zoukankan      html  css  js  c++  java
  • vgg16复现

    主要是练了一下数据读取

    这次用的是 CIFAR-10,每个 batch 文件反序列化后是一个字典,取了前 100 个样本训练了一下

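    要先把每一行还原成 32 * 32 * 3 的图像。注意每一行的 3072 个值是按通道优先存储的(先 1024 个 R,再 1024 个 G,最后 1024 个 B,每个通道内按行存储),所以要先 reshape 成 3 * 32 * 32,再把通道维转到最后;直接 reshape 成 32 * 32 * 3 会把像素打乱:

    self.data = self.data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)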

     在 __getitem__ 里,送入 transforms 之前先用 Image.fromarray() 把数组转成 PIL 图像

    VGG_dataset:

    from torch.utils import data
    from PIL import Image
    import random
    import torchvision.transforms as T
    import matplotlib.pyplot as plt
    
    def unpickle(file, encoding='bytes'):
        """Load and return the object stored in a pickled CIFAR-10 batch file.

        Args:
            file: Path of the batch file to read.
            encoding: Forwarded to ``pickle.load``. The default ``'bytes'``
                decodes Python 2 8-bit strings as ``bytes``, which is why the
                batch dict is indexed with keys such as ``b'data'``.

        Returns:
            The unpickled object (for a CIFAR-10 batch, a dict).

        Warning:
            ``pickle`` can execute arbitrary code while loading; only call
            this on trusted files.
        """
        import pickle
        with open(file, 'rb') as fo:
            # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
            batch = pickle.load(fo, encoding=encoding)
        return batch
    
    # imgs = unpickle('H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1')
    # print(imgs[b'data'].reshape(-1, 3, 32, 32))
    
    
    
    class Dataset(data.Dataset):
        """One pickled CIFAR-10 batch exposed as a ``torch.utils.data.Dataset``.

        Loads the batch file, keeps the first ``limit`` samples and returns
        ``(image_tensor, label)`` pairs prepared for a 224x224 network input.

        Args:
            root: Path to a pickled CIFAR-10 batch file (read via ``unpickle``).
            train: If True, use the training pipeline: a shorter-side resize to
                a size drawn uniformly from [256, 384] for every sample, then a
                random 224x224 crop.
            test: Kept for backward compatibility. The evaluation pipeline is
                used whenever ``train`` is False.
            limit: Number of leading samples to keep (default 100). ``None``
                keeps the whole batch.
        """

        def __init__(self, root, train = True, test = False, limit = 100):
            self.test = test
            self.train = train
            imgs = unpickle(root)
            # Each row holds 3072 values stored channel-first: 1024 red, then
            # 1024 green, then 1024 blue, each plane row-major. Reshape to
            # (N, 3, 32, 32) and move channels last so PIL receives HWC RGB
            # images; reshaping straight to (N, 32, 32, 3) scrambles the pixels.
            # .copy() makes the transposed array C-contiguous again.
            rows = imgs[b'data'][: limit, :]
            self.data = rows.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).copy()
            self.label = imgs[b'labels'][: limit]

            normalize = T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
            if self.train:
                self.transforms = T.Compose([
                    # Scale jitter must be drawn per sample. A random.randint(...)
                    # evaluated here would fix one size for the dataset's lifetime.
                    T.RandomChoice([T.Resize(size) for size in range(256, 385)]),
                    T.RandomCrop(224),
                    T.ToTensor(),
                    normalize,
                ])
            else:
                # Evaluation pipeline. It also covers train=False, test=False,
                # which previously left self.transforms undefined.
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize,
                ])

        def __getitem__(self, index):
            """Return ``(transformed_image, label)`` for sample ``index``."""
            img = Image.fromarray(self.data[index])
            return self.transforms(img), self.label[index]

        def __len__(self):
            """Return the number of samples kept from the batch."""
            return len(self.label)

    config:

    class configuration:
        """Static settings shared by the CIFAR-10 dataset and the VGG training script."""

        # Machine-specific absolute paths to the pickled CIFAR-10 batch files.
        train_root: str = 'H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1'
        test_root: str = 'H:/DataSet/cifar-10-python/cifar-10-batches-py/test_batch'

        # Width of the classifier's output layer.
        label_nums: int = 10

        # Training-loop settings.
        batch_size: int = 4
        epochs: int = 10
        # NOTE(review): this is 10x Adam's default learning rate (1e-3); lower
        # it if the loss does not decrease.
        lr: float = 0.01

    VGG:

    import torch
    import torch.nn as nn
    import torch.utils.data.dataloader as Dataloader
    import numpy as np
    import torch.nn.functional as F
    from config import configuration
    from VGG_dataset import Dataset
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    from PIL import Image
    
    # Run on the GPU when CUDA is available, otherwise on the CPU.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    # Module-wide configuration instance (label count, paths, hyper-parameters).
    con = configuration()
    
    class vgg(nn.Module):
        """VGG-16 classifier (13 conv layers + 3 fully connected layers).

        The 3x3 convolutions are grouped in five stages of 64, 128, 256, 512
        and 512 channels. Each stage ends in 2x2 max pooling. They are followed
        by 4096-4096-``num_classes`` linear layers, with dropout (p=0.5) after
        the first two, as in the VGG paper.

        Args:
            num_classes: Output size. Defaults to ``con.label_nums`` from the
                module-level configuration.
        """

        def __init__(self, num_classes=None):
            super().__init__()
            if num_classes is None:
                num_classes = con.label_nums
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
            self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
            self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
            self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
            self.conv8 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
            self.conv9 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
            self.conv10 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
            self.conv11 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
            self.conv12 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
            self.conv13 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
            self.fc1 = nn.Linear(512 * 7 * 7, 4096)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc3 = nn.Linear(4096, num_classes)

        def forward(self, x):
            """Return raw class scores of shape (N, num_classes).

            ``x`` is a (N, 3, H, W) batch. The five halving poolings need
            H, W >= 32.
            """
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = F.relu(self.conv3(x))
            x = F.max_pool2d(F.relu(self.conv4(x)), 2)
            x = F.relu(self.conv5(x))
            x = F.relu(self.conv6(x))
            x = F.max_pool2d(F.relu(self.conv7(x)), 2)
            x = F.relu(self.conv8(x))
            x = F.relu(self.conv9(x))
            x = F.max_pool2d(F.relu(self.conv10(x)), 2)
            x = F.relu(self.conv11(x))
            x = F.relu(self.conv12(x))
            x = F.max_pool2d(F.relu(self.conv13(x)), 2)
            # Pin the feature map to 7x7 so fc1 fits any input size. For
            # 224x224 inputs the map is already 7x7 and this is the identity.
            x = F.adaptive_avg_pool2d(x, (7, 7))
            x = torch.flatten(x, 1)
            # Dropout is only active in training mode (self.training).
            x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training)
            x = F.dropout(F.relu(self.fc2(x)), p=0.5, training=self.training)
            return self.fc3(x)
    
    # img = Image.open('H:/C5AM385_Intensity.jpg')
    # print(np.array(img).shape)
    
    
    if __name__ == '__main__':
        # Train VGG-16 on the configured CIFAR-10 batch and report per-epoch
        # mean loss and training accuracy.
        model = vgg()
        model.to(device)
        train_dataset = Dataset(con.train_root)
        # NOTE(review): built but never used below; no evaluation pass is run.
        test_dataset = Dataset(con.test_root, False, True)
        train_dataloader = Dataloader.DataLoader(train_dataset, batch_size = con.batch_size, shuffle = True, num_workers = 4)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr = con.lr)

        for epoch in range(con.epochs):
            model.train()
            total_loss = 0.0
            cnt = 0
            true_label = 0
            for data, label in tqdm(train_dataloader):
                # Tensor.to() returns a new tensor rather than moving in place,
                # so the result must be reassigned.
                data = data.to(device)
                label = label.to(device)

                optimizer.zero_grad()
                output = model(data)
                loss_value = criterion(output, label)
                loss_value.backward()
                optimizer.step()

                # .item() yields Python numbers, so each iteration's autograd
                # graph is not kept alive by the running totals.
                total_loss += loss_value.item()
                pred = output.argmax(dim=1)
                true_label += (pred == label).sum().item()
                cnt += 1
            loss_mean = total_loss / max(cnt, 1)
            accuracy = true_label / float(max(len(train_dataset), 1))
            print('Loss:{:.4f}, Accuracy:{:.2f}'.format(loss_mean, accuracy))
        print('Train Accepted!')
    自己选择的路,跪着也要走完。朋友们,虽然这个世界日益浮躁起来,只要能够为了当时纯粹的梦想和感动坚持努力下去,不管其它人怎么样,我们也能够保持自己的本色走下去。
  • 相关阅读:
    SAP PI 如何实现消息定义查询
    EWM与ERP交互程序
    ITS Mobile Template interpretation failed. Template does not exist
    SAP Material Flow System (MFS) 物料流系统简介
    SAP EWM Table list
    EWM RF 屏幕增强
    SAP EWM TCODE list
    SAP扩展仓库管理(SAPEWM)在线研讨会笔记
    ERP与EWM集成配置ERP端组织架构(二)
    EWM RF(Radio Frequency)简介
  • 原文地址:https://www.cnblogs.com/WTSRUVF/p/15364206.html
Copyright © 2011-2022 走看看