0
点赞
收藏
分享

微信扫一扫

基于GoogLeNet的猫十二分类

1.导包

from tensorflow import keras
import tensorflow as tf
from keras.preprocessing import image
import random
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm

2.数据预处理

cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt

(1)定义prepare_image函数从文件中分离路径和标签

def prepare_image(file_path):
    X_train = []
    y_train = []

    with open(file_path) as f:
        context = f.readlines()
    random.shuffle(context)

    for str in context:
        str = str.strip('\n').split('\t')

        X_train.append('./image/cat_12/' + str[0])
        y_train.append(str[1])

    return X_train, y_train

(2)定义preprocess_image函数进行图像的归一化

def preprocess_image(img):
    img = image.load_img(img, target_s
举报

相关推荐

0 条评论