Example
import $ from 'jquery';
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs, img2x, file2img } from './utils.js';
$(async () => {
  const { inputs, labels } = await getInputs();
  console.log(inputs, labels);
  const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 255 } });
  const NUM_CLASSES = 3;
  inputs.forEach(imgEl => {
    surface.drawArea.appendChild(imgEl);
  });
  const MOBILENET_MODEL_PATH = "http://127.0.0.1:8080/mobilenet/web_model/model.json";
  // Load the external pre-trained MobileNet model
  const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
  mobilenet.summary();
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  // Truncate the model: keep everything up to conv_pw_13_relu
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output
  });
  // Pass the truncated model's output into a second, trainable model
  const model = tf.sequential();
  model.add(tf.layers.flatten({
    // Drop the leading dimension (number of images); with no data yet its value is null
    inputShape: layer.outputShape.slice(1)
  }));
  model.add(tf.layers.dense({
    units: 10,
    activation: 'relu'
  }));
  model.add(tf.layers.dense({
    units: NUM_CLASSES,
    activation: 'softmax'
  }));
  // Set the loss function and the optimizer
  model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });
  // Feed the input images through the truncated model to get the training features
  const { xs, ys } = tf.tidy(() => {
    const xs = tf.concat(inputs.map((item) => truncatedMobilenet.predict(img2x(item))));
    const ys = tf.tensor(labels);
    return { xs, ys };
  });
  await model.fit(xs, ys, {
    epochs: 20,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training performance' },
      ['loss'],
      { callbacks: ['onEpochEnd'] }
    )
  });
  const BRAND_CLASSES = ['android', 'apple', 'windows'];
  window.predict = async (file) => {
    const img = await file2img(file);
    document.body.appendChild(img);
    const pred = tf.tidy(() => {
      const x = img2x(img);
      const input = truncatedMobilenet.predict(x);
      return model.predict(input);
    });
    const index = pred.argMax(1).dataSync()[0];
    setTimeout(() => {
      alert(`Prediction: ${BRAND_CLASSES[index]}`);
    }, 0);
  };
  window.download = async () => {
    await model.save('downloads://model');
  };
});
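The Save model button only saves the small classification head (via downloads://model, which produces a model.json plus a binary weights file). To reuse it in a later session, the truncated MobileNet has to be rebuilt and chained in front of the reloaded head. A minimal sketch, assuming the downloaded files are re-hosted at a hypothetical URL such as http://127.0.0.1:8080/brand/model.json:

// Sketch: reloading the trained head later (run inside an async function).
// The brand/model.json URL is an assumption -- point it at wherever you
// re-host the files produced by the "Save model" button.
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
const truncatedMobilenet = tf.model({
  inputs: mobilenet.inputs,
  outputs: mobilenet.getLayer('conv_pw_13_relu').output
});
const head = await tf.loadLayersModel('http://127.0.0.1:8080/brand/model.json');
const predictBrand = (imgEl) => tf.tidy(() => {
  // Same two-step pipeline as in training: features first, then the head.
  return head.predict(truncatedMobilenet.predict(img2x(imgEl)));
});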
HTML
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Document</title>
</head>
<body>
  <div>Icon recognition</div>
  <input type="file" onchange="predict(this.files[0])">
  <button onclick="download()">Save model</button>
  <script src="./t7.js"></script>
</body>
</html>
utils.js
import * as tf from '@tensorflow/tfjs';
// Load the training images ↓↓↓↓↓↓↓↓↓↓↓
const IMAGE_SIZE = 224;
const loadImg = (src) => {
  return new Promise(resolve => {
    const img = new Image();
    img.crossOrigin = "anonymous";
    img.src = src;
    img.width = IMAGE_SIZE;
    img.height = IMAGE_SIZE;
    img.onload = () => resolve(img);
  });
};
export const getInputs = async () => {
  const loadImgs = [];
  const labels = [];
  for (let i = 0; i < 30; i += 1) {
    ['android', 'apple', 'windows'].forEach(label => {
      const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
      const img = loadImg(src);
      loadImgs.push(img);
      labels.push([
        label === 'android' ? 1 : 0,
        label === 'apple' ? 1 : 0,
        label === 'windows' ? 1 : 0,
      ]);
    });
  }
  const inputs = await Promise.all(loadImgs);
  return {
    inputs,
    labels,
  };
};
// Load the training images ↑↑↑↑↑↑↑↑↑↑↑
// Image format conversion ↓↓↓↓↓↓↓↓↓↓↓
export function img2x(imgEl) {
  return tf.tidy(() => {
    const input = tf.browser.fromPixels(imgEl)
      .toFloat()
      .sub(255 / 2)
      .div(255 / 2)
      .reshape([1, 224, 224, 3]);
    return input;
  });
}
export function file2img(f) {
  return new Promise(resolve => {
    const reader = new FileReader();
    reader.readAsDataURL(f);
    reader.onload = (e) => {
      const img = document.createElement('img');
      img.src = e.target.result;
      img.width = 224;
      img.height = 224;
      img.onload = () => resolve(img);
    };
  });
}
// Image format conversion ↑↑↑↑↑↑↑↑↑↑↑
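As a quick sanity check on these helpers: getInputs() resolves to 90 image elements (30 per class) with one-hot labels whose column order matches BRAND_CLASSES, and img2x() yields a [1, 224, 224, 3] tensor scaled to [-1, 1]. A small sketch of what to expect (assumes the training images are served locally as above and runs inside an async function):

// Sketch: expected shapes of the data returned by getInputs() / img2x().
const { inputs, labels } = await getInputs();
console.log(inputs.length);           // 90 HTMLImageElements (3 classes x 30 images)
console.log(tf.tensor(labels).shape); // [90, 3] one-hot labels in [android, apple, windows] order
console.log(img2x(inputs[0]).shape);  // [1, 224, 224, 3], pixel values scaled to [-1, 1]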
Execution result