  • The simplest example: getting to know the end-to-end workflow of building and training a deep learning model

    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torchvision.transforms import transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Define the ResBlock (the basic residual block; see the block diagram in the original post)
    class ResBlock(nn.Module):
        def __init__(self, in_channel, out_channel, stride=1, shortcut=None):
            super(ResBlock, self).__init__()
            self.left = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 3, stride, 1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
                nn.Conv2d(out_channel, out_channel, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channel),
            )
            self.right = shortcut
    
        def forward(self, x):
            out = self.left(x)
            residual = x if self.right is None else self.right(x)
            out += residual
            return F.relu(out)
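
    # The `shortcut` argument lets a block use a projection branch (a 1x1 conv + BN
    # in this script) to reshape x when the block changes the channel count or
    # spatial size; with shortcut=None the plain identity connection is added.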
    
    # Define make_layer: stack block_num ResBlocks into one stage
    def make_layer(in_channel, out_channel, block_num, stride=1):
        shortcut = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1, stride),
            nn.BatchNorm2d(out_channel))
        layers = list()
        layers.append(ResBlock(in_channel, out_channel, stride, shortcut))
    
        for i in range(1, block_num):
            layers.append(ResBlock(out_channel, out_channel))
        return nn.Sequential(*layers)
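
    # Example (illustrative): a stride-2 stage halves the spatial size and changes
    # the channel count, e.g.
    #   stage = make_layer(64, 128, 2, stride=2)
    #   stage(torch.randn(1, 64, 56, 56)).shape  # torch.Size([1, 128, 28, 28])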
    
    # Assemble the ResNet (ResNet-18-style stage layout; see the structure table in the original post)
    class Resnet(nn.Module):
        def __init__(self):
            super(Resnet, self).__init__()
            self.pre = nn.Sequential(
                nn.Conv2d(3, 64, 7, 2, 3, bias=False), nn.BatchNorm2d(64),
                nn.ReLU(True), nn.MaxPool2d(3, 2, 1))
            self.layer1 = make_layer(64, 64, 2)
            self.layer2 = make_layer(64, 128, 2, stride=2)
            self.layer3 = make_layer(128, 256, 2, stride=2)
            self.layer4 = make_layer(256, 512, 2, stride=2)
            self.avg = nn.AvgPool2d(7)
            self.classifier = nn.Sequential(nn.Linear(512, 10))
    
        def forward(self, x):
            x = self.pre(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avg(x)
            x = x.view(x.size(0), -1)
            out = self.classifier(x)
            return out
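
    # Shape walk-through for a 3x224x224 input (matching the 224-pixel transforms
    # set up in the __main__ block below):
    #   pre:    64 x 56 x 56   (7x7 stride-2 conv + 3x3 stride-2 max-pool)
    #   layer1: 64 x 56 x 56
    #   layer2: 128 x 28 x 28
    #   layer3: 256 x 14 x 14
    #   layer4: 512 x 7 x 7
    #   avg:    512 x 1 x 1 -> flatten -> Linear(512, 10)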
    
    # Training function: one pass over the training set
    def net_train():
        net.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # Move the inputs and labels to the GPU (or CPU)
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
    
            # Zero the parameter gradients
            optimizer.zero_grad()
    
            # Forward pass -> compute loss -> backward pass -> optimizer step
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            # Accumulate the loss and report the running average periodically
            running_loss += loss.item()
            if (i + 1) % 128 == 0:  # print every 128 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 128))
                running_loss = 0.0
    
        print('Training Epoch Finished')
    
    # Test function: evaluate accuracy on the test set
    def net_test():
        net.eval()  # switch BatchNorm layers to evaluation mode
        correct = 0
        total = 0
        # Disable gradient tracking during evaluation
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    
        print('Accuracy of the network on the 10000 test images: %d %%' %
              (100 * correct / total))
        return
    
    
    # Dataset function: build the CIFAR-10 train and test DataLoaders
    def net_dataloader(root, train_transform, test_transform):
        trainset = torchvision.datasets.CIFAR10(
            root, train=True, transform=train_transform, download=True)
        testset = torchvision.datasets.CIFAR10(
            root, train=False, transform=test_transform, download=True)
        trainloader = DataLoader(
            trainset, batch_size=128, shuffle=True, num_workers=4)
        testloader = DataLoader(
            testset, batch_size=16, shuffle=False, num_workers=4)
        print('Initializing Dataset...')
        return trainloader, testloader
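
    # Note: num_workers=4 spawns worker processes for data loading; on Windows this
    # requires the `if __name__ == "__main__"` guard used below.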
    
    
    # main
    if __name__ == "__main__":
        # Create the model instance and move it to the GPU (or CPU)
        net = Resnet().to(device)
        # Choose the loss function
        criterion = nn.CrossEntropyLoss()
        # Choose the optimizer
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
        # Dataset root directory
        root = './pydata/data/'
        # Data preprocessing and augmentation
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Build the data loaders
        trainloader, testloader = net_dataloader(root, train_transform,
                                                 test_transform)
        # Run the training loop
        n_epoch = 5  # number of epochs; adjust as needed
        for epoch in range(n_epoch):
            print('Training...')
            net_train()  # train once and test once per epoch
            print('Testing...')
            net_test()
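  • A possible follow-up (a minimal sketch, not part of the script above; the checkpoint name resnet_cifar10.pth is a placeholder): once training finishes, the learned weights can be saved with torch.save and reloaded later for inference.

    # Save only the learned parameters (the state_dict), the usual PyTorch practice
    torch.save(net.state_dict(), 'resnet_cifar10.pth')

    # Rebuild the architecture and load the weights back, e.g. in a separate script
    model = Resnet().to(device)
    model.load_state_dict(torch.load('resnet_cifar10.pth', map_location=device))
    model.eval()  # switch to evaluation mode before running inference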
  • Original post: https://www.cnblogs.com/ariel-dreamland/p/12551284.html