KNN
| Model API Type | Function API |
|---|---|
| ML API | `def fit(dataset: Dataset[_]): KNNModel`<br>`def transform(dataset: Dataset[_]): DataFrame` |
ML API
- Input and output
- Package name: org.apache.spark.ml.neighbors
- Class name: KNN
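- Method name: fit/transform (`fit` is called on `KNN` and returns a `KNNModel`; `transform` is called on that `KNNModel`)
- Input: a training sample Dataset[_] (passed to `fit`) and a test sample Dataset[_] (passed to `transform`)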
| Parameter | Value Type | Description |
|---|---|---|
| dataset | Dataset[_] | DataFrame that contains the sample features |
| k | Int | Number of nearest neighbors |
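The meaning of `k` can be illustrated with a small brute-force search in plain Scala. This is a minimal sketch, not the `KNN` implementation: it assumes Euclidean distance (the metric used by `org.apache.spark.ml.neighbors.KNN` is not stated in this section), and `BruteForceKnn`, `euclidean`, and `kNearest` are hypothetical names.

```scala
// Hypothetical brute-force k-NN sketch; assumes Euclidean distance.
object BruteForceKnn {
  // Euclidean distance between two feature vectors of equal length.
  def euclidean(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "feature vectors must have the same dimension")
    math.sqrt(a.indices.map { i => val d = a(i) - b(i); d * d }.sum)
  }

  // Returns (trainingIndex, distance) for the k training points closest to `query`,
  // sorted by increasing distance. If k exceeds the number of training points,
  // every training point is returned. Ties keep training order (sortBy is stable).
  def kNearest(train: IndexedSeq[Array[Double]], query: Array[Double], k: Int): Seq[(Int, Double)] = {
    require(k >= 1, "k must be at least 1")
    train.indices
      .map(i => (i, euclidean(train(i), query)))
      .sortBy(_._2)
      .take(k)
  }
}
```

For example, with training points (0, 0), (1, 0), (5, 5), (0, 2) and the query (0, 0.5), `kNearest(train, query, 2)` returns indices 0 and 1, at distances 0.5 and about 1.118.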
- Parameters optimized based on native algorithms
- fit parameters
| Parameter | Value Type | Default Value | Description |
|---|---|---|---|
| setFeaturesCol(value: String) | String | features | Feature column name of the training dataset |
| setAuxiliaryCols(value: Array[String]) | Array[String] | Array.empty[String] | Names of additional columns of the training dataset to return with each neighbor |
- transform parameters
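| Parameter | Value Type | Default Value | Description |
|---|---|---|---|
| setFeaturesCol(value: String) | String | features | Feature column name of the test dataset |
| setNeighborsCol(value: String) | String | neighbors | Name of the output column that holds the neighbors (with their additional columns) |
| setDistanceCol(value: String) | String | distances | Name of the output column that holds the neighbor distances |
| setK(value: Int) | Int | 1 | Number of nearest neighbors |
| setTestBatchSize(value: Int) | Int | 1024 | Batch size used when searching for the neighbors of the test samples |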
An example is provided as follows:
```scala
val model = new KNN()
  .setFeaturesCol(featuresCol)
  .setAuxiliaryCols(Array("id"))
  .fit(trainDataDF)
```
- Output: the k nearest neighbors of each test sample, including their distances and the additional columns from the training samples

| Parameter | Value Type | Description |
|---|---|---|
| dataset | Dataset[_] | DataFrame that contains the distances and additional columns |
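How `k`, `testBatchSize`, the neighbors column, and the distances column fit together can be sketched in plain Scala. This is a hypothetical illustration, not the library's algorithm: it assumes Euclidean distance, assumes `testBatchSize` simply splits the test samples into consecutive chunks, and carries one `id` auxiliary value per training row (mirroring `setAuxiliaryCols(Array("id"))` in the examples). `BatchedKnnSketch`, `TrainRow`, `Result`, and `search` are illustrative names.

```scala
// Hypothetical batched brute-force k-NN sketch (not the KNN implementation).
// Assumptions: Euclidean distance, consecutive test batches of size testBatchSize,
// and a single auxiliary "id" value per training row.
object BatchedKnnSketch {
  final case class TrainRow(id: Long, features: Array[Double])

  // One output row: the test features plus the neighbors' ids and distances,
  // aligned by position and sorted by increasing distance.
  final case class Result(features: Array[Double], neighbors: Seq[Long], distances: Seq[Double])

  private def euclidean(a: Array[Double], b: Array[Double]): Double =
    math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

  def search(train: Seq[TrainRow], test: Seq[Array[Double]], k: Int, testBatchSize: Int): Seq[Result] = {
    require(k >= 1 && testBatchSize >= 1, "k and testBatchSize must be positive")
    test.grouped(testBatchSize).flatMap { batch =>
      // Every test sample in the batch is scored against the full training set.
      batch.map { q =>
        val nearest = train.map(r => (r.id, euclidean(r.features, q))).sortBy(_._2).take(k)
        Result(q, nearest.map(_._1), nearest.map(_._2))
      }
    }.toVector
  }
}
```

In this sketch the batch size only controls how many test samples are processed at once; it does not change which neighbors are found, and results keep the order of the test samples.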
- Example
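```scala
import org.apache.spark.ml.neighbors.KNN

// Fit on the training samples, keeping the "id" column for each neighbor
val model = new KNN()
  .setFeaturesCol(featuresCol)
  .setAuxiliaryCols(Array("id"))
  .fit(trainDataDF)

// Find the k nearest neighbors of each test sample
val testResults = model
  .setFeaturesCol(featuresCol)
  .setNeighborsCol(neighborsCol)
  .setDistanceCol(distanceCol)
  .setK(k)
  .setTestBatchSize(testBatchSize)
  .transform(testDataDF)
```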