鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

Linear Regression

LinearRegression为ML API。

模型接口类别

函数接口

ML API

def fit(dataset: Dataset[_]):LinearRegressionModel

def fit(dataset: Dataset[_], paramMap: ParamMap): LinearRegressionModel

def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*):LinearRegressionModel

def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[LinearRegressionModel]

ML API

  • 功能描述

    传入Dataset格式的样本数据,调用fit接口,输出Linear Regression模型。

  • 输入输出
    1. 包名package org.apache.spark.ml.regression
    2. 类名:LinearRegression
    3. 方法名:fit
    4. 输入:Dataset[_],训练样本数据,必须字段如下

      参数名称

      取值类型

      缺省值

      描述

      labelCol

      Double

      label

      Label

      featuresCol

      Vector

      features

      特征标签

    5. 基于原生算法优化的参数
      def setRegParam(value: Double): LinearRegression.this.type
      def setFitIntercept(value: Boolean): LinearRegression.this.type
      def setStandardization(value: Boolean): LinearRegression.this.type
      def setElasticNetParam(value: Double): LinearRegression.this.type
      def setMaxIter(value: Int): LinearRegression.this.type
      def setTol(value: Double): LinearRegression.this.type
      def setWeightCol(value: String): LinearRegression.this.type
      def setSolver(value: String): LinearRegression.this.type
      def setAggregationDepth(value: Int): LinearRegression.this.type
      def setLoss(value: String): LinearRegression.this.type
      def setEpsilon(value: Double): LinearRegression.this.type

      参数及fit代码接口示例:

       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      import org.apache.spark.ml.param.{ParamMap, ParamPair}
      
      val linR = new LinearRegression()
      //定义def fit(dataset: Dataset[_], paramMap: ParamMap) 接口参数
      val paramMap = ParamMap(linR.maxIter -> maxIter)
      .put(linR.regParam, regParam)
      
      // 定义def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): 接口参数
      val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
      for (i <- 0 to  2) {
      paramMaps(i) = ParamMap(linR.maxIter -> maxIter)
      .put(linR.regParam, regParam)
      }//对paramMaps进行赋值
      
      // 定义def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*) 接口参数
      val regParamPair = ParamPair(linR.regParam, regParam)
      val maxIterParamPair = ParamPair(linR.maxIter, maxIter)
      val tolParamPair = ParamPair(linR.tol, tol)
      
      // 调用各个fit接口
      model = linR.fit(trainingData)
      model = linR.fit(trainingData, paramMap)
      models = linR.fit(trainingData, paramMaps)
      model = linR.fit(trainingData, regParamPair, maxIterParamPair, tolParamPair)
      
    6. 输出:LinearRegressionModel,模型预测时的输出字段如下

      参数名称

      取值类型

      缺省值

      描述

      predictionCol

      Int

      prediction

      predictionCol

  • 使用样例
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    import org.apache.spark.ml.regression.LinearRegression
    
    // Load training data
    val training = spark.read.format("libsvm")
      .load("data/mllib/sample_linear_regression_data.txt")
    
    val lr = new LinearRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    
    // Fit the model
    val lrModel = lr.fit(training)
    
    // Summarize the model over the training set and print out some metrics
    val trainingSummary = lrModel.summary