zoukankan      html  css  js  c++  java
  • pytorch训练AlexNet

    一。AlexNet网络结构和参数

     

     二。训练部分

    model.py

     1 import torch.nn as nn
     2 import torch
     3 
     4 
     5 class AlexNet(nn.Module):
     6     def __init__(self, num_classes=1000, init_weights=False):
     7         super(AlexNet, self).__init__()
     8         self.features = nn.Sequential(
     9             nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
    10             nn.ReLU(inplace=True),
    11             nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
    12             nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
    13             nn.ReLU(inplace=True),
    14             nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
    15             nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
    16             nn.ReLU(inplace=True),
    17             nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
    18             nn.ReLU(inplace=True),
    19             nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
    20             nn.ReLU(inplace=True),
    21             nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
    22         )
    23         self.classifier = nn.Sequential(
    24             nn.Dropout(p=0.5),
    25             nn.Linear(128 * 6 * 6, 2048),
    26             nn.ReLU(inplace=True),
    27             nn.Dropout(p=0.5),
    28             nn.Linear(2048, 2048),
    29             nn.ReLU(inplace=True),
    30             nn.Linear(2048, num_classes),
    31         )
    32         if init_weights:
    33             self._initialize_weights()
    34 
    35     def forward(self, x):
    36         x = self.features(x)
    37         x = torch.flatten(x, start_dim=1)
    38         x = self.classifier(x)
    39         return x
    40 
    41     def _initialize_weights(self):
    42         for m in self.modules():
    43             if isinstance(m, nn.Conv2d):
    44                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    45                 if m.bias is not None:
    46                     nn.init.constant_(m.bias, 0)
    47             elif isinstance(m, nn.Linear):
    48                 nn.init.normal_(m.weight, 0, 0.01)
    49                 nn.init.constant_(m.bias, 0)

    train.py

      1 import torch
      2 import torch.nn as nn
      3 from torchvision import transforms, datasets, utils
      4 import matplotlib.pyplot as plt
      5 import numpy as np
      6 import torch.optim as optim
      7 from model import AlexNet
      8 import os
      9 import json
     10 import time
     11 
     12 
     13 def main():
     14     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     15     print("using {} device.".format(device))
     16 
     17     data_transform = {
     18         "train": transforms.Compose([transforms.RandomResizedCrop(224),
     19                                      transforms.RandomHorizontalFlip(),
     20                                      transforms.ToTensor(),
     21                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
     22         "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
     23                                    transforms.ToTensor(),
     24                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
     25 
     26     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
     27     image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
     28     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
     29     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
     30                                          transform=data_transform["train"])
     31     train_num = len(train_dataset)
     32 
     33     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
     34     flower_list = train_dataset.class_to_idx
     35     cla_dict = dict((val, key) for key, val in flower_list.items())
     36     # write dict into json file
     37     json_str = json.dumps(cla_dict, indent=4)
     38     with open('class_indices.json', 'w') as json_file:
     39         json_file.write(json_str)
     40 
     41     batch_size = 32
     42     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
     43     print('Using {} dataloader workers every process'.format(nw))
     44 
     45     train_loader = torch.utils.data.DataLoader(train_dataset,
     46                                                batch_size=batch_size, shuffle=True,
     47                                                num_workers=0)
     48 
     49     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
     50                                             transform=data_transform["val"])
     51     val_num = len(validate_dataset)
     52     validate_loader = torch.utils.data.DataLoader(validate_dataset,
     53                                                   batch_size=batch_size, shuffle=True,
     54                                                   num_workers=0)
     55 
     56     print("using {} images for training, {} images fot validation.".format(train_num,
     57                                                                            val_num))
     58     # test_data_iter = iter(validate_loader)
     59     # test_image, test_label = test_data_iter.next()
     60     # #
     61     # def imshow(img):
     62     #     img = img / 2 + 0.5  # unnormalize
     63     #     npimg = img.numpy()
     64     #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
     65     #     plt.show()
     66     #
     67     # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
     68     # imshow(utils.make_grid(test_image))
     69 
     70     net = AlexNet(num_classes=5, init_weights=True)
     71 
     72     net.to(device)
     73     loss_function = nn.CrossEntropyLoss()
     74     # pata = list(net.parameters())
     75     optimizer = optim.Adam(net.parameters(), lr=0.0002)
     76 
     77     save_path = './AlexNet.pth'
     78     best_acc = 0.0
     79     for epoch in range(10):
     80         # train
     81         net.train()
     82         running_loss = 0.0
     83         t1 = time.perf_counter()
     84         for step, data in enumerate(train_loader, start=0):
     85             images, labels = data
     86             optimizer.zero_grad()
     87             outputs = net(images.to(device))
     88             loss = loss_function(outputs, labels.to(device))
     89             loss.backward()
     90             optimizer.step()
     91 
     92             # print statistics
     93             running_loss += loss.item()
     94             # print train process
     95             rate = (step + 1) / len(train_loader)
     96             a = "*" * int(rate * 50)
     97             b = "." * int((1 - rate) * 50)
     98             print("
    train loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
     99         print()
    100         print(time.perf_counter()-t1)
    101 
    102         # validate
    103         net.eval()
    104         acc = 0.0  # accumulate accurate number / epoch
    105         with torch.no_grad():
    106             for val_data in validate_loader:
    107                 val_images, val_labels = val_data
    108                 outputs = net(val_images.to(device))
    109                 predict_y = torch.max(outputs, dim=1)[1]
    110                 acc += (predict_y == val_labels.to(device)).sum().item()
    111             val_accurate = acc / val_num
    112             if val_accurate > best_acc:
    113                 best_acc = val_accurate
    114                 torch.save(net.state_dict(), save_path)
    115             print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
    116                   (epoch + 1, running_loss / step, val_accurate))
    117 
    118     print('Finished Training')
    119 
    120 
    121 if __name__ == '__main__':
    122     main()
  • 相关阅读:
    ASP.NET MVC 5 学习教程:使用 SQL Server LocalDB
    ASP.NET MVC 5 学习教程:生成的代码详解
    ASP.NET MVC 5 学习教程:通过控制器访问模型的数据
    ASP.NET MVC 5 学习教程:创建连接字符串
    ASP.NET MVC 5 学习教程:添加模型
    ASP.NET MVC 5 学习教程:控制器传递数据给视图
    ASP.NET MVC 5 学习教程:修改视图和布局页
    ASP.NET MVC 5 学习教程:添加视图
    ASP.NET MVC 5 学习教程:添加控制器
    ASP.NET MVC 5 学习教程:快速入门
  • 原文地址:https://www.cnblogs.com/sclu/p/14162460.html
Copyright © 2011-2022 走看看