K-means

The K-means algorithm is provided through the ML API.

Model API Type    Function API

ML API            def fit(dataset: Dataset[_]): KMeansModel
                  def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[KMeansModel]
                  def fit(dataset: Dataset[_], paramMap: ParamMap): KMeansModel
                  def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): KMeansModel

ML API

  • Function

    Takes training sample data as a Dataset, calls the fit API, and returns the trained k-means clustering model.

  • Input and output
    1. Package name: org.apache.spark.ml.clustering
    2. Class name: KMeans
    3. Method name: fit
    4. Input: training sample data (Dataset[_]). Mandatory fields are as follows:

      Parameter     Type     Default Value   Description
      featuresCol   Vector   features        Column containing the feature vectors

    5. Algorithm parameters

      Algorithm Parameter

      def setFeaturesCol(value: String): KMeans.this.type

      def setPredictionCol(value: String): KMeans.this.type

      def setK(value: Int): KMeans.this.type

      def setInitMode(value: String): KMeans.this.type

      def setInitSteps(value: Int): KMeans.this.type

      def setMaxIter(value: Int): KMeans.this.type

      def setThreshold(value: Double): KMeans.this.type

      def setTol(value: Double): KMeans.this.type

      def setSeed(value: Long): KMeans.this.type

    6. Added algorithm parameters

      Parameter    Description                                                      Type
      sampleRate   Ratio of the data used in each iteration to the full data set   Double, 0~1
      optMethod    Whether to trigger sampling                                      String, default/allData

      An example is provided as follows:

      import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
      import org.apache.spark.ml.param.{ParamMap, ParamPair}

      // trainingData, initSteps, maxIter, and tol are assumed to be defined beforehand.
      val kmeans = new KMeans()

      // Parameter for fit(dataset: Dataset[_], paramMap: ParamMap).
      val paramMap = ParamMap(kmeans.initSteps -> initSteps)
        .put(kmeans.maxIter, maxIter)

      // Parameter for fit(dataset: Dataset[_], paramMaps: Array[ParamMap]).
      val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
      for (i <- paramMaps.indices) {
        // Assign a value to each element of paramMaps.
        paramMaps(i) = ParamMap(kmeans.initSteps -> initSteps)
          .put(kmeans.maxIter, maxIter)
      }

      // Parameters for fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*).
      val initStepsParamPair = ParamPair(kmeans.initSteps, initSteps)
      val maxIterParamPair = ParamPair(kmeans.maxIter, maxIter)
      val tolParamPair = ParamPair(kmeans.tol, tol)

      // Call the fit APIs.
      var model: KMeansModel = kmeans.fit(trainingData)
      model = kmeans.fit(trainingData, paramMap)
      val models: Seq[KMeansModel] = kmeans.fit(trainingData, paramMaps)
      model = kmeans.fit(trainingData, initStepsParamPair, maxIterParamPair, tolParamPair)
      
    7. Output: k-means clustering model (KMeansModel). When the model is used for prediction, it adds the following column:

      Parameter       Type   Default Value   Description
      predictionCol   Int    prediction      Column containing the predicted cluster index
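
    The input column, the setters, and the output column described above can be tied together in a short sketch. It assumes a local SparkSession and a made-up data set with two numeric columns, x and y; the column names points and cluster are illustrative choices, not names required by the API. VectorAssembler is used here only as one common way to build the Vector column that featuresCol expects.

    ```scala
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.sql.SparkSession

    // Local session for illustration only; in spark-shell, `spark` already exists.
    val spark = SparkSession.builder().master("local[*]").appName("kmeans-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical raw data with two numeric columns, x and y.
    val raw = Seq((0.0, 0.1), (0.2, 0.0), (9.0, 9.1), (9.2, 9.0)).toDF("x", "y")

    // Pack the numeric columns into a single Vector column, as featuresCol requires.
    val assembled = new VectorAssembler()
      .setInputCols(Array("x", "y"))
      .setOutputCol("points")
      .transform(raw)

    // Configure the estimator with some of the setters listed under "Algorithm parameters".
    val kmeans = new KMeans()
      .setFeaturesCol("points")
      .setPredictionCol("cluster")
      .setK(2)
      .setMaxIter(20)
      .setTol(1e-4)
      .setSeed(1L)

    val model = kmeans.fit(assembled)

    // transform appends the Int prediction column under the configured name.
    val clustered = model.transform(assembled)
    clustered.select("x", "y", "cluster").show()
    ```

    Renaming the columns is optional; if the data already has a Vector column named features, the defaults apply and only setK usually needs to be called.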

  • Sample usage
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.evaluation.ClusteringEvaluator
    
    // Loads data.
    val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
    
    // Trains a k-means model.
    val kmeans = new KMeans().setK(2).setSeed(1L)
    val model = kmeans.fit(dataset)
    
    // Make predictions
    val predictions = model.transform(dataset)
    
    // Evaluate clustering by computing Silhouette score
    val evaluator = new ClusteringEvaluator()
    
    val silhouette = evaluator.evaluate(predictions)
    println(s"Silhouette with squared euclidean distance = $silhouette")
    
    // Shows the result.
    println("Cluster Centers: ")
    model.clusterCenters.foreach(println)
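
    The fit overload that takes an Array[ParamMap] can also be used to compare several values of k in one call. The following is a minimal sketch under stated assumptions: the six-point data set and the candidate values 2 and 3 are made up for illustration, and the silhouette score from ClusteringEvaluator is only one possible way to compare the resulting models.

    ```scala
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.evaluation.ClusteringEvaluator
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.sql.SparkSession

    // Local session for illustration only; in spark-shell, `spark` already exists.
    val spark = SparkSession.builder().master("local[*]").appName("kmeans-k-sweep").getOrCreate()
    import spark.implicits._

    // Small illustrative data set: two well-separated groups of points.
    val dataset = Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1), Vectors.dense(0.2, 0.2),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1), Vectors.dense(9.2, 9.2)
    ).map(Tuple1.apply).toDF("features")

    val kmeans = new KMeans().setSeed(1L)

    // One ParamMap per candidate k; fit returns one model per ParamMap, in the same order.
    val candidateKs = Array(2, 3)
    val paramMaps: Array[ParamMap] = candidateKs.map(k => ParamMap(kmeans.k -> k))
    val models = kmeans.fit(dataset, paramMaps)

    // Score each model with the silhouette measure (higher is better).
    val evaluator = new ClusteringEvaluator()
    val scores = candidateKs.zip(models).map { case (k, model) =>
      k -> evaluator.evaluate(model.transform(dataset))
    }
    scores.foreach { case (k, s) => println(s"k = $k, silhouette = $s") }
    ```

    Each ParamMap overrides only the parameters it names, so the seed set on the estimator still applies to every fitted model.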