zoukankan      html  css  js  c++  java
  • 神经网络学习--PyTorch学习05 定义VGGNet网络

    使用数据集猫狗大战

    import time
    
    import torch
    import torchvision
    from torchvision import datasets, transforms
    import os
    import matplotlib.pyplot as plt
    from torch.autograd import Variable
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 使用GPU 0
    data_dir = "DogsVsCats"
    # 设置数据格式
    data_transform = {x: transforms.Compose([transforms.Scale([64, 64]),  # scale类将原始图缩放至64*64
                        transforms.ToTensor()])
                        for x in ["train", "valid"]}
    # 加载数据
    image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x),
                                transform=data_transform[x])
                      for x in ["train", "valid"]}
    
    # 数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
    # 在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
    dataloader = {x: torch.utils.data.DataLoader(dataset=image_datasets[x],
                                    batch_size=16,
                                    shuffle=True)
                                    for x in ["train", "valid"]}
    
    # 获取一个批次的装载数据  x_example(16,3,64,64) y_example 进行了独热编码,里面全为0和1
    x_example, y_example = next(iter(dataloader["train"]))
    
    # index_classes的 输出结果为{'cat':0,'dog',1}
    index_classes = image_datasets["train"].class_to_idx
    
    #将原始标签的结果存在example_clasees中 {'cat','dog'}
    example_clasees = image_datasets["train"].classes
    
    # 做成网格数据
    img = torchvision.utils.make_grid(x_example)
    img = img.numpy().transpose([1, 2, 0])  # 转换维度
    # print([example_clasees[i] for i in y_example])
    # plt.imshow(img)
    # plt.show()
    
    # VGGNet模型
    class Models(torch.nn.Module):
        def __init__(self):
            super(Models,self).__init__()
            self.Conv = torch.nn.Sequential(
                torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=2, stride=2),
    
                torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=2, stride=2),
    
                torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=2, stride=2),
    
                torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=2, stride=2),
            )
    
            self.Classes = torch.nn.Sequential(
                torch.nn.Linear(4*4*512, 1024),
                torch.nn.ReLU(),
                torch.nn.Dropout(p=0.5),
                torch.nn.Linear(1024, 1024),
                torch.nn.Dropout(p=0.5),
                torch.nn.Linear(1024, 2)
            )
    
        def forward(self, input):
            x = self.Conv(input)
            x = x.view(-1, 4*4*512)
            x = self.Classes(x)
            return x
    
    
    model = Models()
    # print(model)
    
    loss_f = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.00001)
    Use_gpu = torch.cuda.is_available()  # 判断是否存在cuda
    if Use_gpu:
        model = model.cuda()  # ***********************************************************
    epoch_n = 10
    time_open = time.time()
    
    for epoch in range(epoch_n):
        print("Epoch{}/{}".format(epoch,epoch_n-1))
        print("-"*10)
    
        for phase in ["train", "valid"]:
            if phase == "train":
                print("Training...")
                model.train(True)
            else:
                print("Validing...")
                model.train(False)
    
            running_loss = 0.0
            running_corrects = 0
            for batch, data in enumerate(dataloader[phase], 1):  # enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,
                X, y = data
                if Use_gpu:
                    X, y =Variable(X.cuda()),Variable(y.cuda())  # **************************************
                else:
                    X, y = Variable(X), Variable(y)
                y_pred = model(X)  # 得到预测值
                _,pred =torch.max(y_pred,1)
                optimizer.zero_grad()  # 清空梯度
                loss = loss_f(y_pred, y)  # 定义损失函数
    
                if phase == "train":
                    loss.backward()  # 如果是训练,进行反向传播
                    optimizer.step()  # 更新各节点的参数
                running_loss += loss.item()
                running_corrects += torch.sum(pred == y.data)
    
                if batch%500 == 0 and phase == "train":
                    print("Batch{},TrainLoss:{:.4f},Train ACC:{:.4f}".format(
                        batch,running_loss/batch, 100*running_corrects/(16*batch)))
            epocn_loss = running_loss*16/len(image_datasets[phase])
            epoch_acc = 100*running_corrects/len(image_datasets[phase])
            print("{} Loss:{:.4f} Acc:{:4f}%".format(phase, epocn_loss, epoch_acc))
    time_end = time.time()-time_open
    print(time_end)
  • 相关阅读:
    docker search 报错
    mgo连接池
    饿了么这样跳过Redis Cluster遇到的“坑”
    Linux Swap的那些事
    epoll使用详解(精髓)(转)
    select、poll、epoll之间的区别总结[整理](转)
    git merge 和 git rebase 小结(转)
    linux查看端口占用情况
    [LeetCode] Combinations——递归
    C++中的static关键字的总结(转)
  • 原文地址:https://www.cnblogs.com/zuhaoran/p/11502551.html
Copyright © 2011-2022 走看看