LDA

LDA is a model interface of the ML API type.

Model interface category   Function interface
ML API                     def fit(dataset: Dataset[_]): LDAModel
                           def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[LDAModel]
                           def fit(dataset: Dataset[_], paramMap: ParamMap): LDAModel
                           def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): LDAModel
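The overloads that take a ParamMap or ParamPair values do not modify the LDA instance itself: in Spark's Estimator, fit(dataset, paramMap) fits a copy of the estimator with the extra values applied, and fit(dataset, paramMaps) does this once per map, returning the models in the same order. The following is a minimal plain-Scala sketch of the resulting precedence (default < value set through a setter < value passed to fit), assuming parameters can be modeled as a string-keyed Map. ParamPrecedence and its methods are hypothetical names, not Spark APIs.

```scala
// Conceptual model only (not Spark code): how one fit call resolves parameter values.
object ParamPrecedence {
  // Hypothetical stand-in for Spark's ParamMap: parameter name -> value.
  type Params = Map[String, Any]

  // Later maps win: defaults, then values set through setters, then the extra map passed to fit.
  def effective(defaults: Params, explicitlySet: Params, extra: Params): Params =
    defaults ++ explicitlySet ++ extra

  // Mirrors fit(dataset, paramMaps): one resolved parameter set (one model) per map, same order.
  def effectiveForEach(defaults: Params, explicitlySet: Params, extras: Seq[Params]): Seq[Params] =
    extras.map(extra => effective(defaults, explicitlySet, extra))
}
```

For example, with a default of k = 10 and maxIter = 20, a setter call that sets maxIter to 10, and an extra map containing k = 5, the resolved values are k = 5 and maxIter = 10.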

ML API

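  • Function description

    Takes sample data in Dataset format, calls the training interface, and outputs an LDA model.

  • Input and output
    1. Package name: org.apache.spark.ml.clustering (fully qualified class name: org.apache.spark.ml.clustering.LDA)
    2. Class name: LDA
    3. Method name: fit
    4. Input: Dataset[_], the training sample data. Required field: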

      Parameter name   Value type   Default value   Description
      featuresCol      Vector       "features"      Feature vector (term-count vector of one document)

    5. Algorithm parameters, set through the following methods:

      def setCheckpointInterval(value: Int): LDA.this.type

      def setDocConcentration(value: Double): LDA.this.type

      def setDocConcentration(value: Array[Double]): LDA.this.type

      def setFeaturesCol(value: String): LDA.this.type

      def setK(value: Int): LDA.this.type

      def setMaxIter(value: Int): LDA.this.type

      def setSeed(value: Long): LDA.this.type

      def setSubsamplingRate(value: Double): LDA.this.type

      def setTopicConcentration(value: Double): LDA.this.type

      def setTopicDistributionCol(value: String): LDA.this.type

      def setOptimizer(value: String): LDA.this.type

      def setKeepLastCheckpoint(value: Boolean): LDA.this.type

      def setLearningDecay(value: Double): LDA.this.type

      def setLearningOffset(value: Double): LDA.this.type

      def setOptimizeDocConcentration(value: Boolean): LDA.this.type

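      Example of parameter settings and calls to each fit interface:

      import org.apache.spark.ml.clustering.{LDA, LDAModel}
      import org.apache.spark.ml.param.{ParamMap, ParamPair}
      
      // Example hyperparameter values; adjust them for your data.
      val k = 10
      val maxIter = 10
      val checkpointInterval = 10
      
      // trainingData: a DataFrame with a "features" Vector column,
      // for example loaded as in the usage example below.
      val lda = new LDA()
      
      // Parameters for def fit(dataset: Dataset[_], paramMap: ParamMap)
      val paramMap = ParamMap(lda.k -> k)
        .put(lda.maxIter, maxIter)
      
      // Parameters for def fit(dataset: Dataset[_], paramMaps: Array[ParamMap])
      val paramMaps: Array[ParamMap] = new Array[ParamMap](2)
      for (i <- paramMaps.indices) {
        // Assign one ParamMap to each array slot
        paramMaps(i) = ParamMap(lda.k -> k)
          .put(lda.maxIter, maxIter)
      }
      
      // Parameters for def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*)
      val kParamPair = ParamPair(lda.k, k)
      val maxIterParamPair = ParamPair(lda.maxIter, maxIter)
      val checkpointIntervalParamPair = ParamPair(lda.checkpointInterval, checkpointInterval)
      
      // Call each fit interface
      val model: LDAModel = lda.fit(trainingData)
      val modelFromMap: LDAModel = lda.fit(trainingData, paramMap)
      val models: Seq[LDAModel] = lda.fit(trainingData, paramMaps) // one model per ParamMap
      val modelFromPairs: LDAModel = lda.fit(trainingData, kParamPair, maxIterParamPair, checkpointIntervalParamPair)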
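    6. Output: LDAModel, the trained LDA model. When the model is used for prediction, it outputs the following field:

      Parameter name         Value type   Default value         Description
      topicDistributionCol   Vector       "topicDistribution"   Topic distribution of each document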

  • Usage example
    import org.apache.spark.ml.clustering.LDA
    
    // Loads data. Assumes `spark` is an existing SparkSession (as in spark-shell).
    val dataset = spark.read.format("libsvm")
      .load("data/mllib/sample_lda_libsvm_data.txt")
    
    // Trains an LDA model.
    val lda = new LDA().setK(10).setMaxIter(10)
    val model = lda.fit(dataset)
    
    val ll = model.logLikelihood(dataset)
    val lp = model.logPerplexity(dataset)
    println(s"The lower bound on the log likelihood of the entire corpus: $ll")
    println(s"The upper bound on perplexity: $lp")
    
    // Describe topics.
    val topics = model.describeTopics(3)
    println("The topics described by their top-weighted terms:")
    topics.show(false)
    
    // Shows the result.
    val transformed = model.transform(dataset)
    transformed.show(false)
  • Example result
    The lower bound on the log likelihood of the entire corpus: -841.0546578646513
    The upper bound on perplexity: 3.2394551460843743
    The topics described by their top-weighted terms:
    +-----+-----------+---------------------------------------------------------------+
    |topic|termIndices|termWeights                                                    |
    +-----+-----------+---------------------------------------------------------------+
    |0    |[2, 5, 7]  |[0.10606440859619756, 0.10570106168104901, 0.10430389617455987]|
    |1    |[1, 6, 2]  |[0.10185076997493327, 0.09816928141852303, 0.09632454354056506]|
    |2    |[10, 6, 9] |[0.2183019165124768, 0.13864436129889263, 0.13063106158820773] |
    |3    |[0, 4, 8]  |[0.10270701955799236, 0.09842848153379427, 0.09815661242066778]|
    |4    |[9, 6, 4]  |[0.10452964433317273, 0.1041490817814721, 0.10103987046100901] |
    |5    |[1, 10, 0] |[0.10214945362083101, 0.10129059983059674, 0.09513643669014085]|
    |6    |[3, 7, 4]  |[0.11638316687843665, 0.09901763170620775, 0.09795372072055877]|
    |7    |[4, 0, 2]  |[0.10855453653883299, 0.10334275138796098, 0.10034943368696514]|
    |8    |[0, 7, 8]  |[0.11008008210198178, 0.09919723498780184, 0.09810902425203567]|
    |9    |[9, 6, 8]  |[0.10106110089497022, 0.10013295826841445, 0.09769277851351822]|
    +-----+-----------+---------------------------------------------------------------+
    +-----+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |label|features                                                       |topicDistribution                                                                                                                                                                                                     |
    +-----+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.7020102723525649,0.004825993075374452,0.2593820375820705,0.004825958849718463,0.00482593471041594,0.0048259867769672666,0.004825958945138608,0.004826029288984295,0.004825898501024282,0.004825929917741284]       |
    |1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |[0.008050057075554595,0.008049908274306143,0.5358743394043809,0.3997256465275279,0.008049856144496943,0.008050063705485845,0.008050120759460606,0.008050118726041808,0.008050050616141662,0.00804983876660343]        |
    |2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |[0.004196160909228379,0.004196335770441454,0.9622355738659514,0.004196094608077031,0.004195947814373813,0.004195985684081985,0.004196034575816405,0.0041960010288430855,0.004195858812126837,0.004196006931059736]    |
    |3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |[0.0037108247735123012,0.0037108689277450544,0.9666020933193584,0.003710885149198092,0.003710937342701358,0.003710898637581225,0.0037108756147115974,0.00371084146473461,0.0037108939404337966,0.0037108808300235123] |
    |4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |[0.004020615024955425,0.004020656597961219,0.9638138100408753,0.00402067920269795,0.0040207333361871175,0.004020674593541886,0.004020896170250796,0.004020656802269066,0.004020671303761361,0.004020606927500021]     |
    |5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |[0.003711292914086739,0.003711270749947734,0.37751721754443834,0.5927926154552048,0.003711280496155832,0.0037112541886366213,0.0037112946841805156,0.0037112985706198556,0.0037112619372495314,0.0037112134594801333] |
    |6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |[0.0038593999027201936,0.0038594526216442883,0.9652647151220881,0.0038595435459804063,0.0038594924013660528,0.003859503714977076,0.00385946223024681,0.003859418248801901,0.0038595472452573583,0.00385946496691772]  |
    |7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.004386696022976412,0.00438670306926343,0.9605196858001243,0.004386709348681648,0.0043867254271577425,0.004386684670877583,0.004386856901250918,0.004386652103913797,0.00438664418447618,0.004386642471277907]      |
    |8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |[0.004386774230797772,0.004386827847922799,0.004929736666909785,0.9599758309339176,0.0043867768998796354,0.004386851645634705,0.004386859479736638,0.004386804006042446,0.004386800687711282,0.0043867376014473675]   |
    |9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |[0.003326812373404729,0.003326822375240019,0.970058659329348,0.003326809001555563,0.003326846900378563,0.0033268275133451976,0.003326770265726515,0.0033268422245536526,0.0033268053628080743,0.003326804653639628]   |
    |10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |[0.004195866114900055,0.0041958485431816475,0.9622374182867508,0.004195834901098945,0.00419580907412282,0.004195800055494461,0.004196126970464509,0.004195750672824693,0.004195747645015526,0.004195797736146441]     |
    |11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |[0.004826030254621442,0.004825970840289355,0.0054210556618595335,0.9559711158059512,0.004825951272185808,0.0048259739863484715,0.0048259709554675685,0.0048260615010730715,0.004825982971090291,0.0048258867511133344]|
    +-----+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
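The featuresCol input of LDA holds one term-count vector per document, and the features column in the transform output above shows such vectors in Spark's sparse notation (size,[indices],[values]). The following is a minimal plain-Scala sketch of how a tokenized document maps to that form, assuming a fixed vocabulary whose positions serve as term indices. BagOfWords, SparseCounts and toSparseCounts are hypothetical names; in a real Spark pipeline such vectors are usually produced by a feature transformer such as CountVectorizer rather than by hand.

```scala
// Hypothetical helper, not part of Spark: turns tokens into a sparse term-count
// vector laid out like Spark's SparseVector (size, ascending indices, values).
object BagOfWords {
  final case class SparseCounts(size: Int, indices: Array[Int], values: Array[Double]) {
    // Render in the same notation Spark prints, e.g. (11,[0,2],[1.0,3.0])
    override def toString: String =
      s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
  }

  def toSparseCounts(tokens: Seq[String], vocab: IndexedSeq[String]): SparseCounts = {
    // Term -> column index, taken from the term's position in the vocabulary
    val index: Map[String, Int] = vocab.zipWithIndex.toMap
    // Count occurrences of in-vocabulary terms; unknown tokens are dropped
    val counts: Map[Int, Int] = tokens
      .flatMap(t => index.get(t))
      .groupBy(identity)
      .map { case (i, occurrences) => i -> occurrences.size }
    // Sparse vectors store indices in ascending order
    val sorted = counts.toSeq.sortBy(_._1)
    SparseCounts(vocab.size, sorted.map(_._1).toArray, sorted.map(_._2.toDouble).toArray)
  }
}
```

For example, the tokens spark, lda, spark with the vocabulary lda, spark, topic yield (3,[0,1],[1.0,2.0]).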
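setDocConcentration accepts either a single Double or an Array[Double]. According to Spark's LDA documentation, a single value is replicated to a vector of length k (a symmetric Dirichlet prior on each document's topic mixture), an array must otherwise have length k, and when the parameter is not set a default is chosen per optimizer: 1.0 / k for "online" and 50 / k + 1 for "em". The following is a minimal plain-Scala sketch of that resolution rule. The DocConcentration object is hypothetical, it assumes lowercase optimizer names, and it does not check optimizer-specific constraints such as EM requiring symmetric values greater than 1.0.

```scala
// Plain-Scala sketch (not Spark code): expand a docConcentration ("alpha") setting
// into one Dirichlet parameter per topic.
object DocConcentration {
  def resolve(alpha: Option[Array[Double]], k: Int, optimizer: String): Array[Double] = {
    require(k > 1, s"k must be greater than 1, got $k")
    alpha match {
      // Not set: automatic default, which depends on the optimizer (lowercase names assumed)
      case None =>
        optimizer match {
          case "online" => Array.fill(k)(1.0 / k)
          case "em"     => Array.fill(k)(50.0 / k + 1.0)
          case other    => throw new IllegalArgumentException(s"unknown optimizer: $other")
        }
      // setDocConcentration(value: Double): one value, replicated k times
      case Some(Array(single)) => Array.fill(k)(single)
      // setDocConcentration(value: Array[Double]): one value per topic
      case Some(values) =>
        require(values.length == k, s"docConcentration must have length 1 or k = $k, got ${values.length}")
        values
    }
  }
}
```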
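The topicDistribution output gives one weight per topic for each document, and the two corpus metrics printed by the usage example are linked. Assuming Spark's implementation, where the log-perplexity bound is the negated log-likelihood bound divided by the total token count of the evaluated corpus (the sum of all term counts in featuresCol), both read-outs can be sketched in plain Scala as follows. LdaOutputs and its methods are hypothetical helpers that work on plain arrays rather than Spark Vectors.

```scala
// Plain-Scala sketch (not Spark code) for reading LDA outputs.
object LdaOutputs {
  // Index of the most probable topic in one document's topicDistribution vector
  // (the first index wins on ties).
  def dominantTopic(topicDistribution: Array[Double]): Int = {
    require(topicDistribution.nonEmpty, "topic distribution must not be empty")
    topicDistribution.indices.maxBy(i => topicDistribution(i))
  }

  // Total number of tokens in a corpus: the sum of every document's term counts.
  def tokenCount(termCounts: Seq[Array[Double]]): Double =
    termCounts.map(_.sum).sum

  // Log-perplexity bound: negated log-likelihood bound divided by the token count.
  def logPerplexity(logLikelihood: Double, tokens: Double): Double = {
    require(tokens > 0, "the corpus must contain at least one token")
    -logLikelihood / tokens
  }
}
```

For example, the row with label 2.0 in the result above has its largest weight (about 0.962) at index 2, so dominantTopic would return 2 for it.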