PyTorch 手写数字识别（MNIST）
代码
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from torch_study.lesson5_minist_train.utils import plot_curve, plot_image, plt, one_hot
# Number of images per mini-batch, shared by the train and test loaders.
batch_size = 512

# step 1: load the MNIST dataset.
# Both splits use the same preprocessing: ToTensor, then Normalize with
# mean 0.1307 and std 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split; shuffle=True reshuffles the samples every epoch.
train_set = torchvision.datasets.MNIST(
    'mnist_data', train=True, download=True, transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)

# Test split; shuffle=False keeps a fixed iteration order.
test_set = torchvision.datasets.MNIST(
    'mnist_data/', train=False, download=True, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False)
# Inspect one training batch: print the tensor shapes and the pixel value
# range after normalization, then plot a few of the images with labels.
sample_batches = iter(train_loader)
x, y = next(sample_batches)
pixel_min = x.min()
pixel_max = x.max()
print(x.shape, y.shape, pixel_min, pixel_max)
plot_image(x, y, 'image sample')
class Net(nn.Module):
    """Three-layer fully connected classifier for 28x28 digit images.

    Layer sizes are 784 -> 256 -> 64 -> 10, with ReLU after the first two
    linear layers. The last layer returns raw logits (no softmax), which is
    the form ``F.cross_entropy`` and ``argmax`` consume.
    """

    def __init__(self):
        super().__init__()
        # Each nn.Linear computes x @ W.T + b.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits of shape [b, 10].

        Args:
            x: a batch of images, either [b, 1, 28, 28] or already
               flattened to [b, 784]. Every dimension after the batch
               dimension is flattened, so each sample must contain
               28 * 28 values.
        """
        # Flatten to [b, 784]; this is a no-op for input the caller has
        # already reshaped, and fixes the crash for raw [b, 1, 28, 28] input.
        x = x.flatten(start_dim=1)
        # h1 = relu(x @ w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 @ w2 + b2)
        x = F.relu(self.fc2(x))
        # logits = h2 @ w3 + b3
        return self.fc3(x)
net = Net()
# SGD over every parameter [w1, b1, w2, b2, w3, b3], with momentum 0.9.
optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9)
train_loss = []

net.train()
# Make three passes (epochs) over the training set.
for epoch in range(3):
    # Each iteration yields one batch of up to batch_size images.
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28] -> [b, 784]; y: [b] class labels
        x = x.view(x.size(0), 28 * 28)
        out = net(x)  # [b, 10] logits
        # Pass integer class labels directly. The previous one_hot(y)
        # target (presumably a float [b, 10] tensor) is only accepted as
        # class probabilities on PyTorch >= 1.10; for one-hot targets the
        # loss value is identical.
        loss = F.cross_entropy(out, y)

        optimizer.zero_grad()  # clear gradients left from the previous step
        loss.backward()        # compute d(loss)/d(params)
        optimizer.step()       # w' = w - lr * grad (with momentum)

        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

# Plot the per-batch training loss curve.
plot_curve(train_loss)
# Training is done; net now holds the learned [w1, b1, w2, b2, w3, b3].
# Measure classification accuracy on the test set.
net.eval()
total_correct = 0
# Evaluation needs no gradients; no_grad() stops autograd from building a
# computation graph for every test batch.
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10] => pred: [b], the index of the largest logit
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
# Plot the model's predictions for the first test batch.
# no_grad() keeps the global `out` from holding an autograd graph.
with torch.no_grad():
    x, y = next(iter(test_loader))
    out = net(x.view(x.size(0), 28 * 28))
    pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
结果
模型改进方向
增加模型层数
更换损失函数（loss），如比较 MSE 与交叉熵
调整学习率（lr）和批大小（batch_size）