LogisticRegression

LogisticRegression is provided through the ML API.

| Model interface type | Function interface |
|---|---|
| ML API | def fit(dataset: Dataset[_]): LogisticRegressionModel |
| ML API | def fit(dataset: Dataset[_], paramMap: ParamMap): LogisticRegressionModel |
| ML API | def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): LogisticRegressionModel |
| ML API | def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[LogisticRegressionModel] |

ML classification API

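  • Function description

    Pass sample data in Dataset format to the fit interface, which trains and outputs a LogisticRegression model.

  • Input and output
    1. Package name: package org.apache.spark.ml.classification
    2. Class name: LogisticRegression
    3. Method name: fit
    4. Input: Dataset[_], the training sample data. The required fields are as follows.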

      | Param name | Type(s) | Default | Description |
      |---|---|---|---|
      | labelCol | Double | "label" | Label. Requirements: 1) label == label.toInt; 2) label >= 0 |
      | featuresCol | Vector | "features" | Feature vector |
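The label requirement listed for labelCol (an integer-valued, non-negative Double) can be checked before calling fit. The following is a minimal sketch in plain Scala; `LabelCheck` and `isValidLabel` are hypothetical names introduced for illustration and are not part of Spark or of this library.

```scala
object LabelCheck {
  // Hypothetical helper (not a Spark API): returns true when a label
  // satisfies both documented requirements:
  //   1) label == label.toInt  (the Double holds an integer value)
  //   2) label >= 0
  def isValidLabel(label: Double): Boolean =
    label >= 0 && label == label.toInt

  def main(args: Array[String]): Unit = {
    // Made-up sample labels, for illustration only
    val labels = Seq(0.0, 1.0, 2.0, 1.5, -1.0)
    val invalid = labels.filterNot(isValidLabel)
    println(s"Invalid labels: $invalid")
  }
}
```

A check like this could be applied to the label column of the training data before fit is called, so that invalid labels are reported early rather than during training.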

    5. Algorithm parameters

      def setRegParam(value: Double): LogisticRegression.this.type

      def setElasticNetParam(value: Double): LogisticRegression.this.type

      def setMaxIter(value: Int): LogisticRegression.this.type

      def setTol(value: Double): LogisticRegression.this.type

      def setFitIntercept(value: Boolean): LogisticRegression.this.type

      def setFamily(value: String): LogisticRegression.this.type

      def setStandardization(value: Boolean): LogisticRegression.this.type

      override def setThreshold(value: Double): LogisticRegression.this.type

      def setWeightCol(value: String): LogisticRegression.this.type

      override def setThresholds(value: Array[Double]): LogisticRegression.this.type

      def setAggregationDepth(value: Int): LogisticRegression.this.type

      def setLowerBoundsOnCoefficients(value: Matrix): LogisticRegression.this.type

      def setUpperBoundsOnCoefficients(value: Matrix): LogisticRegression.this.type

      def setLowerBoundsOnIntercepts(value: Vector): LogisticRegression.this.type

      def setUpperBoundsOnIntercepts(value: Vector): LogisticRegression.this.type
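As a reading aid for setRegParam and setElasticNetParam: in open-source Apache Spark, these two values combine into an elastic-net penalty on the coefficient vector. The formula below is a sketch of that relationship, assuming this library follows the same definition; the symbols $\lambda$ (regParam), $\alpha$ (elasticNetParam), and $w$ (coefficients) are notation introduced here.

```latex
\lambda \left( \alpha \lVert w \rVert_1 + \frac{1 - \alpha}{2} \lVert w \rVert_2^2 \right),
\qquad \alpha \in [0, 1], \quad \lambda \ge 0
```

Under this definition, $\alpha = 0$ gives a pure L2 (ridge) penalty and $\alpha = 1$ gives a pure L1 (lasso) penalty.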

      Example of defining parameters and calling the fit interfaces:

      import org.apache.spark.ml.classification.LogisticRegression
      import org.apache.spark.ml.param.{ParamMap, ParamPair}

      // Assumed to be defined by the caller: trainingData (a Dataset with "label" and
      // "features" columns), maxIter: Int, regParam: Double, and tol: Double.
      val logR = new LogisticRegression()

      // Argument for def fit(dataset: Dataset[_], paramMap: ParamMap)
      val paramMap = ParamMap(logR.maxIter -> maxIter)
        .put(logR.regParam, regParam)

      // Argument for def fit(dataset: Dataset[_], paramMaps: Array[ParamMap])
      val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
      // Populate paramMaps; "until" keeps the index within the 2-element array
      for (i <- 0 until 2) {
        paramMaps(i) = ParamMap(logR.maxIter -> maxIter)
          .put(logR.regParam, regParam)
      }

      // Arguments for def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*)
      val regParamPair = ParamPair(logR.regParam, regParam)
      val maxIterParamPair = ParamPair(logR.maxIter, maxIter)
      val tolParamPair = ParamPair(logR.tol, tol)

      // Call each fit interface
      var model = logR.fit(trainingData)
      model = logR.fit(trainingData, paramMap)
      val models = logR.fit(trainingData, paramMaps)
      model = logR.fit(trainingData, regParamPair, maxIterParamPair, tolParamPair)
    6. Output: LogisticRegressionModel. The output fields at prediction time are as follows.

      | Param name | Type(s) | Default | Description |
      |---|---|---|---|
      | predictionCol | Double | "prediction" | Predicted label |
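To illustrate the output field, the sketch below fits a model and reads the "prediction" column produced by transform. It assumes the standard Apache Spark ML API and a local SparkSession; the object name `PredictionSketch`, the tiny dataset, and the parameter values are made up for illustration.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}

object PredictionSketch {
  // Fits on a tiny made-up dataset and returns the transformed DataFrame,
  // which contains the input columns plus the model's output columns.
  def predict(spark: SparkSession): DataFrame = {
    val training = spark.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0)),
      (0.0, Vectors.dense(0.1, 1.3)),
      (1.0, Vectors.dense(1.9, 0.8))
    )).toDF("label", "features")

    val model = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
      .fit(training)

    // transform appends the prediction column (default name "prediction")
    model.transform(training)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("PredictionSketch")
      .getOrCreate()
    predict(spark).select("features", "label", "prediction").show(false)
    spark.stop()
  }
}
```

If predictionCol is changed with setPredictionCol, the column name passed to select would need to change accordingly.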

  • Usage example
    import org.apache.spark.ml.classification.LogisticRegression
    
    // Load training data
    val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    
    // Fit the model
    val lrModel = lr.fit(training)
    
    // Print the coefficients and intercept for logistic regression
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
    
    // We can also use the multinomial family for binary classification
    val mlr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setFamily("multinomial")
    
    val mlrModel = mlr.fit(training)
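A model fitted with the multinomial family can be inspected through coefficientMatrix and interceptVector, which in Apache Spark take the place of the binomial coefficients and intercept. The sketch below assumes the standard Spark ML API; the object name `MultinomialInspectSketch` and the tiny dataset are made up for illustration and stand in for the libsvm file used above.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object MultinomialInspectSketch {
  // Fits a multinomial-family model on a tiny made-up binary dataset.
  def fitMultinomial(spark: SparkSession): LogisticRegressionModel = {
    val training = spark.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0)),
      (0.0, Vectors.dense(0.1, 1.3)),
      (1.0, Vectors.dense(1.9, 0.8))
    )).toDF("label", "features")

    new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setFamily("multinomial")
      .fit(training)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("MultinomialInspectSketch")
      .getOrCreate()
    val mlrModel = fitMultinomial(spark)
    // One row of coefficients and one intercept per class
    println(s"Multinomial coefficients: ${mlrModel.coefficientMatrix}")
    println(s"Multinomial intercepts: ${mlrModel.interceptVector}")
    spark.stop()
  }
}
```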