KNN
模型接口类别 |
函数接口 |
---|---|
ML API |
def fit(dataset: Dataset[_]): KNNModel def transform(dataset: Dataset[_]): DataFrame |
- 输入输出
- 包名:org.apache.spark.ml.neighbors
- 类名:KNN
- 方法名:fit/transform
- 输入:训练样本Dataset[_],测试样本Dataset[_]。
Param name
Type(s)
Description
dataset
Dataset[_]
包含样本特征的DF
k
Int
最近邻数
- 算法参数
- fit参数
Param name
Type(s)
Default
Description
setFeaturesCol(value:String
String
features
训练集特征列名
setAuxiliaryCols(value:Array[String])
Array[String]
Array.empty[String]
训练集附加列列名
- transform参数
Param name
Type(s)
Default
Description
setFeaturesCol(value: String)
String
features
测试集特征列名
setNeighborsCol(value:String)
String
neighbors
邻居附加列列名
setDistanceCol(value: String)
String
distances
邻居距离列名
setK(value: Int)
Int
1
近邻数
setTestBatchSize(value: Int)
Int
1024
搜索Batch大小
参数及fit代码接口示例:
1 2 3 4
val model = new KNN() .setFeaturesCol(featuresCol) .setAuxiliaryCols(Array("id")) .fit(trainDataDF)
- fit参数
- 输出:距离测试样本最近的k个近邻,包括距离和训练样本的附加列。
Param name
Type(s)
Description
dataset
Dataset[_]
包含k近邻距离和附加列的DF
- 使用样例
1 2 3 4 5 6 7 8 9 10 11
val model = new KNN() .setFeaturesCol(featuresCol) .setAuxiliaryCols(Array("id")) .fit(trainDataDF) val testResults = model .setFeaturesCol(featuresCol) .setNeighborsCol(neighborsCol) .setDistanceCol(distanceCol) .setK(k) .setTestBatchSize(testBatchSize) .transform(testDataDF)
父主题: 算法API