鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

ALS

ALS为ML API类模型接口。

模型接口类别

函数接口

ML API

def fit(dataset: Dataset[_]): ALSModel

def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[ALSModel]

def fit(dataset: Dataset[_], paramMap: ParamMap): ALSModel

def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): ALSModel

ML API

  • 功能描述

    传入dataSet格式的样本数据,调用fit接口,输出ALS推荐模型。

  • 输入输出
    1. 包名package org.apache.spark.ml.recommendation
    2. 类名:ALS
    3. 方法名:fit
    4. 输入:dataSet[_],训练样本数据
    5. 算法参数

      算法参数

      def setRank(value: Int): ALS.this.type

      def setNumUserBlocks(value: Int): ALS.this.type

      def setNumItemBlocks(value: Int): ALS.this.type

      def setImplicitPrefs(value: Boolean): ALS.this.type

      def setAlpha(value: Double): ALS.this.type

      def setUserCol(value: String): ALS.this.type

      def setItemCol(value: String): ALS.this.type

      def setRatingCol(value: String): ALS.this.type

      def setPredictionCol(value: String): ALS.this.type

      def setMaxIter(value: Int): ALS.this.type

      def setRegParam(value: Double): ALS.this.type

      def setNonnegative(value: Boolean): ALS.this.type

      def setCheckpointInterval(value: Int): ALS.this.type

      def setSeed(value: Long): ALS.this.type

      def setIntermediateStorageLevel(value: String): ALS.this.type

      def setFinalStorageLevel(value: String): ALS.this.type

      def setColdStartStrategy(value: String): ALS.this.type

      def setNumBlocks(value: Int): ALS.type

    6. 新增算法参数。

      参数名称

      参数含义

      取值类型

      spark.boostkit.ALS.blockMaxRow

      计算格莱姆矩阵行分块大小(不建议修改)

      正整数

      spark.boostkit.ALS.unpersistCycle

      srcFactorRDD反持久化周期(不建议修改)

      正整数

      参数及fit代码接口示例:

       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      val als = new ALS()
            .setMaxIter(numIterations)
            .setUserCol("user")
            .setItemCol("product")
            .setRatingCol("rating")
            .setNonnegative(nonnegative)
            .setImplicitPrefs(implicitPrefs)
            .setNumItemBlocks(numItemBlocks)
            .setNumUserBlocks(numUserBlocks)
            .setRegParam(regParam)
            .setAlpha(alpha)
      val model = als.fit(ratings)
      
    7. 输出:ALSModel,ALS推荐模型
  • 使用样例
    1
    2
    3
    val model = als.fit(ratings)
    val predictions = model.transform(ratings)
    val p = predictions.select("rating", "prediction").rdd   .map { case Row(prediction: Float, label: Float) => (prediction, label) }   .map{t =>     val err = (t._1 - t._2)     err * err   }.mean() println("Mean Squared Error = " + p)
    
  • 结果样例
    1
    Mean Squared Error = 0.9962046583156793