`LinearRegression` is provided through the ML API (`org.apache.spark.ml.regression.LinearRegression`).
| Model API category | Function interface |
|---|---|
| ML API | `def fit(dataset: Dataset[_]): LinearRegressionModel` |
| | `def fit(dataset: Dataset[_], paramMap: ParamMap): LinearRegressionModel` |
| | `def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): LinearRegressionModel` |
| | `def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[LinearRegressionModel]` |
Input columns:

| Param name | Type(s) | Default | Description |
|---|---|---|---|
| labelCol | Double | "label" | Label to predict |
| featuresCol | Vector | "features" | Feature vector |
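The input columns do not have to use the default names. A minimal sketch, assuming `spark-mllib` is on the classpath and a local `SparkSession` can be started: `setLabelCol` and `setFeaturesCol` point the estimator at other columns, as long as the label column holds `Double` values and the features column holds `org.apache.spark.ml.linalg.Vector` values. The column names `y` and `x` and the toy data (which follow y = 2x + 1 exactly) are hypothetical.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

// Local session for illustration only; in spark-shell, reuse the existing `spark`
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("lr-input-columns")
  .getOrCreate()

// Hypothetical toy data following y = 2x + 1 exactly.
// Label column: Double; features column: ml.linalg.Vector.
val df = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0)),
  (3.0, Vectors.dense(1.0)),
  (5.0, Vectors.dense(2.0)),
  (7.0, Vectors.dense(3.0))
)).toDF("y", "x")

// Point the estimator at the non-default column names
val lr = new LinearRegression()
  .setLabelCol("y")
  .setFeaturesCol("x")

val model = lr.fit(df)
println(s"coefficients=${model.coefficients} intercept=${model.intercept}")

spark.stop()
```

Because `regParam` defaults to 0.0, the fit should recover a coefficient close to 2.0 and an intercept close to 1.0 on this noise-free data.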
| Algorithm parameters |
|---|
| `def setRegParam(value: Double): LinearRegression.this.type` |
| `def setFitIntercept(value: Boolean): LinearRegression.this.type` |
| `def setStandardization(value: Boolean): LinearRegression.this.type` |
| `def setElasticNetParam(value: Double): LinearRegression.this.type` |
| `def setMaxIter(value: Int): LinearRegression.this.type` |
| `def setTol(value: Double): LinearRegression.this.type` |
| `def setWeightCol(value: String): LinearRegression.this.type` |
| `def setSolver(value: String): LinearRegression.this.type` |
| `def setAggregationDepth(value: Int): LinearRegression.this.type` |
| `def setLoss(value: String): LinearRegression.this.type` |
| `def setEpsilon(value: Double): LinearRegression.this.type` |
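`setRegParam` and `setElasticNetParam` jointly control the regularization term. The plain-Scala sketch below evaluates the elastic-net least-squares objective in the form given in the Spark MLlib guide, with regParam as lambda and elasticNetParam as alpha: (1/2n) * sum((w . x_i + b - y_i)^2) + lambda * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2). It is illustrative only: it ignores instance weights (`weightCol`), the internal feature standardization, and the `huber` loss, and `objective` is a hypothetical helper, not a Spark API.

```scala
// Hypothetical helper: elastic-net-regularized squared-error objective.
// The intercept b is not regularized.
def objective(
    xs: Seq[Array[Double]],  // feature vectors
    ys: Seq[Double],         // labels
    w: Array[Double],        // coefficients
    b: Double,               // intercept
    regParam: Double,        // lambda >= 0
    elasticNetParam: Double  // alpha in [0, 1]: 0 = pure L2 (ridge), 1 = pure L1 (lasso)
): Double = {
  require(xs.size == ys.size && ys.nonEmpty)
  val n = ys.size
  val squaredError = xs.zip(ys).map { case (x, y) =>
    val pred = x.zip(w).map { case (xi, wi) => xi * wi }.sum + b
    (pred - y) * (pred - y)
  }.sum
  val l1 = w.map(wi => math.abs(wi)).sum
  val l2 = w.map(wi => wi * wi).sum
  squaredError / (2.0 * n) +
    regParam * (elasticNetParam * l1 + (1.0 - elasticNetParam) / 2.0 * l2)
}

// Hypothetical toy data following y = 2x + 1 exactly
val xs = Seq(Array(0.0), Array(1.0), Array(2.0))
val ys = Seq(1.0, 3.0, 5.0)

// Same perfect-fit coefficients, with and without regularization
val unregularized = objective(xs, ys, Array(2.0), 1.0, regParam = 0.0, elasticNetParam = 0.8)
val penalized = objective(xs, ys, Array(2.0), 1.0, regParam = 0.3, elasticNetParam = 0.8)
println(s"unregularized=$unregularized penalized=$penalized")
```

With a perfect fit the data term is zero, so the whole objective comes from the penalty: 0.3 * (0.8 * 2 + 0.2 / 2 * 4) = 0.6. This is why a nonzero `regParam` pulls coefficients away from the least-squares solution.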
Example of setting parameters and calling each `fit` interface:
```scala
import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}

// Hypothetical example values; trainingData is assumed to be a DataFrame
// with a Double "label" column and a Vector "features" column.
val maxIter = 10
val regParam = 0.3
val tol = 1e-6

val linR = new LinearRegression()

// Parameters for def fit(dataset: Dataset[_], paramMap: ParamMap)
val paramMap = ParamMap(linR.maxIter -> maxIter)
  .put(linR.regParam, regParam)

// Parameters for def fit(dataset: Dataset[_], paramMaps: Array[ParamMap])
val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
for (i <- paramMaps.indices) { // indices 0 and 1; `0 to 2` would go out of bounds
  paramMaps(i) = ParamMap(linR.maxIter -> maxIter)
    .put(linR.regParam, regParam)
}

// Parameters for def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*)
val regParamPair = ParamPair(linR.regParam, regParam)
val maxIterParamPair = ParamPair(linR.maxIter, maxIter)
val tolParamPair = ParamPair(linR.tol, tol)

// Call each fit interface
var model: LinearRegressionModel = linR.fit(trainingData)
model = linR.fit(trainingData, paramMap)
val models: Seq[LinearRegressionModel] = linR.fit(trainingData, paramMaps)
model = linR.fit(trainingData, regParamPair, maxIterParamPair, tolParamPair)
```
Output columns:

| Param name | Type(s) | Default | Description |
|---|---|---|---|
| predictionCol | Double | "prediction" | Predicted label |
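For `LinearRegressionModel`, the value written to `predictionCol` is the intercept plus the dot product of the coefficient vector with the row's feature vector. The plain-Scala sketch below reproduces that arithmetic on arrays; `predict` here is a hypothetical helper, not the Spark method (which takes an `ml.linalg.Vector`), and the coefficients and features are made-up values.

```scala
// Hypothetical helper mirroring the prediction rule: intercept + coefficients . features
def predict(coefficients: Array[Double], intercept: Double, features: Array[Double]): Double = {
  require(coefficients.length == features.length, "feature dimension mismatch")
  coefficients.zip(features).map { case (c, f) => c * f }.sum + intercept
}

// 2.0 * 3.0 + (-1.0) * 4.0 + 0.5
val yHat = predict(Array(2.0, -1.0), 0.5, Array(3.0, 4.0))
println(s"prediction = $yHat")
```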
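Training example:

```scala
import org.apache.spark.ml.regression.LinearRegression

// Load training data (assumes an existing SparkSession named `spark`, as in spark-shell)
val training = spark.read.format("libsvm")
  .load("data/mllib/sample_linear_regression_data.txt")

val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model
val lrModel = lr.fit(training)

// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")
```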
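`trainingSummary.rootMeanSquaredError` and `trainingSummary.r2` are standard regression metrics. The plain-Scala sketch below computes them from labels and predictions, assuming unweighted data and a model fit with an intercept (the default). The helpers `rmse` and `r2` and the sample values are hypothetical; Spark computes these metrics internally on the training DataFrame.

```scala
// Root mean squared error: square root of the mean squared residual
def rmse(labels: Seq[Double], predictions: Seq[Double]): Double = {
  require(labels.nonEmpty && labels.size == predictions.size)
  val sse = labels.zip(predictions).map { case (y, p) => (y - p) * (y - p) }.sum
  math.sqrt(sse / labels.size)
}

// Coefficient of determination: 1 - SS_res / SS_tot, with SS_tot measured around the label mean
def r2(labels: Seq[Double], predictions: Seq[Double]): Double = {
  require(labels.nonEmpty && labels.size == predictions.size)
  val mean = labels.sum / labels.size
  val ssRes = labels.zip(predictions).map { case (y, p) => (y - p) * (y - p) }.sum
  val ssTot = labels.map(y => (y - mean) * (y - mean)).sum
  1.0 - ssRes / ssTot
}

// Hypothetical labels and predictions: only the last prediction is off, by 2.0
val labels = Seq(1.0, 2.0, 3.0, 4.0)
val preds = Seq(1.0, 2.0, 3.0, 6.0)
println(s"RMSE: ${rmse(labels, preds)} r2: ${r2(labels, preds)}")
```

Here SS_res = 4 and SS_tot = 5, so RMSE = sqrt(4 / 4) = 1.0 and r2 = 1 - 4 / 5 = 0.2.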