3. 实例3:将图片制作成内存对象数据集
将图片样本数据制成数据集,使用TensorFlow队列方式,好处:能充分使用cpu的多线程资源,让训练模型与数据读取平行同时运行
1. 样本介绍
MNIST手写数据集
2. 代码实现:载入文件名称与标签
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
import os
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
def load_samples(sample_dir):
lfilenames = []
labelsnames = []
for (dirpath, dirnames, filenames) in os.walk(sample_dir):
for filename in filenames:
filename_path = os.sep.join([dirpath, filename])
lfilenames.append(filename_path)
labelsnames.append(dirpath.split('\\')[-1])
lab = list(sorted(set(labelsnames)))
labdict = dict(zip(lab, list(range(len(lab)))))
labels = [labdict[i] for i in labelsnames]
return shuffle(np.asarray(lfilenames), np.asarray(labels)), np.asarray(lab)
data_dir = 'minst_digits_images\\'
(image, labels), labelnames = load_samples((data_dir))
load_sample函数会以元组的形式返回3个对象
- lfilenames:图片文件名数组
- lables:数值化后的标签
- lab:数值化的标签与字符串标签对应
3. 代码实现:生成队列中的批次样本数据
具体步骤:
- 用tf.train.slice_input_producer函数生成一个输入队列
- 按照路径读取推按,并对图片进行预处理
- 用tf.train.batch函数将图片变成批次数据
-调用tf.train.batch时,可以指定批次(batch_size)、线程个数(num_threads)、队列长度(capacity)
def get_batches(image, label, resize_w, resize_h, channels, batch_size):
queue = tf.train.slice_input_producer([image, label])
label = queue[1]
image_c = tf.read_file(queue[0]) #从队列里读取image路径
image = tf.image.decode_bmp(image_c, channels) #读取问图片文件
image = tf.image.resize_image_with_crop_or_pad(image, resize_w, resize_h) #修改图片大小
#将图像标准化处理
image = tf.image.per_image_standardization(image)
image_batch,label_batch = tf.train.batch([image, label],
batch_size,
num_threads = 64)
images_batch = tf.cast(image_batch, tf.float32)
labels_batch = tf.reshape(label_batch, [batch_size])
return images_batch, labels_batch
batch_size = 16
image_batchs, labels_batch = get_batches(image, label, 28, 28, 1, batch_size)
4. 代码实现:在会话中使用数据集
定义showresult和showing函数,将图片数据进行可视化输出
建立session,准备运行静态图,并在session中启动协调器队列线程,通过session的run方法获得数据
#在会话中使用数据积集
def showresult(subplot, title, thisimg): #显示单个图片
p = plt.subplot(subplot)
p.axis('off')
p.imshow(np.reshape(thisimg, (28, 28)))
p.set_title(title)
def showimg(index, label, img, ntop): #显示批次图片
plt.figure(figsize=(20,10))
plt.axis('off')
ntop = min(ntop, 9)
print(index)
for i in range(ntop):
showresult(100+10*ntop+1+i, label[i], img[i])
plt.show()
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
coord = tf.train.Coordinator() #建立队列协调器
threads = tf.train.start_queue_runners(sess = sess, coord = coord) #启动队列线程
try:
for step in np.arange(10):
if coord.should_stop():
break
images,label = sess.run([image_batches,label_batches]) #注入数据
showimg(step, label, images, batch_size)
print(label)
except tf.errors.OutOfRangeError:
print("Done!!")
finally:
coord.request_stop()
coord.join(threads)
读取bmp格式图片时,单通道图片会出错