0
点赞
收藏
分享

微信扫一扫

tensorflow.js基本使用 iris(四)

归零者245号 2022-01-16 阅读 50
javascript

示例

$(async () => {
  $('#fi').on('submit',()=>{
    console.log($('#fi a').val());
    console.log($('#fi b').val());
    console.log($('#fi c').val());
    console.log($('#fi d').val());
    if(window.predict){
      window.predict({
        a:$('#fi #a').val(),
        b:$('#fi #b').val(),
        c:$('#fi #c').val(),
        d:$('#fi #d').val()
      });
    }else{
      alert('模型正在训练');
    }
    return false;
  });

  const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
  // xTrain.print();
  // yTrain.print();
  // xTest.print();
  // yTest.print();

  //定义连续模型sequential
  const model = tf.sequential();

  //设置层,全连接层tf.layer.dense
  model.add(tf.layers.dense({
    units: 10,
    inputShape: [xTrain.shape[1]],
    activation: 'sigmoid'
  }));

  //分为三类
  model.add(tf.layers.dense({
    units: 3,
    activation: 'softmax'//激活函数,处理非线性变化,适用于多分类
    // activation: 'sigmoid'//激活函数,处理非线性变化,适用于二分类
  }));

  //设置优化器
  model.compile({
    loss: 'categoricalCrossentropy',//交叉熵,适用于多分类
    optimizer: tf.train.adam(0.1),//优化器
    metrics: ['accuracy']//准确度
  });

  //训练模型
  await model.fit(xTrain, yTrain, {
    epochs: 100,
    validationData: [xTest, yTest],
    callbacks: tfvis.show.fitCallbacks(
      { name: '训练效果' },
      ['loss', 'val_loss', 'acc', 'val_acc'],
      { callbacks: ['onEpochEnd'] }
    )
  });

  window.predict=(form)=>{
    const input=tf.tensor([[
      form.a*1,
      form.b*1,
      form.c*1,
      form.d*1
    ]]);
    const pred=model.predict(input);
    alert(`预测结果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`);//第二位最大值
  }
});

html部分

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Document</title>
  <script src="js/tensorflow/tfjs.js"></script>
  <script src="js/tensorflow/tfjs-vis.js"></script>
  <script src="js/jquery/jquery.js"></script>
  <script src="js/iris-data/data.js"></script>
</head>
<body>
  <div>iris</div>
  <form id="fi">
    <label for="a">
      花萼长度:<input type="text" name="a" id="a">
    </label><br />
    <label for="b">
      花萼宽度:<input type="text" name="b" id="b">
    </label><br />
    <label for="c">
      花瓣长度:<input type="text" name="c" id="c">
    </label><br />
    <label for="d">
      花瓣宽度:<input type="text" name="d" id="d">
    </label><br />
    <button>提交</button>
  </form>
</body>
<script src="js/iris.js"></script>
</html>
举报

相关推荐

0 条评论