GBDT provides two categories of model interfaces: the ML Classification API and the ML Regression API.
Model interface category | Function interface |
---|---|
ML Classification API | `def fit(dataset: Dataset[_]): GBTClassificationModel` |
ML Classification API | `def fit(dataset: Dataset[_], paramMap: ParamMap): GBTClassificationModel` |
ML Classification API | `def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[GBTClassificationModel]` |
ML Classification API | `def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): GBTClassificationModel` |
ML Regression API | `def fit(dataset: Dataset[_]): GBTRegressionModel` |
ML Regression API | `def fit(dataset: Dataset[_], paramMap: ParamMap): GBTRegressionModel` |
ML Regression API | `def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[GBTRegressionModel]` |
ML Regression API | `def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): GBTRegressionModel` |
Param name | Type(s) | Default | Description |
---|---|---|---|
labelCol | Double | "label" | Label to predict |
featuresCol | Vector | "features" | Feature vector |
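The input only needs a `Double` label column and a `Vector` feature column, which are named `label` and `features` by default. The sketch below builds such a DataFrame and prints its schema. It is a minimal illustration that assumes a local `SparkSession`, and its rows are made-up toy values rather than data from this document.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}

object GbdtInputSketch {
  // Builds a DataFrame whose column names match the default labelCol ("label")
  // and featuresCol ("features"). The rows are hypothetical toy values.
  def buildInput(spark: SparkSession): DataFrame =
    spark.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.1, 0.1)),
      (1.0, Vectors.dense(2.0, 1.0, -1.0)),
      (0.0, Vectors.dense(2.0, 1.3, 1.0)),
      (1.0, Vectors.dense(0.0, 1.2, -0.5))
    )).toDF("label", "features")

  def main(args: Array[String]): Unit = {
    // Local session for illustration; spark-shell already provides `spark`
    val spark = SparkSession.builder().master("local[1]").appName("gbdt-input").getOrCreate()
    buildInput(spark).printSchema() // print column names and types
    spark.stop()
  }
}
```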
Param name | Type(s) | Example | Description |
---|---|---|---|
paramMap | ParamMap | ParamMap(A.c -> b) | Assigns value b to parameter c of model A |
paramMaps | Array[ParamMap] | Array[ParamMap](n) | Creates a list of n ParamMap model parameter sets |
firstParamPair | ParamPair | ParamPair(A.c, b) | Assigns value b to parameter c of model A |
otherParamPairs | ParamPair | ParamPair(A.e, f) | Assigns value f to parameter e of model A |
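Open-source Spark ML defines these overloads in terms of one another:
- `Estimator.fit(dataset, paramMap)` is `copy(paramMap).fit(dataset)`.
- The `ParamPair` overload collects its pairs into a `ParamMap`.
- The `Array[ParamMap]` overload fits one model per map.

Assuming this GBDT implementation keeps that behavior, the sketch below shows what a `ParamMap` or `ParamPair` changes without training anything. The values in the map override the estimator's own settings on a copy, and the original estimator is left untouched. The values 20, 4, and 64 are arbitrary.

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.param.{ParamMap, ParamPair}

object ParamMapSketch {
  // Estimator with its own embedded settings
  val gbdt: GBTClassifier = new GBTClassifier().setMaxIter(5).setMaxDepth(3)

  // ParamMap(A.c -> b): assign b to parameter c of estimator A; put adds more entries
  val paramMap: ParamMap = ParamMap(gbdt.maxIter -> 20).put(gbdt.maxDepth, 4)

  // ParamPair(A.e, f): a single (parameter, value) assignment
  val binsPair: ParamPair[Int] = ParamPair(gbdt.maxBins, 64)

  // fit(dataset, paramMap) trains a copy of the estimator with the overrides applied;
  // calling copy directly shows the effective settings without training a model
  val overridden: GBTClassifier = gbdt.copy(paramMap.copy.put(binsPair))

  def main(args: Array[String]): Unit = {
    println(s"copy:     maxIter=${overridden.getMaxIter}, maxDepth=${overridden.getMaxDepth}, maxBins=${overridden.getMaxBins}")
    println(s"original: maxIter=${gbdt.getMaxIter}, maxDepth=${gbdt.getMaxDepth}, maxBins=${gbdt.getMaxBins}")
  }
}
```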
Algorithm parameters |
---|
`def setCheckpointInterval(value: Int): GBTClassifier.this.type` |
`def setFeatureSubsetStrategy(value: String): GBTClassifier.this.type` |
`def setFeaturesCol(value: String): GBTClassifier` |
`def setImpurity(value: String): GBTClassifier.this.type` |
`def setLabelCol(value: String): GBTClassifier` |
`def setLossType(value: String): GBTClassifier.this.type` |
`def setMaxBins(value: Int): GBTClassifier.this.type` |
`def setMaxDepth(value: Int): GBTClassifier.this.type` |
`def setMaxIter(value: Int): GBTClassifier.this.type` |
`def setMinInfoGain(value: Double): GBTClassifier.this.type` |
`def setMinInstancesPerNode(value: Int): GBTClassifier.this.type` |
`def setPredictionCol(value: String): GBTClassifier` |
`def setProbabilityCol(value: String): GBTClassifier` |
`def setRawPredictionCol(value: String): GBTClassifier` |
`def setSeed(value: Long): GBTClassifier.this.type` |
`def setStepSize(value: Double): GBTClassifier.this.type` |
`def setSubsamplingRate(value: Double): GBTClassifier.this.type` |
`def setThresholds(value: Array[Double]): GBTClassifier` |
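Each setter returns the estimator, so the setters can be chained. The sketch below configures a `GBTClassifier` this way and prints the resulting parameters. The values are illustrative only and are not tuning recommendations. `"logistic"` is the loss type that `GBTClassifier` supports.

```scala
import org.apache.spark.ml.classification.GBTClassifier

object GbtClassifierSettersSketch {
  // Illustrative values only; each setter returns the estimator for chaining
  val gbt: GBTClassifier = new GBTClassifier()
    .setLabelCol("label")
    .setFeaturesCol("features")
    .setLossType("logistic")          // loss type supported by GBTClassifier
    .setMaxIter(20)                   // number of boosting iterations (trees)
    .setMaxDepth(5)
    .setMaxBins(32)
    .setStepSize(0.1)                 // learning rate, in (0, 1]
    .setSubsamplingRate(0.8)          // fraction of rows sampled per iteration
    .setFeatureSubsetStrategy("sqrt") // features considered at each split
    .setMinInstancesPerNode(1)
    .setMinInfoGain(0.0)
    .setCheckpointInterval(10)
    .setSeed(42L)

  def main(args: Array[String]): Unit =
    // explainParams lists every parameter with its current and default value
    println(gbt.explainParams())
}
```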
Parameter name | Meaning | Value type |
---|---|---|
doUseAcc | Switch for the feature-parallel training mode | True/False (Boolean) |
Example of parameter definitions and fit interface calls:
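```scala
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.param.{ParamMap, ParamPair}

// Example hyperparameter values (illustrative only)
val maxDepth = 5
val maxIter = 10
val maxBins = 32

val gbdt = new GBTClassifier() // define the classification model

// Parameters for def fit(dataset: Dataset[_], paramMap: ParamMap)
val paramMap = ParamMap(gbdt.maxDepth -> maxDepth)
  .put(gbdt.maxIter, maxIter)

// Parameters for def fit(dataset: Dataset[_], paramMaps: Array[ParamMap])
val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
for (i <- 0 until 2) { // populate paramMaps; `until` keeps the index within bounds
  paramMaps(i) = ParamMap(gbdt.maxDepth -> maxDepth)
    .put(gbdt.maxIter, maxIter)
}

// Parameters for def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*)
val maxDepthParamPair = ParamPair(gbdt.maxDepth, maxDepth)
val maxIterParamPair = ParamPair(gbdt.maxIter, maxIter)
val maxBinsParamPair = ParamPair(gbdt.maxBins, maxBins)

// Call each fit overload; trainingData is a DataFrame with "label" and "features" columns
val model1: GBTClassificationModel = gbdt.fit(trainingData)
val model2: GBTClassificationModel = gbdt.fit(trainingData, paramMap)
val models: Seq[GBTClassificationModel] = gbdt.fit(trainingData, paramMaps)
val model3: GBTClassificationModel = gbdt.fit(trainingData, maxDepthParamPair, maxIterParamPair, maxBinsParamPair)
```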
Param name | Type(s) | Default | Description |
---|---|---|---|
predictionCol | Double | "prediction" | Predicted label |
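`predictionCol` names the `Double` column that the fitted model's `transform` appends to its input. The sketch below fits a small `GBTClassifier` and reads that column back. It assumes a local `SparkSession`, and its rows are hypothetical, linearly separable toy data.

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}

object GbtPredictionColSketch {
  // Fits a small GBT classifier on hypothetical toy rows and returns the
  // transformed DataFrame, which gains a Double column named by predictionCol.
  def predict(spark: SparkSession): DataFrame = {
    val train = spark.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (0.0, Vectors.dense(0.5, 1.5)),
      (1.0, Vectors.dense(3.0, 0.0)),
      (1.0, Vectors.dense(3.5, 0.5))
    )).toDF("label", "features")

    val model = new GBTClassifier()
      .setMaxIter(5)
      .setPredictionCol("prediction") // the default name, set explicitly here
      .fit(train)

    model.transform(train)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("gbt-prediction").getOrCreate()
    predict(spark).select("features", "label", "prediction").show()
    spark.stop()
  }
}
```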
Example of fit(dataset: Dataset[_]): GBTClassificationModel:
```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load and parse the data file, converting it to a DataFrame.
// (`spark` is an existing SparkSession, e.g. the one provided by spark-shell.)
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a GBT model.
val gbt = new GBTClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("Test Error = " + (1.0 - accuracy))

val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
println("Learned classification GBT model:\n" + gbtModel.toDebugString)
```
```
Test Error = 0.0714285714285714
Learned classification GBT model:
GBTClassificationModel (uid=gbtc_72086dba9af5) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 406 <= 9.5)
     Predict: 1.0
    Else (feature 406 > 9.5)
     Predict: -1.0
  Tree 1 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 209 <= 241.5)
      If (feature 154 <= 55.0)
       Predict: 0.4768116880884702
      Else (feature 154 > 55.0)
       Predict: 0.4768116880884703
     Else (feature 209 > 241.5)
      Predict: 0.47681168808847035
    Else (feature 406 > 9.5)
     If (feature 461 <= 143.5)
      Predict: -0.47681168808847024
     Else (feature 461 > 143.5)
      Predict: -0.47681168808847035
  Tree 2 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 657 <= 116.5)
      If (feature 154 <= 9.5)
       Predict: 0.4381935810427206
      Else (feature 154 > 9.5)
       Predict: 0.43819358104272066
     Else (feature 657 > 116.5)
      Predict: 0.43819358104272066
    Else (feature 406 > 9.5)
     If (feature 322 <= 16.0)
      Predict: -0.4381935810427206
     Else (feature 322 > 16.0)
      Predict: -0.4381935810427206
  Tree 3 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 598 <= 166.5)
      If (feature 180 <= 3.0)
       Predict: 0.4051496802845983
      Else (feature 180 > 3.0)
       Predict: 0.4051496802845984
     Else (feature 598 > 166.5)
      Predict: 0.4051496802845983
    Else (feature 406 > 9.5)
     Predict: -0.4051496802845983
  Tree 4 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 537 <= 47.5)
      If (feature 606 <= 7.0)
       Predict: 0.3765841318352991
      Else (feature 606 > 7.0)
       Predict: 0.37658413183529926
     Else (feature 537 > 47.5)
      Predict: 0.3765841318352994
    Else (feature 406 > 9.5)
     If (feature 124 <= 35.5)
      If (feature 376 <= 1.0)
       If (feature 516 <= 26.5)
        If (feature 266 <= 50.5)
         Predict: -0.3765841318352991
        Else (feature 266 > 50.5)
         Predict: -0.37658413183529915
       Else (feature 516 > 26.5)
        Predict: -0.3765841318352992
      Else (feature 376 > 1.0)
       Predict: -0.3765841318352994
     Else (feature 124 > 35.5)
      Predict: -0.3765841318352994
  Tree 5 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 570 <= 3.5)
      Predict: 0.35166478958101005
     Else (feature 570 > 3.5)
      Predict: 0.35166478958101
    Else (feature 406 > 9.5)
     If (feature 266 <= 14.0)
      If (feature 267 <= 12.5)
       Predict: -0.35166478958101005
      Else (feature 267 > 12.5)
       If (feature 267 <= 36.0)
        Predict: -0.35166478958101005
       Else (feature 267 > 36.0)
        Predict: -0.3516647895810101
     Else (feature 266 > 14.0)
      Predict: -0.35166478958101005
  Tree 6 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 207 <= 7.5)
      Predict: 0.32974984655529926
     Else (feature 207 > 7.5)
      Predict: 0.3297498465552993
    Else (feature 406 > 9.5)
     If (feature 490 <= 185.0)
      Predict: -0.32974984655529926
     Else (feature 490 > 185.0)
      Predict: -0.3297498465552993
  Tree 7 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 568 <= 22.0)
      Predict: 0.3103372455197956
     Else (feature 568 > 22.0)
      Predict: 0.31033724551979563
    Else (feature 406 > 9.5)
     If (feature 379 <= 133.5)
      If (feature 237 <= 250.5)
       Predict: -0.3103372455197956
      Else (feature 237 > 250.5)
       Predict: -0.3103372455197957
     Else (feature 379 > 133.5)
      If (feature 433 <= 183.5)
       If (feature 516 <= 9.0)
        Predict: -0.3103372455197956
       Else (feature 516 > 9.0)
        Predict: -0.3103372455197957
      Else (feature 433 > 183.5)
       Predict: -0.3103372455197957
  Tree 8 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 184 <= 19.0)
      Predict: 0.2930291649125433
     Else (feature 184 > 19.0)
      If (feature 155 <= 147.0)
       If (feature 180 <= 3.0)
        Predict: 0.2930291649125433
       Else (feature 180 > 3.0)
        Predict: 0.2930291649125433
      Else (feature 155 > 147.0)
       Predict: 0.2930291649125434
    Else (feature 406 > 9.5)
     If (feature 379 <= 133.5)
      Predict: -0.2930291649125433
     Else (feature 379 > 133.5)
      If (feature 433 <= 52.5)
       Predict: -0.2930291649125433
      Else (feature 433 > 52.5)
       If (feature 462 <= 143.5)
        Predict: -0.2930291649125433
       Else (feature 462 > 143.5)
        Predict: -0.2930291649125434
  Tree 9 (weight 0.1):
    If (feature 406 <= 9.5)
     If (feature 183 <= 3.0)
      Predict: 0.27750666438358246
     Else (feature 183 > 3.0)
      If (feature 183 <= 19.5)
       Predict: 0.27750666438358246
      Else (feature 183 > 19.5)
       Predict: 0.2775066643835825
    Else (feature 406 > 9.5)
     If (feature 239 <= 50.5)
      If (feature 435 <= 102.0)
       Predict: -0.27750666438358246
      Else (feature 435 > 102.0)
       Predict: -0.2775066643835825
     Else (feature 239 > 50.5)
      Predict: -0.27750666438358257
```
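In the debug output above, each tree contributes its leaf value multiplied by its weight, and the sum of these products is the margin. The sketch below computes that margin for one input. It assumes Spark's logistic-loss convention: labels are mapped to -1/+1, the probability of class 1 is 1/(1 + e^(-2·margin)), and class 1 is predicted when the margin is positive. For illustration it takes one leaf per tree on the `feature 406 <= 9.5` branch. The leaves on that branch have nearly identical values, so the choice of leaf barely matters. For such an input, the margin comes out at about 1.33, which gives class 1.

```scala
object GbtMarginSketch {
  // Tree weights from the debug output: 1.0 for tree 0, 0.1 for trees 1..9
  val weights: Array[Double] = Array(1.0) ++ Array.fill(9)(0.1)

  // One leaf value per tree on the "feature 406 <= 9.5" branch, copied from the output
  val leafValues: Array[Double] = Array(
    1.0, 0.4768116880884702, 0.4381935810427206, 0.4051496802845983,
    0.3765841318352991, 0.35166478958101005, 0.32974984655529926,
    0.3103372455197956, 0.2930291649125433, 0.27750666438358246)

  // Margin = sum over trees of weight_i * leafValue_i
  val margin: Double = weights.zip(leafValues).map { case (w, v) => w * v }.sum

  // Probability of class 1 under the assumed logistic-loss convention
  val probability: Double = 1.0 / (1.0 + math.exp(-2.0 * margin))

  // Class 1 when the margin is positive (equivalently, probability > 0.5)
  val prediction: Double = if (margin > 0) 1.0 else 0.0

  def main(args: Array[String]): Unit =
    println(f"margin=$margin%.4f probability=$probability%.4f prediction=$prediction")
}
```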
Param name | Type(s) | Default | Description |
---|---|---|---|
labelCol | Double | "label" | Label to predict |
featuresCol | Vector | "features" | Feature vector |
Param name | Type(s) | Example | Description |
---|---|---|---|
paramMap | ParamMap | ParamMap(A.c -> b) | Assigns value b to parameter c of model A |
paramMaps | Array[ParamMap] | Array[ParamMap](n) | Creates a list of n ParamMap model parameter sets |
firstParamPair | ParamPair | ParamPair(A.c, b) | Assigns value b to parameter c of model A |
otherParamPairs | ParamPair | ParamPair(A.e, f) | Assigns value f to parameter e of model A |
Algorithm parameters |
---|
`def setCheckpointInterval(value: Int): GBTRegressor.this.type` |
`def setFeatureSubsetStrategy(value: String): GBTRegressor.this.type` |
`def setFeaturesCol(value: String): GBTRegressor` |
`def setImpurity(value: String): GBTRegressor.this.type` |
`def setLabelCol(value: String): GBTRegressor` |
`def setLossType(value: String): GBTRegressor.this.type` |
`def setMaxBins(value: Int): GBTRegressor.this.type` |
`def setMaxDepth(value: Int): GBTRegressor.this.type` |
`def setMaxIter(value: Int): GBTRegressor.this.type` |
`def setMinInfoGain(value: Double): GBTRegressor.this.type` |
`def setMinInstancesPerNode(value: Int): GBTRegressor.this.type` |
`def setPredictionCol(value: String): GBTRegressor` |
`def setSeed(value: Long): GBTRegressor.this.type` |
`def setStepSize(value: Double): GBTRegressor.this.type` |
`def setSubsamplingRate(value: Double): GBTRegressor.this.type` |
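For `GBTRegressor`, the loss type can be `"squared"` (the default) or `"absolute"`. The sketch below chains the setters listed above. Its values are illustrative and are not tuning recommendations.

```scala
import org.apache.spark.ml.regression.GBTRegressor

object GbtRegressorSettersSketch {
  // Illustrative values only; each setter returns the estimator for chaining
  val gbt: GBTRegressor = new GBTRegressor()
    .setLabelCol("label")
    .setFeaturesCol("features")
    .setLossType("absolute")   // "squared" (default) or "absolute"
    .setMaxIter(20)            // number of boosting iterations (trees)
    .setMaxDepth(5)
    .setStepSize(0.1)          // learning rate, in (0, 1]
    .setSubsamplingRate(0.8)   // fraction of rows sampled per iteration
    .setSeed(42L)

  def main(args: Array[String]): Unit =
    // explainParams lists every parameter with its current and default value
    println(gbt.explainParams())
}
```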
Example of parameter definitions and fit interface calls:
```scala
import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}

// Example hyperparameter values (illustrative only)
val maxDepth = 5
val maxIter = 10
val maxBins = 32

val gbdt = new GBTRegressor() // define the regression model

// Parameters for def fit(dataset: Dataset[_], paramMap: ParamMap)
val paramMap = ParamMap(gbdt.maxDepth -> maxDepth)
  .put(gbdt.maxIter, maxIter)

// Parameters for def fit(dataset: Dataset[_], paramMaps: Array[ParamMap])
val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
for (i <- 0 until 2) { // populate paramMaps; `until` keeps the index within bounds
  paramMaps(i) = ParamMap(gbdt.maxDepth -> maxDepth)
    .put(gbdt.maxIter, maxIter)
}

// Parameters for def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*)
val maxDepthParamPair = ParamPair(gbdt.maxDepth, maxDepth)
val maxIterParamPair = ParamPair(gbdt.maxIter, maxIter)
val maxBinsParamPair = ParamPair(gbdt.maxBins, maxBins)

// Call each fit overload; trainingData is a DataFrame with "label" and "features" columns
val model1: GBTRegressionModel = gbdt.fit(trainingData)                    // returns GBTRegressionModel
val model2: GBTRegressionModel = gbdt.fit(trainingData, paramMap)          // returns GBTRegressionModel
val models: Seq[GBTRegressionModel] = gbdt.fit(trainingData, paramMaps)    // returns Seq[GBTRegressionModel]
val model3: GBTRegressionModel =
  gbdt.fit(trainingData, maxDepthParamPair, maxIterParamPair, maxBinsParamPair) // returns GBTRegressionModel
```
Param name | Type(s) | Default | Description |
---|---|---|---|
predictionCol | Double | "prediction" | Predicted label |
Example of fit(dataset: Dataset[_]): GBTRegressionModel:
```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}

// Load and parse the data file, converting it to a DataFrame.
// (`spark` is an existing SparkSession, e.g. the one provided by spark-shell.)
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a GBT model.
val gbt = new GBTRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)

// Chain indexer and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, gbt))

// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")

val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel]
println("Learned regression GBT model:\n" + gbtModel.toDebugString)
```
```
Root Mean Squared Error (RMSE) on test data = 0.0
Learned regression GBT model:
GBTRegressionModel (uid=gbtr_842c8acff963) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 70.5)
     If (feature 99 in {0.0,3.0})
      Predict: 0.0
     Else (feature 99 not in {0.0,3.0})
      Predict: 1.0
    Else (feature 434 > 70.5)
     Predict: 1.0
  Tree 1 (weight 0.1):
    Predict: 0.0
  Tree 2 (weight 0.1):
    Predict: 0.0
  Tree 3 (weight 0.1):
    Predict: 0.0
  Tree 4 (weight 0.1):
    Predict: 0.0
  Tree 5 (weight 0.1):
    Predict: 0.0
  Tree 6 (weight 0.1):
    Predict: 0.0
  Tree 7 (weight 0.1):
    Predict: 0.0
  Tree 8 (weight 0.1):
    Predict: 0.0
  Tree 9 (weight 0.1):
    Predict: 0.0
```
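In the regression output above, tree 0 already predicts exactly 0.0 or 1.0, and trees 1 to 9 all predict 0.0. The weighted sum over all trees therefore equals tree 0's output. An RMSE of 0.0 means that every prediction on the test split matched its 0/1 label. The sketch below transcribes tree 0 from the debug string, sums the weighted tree outputs, and computes RMSE as sqrt(mean((prediction - label)^2)), which is the definition of the `rmse` metric. The feature values and labels in `main` are hypothetical and are not taken from the dataset.

```scala
object GbtRegressionSketch {
  // Tree 0 (weight 1.0) transcribed from the debug string above.
  // f434 and f99 stand for the values of indexed features 434 and 99.
  def tree0(f434: Double, f99: Double): Double =
    if (f434 <= 70.5) {
      if (Set(0.0, 3.0).contains(f99)) 0.0 else 1.0
    } else 1.0

  // Ensemble prediction = sum of weight_i * tree_i(x); trees 1..9 (weight 0.1) all predict 0.0
  def ensemble(f434: Double, f99: Double): Double = {
    val weights = 1.0 +: Seq.fill(9)(0.1)
    val outputs = tree0(f434, f99) +: Seq.fill(9)(0.0)
    weights.zip(outputs).map { case (w, v) => w * v }.sum
  }

  // Root mean squared error: sqrt(mean((prediction - label)^2))
  def rmse(predictions: Seq[Double], labels: Seq[Double]): Double = {
    require(predictions.nonEmpty && predictions.size == labels.size, "need matching, non-empty inputs")
    val mse = predictions.zip(labels).map { case (p, y) => (p - y) * (p - y) }.sum / predictions.size
    math.sqrt(mse)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical (f434, f99, label) rows
    val rows = Seq((10.0, 0.0, 0.0), (10.0, 1.0, 1.0), (200.0, 3.0, 1.0))
    val preds = rows.map { case (a, b, _) => ensemble(a, b) }
    println(s"predictions = $preds, rmse = ${rmse(preds, rows.map(_._3))}")
  }
}
```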
Interface applicability: