LDA
LDA is a model interface of the ML API category.
| Model interface category | Function interface |
|---|---|
| ML API | def fit(dataset: Dataset[_]): LDAModel |
| | def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[LDAModel] |
| | def fit(dataset: Dataset[_], paramMap: ParamMap): LDAModel |
| | def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): LDAModel |
ML API
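- Function description: trains a Latent Dirichlet Allocation (LDA) topic model on a corpus of documents, each represented as a vector of term counts.
- Input and output
- Package: org.apache.spark.ml.clustering
- Class name: LDA
- Method name: fit
- Input: Dataset[_], the training sample data. Required field: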
| Parameter | Type | Default | Description |
|---|---|---|---|
| featuresCol | Vector | "features" | Feature vector (term counts of each document) |
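The `features` column is expected to hold one term-count vector per document, all of the same length (the vocabulary size). The following is a minimal sketch of building such a Dataset in memory, assuming a local `SparkSession`; the object name `LdaInputSketch`, the extra `id` column, the five-term vocabulary, and the counts are made-up illustrations, not part of the API:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, SparkSession}

object LdaInputSketch {
  // Hypothetical helper: builds a 3-document corpus over a 5-term vocabulary.
  // Each "features" vector holds the term counts of one document.
  def buildCorpus(spark: SparkSession): DataFrame = {
    val vocabSize = 5
    val docs: Seq[(Long, Vector)] = Seq(
      (0L, Vectors.sparse(vocabSize, Array(0, 2), Array(3.0, 1.0))),
      (1L, Vectors.sparse(vocabSize, Array(1, 3, 4), Array(2.0, 2.0, 1.0))),
      (2L, Vectors.dense(1.0, 0.0, 4.0, 0.0, 2.0))
    )
    spark.createDataFrame(docs).toDF("id", "features")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("lda-input-sketch").getOrCreate()
    buildCorpus(spark).show(false)
    spark.stop()
  }
}
```

Sparse and dense vectors can be mixed in one column; the libsvm loader used in the usage example produces sparse vectors of the same kind.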
- Algorithm parameters (setters on LDA):
def setCheckpointInterval(value: Int): LDA.this.type
def setDocConcentration(value: Double): LDA.this.type
def setDocConcentration(value: Array[Double]): LDA.this.type
def setFeaturesCol(value: String): LDA.this.type
def setK(value: Int): LDA.this.type
def setMaxIter(value: Int): LDA.this.type
def setSeed(value: Long): LDA.this.type
def setSubsamplingRate(value: Double): LDA.this.type
def setTopicConcentration(value: Double): LDA.this.type
def setTopicDistributionCol(value: String): LDA.this.type
def setOptimizer(value: String): LDA.this.type
def setKeepLastCheckpoint(value: Boolean): LDA.this.type
def setLearningDecay(value: Double): LDA.this.type
def setLearningOffset(value: Double): LDA.this.type
def setOptimizeDocConcentration(value: Boolean): LDA.this.type
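Because every setter returns `LDA.this.type`, the setters can be chained. A minimal sketch with illustrative values, not tuned recommendations; `LdaParamsSketch` is a hypothetical name. The comments reflect our understanding of Spark's LDA, where the mini-batch settings apply only to the default "online" optimizer:

```scala
import org.apache.spark.ml.clustering.LDA

object LdaParamsSketch {
  // Builds an LDA estimator with the listed parameters set explicitly (illustrative values).
  def buildLda(): LDA =
    new LDA()
      .setK(5)                           // number of topics
      .setMaxIter(20)                    // number of training iterations
      .setOptimizer("online")            // "online" (default) or "em"
      .setSubsamplingRate(0.1)           // online: corpus fraction per mini-batch, in (0, 1]
      .setLearningDecay(0.51)            // online: learning-rate decay, in (0.5, 1]
      .setLearningOffset(1024.0)         // online: positive offset that down-weights early iterations
      .setOptimizeDocConcentration(true) // online: also learn docConcentration during training
      .setDocConcentration(0.1)          // alpha; a single value is replicated to all k topics
      .setTopicConcentration(0.1)        // beta (also called eta), prior on topic-term distributions
      .setCheckpointInterval(10)         // ignored unless a checkpoint directory is set
      .setSeed(42L)
      .setFeaturesCol("features")
      .setTopicDistributionCol("topicDistribution")

  def main(args: Array[String]): Unit = {
    // explainParams() lists every parameter with its documentation and current value.
    println(buildLda().explainParams())
  }
}
```

`explainParams()` prints each parameter's documentation together with its default and current value, which is a convenient way to check a configuration before calling `fit`.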
Example of setting parameters and calling each fit overload:
import org.apache.spark.ml.clustering.{LDA, LDAModel}
import org.apache.spark.ml.param.{ParamMap, ParamPair}

// Example values; trainingData is assumed to be a Dataset with a "features" column.
val k = 10
val maxIter = 20
val checkpointInterval = 10

val lda = new LDA()

// Parameters for def fit(dataset: Dataset[_], paramMap: ParamMap)
val paramMap = ParamMap(lda.k -> k)
  .put(lda.maxIter, maxIter)

// Parameters for def fit(dataset: Dataset[_], paramMaps: Array[ParamMap])
val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
for (i <- 0 until 2) { // indices 0 and 1 only; the array has two slots
  paramMaps(i) = ParamMap(lda.k -> k)
    .put(lda.maxIter, maxIter)
}

// Parameters for def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*)
val kParamPair = ParamPair(lda.k, k)
val maxIterParamPair = ParamPair(lda.maxIter, maxIter)
val checkpointIntervalParamPair = ParamPair(lda.checkpointInterval, checkpointInterval)

// Call each fit overload
var model: LDAModel = lda.fit(trainingData)
model = lda.fit(trainingData, paramMap)
val models: Seq[LDAModel] = lda.fit(trainingData, paramMaps)
model = lda.fit(trainingData, kParamPair, maxIterParamPair, checkpointIntervalParamPair)
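Values in a `ParamMap` or `ParamPair` override the estimator's own settings for that call only: per Spark's `Estimator` contract, `fit(dataset, paramMap)` fits a copy of the estimator, so `lda` itself is not modified. A minimal sketch, assuming a local `SparkSession`; the four-document corpus and the object name `LdaParamMapSketch` are made up for illustration:

```scala
import org.apache.spark.ml.clustering.{LDA, LDAModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SparkSession

object LdaParamMapSketch {
  // Fits one model per ParamMap; returns the estimator's k and each model's k.
  def fitCandidates(spark: SparkSession): (Int, Seq[Int]) = {
    // Made-up corpus: 4 documents over a 4-term vocabulary (term counts).
    val corpus = spark.createDataFrame(Seq(
      Tuple1(Vectors.dense(3.0, 1.0, 0.0, 0.0)),
      Tuple1(Vectors.dense(2.0, 2.0, 0.0, 1.0)),
      Tuple1(Vectors.dense(0.0, 0.0, 4.0, 2.0)),
      Tuple1(Vectors.dense(0.0, 1.0, 3.0, 3.0))
    )).toDF("features")

    // subsamplingRate = 1.0 so every mini-batch uses the whole tiny corpus.
    val lda = new LDA().setK(2).setMaxIter(5).setSubsamplingRate(1.0).setSeed(1L)

    // Each ParamMap overrides k for its own fit; lda itself keeps k = 2.
    val candidates = Array(ParamMap(lda.k -> 2), ParamMap(lda.k -> 3))
    val models: Seq[LDAModel] = lda.fit(corpus, candidates)

    models.foreach { m =>
      println(s"k=${m.getK} logPerplexity=${m.logPerplexity(corpus)}")
    }
    (lda.getK, models.map(_.getK))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("lda-parammap-sketch").getOrCreate()
    val (estimatorK, modelKs) = fitCandidates(spark)
    println(s"estimator k = $estimatorK, model k values = $modelKs")
    spark.stop()
  }
}
```

Comparing `logPerplexity` across the fitted models (lower is better) is one simple way to choose k; in practice it should be measured on held-out documents rather than on the training corpus.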
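- Output: LDAModel, the trained LDA model. Output field produced when the model is used for prediction:

| Parameter | Type | Default | Description |
|---|---|---|---|
| topicDistributionCol | Vector | "topicDistribution" | Topic distribution of each document |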
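Each `topicDistribution` vector has length k and gives the document's weight on every topic, so a common follow-up is to label each document with its highest-weight topic. A minimal sketch, assuming an already fitted `LDAModel`; the object name `DominantTopicSketch` and the `dominantTopic` column are hypothetical, and the argmax is computed with a plain UDF over `org.apache.spark.ml.linalg.Vector.argmax`:

```scala
import org.apache.spark.ml.clustering.LDAModel
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

object DominantTopicSketch {
  // Returns the index of the largest weight in a topic-distribution vector.
  private val argmaxUdf = udf((v: Vector) => v.argmax)

  // Runs model.transform and appends a hypothetical "dominantTopic" column
  // holding the topic with the highest weight for each document.
  def withDominantTopic(model: LDAModel, docs: DataFrame): DataFrame =
    model.transform(docs)
      .withColumn("dominantTopic", argmaxUdf(col(model.getTopicDistributionCol)))
}
```

For example, with the `model` and `dataset` from the usage example, `DominantTopicSketch.withDominantTopic(model, dataset).select("label", "dominantTopic").show()` would list one topic index per document.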
- Usage example
import org.apache.spark.ml.clustering.LDA

// Loads data. `spark` is the active SparkSession (for example, the one provided by spark-shell).
val dataset = spark.read.format("libsvm")
  .load("data/mllib/sample_lda_libsvm_data.txt")

// Trains an LDA model.
val lda = new LDA().setK(10).setMaxIter(10)
val model = lda.fit(dataset)

val ll = model.logLikelihood(dataset)
val lp = model.logPerplexity(dataset)
println(s"The lower bound on the log likelihood of the entire corpus: $ll")
println(s"The upper bound on perplexity: $lp")

// Describe topics.
val topics = model.describeTopics(3)
println("The topics described by their top-weighted terms:")
topics.show(false)

// Shows the result.
val transformed = model.transform(dataset)
transformed.show(false)
- Sample output
The lower bound on the log likelihood of the entire corpus: -841.0546578646513
The upper bound on perplexity: 3.2394551460843743
The topics described by their top-weighted terms:
+-----+-----------+---------------------------------------------------------------+
|topic|termIndices|termWeights                                                    |
+-----+-----------+---------------------------------------------------------------+
|0    |[2, 5, 7]  |[0.10606440859619756, 0.10570106168104901, 0.10430389617455987]|
|1    |[1, 6, 2]  |[0.10185076997493327, 0.09816928141852303, 0.09632454354056506]|
|2    |[10, 6, 9] |[0.2183019165124768, 0.13864436129889263, 0.13063106158820773] |
|3    |[0, 4, 8]  |[0.10270701955799236, 0.09842848153379427, 0.09815661242066778]|
|4    |[9, 6, 4]  |[0.10452964433317273, 0.1041490817814721, 0.10103987046100901] |
|5    |[1, 10, 0] |[0.10214945362083101, 0.10129059983059674, 0.09513643669014085]|
|6    |[3, 7, 4]  |[0.11638316687843665, 0.09901763170620775, 0.09795372072055877]|
|7    |[4, 0, 2]  |[0.10855453653883299, 0.10334275138796098, 0.10034943368696514]|
|8    |[0, 7, 8]  |[0.11008008210198178, 0.09919723498780184, 0.09810902425203567]|
|9    |[9, 6, 8]  |[0.10106110089497022, 0.10013295826841445, 0.09769277851351822]|
+-----+-----------+---------------------------------------------------------------+

+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|features                                                       |topicDistribution                                                                                                                                                                                                     |
+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.7020102723525649,0.004825993075374452,0.2593820375820705,0.004825958849718463,0.00482593471041594,0.0048259867769672666,0.004825958945138608,0.004826029288984295,0.004825898501024282,0.004825929917741284]       |
|1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |[0.008050057075554595,0.008049908274306143,0.5358743394043809,0.3997256465275279,0.008049856144496943,0.008050063705485845,0.008050120759460606,0.008050118726041808,0.008050050616141662,0.00804983876660343]        |
|2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |[0.004196160909228379,0.004196335770441454,0.9622355738659514,0.004196094608077031,0.004195947814373813,0.004195985684081985,0.004196034575816405,0.0041960010288430855,0.004195858812126837,0.004196006931059736]    |
|3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |[0.0037108247735123012,0.0037108689277450544,0.9666020933193584,0.003710885149198092,0.003710937342701358,0.003710898637581225,0.0037108756147115974,0.00371084146473461,0.0037108939404337966,0.0037108808300235123] |
|4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |[0.004020615024955425,0.004020656597961219,0.9638138100408753,0.00402067920269795,0.0040207333361871175,0.004020674593541886,0.004020896170250796,0.004020656802269066,0.004020671303761361,0.004020606927500021]     |
|5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0])|[0.003711292914086739,0.003711270749947734,0.37751721754443834,0.5927926154552048,0.003711280496155832,0.0037112541886366213,0.0037112946841805156,0.0037112985706198556,0.0037112619372495314,0.0037112134594801333] |
|6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |[0.0038593999027201936,0.0038594526216442883,0.9652647151220881,0.0038595435459804063,0.0038594924013660528,0.003859503714977076,0.00385946223024681,0.003859418248801901,0.0038595472452573583,0.00385946496691772]  |
|7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.004386696022976412,0.00438670306926343,0.9605196858001243,0.004386709348681648,0.0043867254271577425,0.004386684670877583,0.004386856901250918,0.004386652103913797,0.00438664418447618,0.004386642471277907]      |
|8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |[0.004386774230797772,0.004386827847922799,0.004929736666909785,0.9599758309339176,0.0043867768998796354,0.004386851645634705,0.004386859479736638,0.004386804006042446,0.004386800687711282,0.0043867376014473675]   |
|9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |[0.003326812373404729,0.003326822375240019,0.970058659329348,0.003326809001555563,0.003326846900378563,0.0033268275133451976,0.003326770265726515,0.0033268422245536526,0.0033268053628080743,0.003326804653639628]   |
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |[0.004195866114900055,0.0041958485431816475,0.9622374182867508,0.004195834901098945,0.00419580907412282,0.004195800055494461,0.004196126970464509,0.004195750672824693,0.004195747645015526,0.004195797736146441]     |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |[0.004826030254621442,0.004825970840289355,0.0054210556618595335,0.9559711158059512,0.004825951272185808,0.0048259739863484715,0.0048259709554675685,0.0048260615010730715,0.004825982971090291,0.0048258867511133344]|
+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+