0
点赞
收藏
分享

微信扫一扫

TensorFlow 从入门到精通(4)—— Fashion-MNIST分类问题


Fashion-MNIST是德国研究机构Zalando Research在2017年8月份开源的一个数据集,它包含了如图13.7所示的人们日常穿着的衣服、鞋子、包等十类生活用品,其中包含60000张训练图片和10000张测试图片,每张图片和MNIST数据集一样,均是28*28像素的灰度图,包括10个种类的生活用品(T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot)。

  • 其实本质上也就是十分类问题,所以我们的手写数字识别可以很好的迁移到这个问题上。

一、数据集

import tensorflow as tf
# Show the TensorFlow version this notebook was run with (output below: '2.6.0').
tf.__version__

'2.6.0'

# Load the Fashion-MNIST dataset bundled with tf.keras.
# load_data() returns two (images, labels) tuples: the training split and the test split.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()

# Inspect the array shapes (output below: (60000, 28, 28), (10000, 28, 28), (60000,)).
train_images.shape,test_images.shape,train_labels.shape

((60000, 28, 28), (10000, 28, 28), (60000,))

# 展示图片
import matplotlib.pyplot as plt

def plot_image(image):
    """Display a single image with a binary (grayscale) colormap.

    Args:
        image: a 2-D array of pixel intensities (here a 28x28 grayscale image).
    """
    # Fix: the pasted source had lost the function-body indentation.
    plt.imshow(image, cmap='binary')
    plt.show()

plot_image(train_images[0])

TensorFlow 从入门到精通(4)—— Fashion-MNIST分类问题_数据集

# Split the original 60000-image training set into training and validation subsets.
total_num = len(train_images)
split_valid = 0.2  # fraction held out for validation
train_num = int((1 - split_valid) * total_num)

# Training set: first 80% of the images.
train_x = train_images[:train_num]
train_y = train_labels[:train_num]
# Validation set: remaining 20%.
valid_x = train_images[train_num:]
valid_y = train_labels[train_num:]
# Test set: use the official 10000-image test split as-is.
test_x = test_images
test_y = test_labels  # fix: this line was truncated in the original ("test_y =")

# Reshape + normalize: flatten each 28x28 image into a 784-vector and scale pixels to [0, 1].
train_x = tf.cast(train_x.reshape(-1,784)/255.0,dtype=tf.float32)
valid_x = tf.cast(valid_x.reshape(-1,784)/255.0,dtype=tf.float32)
test_x = tf.cast(test_x.reshape(-1,784)/255.0,dtype=tf.float32)

# One-hot encode the integer class labels (10 classes).
train_y = tf.one_hot(train_y,10)
valid_y = tf.one_hot(valid_y,10)
test_y = tf.one_hot(test_y,10)

二、模型

def model(w, x, b):
    """Single-layer softmax classifier: softmax(x @ w + b).

    Args:
        w: weight matrix, shape (784, 10).
        x: batch of flattened images, shape (N, 784).
        b: bias vector, shape (10,).

    Returns:
        Tensor of class probabilities, shape (N, 10).
    """
    # Fix: the pasted source had lost the function-body indentation.
    logits = tf.matmul(x, w) + b
    return tf.nn.softmax(logits)

# 损失函数
def loss(w,x,y,b):
pred = model(w,x,b)
loss_ = tf.keras.losses.categorical_crossentropy(y_true=y,y_pred=pred)
return tf.reduce_mean(loss_)

# 准确率
def accuracy(w,x,y,b):
pred = model(w,x,b)
acc = tf.equal(tf.argmax(pred,axis=1),tf.argmax(y,axis=1))
return tf.reduce_mean(tf.cast(acc,dtype=tf.float32))

# 计算梯度
def grad(w,x,y,b):
with tf.GradientTape() as tape:
loss_ = loss(w,x,y,b)
return tape.gradient(loss_,[w,b])

三、训练

# Define hyperparameters.
train_epochs = 50
learning_rate = 0.003
batch_size = 50

# Trainable parameters: weights drawn from a standard normal, biases start at zero.
W = tf.Variable(tf.random.normal(shape=(784,10)),dtype=tf.float32)
B = tf.Variable(tf.zeros(10),dtype=tf.float32)
# Number of mini-batches per epoch (integer division drops any remainder samples).
total_steps = train_num // batch_size
# Per-epoch metric history for the plots further down.
# NOTE(review): "trian"/"valide" are typos, but the names are referenced by later
# cells in this file, so they are kept unchanged here.
train_loss_list = []
valid_loss_list = []
trian_acc_list = []
valide_acc_list = []

# Adam optimizer with the configured learning rate.
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)

# Training loop. Indentation restored from the per-epoch output below: metrics are
# computed once per epoch, after all mini-batch steps of that epoch.
for epoch in range(train_epochs):
    for step in range(total_steps):
        # Take the next contiguous mini-batch (no shuffling between epochs).
        xs = train_x[step*batch_size:(step+1)*batch_size]
        ys = train_y[step*batch_size:(step+1)*batch_size]
        grads = grad(W,xs,ys,B)
        # Apply the [dW, dB] gradients to the matching variables [W, B].
        optimizer.apply_gradients(zip(grads,[W,B]))
    # End-of-epoch metrics over the full training and validation splits.
    # (Variable names keep the original's "trian" typo for consistency with the lists.)
    trian_loss = loss(W,train_x,train_y,B).numpy()
    valid_loss = loss(W,valid_x,valid_y,B).numpy()
    train_accuracy = accuracy(W,train_x,train_y,B).numpy()
    valid_accuracy = accuracy(W,valid_x,valid_y,B).numpy()
    trian_acc_list.append(train_accuracy)
    valide_acc_list.append(valid_accuracy)
    train_loss_list.append(trian_loss)
    valid_loss_list.append(valid_loss)
    print(f'{epoch+1}:trian_loss:{trian_loss}valid_loss:{valid_loss}train_accuracy:{train_accuracy}valid_accuracy:{valid_accuracy}')

1:trian_loss:1.138431429862976valid_loss:1.1841025352478027train_accuracy:0.7462499737739563valid_accuracy:0.7403333187103271
2:trian_loss:0.8920454382896423valid_loss:0.9442158341407776train_accuracy:0.7822083234786987valid_accuracy:0.7770000100135803
3:trian_loss:0.780419111251831valid_loss:0.8372109532356262train_accuracy:0.7978749871253967valid_accuracy:0.7905833125114441
4:trian_loss:0.7121660113334656valid_loss:0.771328866481781train_accuracy:0.807729184627533valid_accuracy:0.8003333210945129
5:trian_loss:0.6643809676170349valid_loss:0.7260748744010925train_accuracy:0.8161249756813049valid_accuracy:0.8078333139419556
6:trian_loss:0.6263258457183838valid_loss:0.6907376050949097train_accuracy:0.8222916722297668valid_accuracy:0.8115000128746033
7:trian_loss:0.5948985815048218valid_loss:0.6614484786987305train_accuracy:0.827958345413208valid_accuracy:0.8144166469573975
8:trian_loss:0.5691960453987122valid_loss:0.6380931735038757train_accuracy:0.8324791789054871valid_accuracy:0.8181666731834412
9:trian_loss:0.5439199805259705valid_loss:0.6148083209991455train_accuracy:0.8366041779518127valid_accuracy:0.8224166631698608
10:trian_loss:0.5263316035270691valid_loss:0.5994471311569214train_accuracy:0.839062511920929valid_accuracy:0.8263333439826965
11:trian_loss:0.5114473700523376valid_loss:0.5868139863014221train_accuracy:0.8418750166893005valid_accuracy:0.827750027179718
12:trian_loss:0.49910587072372437valid_loss:0.5764788389205933train_accuracy:0.8445833325386047valid_accuracy:0.8302500247955322
13:trian_loss:0.4881014823913574valid_loss:0.5674534440040588train_accuracy:0.8461250066757202valid_accuracy:0.8321666717529297
14:trian_loss:0.4786192774772644valid_loss:0.5599692463874817train_accuracy:0.8481666445732117valid_accuracy:0.8335833549499512
15:trian_loss:0.4701770544052124valid_loss:0.5534582734107971train_accuracy:0.8498749732971191valid_accuracy:0.8341666460037231
16:trian_loss:0.46222203969955444valid_loss:0.547866702079773train_accuracy:0.8515416383743286valid_accuracy:0.8355000019073486
17:trian_loss:0.4555463194847107valid_loss:0.5430532097816467train_accuracy:0.8529999852180481valid_accuracy:0.8364166617393494
18:trian_loss:0.44957873225212097valid_loss:0.5388360023498535train_accuracy:0.8540208339691162valid_accuracy:0.8360833525657654
19:trian_loss:0.44411206245422363valid_loss:0.5349902510643005train_accuracy:0.8552500009536743valid_accuracy:0.8370833396911621
20:trian_loss:0.43913379311561584valid_loss:0.5315722823143005train_accuracy:0.8560000061988831valid_accuracy:0.8379166722297668
21:trian_loss:0.43459969758987427valid_loss:0.5285435914993286train_accuracy:0.8567291498184204valid_accuracy:0.8379166722297668
22:trian_loss:0.4275083541870117valid_loss:0.5233802795410156train_accuracy:0.8580625057220459valid_accuracy:0.8403333425521851
23:trian_loss:0.423616498708725valid_loss:0.5209391117095947train_accuracy:0.8589583039283752valid_accuracy:0.840833306312561
24:trian_loss:0.4200538396835327valid_loss:0.5187627077102661train_accuracy:0.8595208525657654valid_accuracy:0.840416669845581
25:trian_loss:0.41685953736305237valid_loss:0.5168584585189819train_accuracy:0.8600833415985107valid_accuracy:0.840666651725769
26:trian_loss:0.41382479667663574valid_loss:0.5151072144508362train_accuracy:0.8607708215713501valid_accuracy:0.8409166932106018
27:trian_loss:0.41100260615348816valid_loss:0.5135394930839539train_accuracy:0.8615416884422302valid_accuracy:0.8410000205039978
28:trian_loss:0.40838223695755005valid_loss:0.5121424198150635train_accuracy:0.8618124723434448valid_accuracy:0.8412500023841858
29:trian_loss:0.4059460163116455valid_loss:0.5109005570411682train_accuracy:0.862541675567627valid_accuracy:0.8418333530426025
30:trian_loss:0.40367549657821655valid_loss:0.5097998380661011train_accuracy:0.8631458282470703valid_accuracy:0.8420833349227905
31:trian_loss:0.4012702703475952valid_loss:0.5087489485740662train_accuracy:0.8636458516120911valid_accuracy:0.8427500128746033
32:trian_loss:0.3992216885089874valid_loss:0.507870614528656train_accuracy:0.8638333082199097valid_accuracy:0.8429999947547913
33:trian_loss:0.3973230719566345valid_loss:0.5071054697036743train_accuracy:0.8643125295639038valid_accuracy:0.8428333401679993
34:trian_loss:0.39554792642593384valid_loss:0.5064404010772705train_accuracy:0.8648541569709778valid_accuracy:0.8426666855812073
35:trian_loss:0.3938808739185333valid_loss:0.5058655738830566train_accuracy:0.8652499914169312valid_accuracy:0.8424999713897705
36:trian_loss:0.39210212230682373valid_loss:0.505139172077179train_accuracy:0.8655624985694885valid_accuracy:0.8431666493415833
37:trian_loss:0.3905573785305023valid_loss:0.5046170353889465train_accuracy:0.8657916784286499valid_accuracy:0.8428333401679993
38:trian_loss:0.38911256194114685valid_loss:0.5041868686676025train_accuracy:0.8661875128746033valid_accuracy:0.8430833220481873
39:trian_loss:0.3877559304237366valid_loss:0.5038341283798218train_accuracy:0.8667291402816772valid_accuracy:0.8431666493415833
40:trian_loss:0.38655149936676025valid_loss:0.5036598443984985train_accuracy:0.8669583201408386valid_accuracy:0.843416690826416
41:trian_loss:0.38540738821029663valid_loss:0.5034105777740479train_accuracy:0.867270827293396valid_accuracy:0.844083309173584
42:trian_loss:0.38427892327308655valid_loss:0.503218412399292train_accuracy:0.8675000071525574valid_accuracy:0.843999981880188
43:trian_loss:0.3831900656223297valid_loss:0.5030636191368103train_accuracy:0.8677916526794434valid_accuracy:0.8445833325386047
44:trian_loss:0.38214486837387085valid_loss:0.5029441714286804train_accuracy:0.8681874871253967valid_accuracy:0.8448333144187927
45:trian_loss:0.3811493217945099valid_loss:0.502860426902771train_accuracy:0.8684166669845581valid_accuracy:0.8445000052452087
46:trian_loss:0.3802007734775543valid_loss:0.502810001373291train_accuracy:0.8686875104904175valid_accuracy:0.8445833325386047
47:trian_loss:0.37929588556289673valid_loss:0.5027902126312256train_accuracy:0.8688750267028809valid_accuracy:0.8450000286102295
48:trian_loss:0.37843266129493713valid_loss:0.5027962923049927train_accuracy:0.8690416812896729valid_accuracy:0.8450000286102295
49:trian_loss:0.3776928782463074valid_loss:0.5028855800628662train_accuracy:0.8692708611488342valid_accuracy:0.8448333144187927
50:trian_loss:0.37693071365356445valid_loss:0.5029789805412292train_accuracy:0.8694166541099548valid_accuracy:0.8445833325386047

# Final accuracy on the held-out test set (output below: 0.8304).
accuracy(W,test_x,test_y,B).numpy()

0.8304

# Loss curves: red = training loss, blue = validation loss, one point per epoch.
plt.plot(train_loss_list,'r')
plt.plot(valid_loss_list,'b')

[<matplotlib.lines.Line2D at 0x7fec3befee10>]

TensorFlow 从入门到精通(4)—— Fashion-MNIST分类问题_tensorflow_02

# Accuracy curves: red = training accuracy, blue = validation accuracy, one point per epoch.
plt.plot(trian_acc_list,'r')
plt.plot(valide_acc_list,'b')

[<matplotlib.lines.Line2D at 0x7fec3beed690>]

TensorFlow 从入门到精通(4)—— Fashion-MNIST分类问题_神经网络_03

四、预测

def predict(x, w, b):
    """Return the predicted class index for each sample in x.

    Args:
        x: batch of flattened images, shape (N, 784).
        w, b: trained model parameters.

    Returns:
        Int tensor of shape (N,) with the argmax class per sample.
    """
    pred = model(w, x, b)
    pred_ = tf.argmax(pred, axis=1)
    # Fix: the original ended with a bare "return", returning None — yet the caller
    # below indexes the result (predict(...)[id]), so the argmax tensor must be returned.
    return pred_

import numpy as np

# Pick a random test sample and compare the model's prediction with the true label.
idx = np.random.randint(0, len(test_x))  # renamed from "id" to avoid shadowing the builtin
# Predicted class index for the chosen sample.
pred = predict(test_x, W, B)[idx]
# Ground-truth label (raw integer label, not one-hot).
true = test_labels[idx]
print(true, pred.numpy())

1 1

# R^2 between the one-hot labels and the predicted probabilities.
# NOTE(review): R^2 is a regression metric and is unusual for a classifier —
# accuracy or F1 would be the standard choice here; kept as the author wrote it.
import sklearn.metrics as sm
print(f'r2:{sm.r2_score(test_y,model(W,test_x,B))}')

r2:0.7226787373136627


举报

相关推荐

0 条评论