- Model SQL file: https://pan.baidu.com/s/1hugrI9e
- Training data: https://pan.baidu.com/s/1kWz8fNh
- NaiveBayes training with Spark MLlib
package com.xxx.xxx.xxx
import java.io.ObjectInputStream
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.{Arrays, Date, Scanner}
import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
import org.apache.spark.{SparkConf, SparkContext}
import scala.collection.mutable.ArrayBuffer
/**
* Created by Zsh on 1/31 0031.
*/
object WudiBayesModel {
var conn: Connection = null
var stmt: PreparedStatement = null
val outputPath = "/zshModel"
val driverUrl = "jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8&autoDeserialize=true"
var df :DataFrame = null
val classify:String = "健康"
var model:NaiveBayesModel = null
val lables = "社会|体育|汽车|女性|新闻|科技|财经|军事|广告|娱乐|健康|教育|旅游|生活|文化"
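//NOTE: the position of each label in this string is the numeric class index used by the model,
//so training and prediction must use the same string in the same order.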
def main(args: Array[String]): Unit ={
// training(lables)
// val text = "今年“五一”假期虽然缩短为三天,但来个“周边游”却正逢其时。昨日记者从保定市旅游部门了解到,5月1日—3日,该市满城汉墓景区将举办“2008全国滑翔伞俱乐部联赛”第一站比赛。届时将有来自全国各地的滑翔伞高手云集陵山,精彩上演自由飞翔。\uE40C滑翔伞俱乐部举办联赛为我国第一次,所有参赛运动员均在国际或国内大型比赛中取得过名次,并且所有运动员必须持2008年贴花的中国航空运动协会会员证书和B级以上滑翔伞运动证书,所使用的比赛动作均为我国滑翔伞最新动作。本届比赛的项目,除保留传统的“留空时间赛”和“精确着陆赛”以外,还增加了“盘升高度赛”等内容。届时,参赛运动员将冲击由保定市运动员韩均创造的1450米的盘升高度记录。截至目前,已有11个省市的50多名运动员报名参赛,其中包括多名外籍运动员和7名女运动员。 (来源:燕赵晚报)\uE40C(责任编辑:李妍)"
//Test the model read back from MySQL
// BayesUtils.testMysql(text,lables)
// inputTestModel()
}
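//Typical workflow: run training(lables) once to build and persist the model (to outputPath and to MySQL),
//then use inputTestModel() or BayesUtils.testMysql(text, lables) to check predictions.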
//Test the model with text typed on the console
def inputTestModel(): Unit ={
val scan = new Scanner(System.in)
val startTime = new Date().getTime
loadModel
val time2 =new Date().getTime
println("加载模型时间:"+(time2-startTime))
println("模型加载完毕-----")
while(true) {
val str = scan.nextLine()
testData(model,str,lables)
println("---------------------------------")
}
}
//Batch-test prediction accuracy for a single class
def batchTesting(): Unit ={
var time2 =new Date().getTime
val result = df.map(x=>testData(model,x.getAs("content").toString,lables))
var time3 =new Date().getTime
println("预测需要时间:"+ (time3-time2))
println("准确率:" + result.filter(_.equals(classify)).count().toDouble/result.count())
}
//Load the saved model and the test data
def loadModel(){
val conf = new SparkConf().setAppName("NaiveBayesExample1")
.setMaster("local")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.kryoserializer.buffer.max", "1024mb")
val sc =new SparkContext(conf)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
model = NaiveBayesModel.load(sc, outputPath) //assign to the shared field so inputTestModel/batchTesting can use it
val jdbcDF = sqlContext.read.options(Map("url" -> driverUrl, "dbtable" -> "industry_classify_tmp")).format("jdbc").load()
jdbcDF.registerTempTable("testData")
val sql = "select content from testData where classify in ('"+classify+"')"
df = sqlContext.sql(sql)
}
def testModel()={
val conf = new SparkConf().setAppName("NaiveBayesExample1")
.setMaster("local")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.kryoserializer.buffer.max", "1024mb")
val sc =new SparkContext(conf)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val model = NaiveBayesModel.load(sc,outputPath)
model
}
//Predict the class of a single text
def testData(model :NaiveBayesModel,text:String,labels_name:String)={
// val text= "新浪微博采集"
val dim = math.pow(2, 16).toInt //must match the HashingTF dimension used in training
val hashingTF= new HashingTF(dim)
val tfVector = hashingTF.transform(tokenizer(text))
val d = model.predict(tfVector)
// val labels_name = "社会|体育|汽车|女性|新闻|科技|财经|军事|娱乐|健康|教育|旅游|文化"
val list2 = labels_name.split("\\|").toList
//println(list2(d.toInt))
println("result:"+list2(d.toInt) + " " + d + " " + text)
list2(d.toInt)
}
//Train the model
def training(labels_name:String): Unit ={
//Full set of class labels
// val labels_name = "社会|体育|汽车|女性|新闻|科技|财经|军事|娱乐|健康|教育|旅游|文化"
val list2 = labels_name.split("\\|").toList
//Map each label to a numeric index (社会 -> 0.0, 体育 -> 1.0, ...)
val num = list2.indices.map(_.toDouble).toList
val tuples = list2.zip(num).toMap
val temp = labels_name.split("\\|").toList
//Build a double-quoted, comma-separated label list for the SQL IN clause, e.g. "社会","体育",...
val str = temp.map(t => "\"" + t + "\"").mkString(",")
val conf = new SparkConf().setAppName("NaiveBayesExample1")
.setMaster("local[4]")
// .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
// .set("spark.kryoserializer.buffer.max", "1024mb")
val sc =new SparkContext(conf)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val jdbcDF = sqlContext.read.options(Map("url" -> driverUrl, "dbtable" -> "industry_classify_tmp")).format("jdbc").load()
jdbcDF.registerTempTable("testData")
val sql = "select content,classify from testData where classify in ("+str+")"
// println("str"+str)
//val jdbcDF = sqlContext.read.options(Map("url" -> driverUrl, "dbtable" -> "industry_classify_tmp")).format("jdbc").load()
// jdbcDF.registerTempTable("testData")
// val sql = "select content,classify from testData where classify in ('"+str+"')"
// df = sqlContext.sql(sql)
//Read the training data from MySQL
val trainData = sqlContext.sql(sql)
// val trainData = df
println("trcount:"+trainData.count())
//Extract the content and label columns, e.g. (text, 教育)
val trainData1 = trainData.map(x=>(x.getAs("content").toString,x.getAs("classify").toString))
//Tokenize the content and replace the label with its numeric index
val trainData2 = trainData1.map(x=>(tokenizer(x._1),tuples(x._2)))
//Token sequences used for the TF-IDF computation
val cData = trainData2.map(_._1)
//Numeric labels (0.0, 1.0, 2.0, ...)
val clData = trainData2.map(_._2)
//Feature vector dimension; larger values make the saved model bigger, so set it to a reasonable size
val dim = math.pow(2, 16).toInt
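//NOTE: the same number of features must be used by HashingTF at prediction time,
//otherwise the vectors will not line up with the trained model.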
//Compute term frequencies (TF)
val hashTF= new HashingTF(dim)
val tf = hashTF.transform(cData).cache()
//Fit IDF and transform TF into TF-IDF
val hashIDF = new IDF().fit(tf)
val idf = hashIDF.transform(tf)
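//NOTE: the fitted IDF model (hashIDF) is not persisted, so the prediction paths above apply raw TF only.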
//Pair each TF-IDF vector with its numeric label
val zip = clData.zip(idf)
//Convert to LabeledPoint, the type NaiveBayes.train expects
val tData = zip.map{case (label,vector) =>LabeledPoint(label,vector)}
//Split the data: 70% for training, 30% for validation
val splits = tData.randomSplit(Array(0.7, 0.3), seed = 11L)
val trData = splits(0).cache()
val teData = splits(1).cache()
val model = NaiveBayes.train(trData,lambda = 1.0, modelType = "multinomial")
model.save(sc,outputPath)
println("save model success !")
//Convert the model into BayesModelData2 (see the sketch after this listing) and save it to MySQL
val data = BayesModelData2(model.labels.toArray,
model.pi.toArray,model.theta.map(_.toArray).toArray,"multinomial")
//Persist to MySQL
serializeToMysql(data)
//=============== Model validation
val testAndLabel = teData.map(x=>{(model.predict(x.features),x.label)})
// println("****************************")
// testAndLabel.foreach(println)
// println("****************************")
val total = testAndLabel.count()
//Actual number of samples in the target class (label index 11.0)
val totalPostiveNum = testAndLabel.filter(x => x._2 == 11.0).count()
//Number of samples predicted as the target class
val totalTrueNum = testAndLabel.filter(x => x._1 == 11.0).count()
//Correct predictions within the target class
val testRealTrue = testAndLabel.filter(x => x._1 == x._2 && x._2 == 11.0).count()
//Correct predictions over all classes
val testReal = testAndLabel.filter(x => x._1 == x._2).count()
val testAccuracy = 1.0 * testReal / total
val testPrecision = 1.0 * testRealTrue / totalTrueNum
val testRecall = 1.0 * testRealTrue / totalPostiveNum
println("Classification metrics: ============================")
println("Accuracy: ", testAccuracy) //correct predictions / total  Accuracy=(TP+TN)/(TP+FP+TN+FN)  Error=(FP+FN)/(TP+FP+TN+FN)
println("Precision: ", testPrecision) //predicted positive and actually positive / all predicted positive  Precision=TP/(TP+FP)
println("Recall: ", testRecall) //predicted positive and actually positive / all actually positive  Recall=TP/(TP+FN)
// val accuracy = 1.0 * testAndLabel.filter(x => x._1 == x._2).count() / teData.count()
println("模型准确度============================")
}
def tokenizer(line: String): Seq[String] = {
val reg1 = "@\\w{2,20}:".r
val reg2 = "http://[0-9a-zA-Z/\\?&#%$@\\=\\\\]+".r
AnsjSegment(line)
.split(",")
.filter(_!=null)
.filter(token => !reg1.pattern.matcher(token).matches)
.filter(token => !reg2.pattern.matcher(token).matches)
// .filter(token => !stopwordSet.contains(token))
.toSeq
}
def AnsjSegment(line: String): String={
val StopNatures="""w","",null,"s", "f", "b", "z", "r", "q", "d", "p", "c", "uj", "ul","en", "y", "o", "h", "k", "x"""
val KeepNatures=List("n","v","a","m","t")
val StopWords=Arrays.asList("的", "是","了") //Arrays.asList(stopwordlist.toString())
//val filter = new FilterRecognition()
//Add stop words
//filter.insertStopWords(StopWords)
//Add stop parts-of-speech
//filter.insertStopNatures(StopNatures)
//filter.insertStopRegex("小.*?")
//This step keeps only the tokens, without their part-of-speech tags
//for (i <- Range(0, filter1.size())) {
//word += words.get(i).getName
//}
val words = ToAnalysis.parse(line)
val word = ArrayBuffer[String]()
for (i <- Range(0,words.size())) { //KeepNatures.contains(words.get(i).getNatureStr.substring(0,1))&&
if(KeepNatures.contains(words.get(i).getNatureStr.substring(0,1))&&words.get(i).getName.length()>=2)
word += words.get(i).getName
}
// println(word)
word.mkString(",")
}
//Serialize the model data to MySQL
def serializeToMysql[T](o: BayesModelData2) { //object serialization
val model_Id = "test"
new MysqlConn()
val query="replace into "+"ams_recommender_model"+"(model_ID,model_Data) values (?,?)"
stmt=conn.prepareStatement(query)
stmt.setString(1, model_Id)
stmt.setObject(2,o)
stmt.executeUpdate()
conn.close()
}
class MysqlConn() {
val trainning_url="jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8"
try {
//Training and output currently share the same URL; they could be split later
conn = DriverManager.getConnection(trainning_url, "xxx", "xxx")
} catch {
case e: Exception => println("MySQL connection error: " + e.getMessage)
}
}
//Read the serialized model back from MySQL and convert it to BayesModelData2
def deserializeFromMysql[T](): BayesModelData2 = { //object deserialization
new MysqlConn()
val model_Id = "test"
val query="select model_Data from "+"ams_recommender_model"+" where model_ID='"+ model_Id +"' "
stmt=conn.prepareStatement(query)
val resultSet = stmt.executeQuery()
resultSet.next()
val bis= resultSet.getBlob("model_Data").getBinaryStream()
val ois = new ObjectInputStream(bis)
conn.close()
ois.readObject.asInstanceOf[BayesModelData2]
}
}
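- The BayesModelData2 case class itself is not listed above. A minimal sketch that is consistent with how it is used here (field types inferred from model.labels / model.pi / model.theta and from the new NaiveBayesModel(...) call in BayesUtils; treat it as an assumption, not the original definition):
package com.izhonghong.mission.learn
//Plain holder for the NaiveBayes parameters stored in MySQL; case classes are Serializable by default,
//which stmt.setObject / ObjectInputStream rely on
case class BayesModelData2(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]], modelType: String)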
- The calling class BayesUtils has to live under the org.apache.spark package, because the NaiveBayesModel constructor is private[spark].
- Reference: How to use BLAS library in Spark (Symbol BLAS is inaccessible from this space) - spark: http://note.youdao.com/noteshare?id=7f1eec90cc6e56303d06ff92422c29b6&sub=wcp151747625212826
Caller code
package org.apache.spark
import com.izhonghong.mission.learn.BayesModelData2
import com.izhonghong.mission.learn.WudiBayesModel.{deserializeFromMysql, tokenizer}
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Vector}
/**
* Created by Zsh on 2/1 0001.
*/
object BayesUtils {
//Test the model read back from MySQL
def testMysql(text:String,labels_name:String){
val hashingTF = new HashingTF(math.pow(2, 16).toInt) //use the same dimension as in training (the default is 2^20)
val tfVector = hashingTF.transform(tokenizer(text))
val BayesModelData2 = deserializeFromMysql()
val model = new NaiveBayesModel(BayesModelData2.labels,BayesModelData2.pi,BayesModelData2.theta,BayesModelData2.modelType)
val d = model.predict(tfVector)
// val d = predict(BayesModelData2,tfVector)
val list2 = labels_name.split("\\|").toList
list2(d.toInt)
println("result:"+list2(d.toInt) + " " + d + " " + text)
}
//Predict and return the class index. This was extracted from the NaiveBayesModel source because NaiveBayesModel could not be referenced at first; after extracting the source it turned out that code under the org.apache.spark package can simply call new NaiveBayesModel.
def predict(bayesModel :BayesModelData2,tfVector:Vector)={
val thetaMatrix = new DenseMatrix(bayesModel.labels.length, bayesModel.theta(0).length, bayesModel.theta.flatten, true)
val piVector = new DenseVector(bayesModel.pi)
val prob = thetaMatrix.multiply(tfVector)
org.apache.spark.mllib.linalg.BLAS.axpy( 1.0, piVector, prob)
val d = bayesModel.labels(prob.argmax)
d
}
}
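- For completeness, a minimal driver sketch that exercises the MySQL-backed model (the object name and sample text are illustrative; it assumes training(lables) has already been run so that a model row exists in ams_recommender_model):
package org.apache.spark
object BayesUtilsDemo {
def main(args: Array[String]): Unit = {
//Same label string, in the same order, as used during training
val labels = "社会|体育|汽车|女性|新闻|科技|财经|军事|广告|娱乐|健康|教育|旅游|生活|文化"
//Classify one piece of text with the model deserialized from MySQL
BayesUtils.testMysql("届时将有来自全国各地的滑翔伞高手云集陵山,精彩上演自由飞翔。", labels)
}
}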
- Key dependencies
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.6.0</version>
</dependency>
<dependency>
<groupId>org.ansj</groupId>
<artifactId>ansj_seg</artifactId>
<version>5.0.4</version>
</dependency>
Full pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.xxx</groupId>
<artifactId>xxx-xxx-xxx</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>1.6</maven.compiler.source>
<maven.compiler.target>1.6</maven.compiler.target>
<encoding>UTF-8</encoding>
<scala.tools.version>2.10</scala.tools.version>
<scala.version>2.10.6</scala.version>
<hbase.version>1.2.2</hbase.version>
</properties>
<dependencies>
<!-- <dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.1.0</version>
</dependency>-->
<!--<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>1.6.0</version>
</dependency>-->
<!-- <dependency>
<groupId>com.hankcs</groupId>
<artifactId>hanlp</artifactId>
<version>portable-1.5.0</version>
</dependency>-->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.6.0</version>
</dependency>
<dependency>
<groupId>org.ansj</groupId>
<artifactId>ansj_seg</artifactId>
<version>5.0.4</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>2.10.6</version>
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka-clients</artifactId>
<version>0.10.0.0</version>
</dependency>
<dependency>
<groupId>net.sf.json-lib</groupId>
<classifier>jdk15</classifier>
<artifactId>json-lib</artifactId>
<version>2.4</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka_2.10</artifactId>
<version>1.6.2</version>
</dependency>
<!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming-kafka-0-10_2.10</artifactId>
<version>2.1.1</version> </dependency> -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_2.10</artifactId>
<version>1.6.2</version>
<exclusions>
<exclusion>
<artifactId>scala-library</artifactId>
<groupId>org.scala-lang</groupId>
</exclusion>
</exclusions>
</dependency>
<!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_2.10</artifactId>
<version>2.1.1</version> <scope>provided</scope> </dependency> -->
<dependency>
<groupId>com.huaban</groupId>
<artifactId>jieba-analysis</artifactId>
<version>1.0.2</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.14</version>
</dependency>
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>2.9.0</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.apache.hbase</groupId>
<artifactId>hbase-server</artifactId>
<version>1.2.2</version>
<exclusions>
<exclusion>
<artifactId>servlet-api-2.5</artifactId>
<groupId>org.mortbay.jetty</groupId>
</exclusion>
</exclusions>
</dependency>
<!-- <dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.18</version>
</dependency>-->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.10</artifactId>
<version>1.6.2</version>
<!-- <version>2.1.1</version> -->
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<version>2.7.0</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>2.7.0</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>
<version>2.7.0</version>
<exclusions>
<exclusion>
<groupId>javax.servlet.jsp</groupId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<artifactId>servlet-api</artifactId>
<groupId>javax.servlet</groupId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.10</artifactId>
<version>1.6.2</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.10</artifactId>
<version>1.6.2</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.39</version>
</dependency>
<!--<dependency>
<groupId>org.apache.hbase</groupId>
<artifactId>hbase-server</artifactId>
<version>1.2.2</version>
</dependency>-->
<!-- Test -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.specs2</groupId>
<artifactId>specs2_${scala.tools.version}</artifactId>
<version>1.13</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.tools.version}</artifactId>
<version>2.0.M6-SNAP8</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.2.0</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<archive>
<manifest>
<addClasspath>true</addClasspath>
<classpathPrefix>lib/</classpathPrefix>
<mainClass></mainClass>
</manifest>
</archive>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>copy</id>
<phase>package</phase>
<goals>
<goal>copy-dependencies</goal>
</goals>
<configuration>
<outputDirectory>${project.build.directory}/lib</outputDirectory>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<!-- <build> <plugins> <plugin> <artifactId>maven-assembly-plugin</artifactId>
<configuration> <archive> <manifest> 这里要替换成jar包main方法所在类 <mainClass>com.sf.pps.client.IntfClientCall</mainClass>
</manifest> <manifestEntries> <Class-Path>.</Class-Path> </manifestEntries>
</archive> <descriptorRefs> <descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs> </configuration> <executions> <execution> <id>make-assembly</id>
this is used for inheritance merges <phase>package</phase> 指定在打包节点执行jar包合并操作
<goals> <goal>single</goal> </goals> </execution> </executions> </plugin>
</plugins> </build> -->
</project>