pytorch GoogLeNet

I. GoogLeNet Network Structure

1. Features:

GoogLeNet uses the Inception structure together with two auxiliary classifiers. The Inception structure is a parallel, multi-branch structure, and 1x1 convolution kernels are added for dimensionality reduction, which reduces the number of trainable parameters.
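To make the dimensionality-reduction point concrete, here is a minimal sketch (not from the original post) that counts the learnable parameters of a direct 5x5 convolution versus a 1x1 reduction followed by the same 5x5 convolution, using the 192 -> 16 -> 32 channel numbers of the inception3a 5x5 branch defined in model.py below:

import torch.nn as nn

def conv_params(conv):
    # total learnable parameters (weights + biases) of a conv layer
    return sum(p.numel() for p in conv.parameters())

# direct 5x5 convolution: 192 -> 32 channels
direct = nn.Conv2d(192, 32, kernel_size=5, padding=2)

# 1x1 reduction first, then the 5x5 convolution: 192 -> 16 -> 32 channels
reduced = nn.Sequential(nn.Conv2d(192, 16, kernel_size=1),
                        nn.Conv2d(16, 32, kernel_size=5, padding=2))

print(conv_params(direct))                   # 153632
print(sum(conv_params(c) for c in reduced))  # 15920, roughly a 10x reduction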

2. Network structure

3. Inception structure

4. Parameter table

II. Training Code

    model.py

import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:    # the auxiliary classifiers are skipped in eval mode
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:    # the auxiliary classifiers are skipped in eval mode
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:   # only the main head is returned in eval mode
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # padding keeps the output size equal to the input size
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # padding keeps the output size equal to the input size
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        # concatenate the four parallel branches along the channel dimension
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x
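As a quick sanity check (not part of the original post), the snippet below runs a random dummy batch through the model to confirm the tensor shapes annotated in the comments above. In training mode the forward pass returns the main logits plus the two auxiliary outputs; in eval mode only the main head is returned.

import torch
from model import GoogLeNet

net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)

# training mode: three outputs (main, aux2, aux1)
net.train()
x = torch.randn(2, 3, 224, 224)
logits, aux2, aux1 = net(x)
print(logits.shape, aux2.shape, aux1.shape)  # three times torch.Size([2, 5])

# eval mode: only the main head
net.eval()
with torch.no_grad():
    print(net(x).shape)  # torch.Size([2, 5])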

    train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torchvision
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # num_workers=0 avoids multiprocessing issues on Windows; set it to nw on Linux
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                            val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)

    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)
    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    best_acc = 0.0
    save_path = './googleNet.pth'
    for epoch in range(30):
        # train
        net.train()
        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            # the auxiliary losses are weighted by 0.3, as in the GoogLeNet paper
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # print train progress
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
        print()

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # in eval mode only the main head is returned
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            val_accurate = acc / val_num
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
                  (epoch + 1, running_loss / (step + 1), val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()
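After training, the best checkpoint saved at ./googleNet.pth can be reloaded for evaluation. A minimal sketch, assuming the same 5-class flower setup as above:

import torch
from model import GoogLeNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = GoogLeNet(num_classes=5, aux_logits=True)
# reload the best weights saved by train.py
net.load_state_dict(torch.load('./googleNet.pth', map_location=device))
net.to(device)
net.eval()  # disables dropout and the auxiliary classifier outputs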

     predict.py

import torch
from model import GoogLeNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("../rose.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model without the auxiliary classifiers
model = GoogLeNet(num_classes=5, aux_logits=False)
# load model weights; strict=False skips the aux1/aux2 weights
# in the checkpoint, which this model does not define
model_weight_path = "./googleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)])
plt.show()
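To see the full class distribution rather than only the top prediction, a small extension (not in the original post) can be appended to the script above; it reuses the predict tensor and class_indict defined there:

# print the softmax probability of every class
for i, prob in enumerate(predict):
    print("{}: {:.3f}".format(class_indict[str(i)], prob.item()))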
Original post: https://www.cnblogs.com/sclu/p/14164399.html