手写神经网络，实现 MNIST 数据集分类。本过程仅使用 NumPy 实现，没有使用任何机器学习或深度学习框架。
1. 数据读取及处理
import numpy as np
import struct
import matplotlib.pyplot as plt
import os
from PIL import Image
from sklearn.utils import gen_batches
from sklearn.metrics import classification_report, confusion_matrix
from typing import *
from numpy.linalg import *
# Locations of the four raw MNIST IDX files, relative to the current working
# directory: training images/labels followed by test images/labels.
(
    train_image_file,
    train_label_file,
    test_image_file,
    test_label_file,
) = (
    './mnist/train-images-idx3-ubyte',
    './mnist/train-labels-idx1-ubyte',
    './mnist/t10k-images-idx3-ubyte',
    './mnist/t10k-labels-idx1-ubyte',
)
def decode_image(path):
    """Read an IDX3 image file (the MNIST image format) into a float array.

    The file begins with a 16-byte big-endian header of four unsigned 32-bit
    integers: magic number (2051 for IDX3 unsigned-byte image files), image
    count, rows, cols. The rest of the file is the uint8 pixel data, stored
    one image after another.

    Args:
        path: Path to the IDX3 image file, e.g. ``train-images-idx3-ubyte``.

    Returns:
        A ``float64`` array of shape ``(num, rows * cols)``: one flattened
        image per row. Pixel values are left in ``[0, 255]`` (not normalized).

    Raises:
        ValueError: If the magic number is not 2051, or if the pixel payload
            does not contain exactly ``num * rows * cols`` bytes.
    """
    with open(path, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            # Catches e.g. a label file (magic 2049) passed by mistake.
            raise ValueError(
                f'{path}: bad magic number {magic}, expected 2051 (IDX3 images)'
            )
        # frombuffer works on any bytes object; fromfile needs a real OS file.
        pixels = np.frombuffer(f.read(), dtype=np.uint8)

    expected = num * rows * cols
    if pixels.size != expected:
        raise ValueError(
            f'{path}: header declares {num} images of {rows}x{cols} '
            f'({expected} bytes) but payload has {pixels.size} bytes'
        )
    # Use the header dimensions instead of a hard-coded 784 so that any
    # IDX3 image size is handled; for MNIST this is (num, 784) as before.
    return pixels.reshape(num, rows * cols).astype(float)
def decode_label(path):
with open(path, 'rb') as f: