ALS
ALS is a model interface of the ML API type.
Model API type | Function API |
---|---|
ML API | def fit(dataset: Dataset[_]): ALSModel |
ML API | def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[ALSModel] |
ML API | def fit(dataset: Dataset[_], paramMap: ParamMap): ALSModel |
ML API | def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): ALSModel |
ML API
- Function description
- Input and output
- Package name: package org.apache.spark.ml.recommendation
- Class name: ALS
- Method name: fit
- Input: Dataset[_], the training sample data
- Algorithm parameters
def setRank(value: Int): ALS.this.type
def setNumUserBlocks(value: Int): ALS.this.type
def setNumItemBlocks(value: Int): ALS.this.type
def setImplicitPrefs(value: Boolean): ALS.this.type
def setAlpha(value: Double): ALS.this.type
def setUserCol(value: String): ALS.this.type
def setItemCol(value: String): ALS.this.type
def setRatingCol(value: String): ALS.this.type
def setPredictionCol(value: String): ALS.this.type
def setMaxIter(value: Int): ALS.this.type
def setRegParam(value: Double): ALS.this.type
def setNonnegative(value: Boolean): ALS.this.type
def setCheckpointInterval(value: Int): ALS.this.type
def setSeed(value: Long): ALS.this.type
def setIntermediateStorageLevel(value: String): ALS.this.type
def setFinalStorageLevel(value: String): ALS.this.type
def setColdStartStrategy(value: String): ALS.this.type
def setNumBlocks(value: Int): ALS.this.type
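For orientation, several of these setters map onto the matrix-factorization problem that Apache Spark's ALS documents. What follows is a sketch of that standard formulation; whether this implementation changes it is not stated here. With k = rank, λ = regParam, and n_u, n_i the numbers of ratings of user u and item i (Spark scales regParam by these counts, following ALS-WR), explicit feedback (setImplicitPrefs(false)) minimizes:

$$
\min_{X,Y}\ \sum_{(u,i)\in\Omega}\left(r_{ui}-x_u^{\top}y_i\right)^2+\lambda\left(\sum_u n_u\lVert x_u\rVert^2+\sum_i n_i\lVert y_i\rVert^2\right),\qquad x_u,\,y_i\in\mathbb{R}^k
$$

Each of the maxIter iterations alternates between the two factor sets. With Y fixed, setting the gradient with respect to x_u to zero gives a closed form for every user, where Ω_u is the set of items rated by u (items are updated symmetrically with X fixed):

$$
x_u=\left(\sum_{i\in\Omega_u}y_i\,y_i^{\top}+\lambda\,n_u\,I_k\right)^{-1}\sum_{i\in\Omega_u}r_{ui}\,y_i
$$

setNonnegative(true) replaces this unconstrained solve with a nonnegative least-squares solve. With setImplicitPrefs(true), each rating is turned into a preference p_ui (1 if r_ui > 0, else 0) weighted by a confidence c_ui = 1 + α|r_ui|, where α = alpha.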
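- Additional algorithm parameters:

Parameter name | Meaning | Value type |
---|---|---|
spark.sophon.ALS.blockMaxRow | Row block size used when computing the Gram matrix (changing it is not recommended) | Positive integer |
spark.sophon.ALS.unpersistCycle | Unpersist cycle of srcFactorRDD (changing it is not recommended) | Positive integer |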
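The table only states that blockMaxRow is the row block size used when computing a Gram matrix. As a rough illustration of what row blocking means, the sketch below (a hypothetical object BlockedGram on plain Scala arrays, not the library's code) accumulates G = YᵀY for a factor matrix Y with `rank` columns, processing consecutive blocks of at most blockMaxRow rows. In this sketch the block size changes only how many rows are handled per step, not the result.

```scala
// Hypothetical illustration only: not the library's implementation of
// spark.sophon.ALS.blockMaxRow.
object BlockedGram {
  type Matrix = Array[Array[Double]]

  /** Returns G = Y^T Y, where y has one row per user or item and `rank` columns. */
  def gram(y: Matrix, rank: Int, blockMaxRow: Int): Matrix = {
    require(blockMaxRow > 0, "blockMaxRow must be a positive integer")
    val g = Array.ofDim[Double](rank, rank)
    // Process the rows in consecutive blocks of at most blockMaxRow rows.
    y.grouped(blockMaxRow).foreach { block =>
      // Add this block's contribution, the sum of row * row^T, to g.
      for (row <- block; i <- 0 until rank; j <- 0 until rank) {
        g(i)(j) += row(i) * row(j)
      }
    }
    g
  }
}
```

For example, `BlockedGram.gram(Array(Array(1.0, 2.0), Array(3.0, 4.0), Array(5.0, 6.0)), 2, 2)` returns the 2 x 2 matrix with rows (35, 44) and (44, 56).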
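Example of parameter settings and the fit API call:

```scala
import org.apache.spark.ml.recommendation.ALS

// numIterations, nonnegative, implicitPrefs, numItemBlocks, numUserBlocks,
// regParam and alpha are user-supplied values; ratings is the training Dataset.
val als = new ALS()
  .setMaxIter(numIterations)
  .setUserCol("user")
  .setItemCol("product")
  .setRatingCol("rating")
  .setNonnegative(nonnegative)
  .setImplicitPrefs(implicitPrefs)
  .setNumItemBlocks(numItemBlocks)
  .setNumUserBlocks(numUserBlocks)
  .setRegParam(regParam)
  .setAlpha(alpha)
val model = als.fit(ratings)
```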
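To show what setRank and setRegParam control inside a single ALS step, here is a minimal sketch of one explicit-feedback user update, assuming the item factors are held fixed and regParam is scaled by the user's rating count as in Spark's documented ALS-WR weighting. The object AlsStep and its methods are hypothetical; the library solves these systems in a distributed, blocked fashion rather than like this.

```scala
// Hypothetical sketch of one explicit-feedback ALS user update.
object AlsStep {
  /** Solves A x = b for a symmetric positive-definite A via Cholesky (A = L L^T). */
  def choleskySolve(a: Array[Array[Double]], b: Array[Double]): Array[Double] = {
    val n = b.length
    val l = Array.ofDim[Double](n, n)
    for (i <- 0 until n; j <- 0 to i) {
      var sum = a(i)(j)
      for (k <- 0 until j) sum -= l(i)(k) * l(j)(k)
      l(i)(j) = if (i == j) math.sqrt(sum) else sum / l(j)(j)
    }
    // Forward substitution: L z = b.
    val z = new Array[Double](n)
    for (i <- 0 until n) {
      var sum = b(i)
      for (k <- 0 until i) sum -= l(i)(k) * z(k)
      z(i) = sum / l(i)(i)
    }
    // Back substitution: L^T x = z.
    val x = new Array[Double](n)
    for (i <- n - 1 to 0 by -1) {
      var sum = z(i)
      for (k <- i + 1 until n) sum -= l(k)(i) * x(k)
      x(i) = sum / l(i)(i)
    }
    x
  }

  /**
   * Returns x_u = (sum_i y_i y_i^T + regParam * n_u * I)^-1 * sum_i r_ui y_i,
   * where `ratings` holds (item factor y_i, rating r_ui) for the n_u rated items.
   */
  def updateUser(ratings: Seq[(Array[Double], Double)], rank: Int, regParam: Double): Array[Double] = {
    require(ratings.nonEmpty, "the user must have at least one rating")
    val a = Array.ofDim[Double](rank, rank)
    val b = new Array[Double](rank)
    for ((y, r) <- ratings; i <- 0 until rank) {
      b(i) += r * y(i)
      for (j <- 0 until rank) a(i)(j) += y(i) * y(j)
    }
    // Regularization weighted by the number of ratings (ALS-WR style).
    val lambda = regParam * ratings.size
    for (i <- 0 until rank) a(i)(i) += lambda
    choleskySolve(a, b)
  }
}
```

With rank 1, two ratings (2.0 and 4.0) of items whose factor is 1.0, and regParam 0.5, the system is (2 + 0.5 * 2) x = 6, so the user factor is 2.0. A positive regParam also keeps the matrix positive definite, which the Cholesky solve requires.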
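- Output: ALSModel, the ALS recommendation model
- Usage example

```scala
import org.apache.spark.sql.Row

val model = als.fit(ratings)
val predictions = model.transform(ratings)
// Columns are selected in the order (rating, prediction); both are assumed to be Float.
val mse = predictions.select("rating", "prediction").rdd
  .map { case Row(label: Float, prediction: Float) => (prediction, label) }
  .map { case (prediction, label) =>
    val err = (prediction - label).toDouble
    err * err
  }
  .mean()
println("Mean Squared Error = " + mse)
```

- Example result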
Mean Squared Error = 0.9962046583156793
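The usage example computes the mean squared error over (rating, prediction) pairs. The same arithmetic in plain Scala, as a hypothetical helper Metrics.mse, makes the formula explicit. When transform is applied to data containing users or items not seen during training, Spark's default cold-start strategy ("nan") produces NaN predictions that would turn the mean into NaN; setColdStartStrategy("drop") removes those rows. Spark's RegressionEvaluator with metricName "mse" is a built-in alternative.

```scala
// Hypothetical helper mirroring the usage example's MSE computation.
object Metrics {
  /** Mean of (rating - prediction)^2 over all pairs. */
  def mse(pairs: Seq[(Double, Double)]): Double = {
    require(pairs.nonEmpty, "need at least one (rating, prediction) pair")
    pairs.map { case (rating, prediction) =>
      val err = rating - prediction
      err * err
    }.sum / pairs.size
  }
}
```

For instance, `Metrics.mse(Seq((1.0, 2.0), (3.0, 3.0)))` averages the squared errors 1.0 and 0.0 and returns 0.5.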