0
点赞
收藏
分享

微信扫一扫

【神经网络八股扩展】:数据增强


文章目录

  • ​​前言​​
  • ​​TensorFlow2数据增强函数​​
  • ​​数据增强+网络八股代码:​​
  • ​​总结​​

前言

本讲目标:数据增强,增大数据量
关于我们为何要使用数据增强以及常用的几种数据增强的手法,可以看看下面的文章,虽说是翻译的,但仍有可鉴之处:
​​​数据增强(Data Augmentation)​​

TensorFlow2数据增强函数

对图像的增强就是对图像的简单形变,用来应对因为拍照角度不同引起的图片形变。
TensorFlow2给出了数据增强函数

image_gen_train =tf.keras.preprocessing.image.ImageDataGenerator(
rescale = 所有数据将乘以该数值
rotation_range =随机旋转角度数范围
width_shift_range = 随机宽度偏移量
height_shift_range =随机高度偏移量
horizontal_flip =是否随机水平翻转
zoom_range =随机缩放的范围[1-n,1+n])
image_gen_train.fit(x_train)
fit需要输入4维数据,所以将x_train reshape为(60000,28,28,1)
这个1表示单通道灰度值
model.fit同步更新为.flow形式

划红线的部分为需要注意的地方:
【神经网络八股扩展】:数据增强_神经网络

数据增强+网络八股代码:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1) # 给数据增加一个维度,使数据和网络结构匹配

image_gen_train = ImageDataGenerator(
rescale=1. / 1., # 如为图像,分母为255时,可归至0~1
rotation_range=45, # 随机45度旋转
width_shift_range=.15, # 宽度偏移
height_shift_range=.15, # 高度偏移
horizontal_flip=True, # 水平翻转
zoom_range=0.5 # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])

model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test),
validation_freq=1)
model.summary()

【神经网络八股扩展】:数据增强_tensorflow_02
随着迭代轮数增加,准确率不断提高。但从数据集上不能看出数据增强的效果,要在实际应用中去使用。

总结


举报

相关推荐

0 条评论