一、网络结构和参数
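特点：通过堆叠多个 3×3 的小卷积核来获得与大卷积核相同的感受野（两个 3×3 卷积层的感受野相当于一个 5×5 卷积层，三个相当于一个 7×7 卷积层），从而在减少网络参数的同时加深了网络深度。例如，当输入、输出通道数均为 C 时，三个 3×3 卷积层共有 27C² 个权重，而一个 7×7 卷积层有 49C² 个权重。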
二、模型定义和训练代码
model.py
import torch.nn as nn
import torch


class VGG(nn.Module):
    """VGG network: a convolutional backbone followed by an MLP classifier.

    Args:
        features: Convolutional backbone (e.g. built by ``make_features``).
            Its output must have 512 channels, because the first Linear layer
            expects 512 * 7 * 7 inputs.
        num_classes: Number of output logits.
        init_weights: If True, re-initialize all Conv2d/Linear weights with
            Xavier-uniform and their biases with zero.

    The classifier uses 2048-wide hidden layers (the original VGG paper uses
    4096), which keeps the parameter count smaller.
    """

    def __init__(self, features, num_classes=1000, init_weights=False):
        super().__init__()
        self.features = features
        # Pool the backbone output to a fixed 7x7 grid so the classifier's
        # input size (512*7*7) holds for any input resolution. With the cfgs
        # below and 224x224 inputs the backbone already yields 7x7, so this is
        # an identity there. It has no parameters, so state_dicts saved before
        # this layer existed still load.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * 7 * 7, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Map an N x 3 x H x W image batch to N x num_classes logits."""
        # N x 3 x H x W
        x = self.features(x)
        # N x 512 x h x w
        x = self.avgpool(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        return self.classifier(x)

    def _initialize_weights(self):
        """Xavier-uniform init for Conv2d/Linear weights; zero init for biases.

        Alternatives previously left commented out here:
        ``nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')`` for
        conv layers and ``nn.init.normal_(w, 0, 0.01)`` for linear layers.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


def make_features(cfg: list, in_channels: int = 3) -> nn.Sequential:
    """Build the VGG convolutional backbone from a layer config.

    Args:
        cfg: Sequence whose int entries add a 3x3 (padding 1) Conv2d with that
            many output channels followed by ReLU, and whose "M" entries add a
            2x2 / stride 2 max-pool.
        in_channels: Channel count of the input images (3 for RGB).

    Returns:
        An ``nn.Sequential`` containing the layers in cfg order.
    """
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = v
    return nn.Sequential(*layers)


# Layer configs for the standard VGG depths; "M" marks a max-pool.
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name: str = "vgg16", **kwargs) -> VGG:
    """Build a VGG model by config name.

    Args:
        model_name: A key of ``cfgs`` ("vgg11", "vgg13", "vgg16", "vgg19").
        **kwargs: Forwarded to ``VGG`` (e.g. ``num_classes``, ``init_weights``).

    Raises:
        ValueError: If ``model_name`` is not a key of ``cfgs``. This replaces
            the old print-and-exit, which terminated the caller's process.
    """
    try:
        cfg = cfgs[model_name]
    except KeyError:
        raise ValueError(
            "model name {} not in cfgs dict! available: {}".format(
                model_name, sorted(cfgs)
            )
        ) from None
    return VGG(make_features(cfg), **kwargs)
train.py
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

from model import vgg


def _build_transforms():
    """Return the train/val preprocessing pipelines keyed by split name."""
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     normalize]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   normalize]),
    }


def _train_one_epoch(net, loader, loss_function, optimizer, device):
    """Train ``net`` for one pass over ``loader`` and return the mean batch loss."""
    net.train()
    running_loss = 0.0
    num_batches = len(loader)
    for step, (images, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        # Progress bar; "\r" redraws it in place on the same line.
        rate = (step + 1) / num_batches
        done = "*" * int(rate * 50)
        todo = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), done, todo, batch_loss),
              end="")
    print()
    # Average over the number of batches. The old code divided by the last
    # step index (len - 1), which is off by one and fails for a single batch.
    return running_loss / max(num_batches, 1)


def _evaluate(net, loader, device):
    """Return how many samples in ``loader`` ``net`` classifies correctly."""
    net.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = net(images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            correct += (predict_y == labels.to(device)).sum().item()
    return correct


def main(*, epochs: int = 30, batch_size: int = 32, lr: float = 0.0001,
         model_name: str = "vgg16"):
    """Train a VGG classifier on the flower data set and keep the best checkpoint.

    Expects ``<cwd>/../../data_set/flower_data/{train,val}`` in ImageFolder
    layout. Writes ``class_indices.json`` (index -> class name) and saves the
    state_dict with the best validation accuracy to ``./<model_name>Net.pth``.

    Raises:
        FileNotFoundError: If the data set directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = _build_transforms()

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps class name -> index, e.g. {'daisy': 0, 'dandelion': 1, ...};
    # invert it so predictions can be mapped back to names.
    flower_list = train_dataset.class_to_idx
    cla_dict = {idx: name for name, idx in flower_list.items()}
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    # os.cpu_count() may return None; fall back to 1 so min() does not fail.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # Size the output layer from the data set rather than hard-coding 5.
    net = vgg(model_name=model_name, num_classes=len(flower_list), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    for epoch in range(epochs):
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer, device)
        val_accurate = _evaluate(net, validate_loader, device) / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
              (epoch + 1, train_loss, val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()