0
点赞
收藏
分享

微信扫一扫

Spark NaiveBayes Demo 朴素贝叶斯分类算法


  • 模型 SQL 文件:https://pan.baidu.com/s/1hugrI9e
  • 训练/测试数据:https://pan.baidu.com/s/1kWz8fNh
  • NaiveBayes Spark MLlib 训练代码:

package com.xxx.xxx.xxx

import java.io.ObjectInputStream
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.{Arrays, Date, Scanner}

import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.ArrayBuffer

/**
  * Multinomial Naive Bayes text classifier for the categories listed in [[lables]] (Spark MLlib 1.6).
  *
  * Pipeline: Ansj segmentation ([[tokenizer]]) -> HashingTF -> IDF (training only) -> NaiveBayes.
  * [[training]] reads labelled rows from the MySQL table `industry_classify_tmp`, saves the model to
  * [[outputPath]] and also stores its parameters as a BayesModelData2 in `ams_recommender_model`.
  * The other methods load the saved model and classify text with it.
  *
  * NOTE(review): prediction ([[testData]]) does not apply the IDF weights used during training,
  * because the fitted IDF model is not persisted anywhere. TODO confirm this is intended.
  *
  * Created by Zsh on 1/31 0031.
  */
object WudiBayesModel {
  // Shared JDBC handles, (re)assigned by MysqlConn, serializeToMysql and deserializeFromMysql.
  var conn: Connection = null
  var stmt: PreparedStatement = null
  // Directory the MLlib model is saved to / loaded from.
  val outputPath = "/zshModel"
  // JDBC URL for the labelled text (credentials are placeholders in this copy).
  val driverUrl = "jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8&autoDeserialize=true"
  // Rows of category `classify`; filled by loadModel() and consumed by batchTesting().
  var df: DataFrame = null
  // Category evaluated by batchTesting() and by the per-class metrics in training().
  val classify: String = "健康"
  // Model loaded by loadModel().
  var model: NaiveBayesModel = null
  // All category names; a name's position in this list is its numeric class id.
  val lables = "社会|体育|汽车|女性|新闻|科技|财经|军事|广告|娱乐|健康|教育|旅游|生活|文化"

  /** HashingTF dimension shared by training and prediction; the model's theta has this many columns. */
  private val FeatureDim: Int = 1 << 16

  // Tokens dropped by tokenizer(): "@name:" mentions and plain http URLs.
  private val MentionPattern = "@\\w{2,20}:".r.pattern
  private val UrlPattern = "http://[0-9a-zA-Z/\\?&#%$@\\=\\\\]+".r.pattern

  /** Entry point. Every step is commented out; enable the one to run. */
  def main(args: Array[String]): Unit = {
    // training(lables)

    // val text = "今年“五一”假期虽然缩短为三天,但来个“周边游”却正逢其时。昨日记者从保定市旅游部门了解到,5月1日—3日,该市满城汉墓景区将举办“2008全国滑翔伞俱乐部联赛”第一站比赛。届时将有来自全国各地的滑翔伞高手云集陵山,精彩上演自由飞翔。\uE40C滑翔伞俱乐部举办联赛为我国第一次,所有参赛运动员均在国际或国内大型比赛中取得过名次,并且所有运动员必须持2008年贴花的中国航空运动协会会员证书和B级以上滑翔伞运动证书,所使用的比赛动作均为我国滑翔伞最新动作。本届比赛的项目,除保留传统的“留空时间赛”和“精确着陆赛”以外,还增加了“盘升高度赛”等内容。届时,参赛运动员将冲击由保定市运动员韩均创造的1450米的盘升高度记录。截至目前,已有11个省市的50多名运动员报名参赛,其中包括多名外籍运动员和7名女运动员。 (来源:燕赵晚报)\uE40C(责任编辑:李妍)"
    // Classify `text` with the model read back from MySQL:
    // BayesUtils.testMysql(text, lables)

    // inputTestModel()
  }

  /** Interactive check: loads the model, then classifies each line read from stdin until EOF. */
  def inputTestModel(): Unit = {
    val scan = new Scanner(System.in)
    val startTime = System.currentTimeMillis()
    loadModel()
    println("加载模型时间:" + (System.currentTimeMillis() - startTime))
    println("模型加载完毕-----")
    // Stop cleanly at end of input instead of failing in nextLine().
    while (scan.hasNextLine) {
      testData(model, scan.nextLine(), lables)
      println("---------------------------------")
    }
  }

  /**
    * Classifies every row of [[df]] and prints the share predicted as [[classify]]. loadModel() only
    * selects rows of that category, so this is the category's recall. Requires loadModel() to have run.
    */
  def batchTesting(): Unit = {
    val start = System.currentTimeMillis()
    // Explicit type argument: a bare getAs(...) infers T = Nothing and can fail with ClassCastException.
    // The result is cached because it is counted twice.
    val result = df.map(row => testData(model, row.getAs[String]("content"), lables)).cache()
    // Counting inside the timed region makes the timing measure the actual (lazy) prediction work.
    val total = result.count()
    println("预测需要时间:" + (System.currentTimeMillis() - start))
    println("准确率:" + result.filter(_ == classify).count().toDouble / total)
  }

  /** Local SparkConf with Kryo serialization, shared by loadModel() and testModel(). */
  private def localKryoConf(): SparkConf =
    new SparkConf().setAppName("NaiveBayesExample1")
      .setMaster("local")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryoserializer.buffer.max", "1024mb")

  /**
    * Starts a local SparkContext, loads the saved model into [[model]] and the rows of category
    * [[classify]] from `industry_classify_tmp` into [[df]].
    */
  def loadModel(): Unit = {
    val sc = new SparkContext(localKryoConf())
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    // Assign the field (not a local val) so inputTestModel()/batchTesting() see the loaded model.
    model = NaiveBayesModel.load(sc, outputPath)
    val jdbcDF = sqlContext.read.options(Map("url" -> driverUrl, "dbtable" -> "industry_classify_tmp")).format("jdbc").load()
    jdbcDF.registerTempTable("testData")
    df = sqlContext.sql("select content from testData where classify in ('" + classify + "')")
  }

  /** Starts a local SparkContext and returns the model saved at [[outputPath]]. */
  def testModel(): NaiveBayesModel = {
    val sc = new SparkContext(localKryoConf())
    NaiveBayesModel.load(sc, outputPath)
  }

  /**
    * Classifies `text` with `model`, prints the result and returns the predicted category name.
    *
    * @param labels_name "|"-separated category names in class-id order (the order used for training)
    */
  def testData(model: NaiveBayesModel, text: String, labels_name: String): String = {
    // Must match the training dimension: theta has FeatureDim columns and predict() fails on any other size.
    val tfVector = new HashingTF(FeatureDim).transform(tokenizer(text))
    val d = model.predict(tfVector)
    val label = labels_name.split("\\|")(d.toInt)
    println("result:" + label + " " + d + " " + text)
    label
  }

  /**
    * Trains the classifier on the MySQL rows whose category is in `labels_name`, saves it to
    * [[outputPath]] and to MySQL, then prints accuracy plus precision/recall for [[classify]]
    * on a 30% hold-out.
    *
    * @param labels_name "|"-separated category names; a name's position is its numeric class id
    */
  def training(labels_name: String): Unit = {
    val labelNames = labels_name.split("\\|").toList
    // Category name -> class id (0.0, 1.0, ...).
    val tuples: Map[String, Double] =
      labelNames.zipWithIndex.map { case (name, i) => name -> i.toDouble }.toMap
    // Double-quoted, comma-separated list for the SQL IN clause, e.g. "社会","体育".
    val str = labelNames.mkString("\"", "\",\"", "\"")

    val conf = new SparkConf().setAppName("NaiveBayesExample1")
      .setMaster("local[4]")
    val sc = new SparkContext(conf)
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val jdbcDF = sqlContext.read.options(Map("url" -> driverUrl, "dbtable" -> "industry_classify_tmp")).format("jdbc").load()
    jdbcDF.registerTempTable("testData")
    // Training rows (content, classify) read from MySQL.
    val trainData = sqlContext.sql("select content,classify from testData where classify in (" + str + ")")
    println("trcount:" + trainData.count())

    val hashTF = new HashingTF(FeatureDim)
    // Keep each label in the same record as its TF vector. Splitting them into two RDDs and
    // zip()-ing them back mispairs them if the unordered JDBC read is recomputed in a different order.
    val labelledTf = trainData.map { row =>
      (tuples(row.getAs[String]("classify")), hashTF.transform(tokenizer(row.getAs[String]("content"))))
    }.cache()
    val idfModel = new IDF().fit(labelledTf.map(_._2))
    val tData = labelledTf.map { case (label, tf) => LabeledPoint(label, idfModel.transform(tf)) }
    // 70% training, 30% validation.
    val splits = tData.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trData = splits(0).cache()
    val teData = splits(1).cache()

    val nbModel = NaiveBayes.train(trData, lambda = 1.0, modelType = "multinomial")
    nbModel.save(sc, outputPath)
    println("save model success !")
    // Copy the model parameters into BayesModelData2 and store them in MySQL.
    serializeToMysql(BayesModelData2(nbModel.labels.toArray,
      nbModel.pi.toArray, nbModel.theta.map(_.toArray).toArray, "multinomial"))

    // ===== Validation: (predicted, actual) for every hold-out point.
    val testAndLabel = teData.map(p => (nbModel.predict(p.features), p.label)).cache()
    // Per-class metrics are computed for `classify`. If it is not among labels_name, -1.0 matches
    // nothing and precision/recall come out as NaN.
    val target = tuples.getOrElse(classify, -1.0)
    val total = testAndLabel.count()
    val totalPostiveNum = testAndLabel.filter(_._2 == target).count() // actual positives: TP + FN
    val totalTrueNum = testAndLabel.filter(_._1 == target).count()    // predicted positives: TP + FP
    val testRealTrue = testAndLabel.filter(x => x._1 == x._2 && x._2 == target).count() // TP
    val testReal = testAndLabel.filter(x => x._1 == x._2).count()     // correct predictions, all classes

    val testAccuracy = 1.0 * testReal / total
    val testPrecision = 1.0 * testRealTrue / totalTrueNum
    val testRecall = 1.0 * testRealTrue / totalPostiveNum

    println("统计分类准确率:============================")
    println("准确率:" + testAccuracy)  // correct / total
    println("精确度:" + testPrecision) // Precision = TP / (TP + FP)
    println("召回率:" + testRecall)    // Recall = TP / (TP + FN)
    println("模型准确度============================")
  }

  /** Segments `line` and drops empty tokens, "@name:" mentions and http URLs. */
  def tokenizer(line: String): Seq[String] =
    AnsjSegment(line)
      .split(",")
      // "".split(",") yields Array("") when nothing was kept; that empty token must not become a feature.
      .filter(token => token.nonEmpty &&
        !MentionPattern.matcher(token).matches &&
        !UrlPattern.matcher(token).matches)
      .toSeq

  /**
    * Segments `line` with Ansj ToAnalysis and returns the kept words joined by ",".
    * A word is kept when it has at least 2 characters and its nature tag starts with n, v, a, m or t
    * (presumably noun / verb / adjective / numeral / time word in Ansj's tag set — confirm).
    */
  def AnsjSegment(line: String): String = {
    val keepNatures = Set("n", "v", "a", "m", "t")
    val terms = ToAnalysis.parse(line)
    val kept = ArrayBuffer[String]()
    for (i <- 0 until terms.size()) {
      val term = terms.get(i)
      val nature = term.getNatureStr
      // Guard: substring(0, 1) would throw on a null or empty nature tag.
      if (nature != null && nature.nonEmpty && keepNatures(nature.substring(0, 1)) && term.getName.length >= 2)
        kept += term.getName
    }
    kept.mkString(",")
  }

  /** Closes the shared statement and connection, ignoring errors raised while closing. */
  private def closeQuietly(): Unit = {
    if (stmt != null) try stmt.close() catch { case _: java.sql.SQLException => }
    if (conn != null) try conn.close() catch { case _: java.sql.SQLException => }
  }

  /**
    * Stores `o` in `ams_recommender_model` under model_ID "test", replacing any existing row.
    * `o` goes through PreparedStatement.setObject, and deserializeFromMysql reads it back with
    * ObjectInputStream. BayesModelData2 therefore presumably must be Serializable — confirm.
    */
  def serializeToMysql[T](o: BayesModelData2): Unit = {
    val model_Id = "test"
    new MysqlConn()
    try {
      stmt = conn.prepareStatement("replace into ams_recommender_model(model_ID,model_Data) values (?,?)")
      stmt.setString(1, model_Id)
      stmt.setObject(2, o)
      // REPLACE is a write; executeQuery() is only for statements that return a ResultSet.
      stmt.executeUpdate()
    } finally {
      closeQuietly()
    }
  }

  /** Opens a MySQL connection into the shared `conn` field; prints and rethrows if connecting fails. */
  class MysqlConn() {
    val trainning_url = "jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8"
    try {
      // Training and output currently share one URL; they could be split later.
      conn = DriverManager.getConnection(trainning_url, "xxx", "xxx")
    } catch {
      case e: java.sql.SQLException =>
        println("mysql连接异常")
        // Callers use `conn` immediately, so swallowing the error would only surface later as an NPE.
        throw e
    }
  }

  /**
    * Reads the model stored under model_ID "test" from `ams_recommender_model`.
    * NOTE(review): Java deserialization trusts whatever bytes are stored in that table.
    *
    * @throws java.sql.SQLException if no such row exists or the query fails
    */
  def deserializeFromMysql[T](): BayesModelData2 = {
    val model_Id = "test"
    new MysqlConn()
    try {
      stmt = conn.prepareStatement("select model_Data from ams_recommender_model where model_ID=?")
      stmt.setString(1, model_Id)
      val resultSet = stmt.executeQuery()
      if (!resultSet.next())
        throw new java.sql.SQLException("no model stored with model_ID '" + model_Id + "'")
      // Deserialize while the connection is still open; it is closed in `finally`.
      val ois = new ObjectInputStream(resultSet.getBlob("model_Data").getBinaryStream())
      try ois.readObject.asInstanceOf[BayesModelData2] finally ois.close()
    } finally {
      closeQuietly()
    }
  }
}

 

  • 调用类 BayesUtils 必须放在 org.apache.spark 包(而不是任意目录)下,因为 NaiveBayesModel 的构造器是 private[spark] 的,只有该包内的代码才能直接 new NaiveBayesModel;mllib 的 BLAS 同理
  • 参考:How to use BLAS library in Spark (Symbol BLAS is inaccessible from this space) - spark:http://note.youdao.com/noteshare?id=7f1eec90cc6e56303d06ff92422c29b6&sub=wcp151747625212826

调用

package org.apache.spark

import com.izhonghong.mission.learn.BayesModelData2
import com.izhonghong.mission.learn.WudiBayesModel.{deserializeFromMysql, tokenizer}
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Vector}

/**
 * Rebuilds a NaiveBayesModel from the copy stored in MySQL and classifies text with it.
 *
 * This object lives in package org.apache.spark because the NaiveBayesModel constructor and
 * mllib's BLAS are private[spark].
 *
 * Created by Zsh on 2/1 0001.
 */
object BayesUtils {
  /** Must equal the HashingTF dimension used at training time (2^16 in WudiBayesModel.training). */
  private val FeatureDim: Int = 1 << 16

  /**
   * Classifies `text` with the model read from MySQL and prints the predicted category.
   *
   * @param labels_name "|"-separated category names in class-id order
   */
  def testMysql(text: String, labels_name: String): Unit = {
    // The default HashingTF() uses 2^20 features, which does not match theta's column count.
    val tfVector = new HashingTF(FeatureDim).transform(tokenizer(text))
    val data = deserializeFromMysql()
    val model = new NaiveBayesModel(data.labels, data.pi, data.theta, data.modelType)
    val d = model.predict(tfVector)
    // val d = predict(data, tfVector)
    val label = labels_name.split("\\|")(d.toInt)
    println("result:" + label + " " + d + " " + text)
  }

  /**
   * Returns the predicted label for `tfVector`. The scoring was extracted from NaiveBayesModel's
   * multinomial path: theta * x + pi, then argmax. It was written before it became clear that
   * NaiveBayesModel itself can be constructed from inside the org.apache.spark package.
   */
  def predict(bayesModel: BayesModelData2, tfVector: Vector): Double = {
    val numLabels = bayesModel.labels.length
    val numFeatures = bayesModel.theta(0).length
    // isTransposed = true: the flattened values are row-major, so each theta row becomes a matrix row.
    val thetaMatrix = new DenseMatrix(numLabels, numFeatures, bayesModel.theta.flatten, true)
    val scores = thetaMatrix.multiply(tfVector)
    org.apache.spark.mllib.linalg.BLAS.axpy(1.0, new DenseVector(bayesModel.pi), scores)
    bayesModel.labels(scores.argmax)
  }
}

  • 主要依赖配置(pom.xml 片段)

<!-- Main dependencies (excerpt of the full POM below). Keep spark-mllib on the same version
     as the other Spark artifacts (spark-core / spark-sql / spark-streaming use 1.6.2). -->
<dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>

        <dependency>
            <groupId>org.ansj</groupId>
            <artifactId>ansj_seg</artifactId>
            <version>5.0.4</version>
        </dependency>

完整 pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xxx</groupId>
    <artifactId>xxx-xxx-xxx</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Matches the maven-compiler-plugin configuration below (1.8). -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <encoding>UTF-8</encoding>
        <scala.tools.version>2.10</scala.tools.version>
        <scala.version>2.10.6</scala.version>
        <!-- One version for every spark-* artifact; mixing 1.6.0 and 1.6.2 risks binary mismatches. -->
        <spark.version>1.6.2</spark.version>
        <hbase.version>1.2.2</hbase.version>
    </properties>

    <dependencies>
       <!-- <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.1.0</version>
        </dependency>-->
        <!--<dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>1.6.0</version>
        </dependency>-->
       <!-- <dependency>
            <groupId>com.hankcs</groupId>
            <artifactId>hanlp</artifactId>
            <version>portable-1.5.0</version>
        </dependency>-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.ansj</groupId>
            <artifactId>ansj_seg</artifactId>
            <version>5.0.4</version>
        </dependency>

        <!-- scala-library is declared once, further below, using ${scala.version}. -->
        <dependency>
            <groupId>org.apache.kafka</groupId>
            <artifactId>kafka-clients</artifactId>
            <version>0.10.0.0</version>
        </dependency>



        <dependency>
            <groupId>net.sf.json-lib</groupId>
            <classifier>jdk15</classifier>
            <artifactId>json-lib</artifactId>
            <version>2.4</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming-kafka-0-10_2.10</artifactId>
            <version>2.1.1</version> </dependency> -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.10</artifactId>
            <version>${spark.version}</version>
            <exclusions>
                <exclusion>
                    <artifactId>scala-library</artifactId>
                    <groupId>org.scala-lang</groupId>
                </exclusion>
            </exclusions>
        </dependency>

        <!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_2.10</artifactId>
            <version>2.1.1</version> <scope>provided</scope> </dependency> -->
        <dependency>
            <groupId>com.huaban</groupId>
            <artifactId>jieba-analysis</artifactId>
            <version>1.0.2</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.14</version>
        </dependency>


        <dependency>
            <groupId>redis.clients</groupId>
            <artifactId>jedis</artifactId>
            <version>2.9.0</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>${hbase.version}</version>
            <exclusions>
                <exclusion>
                    <artifactId>servlet-api-2.5</artifactId>
                    <groupId>org.mortbay.jetty</groupId>
                </exclusion>
            </exclusions>
        </dependency>
      <!--  <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.18</version>
        </dependency>-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.10</artifactId>
            <version>${spark.version}</version>
            <!-- <version>2.1.1</version> -->
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-hdfs</artifactId>
            <version>2.7.0</version>
            <exclusions>
                <exclusion>
                    <groupId>javax.servlet.jsp</groupId>
                    <artifactId>*</artifactId>
                </exclusion>
                <exclusion>
                    <artifactId>servlet-api</artifactId>
                    <groupId>javax.servlet</groupId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.39</version>
        </dependency>
        <!--<dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>1.2.2</version>
        </dependency>-->

        <!-- Test -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.specs2</groupId>
            <artifactId>specs2_${scala.tools.version}</artifactId>
            <version>1.13</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.scalatest</groupId>
            <artifactId>scalatest_${scala.tools.version}</artifactId>
            <version>2.0.M6-SNAP8</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.2.0</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <configuration>
                    <archive>
                        <manifest>
                            <addClasspath>true</addClasspath>
                            <classpathPrefix>lib/</classpathPrefix>
                            <!-- Set to the fully qualified name of the object whose main() should run. -->
                            <mainClass></mainClass>
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-dependency-plugin</artifactId>
                <executions>
                    <execution>
                        <id>copy</id>
                        <phase>package</phase>
                        <goals>
                            <goal>copy-dependencies</goal>
                        </goals>
                        <configuration>
                            <outputDirectory>${project.build.directory}/lib</outputDirectory>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    <!-- Alternative fat-jar build:
        <build> <plugins> <plugin> <artifactId>maven-assembly-plugin</artifactId>
        <configuration> <archive> <manifest> (replace with the class containing the jar's main method) <mainClass>com.sf.pps.client.IntfClientCall</mainClass>
        </manifest> <manifestEntries> <Class-Path>.</Class-Path> </manifestEntries>
        </archive> <descriptorRefs> <descriptorRef>jar-with-dependencies</descriptorRef>
        </descriptorRefs> </configuration> <executions> <execution> <id>make-assembly</id>
        this is used for inheritance merges <phase>package</phase> (merge the jars during the package phase)
        <goals> <goal>single</goal> </goals> </execution> </executions> </plugin>
        </plugins> </build> -->

</project>

 

 

举报

相关推荐

0 条评论