zoukankan      html  css  js  c++  java
  • AlexNet复现代码

    train.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import os
    from tensorboardX import SummaryWriter
    import torchvision.datasets as Datasets
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from tqdm import tqdm
    
    # Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Training hyperparameters that the script below actually reads.
    NUM_EPOCHS = 10
    BATCH_SIZE = 4
    IMAGE_DIM = 227     # side length of the center crop fed to the network
    NUM_CLASSES = 10    # width of AlexNet's final linear layer
    # NOTE(review): the four constants below are not referenced anywhere in
    # train.py as shown (the optimizer hard-codes Adam with lr=1e-4); they are
    # kept so any external importer still finds them.
    MOMENTUM = 0.9
    LR_DECAY = 0.0005
    LR_INIT = 0.01
    DEVICE_DIS = [0, 1, 2, 3]

    # Dataset input and output locations, relative to the working directory.
    INPUT_ROOT_DIR = 'alexnet_data_in'
    OUTPUT_DIR = 'alexnet_data_out'
    TRAIN_IMG_DIR = f'{INPUT_ROOT_DIR}/imagenet'   # ImageFolder root for training
    LOG_DIR = f'{OUTPUT_DIR}/tblogs'               # TensorBoard log directory
    CHECKPOINT_DIR = f'{OUTPUT_DIR}/models'        # where checkpoints are saved

    # Executed at import time, so `from train import AlexNet` creates it as well.
    os.makedirs(CHECKPOINT_DIR, exist_ok = True)
    
    
    
    class AlexNet(nn.Module):
        """AlexNet-style image classifier.

        Five convolutional layers followed by three fully connected layers.
        ReLU follows every layer except the last; local response normalization
        and 3x3/stride-2 max pooling follow conv1 and conv2, and max pooling
        follows conv5. Dropout (p=0.5) after fc1 and fc2 is active only in
        training mode (``model.train()``), not after ``model.eval()``.

        Args:
            num_classes: number of output logits. ``None`` (the default) uses the
                module-level ``NUM_CLASSES`` constant, matching the old behavior.
        """
        def __init__(self, num_classes=None):
            super(AlexNet, self).__init__()
            if num_classes is None:
                num_classes = NUM_CLASSES
            self.conv1 = nn.Conv2d(3, 96, 11, stride = 4)
            self.conv2 = nn.Conv2d(96, 256, 5, padding = 2)
            # NOTE(review): the original AlexNet uses a 3x3 kernel for conv3; this
            # layer is 2x2 with padding 1 (grows 13x13 to 14x14). Left unchanged so
            # previously saved checkpoints still load; fc1's 6x6 input holds either way.
            self.conv3 = nn.Conv2d(256, 384, 2, padding = 1)
            self.conv4 = nn.Conv2d(384, 384, 3, padding = 1)
            self.conv5 = nn.Conv2d(384, 256, 3, padding = 1)
            # fc1 consumes conv5's pooled output flattened: 256 channels x 6 x 6.
            self.fc1 = nn.Linear((256 * 6 * 6), 4096)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc3 = nn.Linear(4096, num_classes)

        def forward(self, x):
            """Return raw class scores (logits) of shape (N, num_classes).

            ``x`` is a batch of shape (N, 3, H, W) whose spatial size must reduce
            to 6x6 after conv5's pooling (227x227, the training crop size, does).
            """
            x = F.max_pool2d(F.local_response_norm(F.relu(self.conv1(x)), size = 5, alpha = 0.0001, beta = 0.75, k = 2), kernel_size = 3, stride = 2)
            x = F.max_pool2d(F.local_response_norm(F.relu(self.conv2(x)), size = 5, alpha = 0.0001, beta = 0.75, k = 2), kernel_size = 3, stride = 2)
            x = F.relu(self.conv3(x))
            x = F.relu(self.conv4(x))
            x = F.max_pool2d(F.relu(self.conv5(x)), kernel_size = 3, stride = 2)
            x = torch.flatten(x, 1)
            # Dropout follows self.training (set by .train()/.eval()) instead of being
            # forced on, so inference is deterministic. It is applied out-of-place
            # because the ReLU output is kept by autograd for the backward pass.
            x = F.relu(self.fc1(x))
            x = F.dropout(x, p = 0.5, training = self.training)
            x = F.relu(self.fc2(x))
            x = F.dropout(x, p = 0.5, training = self.training)
            return self.fc3(x)
    
    
    
    if __name__ == '__main__':
        # Script entry point: build the model and data pipeline, train for
        # NUM_EPOCHS epochs, log per-epoch loss/accuracy to stdout and
        # TensorBoard, and save a checkpoint every 10 epochs.
        tbwrite = SummaryWriter(log_dir = LOG_DIR)
        print('tensorboardX summary write created')
        alexnet = AlexNet().to(device)
        print(alexnet)
        print('alexnet created')
        dataset = Datasets.ImageFolder(TRAIN_IMG_DIR, transform = transforms.Compose([
            transforms.CenterCrop(IMAGE_DIM),
            # Convert the loaded PIL image to a tensor.
            transforms.ToTensor(),
            # Per-channel normalization (the commonly used ImageNet mean/std).
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]))
        print('dataset created')

        dataloader = Dataloader.DataLoader(
            dataset,
            shuffle = True,
            num_workers = 8,
            drop_last = True,
            batch_size = BATCH_SIZE)
        print('dataloader created')
        # NOTE(review): LR_INIT, MOMENTUM and LR_DECAY are defined above but unused;
        # Adam runs with a fixed initial learning rate of 1e-4.
        optimizer = torch.optim.Adam(alexnet.parameters(), lr = 0.0001)
        # Multiply the lr by 0.1 every 30 epochs (never fires while NUM_EPOCHS < 30).
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 30, gamma = 0.1)
        criterion = nn.CrossEntropyLoss()
        print('Start Training...')
        alexnet.train()
        try:
            for epoch in range(NUM_EPOCHS):
                # Per-epoch accumulators: reset every epoch so the logged averages
                # describe this epoch only (the batch count used to be cumulative).
                correct_total = 0
                img_total = 0
                loss_total = 0.0
                num_batches = 0
                for imgs, classes in tqdm(dataloader):
                    imgs, classes = imgs.to(device), classes.to(device)
                    optimizer.zero_grad()
                    output = alexnet(imgs)
                    loss_value = criterion(output, classes)
                    loss_value.backward()
                    optimizer.step()

                    # Training-set statistics; .item() keeps the running sums as
                    # Python numbers rather than tensors on the training device.
                    preds = output.argmax(dim = 1)
                    correct_total += (preds == classes).sum().item()
                    img_total += classes.size(0)
                    loss_total += loss_value.item()
                    num_batches += 1

                # Mean batch loss and accuracy for this epoch; max(..., 1) guards an
                # empty dataloader (e.g. fewer images than BATCH_SIZE with drop_last).
                loss_k = loss_total / max(num_batches, 1)
                accuracy = correct_total / max(img_total, 1)

                # Log to stdout and TensorBoard.
                print('epoch: {} \t Loss: {:.4f} \t Acc: {:.2f}'.format(epoch + 1, loss_k, accuracy))
                tbwrite.add_scalar('loss', loss_k, epoch + 1)
                tbwrite.add_scalar('accuracy', accuracy, epoch + 1)

                lr_scheduler.step()

                if (epoch + 1) % 10 == 0:
                    # The filename (including the 'staes' spelling) must stay in sync
                    # with the checkpoint path predict.py loads.
                    checkpoint_path = os.path.join(CHECKPOINT_DIR, 'alexnet_staes_e{}.pkl'.format(epoch + 1))
                    state = {
                        'epoch' : epoch,
                        'optimizer' : optimizer.state_dict(),
                        'model' : alexnet.state_dict(),
                    }
                    torch.save(state, checkpoint_path)
        finally:
            # Flush pending TensorBoard events even if training is interrupted.
            tbwrite.close()

    predict.py

    import torch
    import torch.nn
    import torch.nn.functional as F
    import numpy as np
    import torchvision.datasets as Dataset
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from PIL import Image
    import matplotlib.pyplot as plt
    import os
    from train import AlexNet
    import cv2
    
    # Run inference on the GPU when available, otherwise on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    IMG_DIM = 227                                # crop size, same as training
    INPUT_ROOT_DIR = 'alexnet_data_in'
    OUTPUT_ROOT_DIR = 'alexnet_data_out'
    IMG_DIR = INPUT_ROOT_DIR + '/test_imagenet'  # directory of images to classify

    # Path of the checkpoint file train.py saves after epoch 10.
    checkpoint_dir = os.path.join(OUTPUT_ROOT_DIR, 'models', 'alexnet_staes_e10.pkl')
    # map_location remaps stored tensors onto `device`, so a checkpoint saved on a
    # GPU also loads on a CPU-only machine.
    checkpoint = torch.load(checkpoint_dir, map_location = device)
    model = AlexNet()
    model.to(device)
    model.load_state_dict(checkpoint['model'])
    # Inference only: switch layers that honor self.training (e.g. dropout) to eval behavior.
    model.eval()
    
    
    def open_list(dir):
        """Return the names of the non-directory entries directly inside ``dir``.

        Subdirectories are not descended into. Returns an empty list (instead of
        ``None``) when ``dir`` does not exist or cannot be listed, because
        ``os.walk`` yields nothing in that case.
        """
        # os.walk yields (dirpath, dirnames, filenames); only the top level matters.
        _dirpath, _dirnames, filenames = next(os.walk(dir), (dir, [], []))
        return filenames
    
    def show(name, img):
        """Display ``img`` in an OpenCV window titled ``name`` until a key is pressed.

        All OpenCV windows are destroyed afterwards, even if the wait is
        interrupted (e.g. by Ctrl+C).
        """
        cv2.imshow(name, img)
        try:
            # A delay of 0 blocks until any key is pressed.
            cv2.waitKey(0)
        finally:
            cv2.destroyAllWindows()
    
    
    # Inference preprocessing; the same crop / tensor conversion / normalization
    # steps that train.py applies to its training images.
    transform = transforms.Compose(
        [
            transforms.CenterCrop(IMG_DIM),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    
    
    img_list = open_list(IMG_DIR)


    # Classify every file in IMG_DIR: print the predicted class index, then show
    # the image in a window titled with that index. No gradients are needed.
    with torch.no_grad():
        for img_name in img_list:
            img_path = os.path.join(IMG_DIR, img_name)
            # convert('RGB') so grayscale / RGBA files still match the 3-channel
            # Normalize; the `with` closes the underlying file handle.
            with Image.open(img_path) as pil_img:
                img = transform(pil_img.convert('RGB'))
            # Tensor.to returns a new tensor, so it must be reassigned to actually
            # move the input onto the model's device; unsqueeze(0) adds the batch
            # dimension the model expects.
            img = img.to(device).unsqueeze(0)
            pred = model(img).argmax(dim = 1).item()
            print(pred)
            show('{}'.format(pred), cv2.imread(img_path))
    自己选择的路,跪着也要走完。朋友们,虽然这个世界日益浮躁起来,只要能够为了当时纯粹的梦想和感动坚持努力下去,不管其它人怎么样,我们也能够保持自己的本色走下去。
  • 相关阅读:
    iOS AutoLayout 代码实现约束—VFL
    理解iOS Event Handling
    一些优秀的iOS第三方库
    iOS中NSNotification、delegate、KVO三者之间的区别与联系?
    laravel 框架加载自定义函数/类文件
    Nodejs 使用 socket.io 简单实现实时通信
    Redis 与 Memcache 的异同之处
    Redis 服务安装
    PHP 依赖管理神器 Composer 基本使用
    Ajax无刷新图片插件使用
  • 原文地址:https://www.cnblogs.com/WTSRUVF/p/15325507.html
Copyright © 2011-2022 走看看