Mahout: TrainNaiveBayesJob Source Code Analysis


Mahout's trainnb command invokes TrainNaiveBayesJob to do the model-training work. The class lives in the package

org.apache.mahout.classifier.naivebayes.training

The input to TrainNaiveBayesJob is the training portion split off from the TF-IDF vector file.

Walking through the TrainNaiveBayesJob code: it first registers some command-line options, such as


LABEL                -L
ALPHA_I              -a
LABEL_INDEX          -li
TRAIN_COMPLEMENTARY  -c

It then reads the labels from the input files and stores them in a label index. In the 20 newsgroups example, two labels are read, and the label index looks like this:



Key class: class org.apache.hadoop.io.Text   Value Class: class org.apache.hadoop.io.IntWritable
Key: 20news-bydate-test: Value: 0
Key: 20news-bydate-train: Value: 1



In effect, this just builds an index over the category labels.
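For the curious, here is a minimal sketch (my own code, not Mahout's) of dumping such a label index with the plain Hadoop SequenceFile API; the path is hypothetical:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;

public class DumpLabelIndex {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path path = new Path("temp/labelindex");   // hypothetical location of the label index
    FileSystem fs = FileSystem.get(conf);
    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
    try {
      Text label = new Text();
      IntWritable index = new IntWritable();
      while (reader.next(label, index)) {      // one <label, index> pair per entry
        System.out.println(label + " -> " + index.get());
      }
    } finally {
      reader.close();
    }
  }
}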

Next, vectors with the same label are summed, i.e., the vectors of all documents in a category are added together. Each vector here is a key/value vector whose entries are term IDs and their TF-IDF values. After the summation there is one vector per category: TF-IDF values with the same term ID are added, and IDs present in only one vector are inserted, much like merging two sorted linked lists. A single job does this:

//  Key class: org.apache.hadoop.io.Text
//  Value class: org.apache.mahout.math.VectorWritable
//  add up all the vectors with the same labels, while mapping the labels into our index
Job indexInstances = prepareJob(getInputPath(),        // input path
    getTempPath(SUMMED_OBSERVATIONS),                  // output path
    SequenceFileInputFormat.class,                     // input format
    IndexInstancesMapper.class,                        // mapper class
    IntWritable.class,                                 // mapper key
    VectorWritable.class,                              // mapper value
    VectorSumReducer.class,                            // reducer class
    IntWritable.class,                                 // reducer key
    VectorWritable.class,                              // reducer value
    SequenceFileOutputFormat.class);                   // output format
indexInstances.setCombinerClass(VectorSumReducer.class);
boolean succeeded = indexInstances.waitForCompletion(true);
if (!succeeded) {
  return -1;
}



The mapper is IndexInstancesMapper and the reducer is VectorSumReducer (which also serves as the combiner, since vector addition is associative). Both are short:


protected void map(Text labelText, VectorWritable instance, Context ctx)
    throws IOException, InterruptedException {
  String label = labelText.toString().split("/")[1];
  if (labelIndex.containsKey(label)) {
    // use the label's index (read from the label index file) as the key
    ctx.write(new IntWritable(labelIndex.get(label)), instance);
  } else {
    ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
  }
}

// sum up the vectors that share the same key
protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx)
    throws IOException, InterruptedException {
  Vector vector = null;
  for (VectorWritable v : values) {
    if (vector == null) {
      vector = v.get();
    } else {
      vector.assign(v.get(), Functions.PLUS);
    }
  }
  ctx.write(key, new VectorWritable(vector));
}
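As a toy illustration of what the reducer's assign(..., Functions.PLUS) does on sparse vectors (my own example, not from the Mahout source): entries with the same index are added, and indices present in only one vector carry over, exactly like the sorted-list merge described above.

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public class SparseSumDemo {
  public static void main(String[] args) {
    Vector a = new RandomAccessSparseVector(10);
    a.set(1, 2.0);
    a.set(4, 3.0);

    Vector b = new RandomAccessSparseVector(10);
    b.set(4, 1.0);
    b.set(7, 5.0);

    // element-wise sum: index 4 is added, indices 1 and 7 are carried over
    a.assign(b, Functions.PLUS);
    System.out.println(a);   // entries 1:2.0, 4:4.0, 7:5.0 (print order may vary)
  }
}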



OK, at this point we have <label_index, label_vector> pairs: each category's ID together with the summed TF-IDF values of all items (features) in that category. This step produces output like the following:


Key: 0
Value: /comp.sys.ibm.pc.hardware/60252:{93562:17.52922821044922,93559:9.745443344116211,93558:107.53932094573975,93557:49.015570640563965,93556:9.745443344116211……}
Key: 1
Value: /alt.atheism/53261:{93562:26.293842315673828,93560:19.490886688232422,93559:9.745443344116211,93558:78.52010536193848,93557:62.2713, 93555:14.35555171……}


The next stage sums these TF-IDF weights per label and per feature. Its input is the previous step's output, and it too runs as a single job:

// sum up all the weights from the previous step, per label and per feature
Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
    getTempPath(WEIGHTS),
    SequenceFileInputFormat.class,
    WeightsMapper.class,
    Text.class,
    VectorWritable.class,
    VectorSumReducer.class,
    Text.class,
    VectorWritable.class,
    SequenceFileOutputFormat.class);
weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize));
weightSummer.setCombinerClass(VectorSumReducer.class);
succeeded = weightSummer.waitForCompletion(true);
if (!succeeded) {
  return -1;
}



This job's mapper is WeightsMapper; the reducer is the same VectorSumReducer as before. The mapper:

protected void map(IntWritable index, VectorWritable value, Context ctx)
    throws IOException, InterruptedException {
  Vector instance = value.get();
  if (weightsPerFeature == null) {
    weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
  }
  int label = index.get();
  // per-feature totals: element-wise sum of the label vectors
  weightsPerFeature.assign(instance, Functions.PLUS);
  // per-label total: zSum() is the sum of all elements of this label's vector
  weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
}
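The excerpt does not show where weightsPerLabel comes from; presumably it is initialized in setup() from the NUM_LABELS value the driver put on the configuration earlier. A plausible reconstruction (not the verbatim Mahout code):

@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
  super.setup(ctx);
  // NUM_LABELS was set by the driver:
  // weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize))
  int numLabels = Integer.parseInt(ctx.getConfiguration().get(WeightsMapper.NUM_LABELS));
  weightsPerLabel = new DenseVector(numLabels);   // one slot per label index
}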



The output of this step is written in cleanup():

protected void cleanup(Context ctx) throws IOException, InterruptedException {
  if (weightsPerFeature != null) {
    ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
        new VectorWritable(weightsPerFeature));
    ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL),
        new VectorWritable(weightsPerLabel));
  }
  super.cleanup(ctx);
}



In other words, the output contains just two key/value pairs:

one for WEIGHTS_PER_FEATURE (the constant __SPF)
one for WEIGHTS_PER_LABEL (the constant __SPL)

weightsPerFeature is the element-wise sum of the per-label vectors from the previous step, i.e., each feature's total TF-IDF weight across all labels; for instance, feature 93562 below is 17.529… + 26.293… ≈ 43.823, the sum of its entries in the two label vectors shown earlier. weightsPerLabel holds each label's total TF-IDF weight (the zSum of that label's vector). The output looks like this:



Key: __SPF
Value: {93562:43.82307052612305,93560:19.490886688232422,93559:19.490886688232422,93558:186.05942630767822,93557:111.28696632385254,93556:9.745443344116211……}
Key: __SPL
Value: {1:7085520.472989678,0:4662610.912284017}



The last stage: look at the source first.

// calculate the Thetas, write out to LABEL_THETA_NORMALIZER vectors
// -- TODO: add reference here to the part of the Rennie paper that discusses this
Job thetaSummer =
    prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(THETAS),
        SequenceFileInputFormat.class,
        ThetaMapper.class,
        Text.class,
        VectorWritable.class,
        VectorSumReducer.class,
        Text.class,
        VectorWritable.class,
        SequenceFileOutputFormat.class);
thetaSummer.setCombinerClass(VectorSumReducer.class);
thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
/* TODO(robinanil): Enable this when thetanormalization works.
succeeded = thetaSummer.waitForCompletion(true);
if (!succeeded) {
  return -1;
}*/



Note that thetaSummer.waitForCompletion(true) is commented out, so this job never actually runs. The Rennie paper mentioned in the comment is the one Mahout's Bayes implementation is based on: "Tackling the Poor Assumptions of Naive Bayes Text Classifiers", which gives the formula for estimating θ. Why the step was disabled is unclear; answers welcome.
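For reference, the smoothed parameter estimate from that paper (paraphrased from the paper's notation, not taken from the Mahout source) is

    θ̂_ci = (N_ci + α_i) / (N_c + α)

where N_ci is the weight of feature i in class c (here, its summed TF-IDF), N_c = Σ_i N_ci, α_i is the smoothing parameter (the job's ALPHA_I option), and α = Σ_i α_i. In the complement variant (TRAIN_COMPLEMENTARY), the counts come from every class except c instead. The classifier then works with normalized log weights, which is presumably what the LABEL_THETA_NORMALIZER vectors would have held.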


And the very last step. The model is really already complete once we have weightsPerFeature and weightsPerLabel; this step just assembles everything into matrix form, one weight vector per row:

____|item1, item2, item3 ……
lab1|
lab2|
……

The source:

// read the weights back into a SparseMatrix-backed model
NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
naiveBayesModel.validate();
// serialize the model to output/naiveBayesModel.bin
naiveBayesModel.serialize(getOutputPath(), getConf());
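Once serialized, the model can be read back and used for classification. A minimal sketch (paths hypothetical; this assumes the standard, non-complementary classifier):

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class ClassifyDemo {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // read the serialized model back from the trainer's output directory (hypothetical path)
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path("output"), conf);
    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // a document as a sparse TF-IDF vector, keyed by term id as in the vectors above
    Vector doc = new RandomAccessSparseVector(model.numFeatures());
    doc.set(93562, 17.5);

    Vector scores = classifier.classifyFull(doc);   // one score per label index
    System.out.println("best label index: " + scores.maxValueIndex());
  }
}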


THE END

http://hnote.org/big-data/mahout/mahout-train-naive-bayes-job

http://soledede.com/