溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Spark中決策樹(shù)源碼分析

發(fā)布時(shí)間:2020-07-09 15:26:55 來(lái)源:網(wǎng)絡(luò) 閱讀:592 作者:jjjssswww 欄目:大數(shù)據(jù)

1.Example

使用Spark MLlib中決策樹(shù)分類器API,訓(xùn)練出一個(gè)決策樹(shù)模型,使用Python開(kāi)發(fā)。

"""
Decision Tree Classification Example.
"""from __future__ import print_functionfrom pyspark import SparkContextfrom pyspark.mllib.tree import DecisionTree, DecisionTreeModelfrom pyspark.mllib.util import MLUtilsif __name__ == "__main__":

    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")    # 加載和解析數(shù)據(jù)文件為RDD
    dataPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/sample_libsvm_data.txt"
    print(dataPath)

    data = MLUtils.loadLibSVMFile(sc,dataPath)    # 將數(shù)據(jù)集分割為訓(xùn)練數(shù)據(jù)集和測(cè)試數(shù)據(jù)集
    (trainingData,testData) = data.randomSplit([0.7,0.3])
    print("train data count: " + str(trainingData.count()))
    print("test data count : " + str(testData.count()))    # 訓(xùn)練決策樹(shù)分類器
    # categoricalFeaturesInfo 為空,表示所有的特征均為連續(xù)值
    model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                         impurity='gini', maxDepth=5, maxBins=32)    # 測(cè)試數(shù)據(jù)集上預(yù)測(cè)
    predictions = model.predict(testData.map(lambda x: x.features))    # 打包真實(shí)值與預(yù)測(cè)值
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)    # 統(tǒng)計(jì)預(yù)測(cè)錯(cuò)誤的樣本的頻率
    testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
    print('Decision Tree Test Error = %5.3f%%'%(testErr*100))
    print("Decision Tree Learned classifiction tree model : ")
    print(model.toDebugString())    # 保存和加載訓(xùn)練好的模型
    modelPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/myDecisionTreeClassificationModel"
    model.save(sc, modelPath)
    sameModel = DecisionTreeModel.load(sc, modelPath)

2.決策樹(shù)源碼分析

決策樹(shù)分類器API為DecisionTree.trainClassifier,進(jìn)入源碼分析。

源碼文件所在路徑為,spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala。

  @Since("1.1.0")
  def trainClassifier(
      input: RDD[LabeledPoint],
      numClasses: Int,
      categoricalFeaturesInfo: Map[Int, Int],
      impurity: String,
      maxDepth: Int,
      maxBins: Int): DecisionTreeModel = {
    val impurityType = Impurities.fromString(impurity)
    train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
      categoricalFeaturesInfo)
  }

訓(xùn)練出一個(gè)分類器,然后調(diào)用了train方法。

  @Since("1.0.0")
  def train(
      input: RDD[LabeledPoint],
      algo: Algo,
      impurity: Impurity,
      maxDepth: Int,
      numClasses: Int,
      maxBins: Int,
      quantileCalculationStrategy: QuantileStrategy,
      categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
    val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
      quantileCalculationStrategy, categoricalFeaturesInfo)
    new DecisionTree(strategy).run(input)
  }

train方法首先將模型類型(分類或者回歸)、信息增益指標(biāo)、決策樹(shù)深度、分類數(shù)目、最大切分箱子數(shù)等參數(shù)封裝為Strategy,然后新建一個(gè)DecisionTree對(duì)象,并調(diào)用run方法。

@Since("1.0.0")class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
  extends Serializable with Logging {  /**
   * @param strategy The configuration parameters for the tree algorithm which specify the type
   *                 of decision tree (classification or regression), feature type (continuous,
   *                 categorical), depth of the tree, quantile calculation strategy, etc.
   */
  @Since("1.0.0")  def this(strategy: Strategy) = this(strategy, seed = 0)

  strategy.assertValid()  /**
   * Method to train a decision tree model over an RDD
   *
   * @param input Training data: RDD of `org`.`apache`.`spark`.`mllib`.`regression`.`LabeledPoint`.
   * @return DecisionTreeModel that can be used for prediction.
   */
  @Since("1.2.0")  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)    val rfModel = rf.run(input)
    rfModel.trees(0)
  }
}

run方法中首先新建一個(gè)RandomForest對(duì)象,將strategy、決策樹(shù)數(shù)目設(shè)置為1,子集選擇策略為"all"傳遞給RandomForest對(duì)象,然后調(diào)用RandomForest中的run方法,最后返回隨機(jī)森林模型中的第一棵決策樹(shù)。

也就是,決策樹(shù)模型使用了隨機(jī)森林模型進(jìn)行訓(xùn)練,將決策樹(shù)數(shù)目設(shè)置為1,然后將隨機(jī)森林模型中的第一棵決策樹(shù)作為結(jié)果,返回作為決策樹(shù)訓(xùn)練模型。

3.隨機(jī)森林源碼分析

隨機(jī)森林的源碼文件所在路徑為,spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala。

private class RandomForest (    private val strategy: Strategy,    private val numTrees: Int,
    featureSubsetStrategy: String,    private val seed: Int)
  extends Serializable with Logging {

  strategy.assertValid()  require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")  require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
    || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess
    || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess,
    s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
    s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
    s" (0.0-1.0], [1-n].")  /**
   * Method to train a decision tree model over an RDD
   *
   * @param input Training data: RDD of `org`.`apache`.`spark`.`mllib`.`regression`.`LabeledPoint`.
   * @return RandomForestModel that can be used for prediction.
   */
  def run(input: RDD[LabeledPoint]): RandomForestModel = {
    val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
      featureSubsetStrategy, seed.toLong, None)    new RandomForestModel(strategy.algo, trees.map(_.toOld))
  }

}

在該文件開(kāi)頭,通過(guò)"import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}"將ml中的RandomForest引入,重新命名為NewRandomForest。

在RandomForest.run方法中,首先新建NewRandomForest模型,并調(diào)用該類的run方法,然后將生成的trees作為新建RandomForestModel的入?yún)ⅰ?/p>

NewRandomForest,源碼文件所在路徑為,spark-1.6/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala。

由于涉及代碼量較大,因此無(wú)法將代碼展開(kāi),run方法主要有如下調(diào)用。

run方法

--->1. val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees,featureSubsetStrategy) # 對(duì)輸入數(shù)據(jù)建立元數(shù)據(jù)--->2. val splits = findSplits(retaggedInput, metadata, seed) # 對(duì)元數(shù)據(jù)中的特征進(jìn)行切分

    --->2.1 計(jì)算采樣率,對(duì)輸入樣本進(jìn)行采樣
    
    --->2.2 findSplitsBySorting(sampledInput, metadata, continuousFeatures) # 對(duì)采樣后的樣本中的特征進(jìn)行切分
    
        --->2.2.1 val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) # 針對(duì)連續(xù)型特征
        
        --->2.2.2 val categories = extractMultiClassCategories(splitIndex + 1, featureArity) # 針對(duì)分類型特征,且特征無(wú)序
        
        --->2.2.3 Array.empty[Split] # 針對(duì)分類型特征,且特征有序,訓(xùn)練時(shí)直接構(gòu)造即可--->3. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) # 將輸入數(shù)據(jù)轉(zhuǎn)換為樹(shù)形數(shù)據(jù)

    --->3.1 input.map { x => TreePoint.labeledPointToTreePoint(x, thresholds, featureArity) # 將LabeledPoint數(shù)據(jù)轉(zhuǎn)換為T(mén)reePoint數(shù)據(jù)
    
    --->3.2 arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex)) # 在(labeledPoint,feature)中找出一個(gè)離散值--->4. val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees,withReplacement, seed) # 對(duì)輸入數(shù)據(jù)進(jìn)行采樣

    --->4.1 convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) #有放回采樣

    --->4.2 convertToBaggedRDDWithoutSampling(input) # 樣本數(shù)為1,采樣率為100%

    --->4.3 convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) # 無(wú)放回采樣--->5. val (nodesForGroup, treeToNodeToIndexInfo) = RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage,metadata, rng) # 取得每棵樹(shù)所有需要切分的結(jié)點(diǎn)

    --->5.1 val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)} # 如果需要子采樣,選擇特征子集
    
    --->5.2 val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L # 計(jì)算添加這個(gè)結(jié)點(diǎn)之后,是否有足夠的內(nèi)存--->6. RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) # 找出最優(yōu)切分點(diǎn)

    --->6.1 val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) #找出每個(gè)結(jié)點(diǎn)最好的切分--->7. new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, strategy.getNumClasses) # 返回決策樹(shù)分類模型


向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI