参考资料:https://cuijiahua.com/blog/2018/01/dl_3.html
代码实现(第一段为模型定义,需保存为 lenet5.py;第二段为训练脚本,通过 `from lenet5 import Lenet5` 导入模型):
import torch
from torch import nn
from torch.nn import functional as F


class Lenet5(nn.Module):
    """LeNet-5 style CNN for 3x32x32 inputs (written for CIFAR-10), emitting 10 logits.

    Shape flow (kernel 3, stride 1, no padding; 2x2 max-pooling):
        [b, 3, 32, 32] -conv-> [b, 6, 30, 30] -pool-> [b, 6, 15, 15]
                       -conv-> [b, 16, 13, 13] -pool-> [b, 16, 6, 6]
                       -flatten-> [b, 576] -fc-> [b, 10]
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        # NOTE(review): there is no activation between the conv and pooling
        # layers (classic LeNet-5 has one). Inserting e.g. nn.ReLU would shift
        # the Sequential indices and break existing state_dict checkpoints,
        # so the architecture is intentionally left unchanged here.
        self.conv_unit = nn.Sequential(
            # conv: [b, 3, 32, 32] => [b, 6, 30, 30]
            nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=0),
            # pool: [b, 6, 30, 30] => [b, 6, 15, 15]
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # conv: [b, 6, 15, 15] => [b, 16, 13, 13]
            nn.Conv2d(6, 16, kernel_size=3, stride=1, padding=0),
            # pool: [b, 16, 13, 13] => [b, 16, 6, 6] (pooling floors 13 / 2)
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Classifier. in_features 16 * 6 * 6 must equal the flattened
        # conv_unit output size for a 32x32 input.
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 6 * 6, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, 10),
        )
        # The loss is kept outside the model: forward returns raw logits,
        # which is what nn.CrossEntropyLoss expects.

    def forward(self, x):
        """Map a batch of images to class logits.

        :param x: image batch of shape [b, 3, 32, 32]
        :return: logits of shape [b, 10]
        """
        # [b, 3, 32, 32] => [b, 16, 6, 6]
        x = self.conv_unit(x)
        # [b, 16, 6, 6] => [b, 16*6*6]; flatten keeps the batch dimension and,
        # unlike view(), does not require a contiguous tensor.
        x = torch.flatten(x, start_dim=1)
        # [b, 16*6*6] => [b, 10]
        return self.fc_unit(x)


def main():
    """Smoke test: push a random batch through the network and print shapes."""
    net = Lenet5()
    tmp = torch.randn(2, 3, 32, 32)
    # The conv_unit output shape is what fixes in_features of the first
    # fc_unit Linear layer, so report it separately from the final logits.
    conv_out = net.conv_unit(tmp)
    print('conv out:', conv_out.shape)
    logits = net(tmp)
    print('logits:', logits.shape)


if __name__ == '__main__':
    main()
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim

from lenet5 import Lenet5


def _build_loaders(batchsz, root='cifar'):
    """Return (train_loader, test_loader) for CIFAR-10, downloading if needed."""
    # Resize to the 32x32 input the network expects, then convert to tensor.
    # No data augmentation is applied.
    tf = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR10(root, True, transform=tf, download=True)
    test_set = datasets.CIFAR10(root, False, transform=tf, download=True)
    # Shuffle only the training data; evaluation order does not affect accuracy.
    train_loader = DataLoader(train_set, batch_size=batchsz, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batchsz, shuffle=False)
    return train_loader, test_loader


def _select_device():
    """Return the first CUDA device if one is available, otherwise the CPU."""
    # 'cuda:0,1' is not a valid torch.device string; multiple GPUs are used
    # via nn.DataParallel in main() instead.
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _train_one_epoch(model, loader, criteon, optimizer, device):
    """Run one optimization pass over loader; return the last batch's loss (NaN if empty)."""
    # Must be re-enabled each epoch because _evaluate switches to eval mode.
    model.train()
    loss = None
    for x, label in loader:
        # x: [b, 3, 32, 32], label: [b]
        x, label = x.to(device), label.to(device)
        logits = model(x)  # [b, 10]
        loss = criteon(logits, label)  # scalar tensor

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # .item() only once per epoch to avoid a device sync on every batch.
    return loss.item() if loss is not None else float('nan')


def _evaluate(model, loader, device):
    """Return top-1 accuracy of model over loader (0.0 if loader is empty)."""
    model.eval()
    total_correct = 0
    total_num = 0
    with torch.no_grad():
        for x, label in loader:
            x, label = x.to(device), label.to(device)
            # [b, 10] -> [b]: highest-scoring class is the prediction.
            pred = model(x).argmax(dim=1)
            total_correct += torch.eq(pred, label).sum().item()
            total_num += x.size(0)
    return total_correct / total_num if total_num else 0.0


def main(epochs=1000, batchsz=32, lr=1e-3):
    """Train Lenet5 on CIFAR-10, printing last-batch loss and test accuracy per epoch.

    :param epochs: number of training epochs (default 1000, as originally hard-coded)
    :param batchsz: mini-batch size for both the train and test loaders
    :param lr: Adam learning rate
    """
    cifar_train, cifar_test = _build_loaders(batchsz)

    device = _select_device()
    model = Lenet5().to(device)
    # The original 'cuda:0,1' intended two GPUs; DataParallel splits each
    # batch across the visible GPUs (model must live on cuda:0 first).
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    for epoch in range(epochs):
        loss = _train_one_epoch(model, cifar_train, criteon, optimizer, device)
        print(epoch, 'loss:', loss)
        acc = _evaluate(model, cifar_test, device)
        print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()
结果: