mahout的trainnb调用的是TrainNaiveBayesJob完成训练模型任务。所在包:
org.apache.mahout.classifier.naivebayes.training
TrainNaiveBayesJob的输入是在tfidf文件上split出来的一部分,用作训练。
TrainNaiveBayesJob代码分析,
首先加入一些命令行选项,如
LABEL -L
ALPHA_I -a
LABEL_INDEX -li
TRAIN_COMPLEMENTARY -c
然后从输入文件中读取label,将label保存于label index,例如20news group的例子,读取的label有两个,label index如下
Key class: class org.apache.hadoop.io.Text Value Class: class org.apache.hadoop.io.IntWritable
Key: 20news-bydate-test: Value: 0
Key: 20news-bydate-train: Value: 1
其实也就是将分类建一个索引。
接下来,将相同label的vectors相加。也就是将同一个类别的所有的文章的vector相加。这里vector其实是一个key/value vector,每项由词的id和tfidf值组成。这样相加后就是一个一个类的vector,相同id的tfidf相加,没有的则插入,类似两个递增的链表的合并。由一个job来完成:
// Key class: class org.apache.hadoop.io.Text
// Value Class: class org.apache.mahout.math.VectorWritable
//add up all the vectors with the same labels, while mapping the labels into our index
Job indexInstances = prepareJob (getInputPath ( ), //input path
getTempPath (SUMMED_OBSERVATIONS ), //output path
SequenceFileInputFormat. class, //input format
IndexInstancesMapper. class, //mapper class
IntWritable. class, //mapper key
VectorWritable. class, //mapper value
VectorSumReducer. class, //reducer class
IntWritable. class, //reducer key
VectorWritable. class, //reducer value
SequenceFileOutputFormat. class ) ; //output format
indexInstances. setCombinerClass (VectorSumReducer. class ) ;
boolean succeeded = indexInstances. waitForCompletion ( true ) ;
if ( !succeeded ) {
return - 1 ;
}
Mapper为IndexInstancesMapper,Reducer为Reducer VectorSumReducer,代码也比较简单,如下,
protected void map (Text labelText, VectorWritable instance, Context ctx ) throws IOException, InterruptedException {
String label = labelText. toString ( ). split ( "/" ) [ 1 ] ;
if (labelIndex. containsKey (label ) ) {
//从文件中读取的类的index作为key
ctx. write ( new IntWritable (labelIndex. get (label ) ), instance ) ;
} else {
ctx. getCounter (Counter. SKIPPED_INSTANCES ). increment ( 1 ) ;
}
}
//相同key的vector相加
protected void reduce (WritableComparable < ? > key, Iterable < VectorWritable > values, Context ctx )
throws IOException, InterruptedException {
Vector vector = null ;
for (VectorWritable v : values ) {
if (vector == null ) {
vector = v. get ( ) ;
} else {
vector. assign (v. get ( ), Functions. PLUS ) ;
}
}
ctx. write (key, new VectorWritable (vector ) ) ;
}
OK,到现在已经得到了< label_index,label_vector >,即类的id和类中所有item(或者说feature)的TFIDF值。此步得到类似如下的输出,
Key: 0
Value: /comp.sys.ibm.pc.hardware/60252:{93562:17.52922821044922,93559:9.745443344116211,93558:107.53932094573975,93557:49.015570640563965,93556:9.745443344116211……}
key:1
Value:
/alt.atheism/53261:{93562:26.293842315673828,93560:19.490886688232422,93559:9.745443344116211,93558:78.52010536193848,93557:62.2713, 93555:14.35555171……}
下一个阶段就是统计每个label的所有ITIDF和,输入为上一步的输出,并由一个job来执行,
//sum up all the weights from the previous step, per label and per feature
Job weightSummer = prepareJob (getTempPath (SUMMED_OBSERVATIONS ),
getTempPath (WEIGHTS ),
SequenceFileInputFormat. class,
WeightsMapper. class,
Text. class,
VectorWritable. class,
VectorSumReducer. class,
Text. class,
VectorWritable. class,
SequenceFileOutputFormat. class ) ;
weightSummer. getConfiguration ( ). set (WeightsMapper. NUM_LABELS, String. valueOf (labelSize ) ) ;
weightSummer. setCombinerClass (VectorSumReducer. class ) ;
succeeded = weightSummer. waitForCompletion ( true ) ;
if ( !succeeded ) {
return - 1 ;
}
job的mapper为WeightsMapper,reducer与上一步的相同,为VectorSumReducer。
mapper如下,
protected void map (IntWritable index, VectorWritable value, Context ctx ) throws IOException, InterruptedException {
Vector instance = value. get ( ) ;
if (weightsPerFeature == null ) {
weightsPerFeature = new RandomAccessSparseVector (instance. size ( ), instance. getNumNondefaultElements ( ) ) ;
}
int label = index. get ( ) ;
weightsPerFeature. assign (instance, Functions. PLUS ) ;
weightsPerLabel. set (label, weightsPerLabel. get (label ) + instance. zSum ( ) ) ;
}
此步的输出写在cleanup()中。
protected void cleanup ( Context ctx ) throws IOException, InterruptedException {
if (weightsPerFeature != null ) {
ctx. write ( new Text (TrainNaiveBayesJob. WEIGHTS_PER_FEATURE ),
new VectorWritable (weightsPerFeature ) ) ;
ctx. write ( new Text (TrainNaiveBayesJob. WEIGHTS_PER_LABEL ),
new VectorWritable (weightsPerLabel ) ) ;
}
super. cleanup (ctx ) ;
}
也就是说输出只有两个key/value.
一个是WEIGHTS_PER_FEATURE(定义的常量,__SPF)
一个是WEIGHTS_PER_LABEL(__SPL)
weightsPerFeature其实就是保持上一步的vector没变,仍然是一个类中所有iterm(feature)的TFIDF。
weightsPerLabel就是求每个label中的和了。
可以看到输出为,
Key: __SPF
Value: {93562:43.82307052612305,93560:19.490886688232422,93559:19.490886688232422,93558:186.05942630767822,93557:111.28696632385254,93556:9.745443344116211……}
Key: __SPL
Value: {1:7085520.472989678,0:4662610.912284017}
最后一步,先看源代码,
//calculate the Thetas, write out to LABEL_THETA_NORMALIZER vectors
//-- TODO: add reference here to the part of the Rennie paper that discusses this
Job thetaSummer =
prepareJob (getTempPath (SUMMED_OBSERVATIONS ), getTempPath (THETAS ),
SequenceFileInputFormat. class,
ThetaMapper. class,
Text. class,
VectorWritable. class,
VectorSumReducer. class,
Text. class,
VectorWritable. class,
SequenceFileOutputFormat. class ) ;
thetaSummer. setCombinerClass (VectorSumReducer. class ) ;
thetaSummer. getConfiguration ( ). setFloat (ThetaMapper. ALPHA_I, alphaI ) ;
thetaSummer. getConfiguration ( ). setBoolean (ThetaMapper. TRAIN_COMPLEMENTARY, trainComplementary ) ;
/* TODO(robinanil): Enable this when thetanormalization works.
succeeded = thetaSummer.waitForCompletion(true);
if (!succeeded) {
return -1;
}*/
可以看到thetaSummer.waitForCompletion(true)被注释掉了,job没有执行。注释里面说的Rennie paper指的就是mahout bayes算法参考的这篇论文:Tackling the Poor Assumptions of Naive Bayes Text Classifiers,论文里面有个求Ɵ的公式如下。不知为何注释掉?求解。
最最后一步,其实model有weightsPerFeature和weightsPerLabel就完成了。这一步也就是把它们变成矩阵形式,如下,每行一个权重vector。
____|item1,iterm2,item3……
lab1|
lab2|
……
源代码如下,
//得到SparseMatrix矩阵
NaiveBayesModel naiveBayesModel = BayesUtils. readModelFromDir (getTempPath ( ), getConf ( ) ) ;
naiveBayesModel. validate ( ) ;
//序列化,写到output/naiveBayesModel.bin
naiveBayesModel. serialize (getOutputPath ( ), getConf ( ) ) ;
THE END
http://hnote.org/big-data/mahout/mahout-train-naive-bayes-job
http://soledede.com/
微信公众号: