The XGBoost algorithm provides two categories of model interfaces: the ML Classification API and the ML Regression API.

| Model interface category | Function interface |
|---|---|
| ML Classification API | `def fit(dataset: Dataset[_]): XGBoostClassificationModel` |
| ML Classification API | `def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[XGBoostClassificationModel]` |
| ML Classification API | `def fit(dataset: Dataset[_], paramMap: ParamMap): XGBoostClassificationModel` |
| ML Classification API | `def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): XGBoostClassificationModel` |
| ML Regression API | `def fit(dataset: Dataset[_]): XGBoostRegressionModel` |
| ML Regression API | `def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[XGBoostRegressionModel]` |
| ML Regression API | `def fit(dataset: Dataset[_], paramMap: ParamMap): XGBoostRegressionModel` |
| ML Regression API | `def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): XGBoostRegressionModel` |
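ML Classification API

| Parameter | Type | Default | Description |
|---|---|---|---|
| `labelCol` | Double | `label` | Label to predict |
| `featuresCol` | Vector | `features` | Feature column |

| Parameter | Type | Example | Description |
|---|---|---|---|
| `predictionCol` | Double | `prediction` | Predicted label value |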
    def setAllowNonZeroForMissing(value: Boolean): XGBoostClassifier.this.type
    def setAlpha(value: Double): XGBoostClassifier.this.type
    def setBaseMarginCol(value: String): XGBoostClassifier.this.type
    def setBaseScore(value: Double): XGBoostClassifier.this.type
    def setCheckpointInterval(value: Int): XGBoostClassifier.this.type
    def setCheckpointPath(value: String): XGBoostClassifier.this.type
    def setColsampleBylevel(value: Double): XGBoostClassifier.this.type
    def setColsampleBytree(value: Double): XGBoostClassifier.this.type
    def setCustomEval(value: EvalTrait): XGBoostClassifier.this.type
    def setCustomObj(value: ObjectiveTrait): XGBoostClassifier.this.type
    def setEta(value: Double): XGBoostClassifier.this.type
    def setEvalMetric(value: String): XGBoostClassifier.this.type
    def setEvalSets(evalSets: Map[String, DataFrame]): XGBoostClassifier.this.type
    def setFeaturesCol(value: String): XGBoostClassifier
    def setGamma(value: Double): XGBoostClassifier.this.type
    def setGrowPolicy(value: String): XGBoostClassifier.this.type
    def setLabelCol(value: String): XGBoostClassifier.this.type
    def setLambda(value: Double): XGBoostClassifier.this.type
    def setLambdaBias(value: Double): XGBoostClassifier.this.type
    def setMaxBins(value: Int): XGBoostClassifier.this.type
    def setMaxDeltaStep(value: Double): XGBoostClassifier.this.type
    def setMaxDepth(value: Int): XGBoostClassifier.this.type
    def setMaxLeaves(value: Int): XGBoostClassifier.this.type
    def setMaximizeEvaluationMetrics(value: Boolean): XGBoostClassifier.this.type
    def setMinChildWeight(value: Double): XGBoostClassifier.this.type
    def setMissing(value: Float): XGBoostClassifier.this.type
    def setNormalizeType(value: String): XGBoostClassifier.this.type
    def setNthread(value: Int): XGBoostClassifier.this.type
    def setNumClass(value: Int): XGBoostClassifier.this.type
    def setNumEarlyStoppingRounds(value: Int): XGBoostClassifier.this.type
    def setNumRound(value: Int): XGBoostClassifier.this.type
    def setNumWorkers(value: Int): XGBoostClassifier.this.type
    def setObjective(value: String): XGBoostClassifier.this.type
    def setObjectiveType(value: String): XGBoostClassifier.this.type
    def setPredictionCol(value: String): XGBoostClassifier
    def setProbabilityCol(value: String): XGBoostClassifier
    def setRateDrop(value: Double): XGBoostClassifier.this.type
    def setRawPredictionCol(value: String): XGBoostClassifier.this.type
    def setSampleType(value: String): XGBoostClassifier.this.type
    def setScalePosWeight(value: Double): XGBoostClassifier.this.type
    def setSeed(value: Long): XGBoostClassifier.this.type
    def setSilent(value: Int): XGBoostClassifier.this.type
    def setSinglePrecisionHistogram(value: Boolean): XGBoostClassifier.this.type
    def setSketchEps(value: Double): XGBoostClassifier.this.type
    def setSkipDrop(value: Double): XGBoostClassifier.this.type
    def setSubsample(value: Double): XGBoostClassifier.this.type
    def setThresholds(value: Array[Double]): XGBoostClassifier
    def setTimeoutRequestWorkers(value: Long): XGBoostClassifier.this.type
    def setTrainTestRatio(value: Double): XGBoostClassifier.this.type
    def setTreeMethod(value: String): XGBoostClassifier.this.type
    def setUseExternalMemory(value: Boolean): XGBoostClassifier.this.type
    def setWeightCol(value: String): XGBoostClassifier.this.type
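| Parameter | Meaning | Value type |
|---|---|---|
| `grow_policy` | Modified parameter; adds the `depthwiselossltd` option. Controls how new nodes are added to the tree. Takes effect only when `tree_method` is set to `hist`. | String. Default: "depthwise". Options: "depthwise", "lossguide", "depthwiselossltd". |
| `min_loss_ratio` | Controls how aggressively tree nodes are pruned during training. Takes effect only when `grow_policy` is `depthwiselossltd`. | Double. Default: 0. Range: [0,1). |
| `sampling_strategy` | Controls the sampling strategy used during training. | String. Default: "eachTree". Options: "eachTree", "eachIteration", "alliteration", "multiIteration", "gossStyle". |
| `enable_bbgen` | Controls whether the batch Bernoulli bit generation algorithm is used. | Boolean. Default: false. Options: true, false. |
| `sampling_step` | Controls the number of rounds between samplings. Takes effect only when `sampling_strategy` is set to `multiIteration`. | Int. Default: 1. Range: [1,+∞). |
| `auto_subsample` | Controls whether the automatic sampling-rate reduction strategy is used. | Boolean. Default: false. Options: true, false. |
| `auto_k` | Controls the number of rounds used by the automatic sampling-rate reduction strategy. Takes effect only when `auto_subsample` is set to true. | Int. Default: 1. Range: [1,+∞). |
| `auto_subsample_ratio` | Sets the candidate ratios for automatic sampling-rate reduction. | Array[Double]. Default: Array(0.05,0.1,0.2,0.4,0.8,1.0). Range of each value: (0,1]. |
| `auto_r` | Controls the increase in error rate that automatic sampling-rate reduction is allowed to cause. | Double. Default: 0.95. Range: (0,1]. |
| `rabit_enable_tcp_no_delay` | Controls the communication strategy in the Rabit engine. | Boolean. Default: false. Options: true, false. |
| `random_split_denom` | Controls the proportion of candidate split points that are used. | Int. Default: 1. Range: [1,+∞). |
| `default_direction` | Controls the default direction for missing values. | String. Default: "learn". Options: "left", "right", "learn". |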
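Several parameters in the table take effect only when another parameter has a particular value. The sketch below encodes those dependency rules as a small check over a parameter map of the kind passed to `XGBoostClassifier`. The object `ParamDependencyCheck` and its `check` method are hypothetical names introduced only for illustration; they are not part of XGBoost or of this library, and the defaults used are the ones listed in the table.

```scala
// Hypothetical helper: it only restates the dependency rules from the table
// and does not call any XGBoost API.
object ParamDependencyCheck {

  /** Returns one warning per setting that would be ignored or is out of range. */
  def check(params: Map[String, Any]): Seq[String] = {
    // Defaults as documented in the table.
    val growPolicy = params.getOrElse("grow_policy", "depthwise")
    val samplingStrategy = params.getOrElse("sampling_strategy", "eachTree")
    val autoSubsample = params.getOrElse("auto_subsample", false)

    Seq(
      if (params.contains("grow_policy") && !params.get("tree_method").contains("hist"))
        Some("grow_policy takes effect only when tree_method is hist")
      else None,
      if (params.contains("min_loss_ratio") && growPolicy != "depthwiselossltd")
        Some("min_loss_ratio takes effect only when grow_policy is depthwiselossltd")
      else None,
      // min_loss_ratio must lie in [0, 1).
      params.get("min_loss_ratio").collect {
        case r: Double if r < 0.0 || r >= 1.0 => "min_loss_ratio must be in [0, 1)"
      },
      if (params.contains("sampling_step") && samplingStrategy != "multiIteration")
        Some("sampling_step takes effect only when sampling_strategy is multiIteration")
      else None,
      if (params.contains("auto_k") && autoSubsample != true)
        Some("auto_k takes effect only when auto_subsample is true")
      else None
    ).flatten
  }

  def main(args: Array[String]): Unit = {
    val params = Map[String, Any](
      "tree_method" -> "hist",
      "grow_policy" -> "depthwiselossltd",
      "min_loss_ratio" -> 0.1,
      "sampling_strategy" -> "multiIteration",
      "sampling_step" -> 2
    )
    // Every dependency is satisfied here, so no warning is printed.
    check(params).foreach(println)
  }
}
```

Running such a check before constructing the estimator is one way to catch settings that would otherwise be silently ignored.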
Code interface example:

    val xgbClassifier = new XGBoostClassifier(param).setLabelCol("label").setFeaturesCol("features")
    val model = xgbClassifier.fit(train_data)
    val predictions = model.transform(test_data)
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
Complete example:

    package com.bigdata.ml

    import java.io.File
    import java.lang.System.nanoTime

    import scala.util.Random

    import com.typesafe.config.{Config, ConfigFactory}
    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    import org.apache.spark.storage.StorageLevel

    object Xgboost_test {

      // Runs `code` and returns its result with the elapsed time in nanoseconds.
      def profile[R](code: => R, t: Long = nanoTime): (R, Long) = (code, nanoTime - t)

      def getSparkSession(): SparkSession = {
        val conf = new SparkConf().setAppName("XGBOOST-SPARK")
        val spark = SparkSession.builder().config(conf).getOrCreate()
        spark.sparkContext.setLogLevel("ERROR")
        println("SparkSession created successfully!")
        spark
      }

      def main(args: Array[String]): Unit = {
        // args(0) is the path to the configuration file.
        val config = ConfigFactory.parseFile(new File(args(0)))

        // set seed
        Random.setSeed(System.currentTimeMillis())

        // set session
        val spark = this.getSparkSession()
        println("created spark session")
        println(spark.sparkContext.getConf.toDebugString)

        val (_, time) = profile(test(spark, config))
        val time_sec = time.toDouble * 1.0e-9
        println(s"Profiling complete in $time_sec seconds.\n")
      }

      def test(spark: SparkSession, config: Config): Unit = {
        // Copy every configuration entry into the XGBoost parameter map.
        var param = Map[String, Any]()
        val it = config.entrySet.iterator
        while (it.hasNext) {
          val entry = it.next
          param += (entry.getKey -> entry.getValue.unwrapped)
        }
        if (!config.hasPath("allow_non_zero_for_missing")) {
          param += ("allow_non_zero_for_missing" -> true)
        }
        println(param.mkString(";\n"))

        val xgbClassifier = new XGBoostClassifier(param)
          .setLabelCol("label")
          .setFeaturesCol("features")

        val time_point1 = System.currentTimeMillis()
        val train_data = getTrainData(spark, config).persist(StorageLevel.MEMORY_AND_DISK_SER)
        val time_point2 = System.currentTimeMillis()
        val model = xgbClassifier.fit(train_data)
        val time_point3 = System.currentTimeMillis()
        val test_data = getTestData(spark, config).persist(StorageLevel.MEMORY_AND_DISK_SER)
        val predictions = model.transform(test_data)
        val time_point4 = System.currentTimeMillis()

        val load_time = (time_point2 - time_point1) / 1000.0
        println(s"Loading complete in $load_time seconds.")
        val training_time = (time_point3 - time_point2) / 1000.0
        println(s"Training complete in $training_time seconds.")
        val testing_time = (time_point4 - time_point3) / 1000.0
        println(s"Testing complete in $testing_time seconds.")

        // Select (prediction, true label) and compute test error.
        val evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("label")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        val accuracy = evaluator.evaluate(predictions)
        println(s"Test Error = ${1.0 - accuracy}")

        // Select example rows to display.
        predictions.select("prediction", "label", "features").show(5)
      }

      // Loads the LIBSVM training set named by the "tr_fname" configuration key.
      def getTrainData(spark: SparkSession, config: Config): Dataset[Row] = {
        val tr_fname = config.getString("tr_fname")
        println("tr_fname", tr_fname)
        var reader = spark.read
          .format("libsvm")
          .option("vectorType",
            if (config.hasPath("vectorType")) config.getString("vectorType") else "sparse")
        if (config.hasPath("numFeatures")) {
          val numFeatures = config.getInt("numFeatures")
          println("numFeatures", numFeatures)
          reader = reader.option("numFeatures", numFeatures)
        }
        reader.load(tr_fname)
      }

      // Loads the LIBSVM test set named by the "ts_fname" configuration key.
      def getTestData(spark: SparkSession, config: Config): Dataset[Row] = {
        val ts_fname = config.getString("ts_fname")
        println("ts_fname", ts_fname)
        var reader = spark.read
          .format("libsvm")
          .option("vectorType",
            if (config.hasPath("vectorType")) config.getString("vectorType") else "sparse")
        if (config.hasPath("numFeatures")) {
          val numFeatures = config.getInt("numFeatures")
          println("numFeatures", numFeatures)
          reader = reader.option("numFeatures", numFeatures)
        }
        reader.load(ts_fname)
      }
    }
Sample output:

    Test Error = 0.253418287207109
    +----------+-----+--------------------+
    |prediction|label|            features|
    +----------+-----+--------------------+
    |       0.0|  0.0|(28,[0,1,2,3,4,5,...|
    |       0.0|  0.0|(28,[0,1,2,3,4,5,...|
    |       0.0|  0.0|(28,[0,1,2,3,4,5,...|
    |       0.0|  0.0|(28,[0,1,2,3,4,5,...|
    |       0.0|  0.0|(28,[0,1,2,3,4,5,...|
    +----------+-----+--------------------+
    only showing top 5 rows
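The classification example program reads its settings with `ConfigFactory.parseFile(new File(args(0)))`, looks up `tr_fname`, `ts_fname`, and the optional `vectorType` and `numFeatures` keys, and copies every entry into the parameter map given to `XGBoostClassifier`. A configuration file for it might therefore look like the sketch below. The file paths and all parameter values are placeholders chosen for illustration, not values taken from this document; `numFeatures = 28` simply matches the 28-dimensional vectors in the sample output.

```hocon
# Input data (placeholder paths; point these at your own LIBSVM files)
tr_fname = "hdfs:///tmp/data/train.libsvm"
ts_fname = "hdfs:///tmp/data/test.libsvm"

# Optional reader settings used by getTrainData / getTestData
vectorType = "sparse"
numFeatures = 28

# XGBoost parameters (illustrative values)
objective = "binary:logistic"
num_round = 100
num_workers = 4
eta = 0.1
max_depth = 6

# Parameters from the table above (illustrative values)
tree_method = "hist"
grow_policy = "depthwiselossltd"
min_loss_ratio = 0.0
```

Because the program copies all entries into the parameter map, the data-path keys are passed along with the XGBoost parameters; the example code makes no attempt to filter them out.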
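ML Regression API

| Parameter | Type | Default | Description |
|---|---|---|---|
| `labelCol` | Double | `label` | Label to predict |
| `featuresCol` | Vector | `features` | Feature column |

| Parameter | Type | Example | Description |
|---|---|---|---|
| `predictionCol` | Double | `prediction` | Predicted label value |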
    def setAllowNonZeroForMissing(value: Boolean): XGBoostRegressor.this.type
    def setAlpha(value: Double): XGBoostRegressor.this.type
    def setBaseMarginCol(value: String): XGBoostRegressor.this.type
    def setBaseScore(value: Double): XGBoostRegressor.this.type
    def setCheckpointInterval(value: Int): XGBoostRegressor.this.type
    def setCheckpointPath(value: String): XGBoostRegressor.this.type
    def setColsampleBylevel(value: Double): XGBoostRegressor.this.type
    def setColsampleBytree(value: Double): XGBoostRegressor.this.type
    def setCustomEval(value: EvalTrait): XGBoostRegressor.this.type
    def setCustomObj(value: ObjectiveTrait): XGBoostRegressor.this.type
    def setEta(value: Double): XGBoostRegressor.this.type
    def setEvalMetric(value: String): XGBoostRegressor.this.type
    def setEvalSets(evalSets: Map[String, DataFrame]): XGBoostRegressor.this.type
    def setFeaturesCol(value: String): XGBoostRegressor
    def setGamma(value: Double): XGBoostRegressor.this.type
    def setGroupCol(value: String): XGBoostRegressor.this.type
    def setGrowPolicy(value: String): XGBoostRegressor.this.type
    def setLabelCol(value: String): XGBoostRegressor.this.type
    def setLambda(value: Double): XGBoostRegressor.this.type
    def setLambdaBias(value: Double): XGBoostRegressor.this.type
    def setMaxBins(value: Int): XGBoostRegressor.this.type
    def setMaxDeltaStep(value: Double): XGBoostRegressor.this.type
    def setMaxDepth(value: Int): XGBoostRegressor.this.type
    def setMaxLeaves(value: Int): XGBoostRegressor.this.type
    def setMaximizeEvaluationMetrics(value: Boolean): XGBoostRegressor.this.type
    def setMinChildWeight(value: Double): XGBoostRegressor.this.type
    def setMissing(value: Float): XGBoostRegressor.this.type
    def setNormalizeType(value: String): XGBoostRegressor.this.type
    def setNthread(value: Int): XGBoostRegressor.this.type
    def setNumClass(value: Int): XGBoostRegressor.this.type
    def setNumEarlyStoppingRounds(value: Int): XGBoostRegressor.this.type
    def setNumRound(value: Int): XGBoostRegressor.this.type
    def setNumWorkers(value: Int): XGBoostRegressor.this.type
    def setObjective(value: String): XGBoostRegressor.this.type
    def setObjectiveType(value: String): XGBoostRegressor.this.type
    def setPredictionCol(value: String): XGBoostRegressor.this.type
    def setRateDrop(value: Double): XGBoostRegressor.this.type
    def setRawPredictionCol(value: String): XGBoostRegressor
    def setSampleType(value: String): XGBoostRegressor.this.type
    def setScalePosWeight(value: Double): XGBoostRegressor.this.type
    def setSeed(value: Long): XGBoostRegressor.this.type
    def setSilent(value: Int): XGBoostRegressor.this.type
    def setSinglePrecisionHistogram(value: Boolean): XGBoostRegressor.this.type
    def setSketchEps(value: Double): XGBoostRegressor.this.type
    def setSkipDrop(value: Double): XGBoostRegressor.this.type
    def setSubsample(value: Double): XGBoostRegressor.this.type
    def setThresholds(value: Array[Double]): XGBoostRegressor
    def setTimeoutRequestWorkers(value: Long): XGBoostRegressor.this.type
    def setTrainTestRatio(value: Double): XGBoostRegressor.this.type
    def setTreeMethod(value: String): XGBoostRegressor.this.type
    def setUseExternalMemory(value: Boolean): XGBoostRegressor.this.type
    def setWeightCol(value: String): XGBoostRegressor.this.type
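| Parameter | Meaning | Value type |
|---|---|---|
| `grow_policy` | Modified parameter; adds the `depthwiselossltd` option. Controls how new nodes are added to the tree. Takes effect only when `tree_method` is set to `hist`. | String. Default: "depthwise". Options: "depthwise", "lossguide", "depthwiselossltd". |
| `min_loss_ratio` | Controls how aggressively tree nodes are pruned during training. Takes effect only when `grow_policy` is `depthwiselossltd`. | Double. Default: 0. Range: [0,1). |
| `sampling_strategy` | Controls the sampling strategy used during training. | String. Default: "eachTree". Options: "eachTree", "eachIteration", "alliteration", "multiIteration", "gossStyle". |
| `enable_bbgen` | Controls whether the batch Bernoulli bit generation algorithm is used. | Boolean. Default: false. Options: true, false. |
| `sampling_step` | Controls the number of rounds between samplings. Takes effect only when `sampling_strategy` is set to `multiIteration`. | Int. Default: 1. Range: [1,+∞). |
| `auto_subsample` | Controls whether the automatic sampling-rate reduction strategy is used. | Boolean. Default: false. Options: true, false. |
| `auto_k` | Controls the number of rounds used by the automatic sampling-rate reduction strategy. Takes effect only when `auto_subsample` is set to true. | Int. Default: 1. Range: [1,+∞). |
| `auto_subsample_ratio` | Sets the candidate ratios for automatic sampling-rate reduction. | Array[Double]. Default: Array(0.05,0.1,0.2,0.4,0.8,1.0). Range of each value: (0,1]. |
| `auto_r` | Controls the increase in error rate that automatic sampling-rate reduction is allowed to cause. | Double. Default: 0.95. Range: (0,1]. |
| `rabit_enable_tcp_no_delay` | Controls the communication strategy in the Rabit engine. | Boolean. Default: false. Options: true, false. |
| `random_split_denom` | Controls the proportion of candidate split points that are used. | Int. Default: 1. Range: [1,+∞). |
| `default_direction` | Controls the default direction for missing values. | String. Default: "learn". Options: "left", "right", "learn". |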
Code interface example:

    val xgbRegression = new XGBoostRegressor(param).setLabelCol("label").setFeaturesCol("features")
    val model = xgbRegression.fit(train_data)
    val predictions = model.transform(test_data)
    // RegressionEvaluator does not accept "accuracy"; RMSE is used here instead.
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("rmse")
    val rmse = evaluator.evaluate(predictions)
Complete example:

    package com.bigdata.ml

    import java.io.File
    import java.lang.System.nanoTime

    import scala.util.Random

    import com.typesafe.config.{Config, ConfigFactory}
    import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    import org.apache.spark.storage.StorageLevel

    object Xgboost_test {

      // Runs `code` and returns its result with the elapsed time in nanoseconds.
      def profile[R](code: => R, t: Long = nanoTime): (R, Long) = (code, nanoTime - t)

      def getSparkSession(): SparkSession = {
        val conf = new SparkConf().setAppName("XGBOOST-SPARK")
        val spark = SparkSession.builder().config(conf).getOrCreate()
        spark.sparkContext.setLogLevel("ERROR")
        println("SparkSession created successfully!")
        spark
      }

      def main(args: Array[String]): Unit = {
        // args(0) is the path to the configuration file.
        val config = ConfigFactory.parseFile(new File(args(0)))

        // set seed
        Random.setSeed(System.currentTimeMillis())

        // set session
        val spark = this.getSparkSession()
        println("created spark session")
        println(spark.sparkContext.getConf.toDebugString)

        val (_, time) = profile(test(spark, config))
        val time_sec = time.toDouble * 1.0e-9
        println(s"Profiling complete in $time_sec seconds.\n")
      }

      def test(spark: SparkSession, config: Config): Unit = {
        // Copy every configuration entry into the XGBoost parameter map.
        var param = Map[String, Any]()
        val it = config.entrySet.iterator
        while (it.hasNext) {
          val entry = it.next
          param += (entry.getKey -> entry.getValue.unwrapped)
        }
        if (!config.hasPath("allow_non_zero_for_missing")) {
          param += ("allow_non_zero_for_missing" -> true)
        }
        println(param.mkString(";\n"))

        val xgbRegression = new XGBoostRegressor(param)
          .setLabelCol("label")
          .setFeaturesCol("features")

        val time_point1 = System.currentTimeMillis()
        val train_data = getTrainData(spark, config).persist(StorageLevel.MEMORY_AND_DISK_SER)
        val time_point2 = System.currentTimeMillis()
        val model = xgbRegression.fit(train_data)
        val time_point3 = System.currentTimeMillis()
        val test_data = getTestData(spark, config).persist(StorageLevel.MEMORY_AND_DISK_SER)
        val predictions = model.transform(test_data)
        val time_point4 = System.currentTimeMillis()

        val load_time = (time_point2 - time_point1) / 1000.0
        println(s"Loading complete in $load_time seconds.")
        val training_time = (time_point3 - time_point2) / 1000.0
        println(s"Training complete in $training_time seconds.")
        val testing_time = (time_point4 - time_point3) / 1000.0
        println(s"Testing complete in $testing_time seconds.")

        // Select (prediction, true label) and compute test error.
        // RegressionEvaluator does not accept "accuracy"; RMSE is used here instead.
        val evaluator = new RegressionEvaluator()
          .setLabelCol("label")
          .setPredictionCol("prediction")
          .setMetricName("rmse")
        val rmse = evaluator.evaluate(predictions)
        println(s"Test Error = $rmse")

        // Select example rows to display.
        predictions.select("prediction", "label", "features").show(5)
      }

      // Loads the LIBSVM training set named by the "tr_fname" configuration key.
      def getTrainData(spark: SparkSession, config: Config): Dataset[Row] = {
        val tr_fname = config.getString("tr_fname")
        println("tr_fname", tr_fname)
        var reader = spark.read
          .format("libsvm")
          .option("vectorType",
            if (config.hasPath("vectorType")) config.getString("vectorType") else "sparse")
        if (config.hasPath("numFeatures")) {
          val numFeatures = config.getInt("numFeatures")
          println("numFeatures", numFeatures)
          reader = reader.option("numFeatures", numFeatures)
        }
        reader.load(tr_fname)
      }

      // Loads the LIBSVM test set named by the "ts_fname" configuration key.
      def getTestData(spark: SparkSession, config: Config): Dataset[Row] = {
        val ts_fname = config.getString("ts_fname")
        println("ts_fname", ts_fname)
        var reader = spark.read
          .format("libsvm")
          .option("vectorType",
            if (config.hasPath("vectorType")) config.getString("vectorType") else "sparse")
        if (config.hasPath("numFeatures")) {
          val numFeatures = config.getInt("numFeatures")
          println("numFeatures", numFeatures)
          reader = reader.option("numFeatures", numFeatures)
        }
        reader.load(ts_fname)
      }
    }
Sample output:

    Test Error = 0.5872398843658918
    +--------------------+-----+--------------------+
    |          prediction|label|            features|
    +--------------------+-----+--------------------+
    |  0.2738455533981323|  0.0|(28,[0,1,2,3,4,5,...|
    |0.052151769399642944|  0.0|(28,[0,1,2,3,4,5,...|
    | 0.08468279242515564|  0.0|(28,[0,1,2,3,4,5,...|
    | 0.20581847429275513|  0.0|(28,[0,1,2,3,4,5,...|
    |  0.3741578459739685|  0.0|(28,[0,1,2,3,4,5,...|
    +--------------------+-----+--------------------+
    only showing top 5 rows