zoukankan      html  css  js  c++  java
  • pytorch ResNet

    一、ResNet网络结构

    1.1ResNet特点

    • 深层网络结构
    • 残差模块

     

    •  Batch Normalization加速训练

    使一批feature map满足均值为0,方差为1的分布。

    ResNet解决了网络层数增加带来的梯度消失,梯度爆炸和网络退化(degradation)问题。

    1.2网络结构

     residual block的虚线代表主分支和shortcut的shape不同,所以要在shortcut中加入1×1卷积(配合相应的stride),使其输出的shape和主分支相同,才能进行相加。

    1.3参数列表

     二、模型和训练代码

    2.1 model.py

      1 import torch.nn as nn
      2 import torch
      3 
      4 
      class BasicBlock(nn.Module):
          """Residual block built from two 3x3 convolutions.

          The main branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is
          added to the shortcut branch and the sum goes through a final ReLU.
          Only the first convolution applies ``stride``.

          Args:
              in_channel: number of channels of the input feature map.
              out_channel: number of channels produced by both convolutions.
              stride: stride of the first 3x3 convolution.
              downsample: optional module applied to the input to build the
                  shortcut. When ``None`` the input itself is the shortcut, so
                  it must already match the shape of the main-branch output.
          """

          # The block outputs out_channel * expansion channels; the network
          # builder reads this attribute as ``block.expansion``.
          expansion = 1

          def __init__(self, in_channel, out_channel, stride=1, downsample=None):
              super().__init__()
              self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                                     kernel_size=3, stride=stride, padding=1, bias=False)
              self.bn1 = nn.BatchNorm2d(out_channel)
              # inplace=True matches Bottleneck and the ResNet stem; it only ever
              # overwrites freshly computed activations, never the block input.
              self.relu = nn.ReLU(inplace=True)
              self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                                     kernel_size=3, stride=1, padding=1, bias=False)
              self.bn2 = nn.BatchNorm2d(out_channel)
              self.downsample = downsample

          def forward(self, x):
              """Return ``relu(main_branch(x) + shortcut(x))``."""
              identity = x
              if self.downsample is not None:
                  # Project the input so it can be added to the main branch.
                  identity = self.downsample(x)

              out = self.conv1(x)
              out = self.bn1(out)
              out = self.relu(out)

              out = self.conv2(out)
              out = self.bn2(out)

              out += identity
              out = self.relu(out)

              return out
     35 
     36 
     class Bottleneck(nn.Module):
         """Three-convolution residual block: 1x1 squeeze, 3x3, 1x1 expand.

         The main branch is conv1x1 -> BN -> ReLU -> conv3x3 -> BN -> ReLU ->
         conv1x1 -> BN. The shortcut is added to it and the sum goes through a
         final ReLU. ``stride`` is applied by the 3x3 convolution, and the
         block outputs ``out_channel * expansion`` channels.

         Args:
             in_channel: number of channels of the input feature map.
             out_channel: channel width of the first two convolutions.
             stride: stride of the 3x3 convolution.
             downsample: optional module applied to the input to build the
                 shortcut; ``None`` means the input is used unchanged.
         """

         expansion = 4

         def __init__(self, in_channel, out_channel, stride=1, downsample=None):
             super().__init__()
             width = out_channel
             expanded = out_channel * self.expansion

             # 1x1: squeeze channels
             self.conv1 = nn.Conv2d(in_channel, width, kernel_size=1, stride=1, bias=False)
             self.bn1 = nn.BatchNorm2d(width)

             # 3x3: the only convolution that carries the stride
             self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                                    padding=1, bias=False)
             self.bn2 = nn.BatchNorm2d(width)

             # 1x1: unsqueeze channels
             self.conv3 = nn.Conv2d(width, expanded, kernel_size=1, stride=1, bias=False)
             self.bn3 = nn.BatchNorm2d(expanded)

             self.relu = nn.ReLU(inplace=True)
             self.downsample = downsample

         def forward(self, x):
             """Return ``relu(main_branch(x) + shortcut(x))``."""
             shortcut = x if self.downsample is None else self.downsample(x)

             out = self.relu(self.bn1(self.conv1(x)))
             out = self.relu(self.bn2(self.conv2(out)))
             out = self.bn3(self.conv3(out))

             out += shortcut
             return self.relu(out)
     76 
     77 
     class ResNet(nn.Module):
         """ResNet backbone assembled from a configurable residual block type.

         Layout: 7x7 stride-2 convolution -> BN -> ReLU -> 3x3 stride-2 max-pool,
         then four stages (``layer1`` .. ``layer4``) of base width 64/128/256/512.
         When ``include_top`` is true, an adaptive average pool to 1x1 and a
         linear classifier follow.

         Args:
             block: residual block class. It must expose an ``expansion`` class
                 attribute and accept ``(in_channel, out_channel, stride=...,
                 downsample=...)``.
             blocks_num: sequence whose first four entries give the number of
                 blocks in each stage.
             num_classes: output size of the final linear layer.
             include_top: when false, the pooling and classifier layers are not
                 created and ``forward`` returns the ``layer4`` feature map.
         """

         def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
             super().__init__()
             self.include_top = include_top
             # Channel count entering the next stage; _make_layer advances it.
             self.in_channel = 64

             # Stem
             self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                                    padding=3, bias=False)
             self.bn1 = nn.BatchNorm2d(self.in_channel)
             self.relu = nn.ReLU(inplace=True)
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

             # (base width, stride of the first block) for layer1 .. layer4
             stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
             for idx, (width, stride) in enumerate(stage_specs):
                 stage = self._make_layer(block, width, blocks_num[idx], stride=stride)
                 setattr(self, "layer{}".format(idx + 1), stage)

             if self.include_top:
                 self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
                 self.fc = nn.Linear(512 * block.expansion, num_classes)

             # Re-initialise every convolution in the network.
             for conv in (m for m in self.modules() if isinstance(m, nn.Conv2d)):
                 nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

         def _make_layer(self, block, channel, block_num, stride=1):
             """Build one stage of ``block_num`` blocks with base width ``channel``.

             Only the first block receives ``stride`` and, when the shortcut shape
             would not match, a 1x1 conv + BN projection as ``downsample``.
             Updates ``self.in_channel`` to the stage's output channel count.
             """
             out_channel = channel * block.expansion

             projection = None
             if not (stride == 1 and self.in_channel == out_channel):
                 projection = nn.Sequential(
                     nn.Conv2d(self.in_channel, out_channel, kernel_size=1,
                               stride=stride, bias=False),
                     nn.BatchNorm2d(out_channel))

             first = block(self.in_channel, channel, downsample=projection, stride=stride)
             self.in_channel = out_channel

             rest = [block(self.in_channel, channel) for _ in range(1, block_num)]
             return nn.Sequential(first, *rest)

         def forward(self, x):
             """Run the stem and the four stages, then the classifier head if present."""
             x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

             for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
                 x = stage(x)

             if not self.include_top:
                 return x

             x = self.avgpool(x)
             x = torch.flatten(x, 1)
             return self.fc(x)
    135 
    136 
    def resnet34(num_classes=1000, include_top=True):
        """Build a ResNet of BasicBlock units with stage depths [3, 4, 6, 3]."""
        stage_depths = [3, 4, 6, 3]
        return ResNet(BasicBlock, stage_depths,
                      num_classes=num_classes, include_top=include_top)
    139 
    140 
    def resnet101(num_classes=1000, include_top=True):
        """Build a ResNet of Bottleneck units with stage depths [3, 4, 23, 3]."""
        stage_depths = [3, 4, 23, 3]
        return ResNet(Bottleneck, stage_depths,
                      num_classes=num_classes, include_top=include_top)

    2.2 train.py 带迁移学习

      1 import torch
      2 import torch.nn as nn
      3 from torchvision import transforms, datasets
      4 import json
      5 import matplotlib.pyplot as plt
      6 import os
      7 import torch.optim as optim
      8 from model import resnet34, resnet101
      9 
     10 import torchvision.models.resnet
     def main():
         """Fine-tune a pretrained ResNet-34 on the flower dataset, keeping the best weights.

         Steps:
           1. Build train/val ``ImageFolder`` datasets from
              ``<cwd>/../../data_set/flower_data/{train,val}``.
           2. Write the index -> class-name mapping to ``class_indices.json``.
           3. Load ``./resnet34-pre.pth`` into ``resnet34()`` and replace its
              ``fc`` layer with a 5-way classifier.
           4. Train for 3 epochs with Adam (lr=1e-4) and cross-entropy loss, and
              save the state dict to ``./resNet34.pth`` whenever the validation
              accuracy improves.

         Raises:
             AssertionError: if the dataset directory or the pretrained weight
                 file does not exist.
         """
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         print("using {} device.".format(device))

         data_transform = {
             "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
             "val": transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

         data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
         image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
         assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
         train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                              transform=data_transform["train"])
         train_num = len(train_dataset)

         # class_to_idx maps name -> index, e.g.
         # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
         flower_list = train_dataset.class_to_idx
         # Invert it to index -> name and persist it for later lookup.
         cla_dict = dict((val, key) for key, val in flower_list.items())
         json_str = json.dumps(cla_dict, indent=4)
         with open('class_indices.json', 'w') as json_file:
             json_file.write(json_str)

         batch_size = 16
         # Both loaders run in the main process; the value reported below is the
         # one actually passed to the DataLoaders.
         num_workers = 0
         print('Using {} dataloader workers every process'.format(num_workers))

         train_loader = torch.utils.data.DataLoader(train_dataset,
                                                    batch_size=batch_size, shuffle=True,
                                                    num_workers=num_workers)

         validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                 transform=data_transform["val"])
         val_num = len(validate_dataset)
         validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                       batch_size=batch_size, shuffle=False,
                                                       num_workers=num_workers)

         print("using {} images for training, {} images fot validation.".format(train_num,
                                                                                val_num))

         net = resnet34()
         # load pretrain weights (transfer learning)
         # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
         model_weight_path = "./resnet34-pre.pth"
         assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
         # Load onto the CPU first so the checkpoint also loads on machines
         # without the device it was saved from; net.to(device) moves it below.
         net.load_state_dict(torch.load(model_weight_path, map_location="cpu"), strict=False)
         # To train only the new classifier, set requires_grad = False on every
         # parameter of net here, before the fc layer is replaced.
         # change fc layer structure
         in_channel = net.fc.in_features
         net.fc = nn.Linear(in_channel, 5)
         net.to(device)

         loss_function = nn.CrossEntropyLoss()
         optimizer = optim.Adam(net.parameters(), lr=0.0001)

         best_acc = 0.0
         save_path = './resNet34.pth'
         num_batches = len(train_loader)
         for epoch in range(3):
             # train
             net.train()
             running_loss = 0.0
             for step, (images, labels) in enumerate(train_loader):
                 optimizer.zero_grad()
                 logits = net(images.to(device))
                 loss = loss_function(logits, labels.to(device))
                 loss.backward()
                 optimizer.step()

                 # accumulate statistics
                 loss_value = loss.item()
                 running_loss += loss_value
                 # Redraw a one-line progress bar; "\r" returns to the line start.
                 rate = (step + 1) / num_batches
                 a = "*" * int(rate * 50)
                 b = "." * int((1 - rate) * 50)
                 print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss_value), end="")
             print()

             # validate
             net.eval()
             acc = 0.0  # number of correct predictions in this epoch
             with torch.no_grad():
                 for val_images, val_labels in validate_loader:
                     outputs = net(val_images.to(device))
                     predict_y = torch.max(outputs, dim=1)[1]
                     acc += (predict_y == val_labels.to(device)).sum().item()
                 val_accurate = acc / val_num
                 if val_accurate > best_acc:
                     best_acc = val_accurate
                     torch.save(net.state_dict(), save_path)
                 # Average over the number of batches, not the last batch index.
                 print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
                       (epoch + 1, running_loss / num_batches, val_accurate))

         print('Finished Training')
    115 
    116 
    117 if __name__ == '__main__':
    118     main()

    2.3 predict.py

     1 import torch
     2 from model import resnet34
     3 from PIL import Image
     4 from torchvision import transforms
     5 import matplotlib.pyplot as plt
     6 import json
     7 
     # Classify a single image with the fine-tuned ResNet-34 and print the
     # predicted class name together with its softmax probability.
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

     data_transform = transforms.Compose(
         [transforms.Resize(256),
          transforms.CenterCrop(224),
          transforms.ToTensor(),
          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

     # load image; force 3-channel RGB so grayscale or RGBA files still match
     # the 3-channel Normalize above
     img = Image.open("../rose.jpg").convert("RGB")
     plt.imshow(img)
     # [C, H, W] after the transform
     img = data_transform(img)
     # expand batch dimension -> [N, C, H, W]
     img = torch.unsqueeze(img, dim=0)

     # read class_indict (index -> class name)
     try:
         with open('./class_indices.json', 'r') as json_file:
             class_indict = json.load(json_file)
     except (OSError, ValueError) as e:
         # Missing/unreadable file or invalid JSON: report and stop.
         print(e)
         raise SystemExit(-1)

     # create model
     model = resnet34(num_classes=5)
     # load model weights
     model_weight_path = "./resNet34.pth"
     model.load_state_dict(torch.load(model_weight_path, map_location=device))
     # Run on the selected device; without this the model stays on the CPU.
     model.to(device)
     model.eval()
     with torch.no_grad():
         # predict class; move the logits back to the CPU so .numpy() works
         output = torch.squeeze(model(img.to(device))).cpu()
         predict = torch.softmax(output, dim=0)
         predict_cla = torch.argmax(predict).numpy()
     print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
     plt.show()
  • 相关阅读:
    @getMapping与@postMapping
    springcloud--入门
    Linux(centos6.5)mysql安装
    基于用户Spark ALS推荐系统(转)
    hadoop MapReduce在Linux上运行的一些命令
    Navicat连接阿里云轻量级应用服务器mysql
    HDFS操作笔记
    线程池的5种创建方式
    分布式共享锁的程序逻辑流程
    推荐系统常用数据集
  • 原文地址:https://www.cnblogs.com/sclu/p/14165056.html
Copyright © 2011-2022 走看看