# Try it out (试一试).
"""ResNet-18 style image classifier built from basic residual blocks.

Run as a script to print a layer-by-layer summary via ``torchsummary``.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from torchsummary import summary
except ImportError:  # torchsummary is only needed by the __main__ demo below
    summary = None


class ResBlock(nn.Module):
    """Two-convolution residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut.

    Args:
        inchannel: Number of input channels.
        outchannel: Number of output channels.
        stride: Stride of the first 3x3 conv (and of the projection shortcut);
            values > 1 downsample the spatial dimensions.
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super().__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        # Identity shortcut when the shapes already match; otherwise a strided
        # 1x1 projection so the residual can be added element-wise.
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))``."""
        out = self.left(x) + self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet classifier for 3-channel images.

    The stem is a single 3x3, stride-1 conv with 64 filters (no max-pool).
    It is followed by four stages with 64/128/256/512 channels, where stages
    2-4 halve the spatial size. Global average pooling and a linear head
    follow the stages.

    Args:
        ResBlock: Block class used to build every stage; it is called as
            ``ResBlock(inchannel, outchannel, stride)``. The parameter name is
            kept for backward compatibility (it shadows the module-level class).
        num_classes: Output size of the final linear layer.
        num_blocks: Number of blocks in each of the four stages. The default
            ``(2, 2, 2, 2)`` gives ResNet-18.
    """

    def __init__(self, ResBlock, num_classes=10, num_blocks=(2, 2, 2, 2)):
        super().__init__()
        blocks1, blocks2, blocks3, blocks4 = num_blocks
        # Running input-channel count; make_layer advances it stage by stage.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer1 = self.make_layer(ResBlock, 64, blocks1, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, blocks2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, blocks3, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, blocks4, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks producing ``channels`` channels.

        Only the first block uses ``stride``; the remaining blocks use stride 1.
        Side effect: sets ``self.inchannel`` to ``channels`` for the next stage.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling to 1x1 works for any input size. A fixed
        # 28x28 window only matches 224x224 inputs (224 / 8 = 28 after the
        # three stride-2 stages) and fails on smaller ones such as 32x32.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.fc(out)


def ResNet18(num_classes=40):
    """Return a ResNet-18: four stages of two ``ResBlock``s each.

    Args:
        num_classes: Output size of the classifier head (default 40).
    """
    return ResNet(ResBlock, num_classes=num_classes)


if __name__ == "__main__":
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ResNet18().to(device)
    if summary is None:
        print("torchsummary is not installed; skipping model summary.")
    else:
        # Other input sizes such as (3, 32, 32) also work thanks to the
        # adaptive pooling in ResNet.forward.
        summary(model, (3, 224, 224), device=device)