KNN
| Model API Type | Function API |
|---|---|
| ML API | `def fit(dataset: Dataset[_]): KNNModel` |
| ML API | `def transform(dataset: Dataset[_]): DataFrame` |
- Input and output
- Package name: org.apache.spark.ml.neighbors
- Class name: KNN
- Method name: fit/transform
- Input: the training samples (Dataset[_], passed to fit) and the test samples (Dataset[_], passed to transform)
| Parameter | Type | Description |
|---|---|---|
| dataset | Dataset[_] | DataFrame containing the sample features |
| k | Int | Number of nearest neighbors |
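For reference, the neighbors returned for each test sample follow the usual k-nearest-neighbor definition. This document does not state which distance metric the library uses; the formulas below assume Euclidean distance, with training samples x_1, ..., x_n, a test sample q, and m feature dimensions. The permutation π orders the training samples by increasing distance to q (ties broken arbitrarily).

$$
d(\mathbf{q}, \mathbf{x}_i) = \lVert \mathbf{q} - \mathbf{x}_i \rVert_2 = \sqrt{\sum_{j=1}^{m} (q_j - x_{ij})^2}
$$

$$
N_k(\mathbf{q}) = \{\mathbf{x}_{\pi(1)}, \dots, \mathbf{x}_{\pi(k)}\}, \quad \text{where } d(\mathbf{q}, \mathbf{x}_{\pi(1)}) \le d(\mathbf{q}, \mathbf{x}_{\pi(2)}) \le \dots \le d(\mathbf{q}, \mathbf{x}_{\pi(n)})
$$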
- Algorithm parameters
- fit parameters
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| setFeaturesCol(value: String) | String | features | Name of the feature column in the training set |
| setAuxiliaryCols(value: Array[String]) | Array[String] | Array.empty[String] | Names of additional training-set columns to carry over to the output (for example, an ID column) |
- transform parameters
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| setFeaturesCol(value: String) | String | features | Name of the feature column in the test set |
| setNeighborsCol(value: String) | String | neighbors | Name of the output column holding the neighbors (their additional columns) |
| setDistanceCol(value: String) | String | distances | Name of the output column holding the neighbor distances |
| setK(value: Int) | Int | 1 | Number of nearest neighbors to return |
| setTestBatchSize(value: Int) | Int | 1024 | Batch size used for the neighbor search |
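To illustrate what the transform parameters control, here is a minimal pure-Scala sketch of an exhaustive (brute-force) k-nearest-neighbor search. It is not the library's implementation. The Euclidean metric, the brute-force strategy, the reading of testBatchSize as "test samples processed per batch", and the names Sample, NeighborResult, and BruteForceKnn are all assumptions made for illustration; a single Long id stands in for the auxiliary columns.

```scala
// Hypothetical illustration only: not the org.apache.spark.ml.neighbors implementation.
// Assumptions: Euclidean distance, exhaustive search, and testBatchSize meaning
// "number of test samples handled per batch".

// One sample: an id (standing in for the auxiliary columns) plus its feature vector.
final case class Sample(id: Long, features: Array[Double])

// Result for one test sample: neighbor ids and their distances, nearest first
// (roughly the roles of the neighbors and distances output columns).
final case class NeighborResult(testId: Long, neighbors: Seq[Long], distances: Seq[Double])

object BruteForceKnn {
  // Euclidean distance between two feature vectors of equal dimension.
  def euclidean(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "feature vectors must have the same dimension")
    math.sqrt(a.indices.map { i => val d = a(i) - b(i); d * d }.sum)
  }

  // Compares every test sample with every training sample, batch by batch,
  // and keeps the k closest (fewer if the training set has fewer than k samples).
  def search(train: Seq[Sample], test: Seq[Sample], k: Int, testBatchSize: Int): Seq[NeighborResult] = {
    require(k >= 1, "k must be at least 1")
    require(testBatchSize >= 1, "testBatchSize must be at least 1")
    test.grouped(testBatchSize).flatMap { batch =>
      batch.map { q =>
        val nearest = train
          .map(t => (t.id, euclidean(q.features, t.features)))
          .sortBy(_._2) // ascending distance; sortBy is stable for ties
          .take(k)
        NeighborResult(q.id, nearest.map(_._1), nearest.map(_._2))
      }
    }.toVector
  }
}

object KnnSketchDemo {
  def main(args: Array[String]): Unit = {
    val train = Seq(
      Sample(1L, Array(0.0, 0.0)),
      Sample(2L, Array(1.0, 0.0)),
      Sample(3L, Array(5.0, 5.0))
    )
    val test = Seq(Sample(100L, Array(0.2, 0.0)), Sample(101L, Array(4.0, 4.0)))
    // k = 2 plays the role of setK(2); testBatchSize = 1 that of setTestBatchSize(1)
    BruteForceKnn.search(train, test, k = 2, testBatchSize = 1).foreach(println)
  }
}
```

With k = 2, test sample 100 reports training ids 1 and 2, and test sample 101 reports ids 3 and 2, each with distances in ascending order. Exhaustive search costs one distance computation per (test, training) pair; in this sketch, batching only bounds how many queries are handled at once and does not change the result.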
An example is provided as follows:
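```scala
import org.apache.spark.ml.neighbors.KNN

// featuresCol and trainDataDF are assumed to be defined by the caller
val model = new KNN()
  .setFeaturesCol(featuresCol)
  .setAuxiliaryCols(Array("id")) // carry the "id" column through to the output
  .fit(trainDataDF)
```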
- Output: the k nearest neighbors of each test sample, including their distances and the additional columns of the training samples
| Parameter | Type | Description |
|---|---|---|
| dataset | DataFrame | DataFrame containing the distances to the k nearest neighbors and the neighbors' additional columns |
- Example usage
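```scala
import org.apache.spark.ml.neighbors.KNN

// Fit: index the training samples, keeping "id" as an auxiliary column
val model = new KNN()
  .setFeaturesCol(featuresCol)
  .setAuxiliaryCols(Array("id"))
  .fit(trainDataDF)

// Transform: find the k nearest training samples for every test sample
val testResults = model
  .setFeaturesCol(featuresCol)
  .setNeighborsCol(neighborsCol)
  .setDistanceCol(distanceCol)
  .setK(k)
  .setTestBatchSize(testBatchSize)
  .transform(testDataDF)
```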