0
点赞
收藏
分享

微信扫一扫

实例3:将图片制作成内存对象数据集

骨灰级搬砖工 2022-04-14 阅读 78

3. 实例3:将图片制作成内存对象数据集

将图片样本数据制成数据集,使用TensorFlow队列方式,好处:能充分使用cpu的多线程资源,让训练模型与数据读取平行同时运行

1. 样本介绍

MNIST手写数据集

2. 代码实现:载入文件名称与标签

import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()

import os 
from matplotlib import pyplot as plt 
import numpy as np 
from sklearn.utils import shuffle

def load_samples(sample_dir):
    lfilenames = []
    labelsnames = []
    for (dirpath, dirnames, filenames) in os.walk(sample_dir):
        for filename in filenames:
            filename_path = os.sep.join([dirpath, filename])
            lfilenames.append(filename_path)
            labelsnames.append(dirpath.split('\\')[-1])
    lab = list(sorted(set(labelsnames)))
    labdict = dict(zip(lab, list(range(len(lab)))))
    
    labels = [labdict[i] for i in labelsnames]
    
    return shuffle(np.asarray(lfilenames), np.asarray(labels)), np.asarray(lab)

data_dir = 'minst_digits_images\\'

(image, labels), labelnames = load_samples((data_dir))

load_sample函数会以元组的形式返回3个对象

  • lfilenames:图片文件名数组
  • lables:数值化后的标签
  • lab:数值化的标签与字符串标签对应

3. 代码实现:生成队列中的批次样本数据

具体步骤:

  • 用tf.train.slice_input_producer函数生成一个输入队列
  • 按照路径读取推按,并对图片进行预处理
  • 用tf.train.batch函数将图片变成批次数据
    -调用tf.train.batch时,可以指定批次(batch_size)、线程个数(num_threads)、队列长度(capacity)
def get_batches(image, label, resize_w, resize_h, channels, batch_size):
    
    
    queue = tf.train.slice_input_producer([image, label])
    label = queue[1]
    
    image_c = tf.read_file(queue[0])    #从队列里读取image路径
    image = tf.image.decode_bmp(image_c, channels)  #读取问图片文件
    
    image = tf.image.resize_image_with_crop_or_pad(image, resize_w, resize_h)   #修改图片大小
    
    #将图像标准化处理
    image = tf.image.per_image_standardization(image)

    image_batch,label_batch = tf.train.batch([image, label],
                                             batch_size,
                                             num_threads = 64)
    images_batch = tf.cast(image_batch, tf.float32)
    labels_batch = tf.reshape(label_batch, [batch_size])
    return images_batch, labels_batch

batch_size = 16
image_batchs, labels_batch = get_batches(image, label, 28, 28, 1, batch_size)

4. 代码实现:在会话中使用数据集

定义showresult和showing函数,将图片数据进行可视化输出
建立session,准备运行静态图,并在session中启动协调器队列线程,通过session的run方法获得数据

#在会话中使用数据积集
def showresult(subplot, title, thisimg):    #显示单个图片
    p = plt.subplot(subplot)
    p.axis('off')
    p.imshow(np.reshape(thisimg, (28, 28)))
    p.set_title(title)
    
def showimg(index, label, img, ntop):   #显示批次图片
    plt.figure(figsize=(20,10))
    plt.axis('off')
    ntop = min(ntop, 9)
    print(index)
    for i in range(ntop):
        showresult(100+10*ntop+1+i, label[i], img[i])
    plt.show()
    
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    
    coord = tf.train.Coordinator()  #建立队列协调器
    threads = tf.train.start_queue_runners(sess = sess, coord = coord)  #启动队列线程
    
    try:
        for step in np.arange(10):
            if coord.should_stop():
                break
            images,label = sess.run([image_batches,label_batches])  #注入数据
            
            showimg(step, label, images, batch_size)
            print(label)
    except tf.errors.OutOfRangeError:
        print("Done!!")
    finally:
        coord.request_stop()
        
    coord.join(threads)

读取bmp格式图片时,单通道图片会出错

举报

相关推荐

0 条评论