模型搭建
0、VGG模型
1、torchvision自带的VGG模型
import torch
import torchvision
from torchsummary import summary

# Use the GPU when present, otherwise fall back to CPU instead of crashing
# in a hard-coded .cuda() call.
device = "cuda" if torch.cuda.is_available() else "cpu"

# torchvision's built-in VGG-13 with an 8-way output layer.
model = torchvision.models.vgg13(num_classes=8).to(device)

# torchsummary's summary() defaults to device="cuda"; pass the actual device so
# its dummy input is created on the same device as the model.
summary(model, (3, 224, 224), device=device)
2、自己搭建
1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 from torchsummary import summary
5
class VGG(nn.Module):
    """VGG-style convolutional classifier with batch normalization.

    The feature extractor has five stages with output widths 64, 128, 256,
    512 and 512. Stage ``i`` stacks ``arch[i]`` blocks of
    (3x3 conv, stride 1, padding 1, no bias) -> BatchNorm2d -> ReLU.
    Each stage is followed by 2x2 max pooling.

    The classifier is:
        Linear(7*7*512, hidden_features) -> BatchNorm1d -> ReLU
        -> Linear(hidden_features, hidden_features) -> BatchNorm1d -> ReLU
        -> Linear(hidden_features, num_classes)

    Args:
        arch: Number of conv blocks in each of the five stages, e.g.
            ``[2, 2, 3, 3, 3]`` for VGG-16. Only the first five entries
            are used.
        num_classes: Number of output logits.
        hidden_features: Width of the two hidden fully connected layers.
            Defaults to 4096, the width used by the original VGG design.

    The input must have 3 channels. Its spatial size must be at least 32x32
    so the five 2x2 poolings leave a non-empty map. That map is then
    adaptively average-pooled to 7x7, which is a no-op for 224x224 input.
    In training mode, BatchNorm1d needs more than one sample per batch.
    """

    def __init__(self, arch, num_classes=1000, hidden_features=4096):
        super().__init__()
        # Running input-channel count; _make_layer advances it as stages are built.
        self.in_channels = 3
        self.conv3_64 = self._make_layer(64, arch[0])
        self.conv3_128 = self._make_layer(128, arch[1])
        self.conv3_256 = self._make_layer(256, arch[2])
        self.conv3_512a = self._make_layer(512, arch[3])
        self.conv3_512b = self._make_layer(512, arch[4])
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(7 * 7 * 512, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        # fc2 gets its own BatchNorm: reusing bn1 would mix the running
        # statistics and affine parameters of two different layers.
        self.bn2 = nn.BatchNorm1d(hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def _make_layer(self, channels, num):
        """Build one stage of ``num`` (conv3x3 -> BatchNorm2d -> ReLU) blocks.

        Args:
            channels: Output channels of every conv in the stage.
            num: Number of conv blocks in the stage.

        Returns:
            An ``nn.Sequential`` holding the stage's layers. As a side effect,
            ``self.in_channels`` is set to ``channels`` when ``num > 0``.
        """
        layers = []
        for _ in range(num):
            layers += [
                nn.Conv2d(self.in_channels, channels, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
            ]
            self.in_channels = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to class logits of shape (N, num_classes)."""
        stages = (self.conv3_64, self.conv3_128, self.conv3_256,
                  self.conv3_512a, self.conv3_512b)
        for stage in stages:
            x = F.max_pool2d(stage(x), 2)
        # Resample to the 7x7 grid fc1 expects; exact identity for 224x224 input.
        x = F.adaptive_avg_pool2d(x, (7, 7))
        x = self.flatten(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        return self.fc3(x)
52
def VGG_11(num_classes=6):
    """Build a VGG-11-style network (stage depths 1, 1, 2, 2, 2).

    Args:
        num_classes: Number of output logits. Defaults to 6, the value this
            factory previously hard-coded.
    """
    return VGG([1, 1, 2, 2, 2], num_classes=num_classes)
55
def VGG_13(num_classes=7):
    """Build a VGG-13-style network (stage depths 2, 2, 2, 2, 2).

    Args:
        num_classes: Number of output logits. Defaults to 7, the value this
            factory previously hard-coded.
    """
    return VGG([2, 2, 2, 2, 2], num_classes=num_classes)
58
def VGG_16(num_classes=8):
    """Build a VGG-16-style network (stage depths 2, 2, 3, 3, 3).

    Args:
        num_classes: Number of output logits. Defaults to 8, the value this
            factory previously hard-coded.
    """
    return VGG([2, 2, 3, 3, 3], num_classes=num_classes)
61
def VGG_19(num_classes=9):
    """Build a VGG-19-style network (stage depths 2, 2, 4, 4, 4).

    Args:
        num_classes: Number of output logits. Defaults to 9, the value this
            factory previously hard-coded.
    """
    return VGG([2, 2, 4, 4, 4], num_classes=num_classes)
64
# Use the GPU when present, otherwise fall back to CPU instead of crashing
# in a hard-coded .cuda() call.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = VGG_16().to(device)
# torchsummary's summary() defaults to device="cuda"; pass the actual device so
# its dummy input is created on the same device as the model.
summary(net, (3, 224, 224), device=device)