LeNet5网络和CIFAR10数据集
main函数--dataloader--train--test
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from lenet5 import LeNet5


def _make_loader(train, batch_size):
    """Return a shuffled DataLoader over the CIFAR10 train or test split.

    Args:
        train: True for the training split, False for the test split.
        batch_size: Number of images loaded per batch.
    """
    dataset = datasets.CIFAR10(
        'cifar',
        train=train,
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ]),
        download=True,
    )
    # The DataLoader yields several images at once (one batch per step).
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def main():
    """Train LeNet5 on CIFAR10 and report test accuracy after every epoch.

    Runs 1000 epochs with Adam (lr=1e-3) and cross-entropy loss. After each
    epoch it prints the loss of the last training batch and the accuracy
    over the test loader.
    """
    batch_size = 32
    cifar_train = _make_loader(train=True, batch_size=batch_size)
    cifar_test = _make_loader(train=False, batch_size=batch_size)

    # Once loading works, inspect the shapes of a single batch.
    # The built-in next() is used because iterators have no .next() method
    # in Python 3.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Use the GPU when one is present, otherwise fall back to the CPU
    # instead of crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = LeNet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print(model)

    for epoch in range(1000):

        model.train()
        for x, label in cifar_train:
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            # logits: [b, 10], label: [b]
            logits = model(x)
            loss = criteon(logits, label)

            # Backprop: clear stale gradients, compute new ones, update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last training batch of this epoch.
        print(epoch, loss.item())

        model.eval()
        # No gradient bookkeeping is needed while evaluating.
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                # logits: [b, 10]
                logits = model(x)
                pred = logits.argmax(dim=1)
                # Accumulate the number of correct predictions per batch.
                total_correct += torch.eq(pred, label).float().sum().item()
                # x.size(0) is the size of this batch.
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, acc)


if __name__ == '__main__':
    main()
LeNet5网络--tmp测试
import torch
from torch import nn
from torch.nn import functional as F


class LeNet5(nn.Module):
    """
    LeNet5-style network for the CIFAR10 dataset.

    Maps a batch of images shaped [b, 3, 32, 32] to class logits shaped
    [b, 10]. No softmax is applied; the logits are meant to be passed to
    nn.CrossEntropyLoss, which is created outside this class so that
    forward() does not need a label argument.
    """

    def __init__(self):
        super(LeNet5, self).__init__()

        # Convolutional feature extractor: [b, 3, 32, 32] --> [b, 16, 5, 5].
        # Conv2d arguments: in_channels, out_channels, kernel_size,
        # stride, padding.
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )

        # Fully connected classifier applied to the flattened features.
        # The input size 16*5*5 was found by pushing a random
        # [2, 3, 32, 32] tensor through conv_unit, which gives [2, 16, 5, 5].
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            # ReLU between the linear layers (original note: added to
            # counter vanishing gradients in the fully connected part).
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Autograd derives the backward pass from the operations run here.

        Args:
            x: Input tensor shaped [b, 3, 32, 32].

        Returns:
            Logits tensor shaped [b, 10].
        """
        # Dimension 0 of the input is the batch size.
        batch_size = x.size(0)
        # [b, 3, 32, 32] --> [b, 16, 5, 5]
        x = self.conv_unit(x)
        # Flatten: [b, 16, 5, 5] --> [b, 16*5*5]
        x = x.view(batch_size, 16 * 5 * 5)
        # [b, 16*5*5] --> [b, 10]
        logits = self.fc_unit(x)
        return logits


def main():
    """Smoke test: run a random batch through LeNet5 and print its shape."""
    net = LeNet5()
    tmp = torch.randn(2, 3, 32, 32)
    out = net(tmp)
    print('lenet_out:', out.shape)


if __name__ == '__main__':
    main()