0
点赞
收藏
分享

微信扫一扫

基于TensorFlow2的mnist数据集手写字体识别

徐一村 03-14 23:45 阅读 2

使用版本:

Anaconda 2022.05、PyCharm 2022.2、TensorFlow 2.8.0、Python 3.9.12

训练程序:


#导入模块
import keras
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
# Load the MNIST handwritten-digit dataset (images + integer labels).
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

# Add a trailing channel axis so the data matches the (28, 28, 1) input
# shape declared on the first Conv2D layer.
x_train4D = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_test4D = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32')
# Scale pixel values into [0, 1] (assumes 8-bit source pixels, hence /255).
x_train, x_test = x_train4D / 255.0, x_test4D / 255.0

# Two conv + max-pool stages followed by a dropout-regularized dense head.
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), padding='same',
                           input_shape=(28, 28, 1), activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(filters=36, kernel_size=(5, 5), padding='same',
                           activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    # One softmax probability per digit class 0-9.
    tf.keras.layers.Dense(10, activation='softmax'),
])

# summary() prints the table itself and returns None; wrapping it in
# print() would emit a stray "None" line.
model.summary()
# Labels are integer class ids (not one-hot), hence the sparse loss.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam', metrics=['accuracy'])
# Keras holds out a 20% slice of the training data for validation.
model.fit(x=x_train, y=y_train, validation_split=0.2,
          epochs=20, batch_size=300, verbose=2)
# Persist the trained model; the test script reloads it from this file.
model.save('my_model.h5')
# Report performance on the held-out test set, which the preprocessing
# above prepares but which training never touches.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f'test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}')

测试程序：将自己手写的数字图片放在项目目录下的 image 文件夹中，并把代码中的 mypath 改为该文件夹的实际路径。


#导入模块
import keras
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize as imresize
import tensorflow as tf
# Restore the network that the training script saved to disk.
new_model = keras.models.load_model('my_model.h5')
# Re-compile with the same loss / optimizer / metric used during training.
new_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
new_model.summary()
# Output index k of the model corresponds to the digit label str(k).
class_names = [str(digit) for digit in range(10)]
# Folder holding the handwritten images to classify (absolute Windows path).
mypath = 'C:\\users\\admin\\desktop\\mnist-test\\image'
def getimg(mypath):
    """Load every image in *mypath* and convert it to a 28x28 MNIST-style array.

    Each file is read with ``plt.imread``, scaled into [0, 1], inverted
    (so dark ink on a light background becomes bright strokes on a dark
    background), resized to 28x28, and reduced to a single channel.

    Args:
        mypath: Directory that is expected to contain only image files.

    Returns:
        A tuple ``(images, count)``: ``images`` is a NumPy array stacking
        the 28x28 images, and ``count`` is how many images were loaded.
    """
    imgs = []
    # os.listdir order is arbitrary; sort so the plotted order is reproducible.
    for name in sorted(os.listdir(mypath)):
        img = plt.imread(os.path.join(mypath, name))
        # plt.imread yields floats already in [0, 1] for PNG but integer
        # arrays for other formats (e.g. JPEG); only the latter need scaling.
        # The previous unconditional /255 crushed PNG input to near-constant 1.
        if np.issubdtype(img.dtype, np.integer):
            img = img / np.iinfo(img.dtype).max
        img = 1.0 - img
        img = imresize(img, [28, 28])
        # Colour files keep a channel axis after resize; grayscale files are 2-D.
        if img.ndim == 3:
            img = img[:, :, 0]
        imgs.append(img)
    return np.array(imgs), len(imgs)
# Classify the loaded images and show each one with its predicted digit.
imgs = getimg(mypath)
if imgs[1] == 0:
    # Fail with a clear message instead of an opaque error from predict().
    raise SystemExit(f'No images found in {mypath}')
# Add the trailing channel axis the model's input expects: (N, 28, 28, 1).
test_images = np.reshape(imgs[0], [-1, 28, 28, 1])
predictions = new_model.predict(test_images)
plt.figure()
# A fixed 3x3 grid raises ValueError for more than 9 images: keep 3 columns
# and add rows as needed (never fewer than 3, matching the original layout).
n_cols = 3
n_rows = max(3, -(-imgs[1] // n_cols))
for i in range(imgs[1]):
    # The predicted class is the index of the highest softmax probability.
    c = np.argmax(predictions[i])
    plt.subplot(n_rows, n_cols, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(test_images[i, :, :, 0])
    plt.title(class_names[c])
plt.show()

运行结果:

基于TensorFlow2的mnist数据集手写字体识别_tensorflow

举报

相关推荐

0 条评论