zoukankan      html  css  js  c++  java
  • AlexNet复现代码

    train.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import os
    from tensorboardX import SummaryWriter
    import torchvision.datasets as Datasets
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from tqdm import tqdm
    
    # Run on CUDA when a GPU is visible, otherwise fall back to the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Training hyper-parameters.
    NUM_EPOCHS = 10
    BATCH_SIZE = 4
    IMAGE_DIM = 227      # side length of the centre crop fed to the network
    NUM_CLASSES = 10     # output size of the final fully connected layer
    # NOTE(review): the constants below are not read anywhere in train.py as shown
    # (the optimizer is Adam with a literal lr) -- confirm whether other code uses them.
    MOMENTUM = 0.9
    LR_DECAY = 0.0005
    LR_INIT = 0.01
    DEVICE_DIS = [0, 1, 2, 3]

    # Dataset and output locations, relative to the working directory.
    INPUT_ROOT_DIR = 'alexnet_data_in'
    OUTPUT_DIR = 'alexnet_data_out'
    TRAIN_IMG_DIR = '{}/imagenet'.format(INPUT_ROOT_DIR)   # root directory handed to ImageFolder
    LOG_DIR = '{}/tblogs'.format(OUTPUT_DIR)               # tensorboardX log directory
    CHECKPOINT_DIR = '{}/models'.format(OUTPUT_DIR)        # where checkpoints are written

    # Runs at import time (including when another script imports this module), so the
    # checkpoint folder always exists before torch.save is called.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    
    
    
    class AlexNet(nn.Module):
        """AlexNet-style image classifier.

        Five convolution layers (ReLU + local response norm + max-pool after
        conv1/conv2, max-pool after conv5) followed by three fully connected
        layers with dropout between them. The convolution stack must produce a
        256x6x6 feature map for ``fc1``; 3x227x227 inputs (IMAGE_DIM) satisfy this.

        Args:
            num_classes: output size of the last layer. ``None`` (the default)
                uses the module constant NUM_CLASSES, so ``AlexNet()`` keeps
                its previous behaviour.
        """

        def __init__(self, num_classes=None):
            super().__init__()
            if num_classes is None:
                num_classes = NUM_CLASSES
            self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
            self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
            # NOTE(review): the AlexNet paper uses a 3x3 kernel here. Kernel 2 is kept so
            # existing checkpoints still load; with padding 1 it grows 13x13 -> 14x14,
            # and the final 3/2 max-pool still yields 6x6.
            self.conv3 = nn.Conv2d(256, 384, 2, padding=1)
            self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
            self.conv5 = nn.Conv2d(384, 256, 3, padding=1)
            self.fc1 = nn.Linear(256 * 6 * 6, 4096)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc3 = nn.Linear(4096, num_classes)

        def forward(self, x):
            """Return unnormalized class scores of shape (N, num_classes).

            ``x`` is a batch of 3-channel images, shape (N, 3, H, W).
            """
            x = F.relu(self.conv1(x))
            x = F.local_response_norm(x, size=5, alpha=0.0001, beta=0.75, k=2)
            x = F.max_pool2d(x, kernel_size=3, stride=2)

            x = F.relu(self.conv2(x))
            x = F.local_response_norm(x, size=5, alpha=0.0001, beta=0.75, k=2)
            x = F.max_pool2d(x, kernel_size=3, stride=2)

            x = F.relu(self.conv3(x))
            x = F.relu(self.conv4(x))
            x = F.max_pool2d(F.relu(self.conv5(x)), kernel_size=3, stride=2)

            x = torch.flatten(x, 1)
            # Dropout follows the module's train/eval mode, so model.eval() disables it.
            # It is not in-place: the ReLU result it would overwrite is kept for backward.
            x = F.relu(self.fc1(x))
            x = F.dropout(x, p=0.5, training=self.training)
            x = F.relu(self.fc2(x))
            x = F.dropout(x, p=0.5, training=self.training)
            return self.fc3(x)
    
    
    
    if __name__ == '__main__':

        tbwrite = SummaryWriter(log_dir=LOG_DIR)
        print('tensorboardX summary write created')
        alexnet = AlexNet().to(device)
        print(alexnet)
        print('alexnet created')
        dataset = Datasets.ImageFolder(TRAIN_IMG_DIR, transform=transforms.Compose([
            transforms.CenterCrop(IMAGE_DIM),
            # Convert the loaded PIL image into a tensor.
            transforms.ToTensor(),
            # Normalize each channel with the given mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]))
        print('dataset created')

        dataloader = Dataloader.DataLoader(
            dataset,
            shuffle=True,
            num_workers=8,
            drop_last=True,
            batch_size=BATCH_SIZE)
        print('dataloader created')
        # NOTE(review): Adam with a literal learning rate; LR_INIT / MOMENTUM / LR_DECAY
        # are not used by this script.
        optimizer = torch.optim.Adam(alexnet.parameters(), lr=0.0001)
        # Multiply the lr by 0.1 every 30 epochs. NOTE: while NUM_EPOCHS <= 30 the decay
        # never takes effect.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        loss = nn.CrossEntropyLoss()
        print('Start Training...')
        # Explicitly enable training mode for mode-dependent layers such as dropout.
        alexnet.train()
        try:
            for epoch in range(NUM_EPOCHS):
                true_total = 0
                img_total = 0
                loss_total = 0.0
                # Batches seen in THIS epoch. The mean loss must not be divided by a
                # counter that keeps growing across epochs.
                num_batches = 0
                for imgs, classes in tqdm(dataloader):
                    num_batches += 1
                    optimizer.zero_grad()
                    imgs, classes = imgs.to(device), classes.to(device)
                    output = alexnet(imgs)
                    loss_value = loss(output, classes)
                    loss_value.backward()
                    optimizer.step()

                    # Running training accuracy/loss, accumulated as plain Python numbers.
                    preds = output.argmax(dim=1)
                    true_total += (preds == classes).sum().item()
                    img_total += classes.size(0)
                    loss_total += loss_value.item()

                # max(..., 1) avoids ZeroDivisionError when the loader yields no batches
                # (e.g. fewer images than BATCH_SIZE with drop_last=True).
                loss_k = loss_total / max(num_batches, 1)
                accuracy = true_total / max(img_total, 1)
                print('epoch: {} \t Loss: {:.4f} \t Acc: {:.2f}'.format(epoch + 1, loss_k, accuracy))
                # Log the per-epoch metrics to TensorBoard.
                tbwrite.add_scalar('loss', loss_k, epoch + 1)
                tbwrite.add_scalar('accuracy', accuracy, epoch + 1)

                lr_scheduler.step()

                # Save model + optimizer state every 10 epochs.
                if (epoch + 1) % 10 == 0:
                    # The 'staes' spelling is kept: predict.py loads this exact file name.
                    checkpoint_path = os.path.join(CHECKPOINT_DIR, 'alexnet_staes_e{}.pkl'.format(epoch + 1))
                    state = {
                        'epoch': epoch,
                        'optimizer': optimizer.state_dict(),
                        'model': alexnet.state_dict(),
                    }
                    torch.save(state, checkpoint_path)
        finally:
            # Flush pending events and release the log file even if training fails.
            tbwrite.close()

    predict.py

    import torch
    import torch.nn
    import torch.nn.functional as F
    import numpy as np
    import torchvision.datasets as Dataset
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from PIL import Image
    import matplotlib.pyplot as plt
    import os
    from train import AlexNet
    import cv2
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    IMG_DIM = 227  # must match IMAGE_DIM used for training in train.py
    INPUT_ROOT_DIR = 'alexnet_data_in'
    OUTPUT_ROOT_DIR = 'alexnet_data_out'
    IMG_DIR = INPUT_ROOT_DIR + '/test_imagenet'

    # Path of the checkpoint file train.py writes after epoch 10. Despite its name this
    # is a file path; the name is kept for backward compatibility.
    checkpoint_dir = os.path.join(OUTPUT_ROOT_DIR, 'models', 'alexnet_staes_e10.pkl')
    # map_location remaps tensors saved on a GPU so the checkpoint also loads on a
    # CPU-only machine.
    checkpoint = torch.load(checkpoint_dir, map_location=device)
    model = AlexNet()
    model.to(device)
    model.load_state_dict(checkpoint['model'])
    # Inference only: switch to evaluation mode so layers that check `self.training`
    # (e.g. dropout) behave deterministically.
    model.eval()
    
    
    def open_list(dir):
        """Return the sorted names of the non-directory entries directly inside ``dir``.

        Only the top level is listed; sub-directories and their contents are
        skipped. Sorting makes the processing order reproducible. When ``dir``
        does not exist or cannot be read, ``os.walk`` yields nothing and an
        empty list is returned (instead of None).

        Args:
            dir: directory to list. The name shadows the builtin but is kept so
                keyword callers keep working.
        """
        # os.walk yields (dirpath, dirnames, filenames); only the first,
        # top-level triple is needed.
        for _root, _subdirs, filenames in os.walk(dir):
            return sorted(filenames)
        return []
    
    def show(name, img):
        """Display ``img`` in an OpenCV window titled ``name``.

        Blocks until a key is pressed, then closes every OpenCV window.
        """
        cv2.imshow(name, img)
        cv2.waitKey(0)  # delay 0: wait for a key press with no timeout
        cv2.destroyAllWindows()
    
    
    # Test-time preprocessing; uses the same crop / tensor / normalization steps
    # as the training pipeline in train.py.
    transform = transforms.Compose(
        [
            transforms.CenterCrop(IMG_DIM),  # central IMG_DIM x IMG_DIM crop
            transforms.ToTensor(),           # PIL image -> tensor
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    
    
    # open_list can return None (e.g. when IMG_DIR is missing); treat that as "no images".
    img_list = open_list(IMG_DIR) or []


    # Classify each test image, print the predicted class index, and show the image in
    # a window titled with that index. No autograd graph is needed for inference.
    with torch.no_grad():
        for img_name in img_list:
            img_path = os.path.join(IMG_DIR, img_name)
            # Force 3 channels: grayscale/RGBA files would otherwise break Normalize and the
            # batch reshape. torchvision's default ImageFolder loader (used for training)
            # also converts to RGB. The `with` closes the underlying file handle.
            with Image.open(img_path) as pil_img:
                img = transform(pil_img.convert('RGB'))
            # The model expects a batch dimension: (3, H, W) -> (1, 3, H, W).
            # `.to` returns a new tensor, so the result must be reassigned.
            img = img.unsqueeze(0).to(device)
            logits = model(img)
            pred = torch.max(logits, 1)[1].item()
            print(pred)
            show('{}'.format(pred), cv2.imread(img_path))
    自己选择的路，跪着也要走完。朋友们，虽然这个世界日益浮躁起来，但只要能够为了当时纯粹的梦想和感动坚持努力下去，不管其他人怎么样，我们也能够保持自己的本色走下去。
  • 相关阅读:
    第十章 迭代器模式 Iterator
    第四章:使用Proxy代理让客户端服务端分工合作。
    第三章:真正弄清楚一个Mod的组织结构
    第二章:开始开发mod前你需要知道的一些事情
    第一章:在IDEA里搭建基于Forge的Minecraft mod开发环境
    Android实现真正的ViewPager【平滑过渡】+【循环滚动】!!!顺带还有【末页跳转】。
    关于坑爹的PopupWindow的“阻塞”争议问题:Android没有真正的“阻塞式”对话框
    快排-Go版本
    链表翻转(按K个一组)(Go语言)
    牛客刷题-重建二叉树(GO语言版)
  • 原文地址:https://www.cnblogs.com/WTSRUVF/p/15325507.html
Copyright © 2011-2022 走看看