object DecisionTree {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local").setAppName("tree")
val sc = SparkSession.builder().config(conf).getOrCreate()
val lines = sc.sparkContext.textFile("tree.txt")
val tf = new HashingTF(10000)
val transdata = lines.map {
lien => {
val t1 = lien.split(",")
val t2 = t1(1).split(" ")
LabeledPoint(t1(0).toDouble, tf.transform(t2))
}
}
val Array(data,test) = ansdata.randomSplit(Array(0.8,0.2),seed = 1L)
// 设置决策树参数,训练模型
val numClasses = 2;
//不纯度的类型,有基尼不纯度——“gini”,熵——“entropy”
val impurity = "gini";
//对层数进行限制,避免过拟合
val maxDepth = 5;
//决策规则集,可以理解成是决策树的孩子节点的数量
val maxBins = 32;
val tree_model = DecisionTree.trainClassifier(data, numClasses,
Map[Int,Int](), impurity, maxDepth, maxBins);
System.out.println("决策树模型:")
System.out.println(tree_model.toDebugString)
val datatest = "30 帅 收入中等 是公务员"
val strings = datatest.split(" ")
val d = tree_model.predict(tf.transform(strings))
if (d == 1.0){
print("去相亲")
}
if (d == 0.0){
print("不去相亲")
}
}
tree.text
0,32 帅 收入中等 不是公务员
1,25 帅 收入中等 是公务员
0,25 帅 收入中等 不是公务员
1,29 帅 收入中等 是公务员
1,24 帅 收入高 不是公务员
0,31 帅 收入高 不是公务员
0,35 帅 收入中等 是公务员
0,30 不帅 收入中等 不是公务员
0,31 帅 收入高 不是公务员
1,30 帅 收入中等 是公务员
1,21 帅 收入高 不是公务员
0,21 帅 收入中等 不是公务员
1,21 帅 收入中等 是公务员
0,29 不帅 收入中等 是公务员
0,29 帅 收入底 是公务员
0,29 不帅 收入低 是公务员
1,30 帅 收入高 不是公务员