1.导包
from tensorflow import keras
import tensorflow as tf
from keras.preprocessing import image
import random
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm
2.数据预处理
cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt
(1)定义prepare_image函数从文件中分离路径和标签
def prepare_image(file_path):
X_train = []
y_train = []
with open(file_path) as f:
context = f.readlines()
random.shuffle(context)
for str in context:
str = str.strip('\n').split('\t')
X_train.append('./image/cat_12/' + str[0])
y_train.append(str[1])
return X_train, y_train
(2)定义preprocess_image函数进行图像的归一化
def preprocess_image(img):
img = image.load_img(img, target_s