zoukankan      html  css  js  c++  java
  • AlexNet复现代码

    train.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import os
    from tensorboardX import SummaryWriter
    import torchvision.datasets as Datasets
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from tqdm import tqdm
    
    # Use a CUDA device when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Training hyper-parameters.
    # NOTE(review): MOMENTUM, LR_DECAY, LR_INIT and DEVICE_DIS are not referenced
    # by the training code in this file.
    NUM_EPOCHS = 10
    BATCH_SIZE = 4
    MOMENTUM = 0.9
    LR_DECAY = 0.0005
    LR_INIT = 0.01
    IMAGE_DIM = 227  # side length of the square center crop fed to the network
    NUM_CLASSES = 10  # number of logits produced by the classifier head
    DEVICE_DIS = [0, 1, 2, 3]

    # Dataset and output paths, relative to the current working directory.
    INPUT_ROOT_DIR = 'alexnet_data_in'
    TRAIN_IMG_DIR = f'{INPUT_ROOT_DIR}/imagenet'
    OUTPUT_DIR = 'alexnet_data_out'
    LOG_DIR = f'{OUTPUT_DIR}/tblogs'
    CHECKPOINT_DIR = f'{OUTPUT_DIR}/models'

    # Create the checkpoint directory up front (no error if it already exists).
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    
    
    
    class AlexNet(nn.Module):
        """AlexNet-style convolutional classifier.

        Five convolutional layers are followed by three fully connected layers:
        - conv1 and conv2 each use ReLU, local response normalization and 3x3 /
          stride-2 max pooling;
        - conv5 is followed by max pooling;
        - dropout (p=0.5) is applied after fc1 and fc2.

        The ``256 * 6 * 6`` input size of ``fc1`` matches 227x227 RGB inputs.

        Args:
            num_classes: number of output logits. When omitted (``None``) the
                module-level ``NUM_CLASSES`` constant is used, as before.
        """

        def __init__(self, num_classes=None):
            super().__init__()
            if num_classes is None:
                num_classes = NUM_CLASSES
            self.conv1 = nn.Conv2d(3, 96, 11, stride = 4)
            self.conv2 = nn.Conv2d(96, 256, 5, padding = 2)
            # NOTE(review): the original AlexNet uses a 3x3 kernel for conv3.
            # This 2x2 kernel with padding 1 still produces 6x6 maps before fc1.
            # It is kept unchanged so checkpoints saved with this weight shape
            # keep loading.
            self.conv3 = nn.Conv2d(256, 384, 2, padding = 1)
            self.conv4 = nn.Conv2d(384, 384, 3, padding = 1)
            self.conv5 = nn.Conv2d(384, 256, 3, padding = 1)
            self.fc1 = nn.Linear((256 * 6 * 6), 4096)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc3 = nn.Linear(4096, num_classes)

        def forward(self, x):
            """Return raw class scores (batch, num_classes) for x of shape (batch, 3, 227, 227)."""
            x = F.max_pool2d(F.local_response_norm(F.relu(self.conv1(x)), size = 5, alpha = 0.0001, beta = 0.75, k = 2), kernel_size = 3, stride = 2)
            x = F.max_pool2d(F.local_response_norm(F.relu(self.conv2(x)), size = 5, alpha = 0.0001, beta = 0.75, k = 2), kernel_size = 3, stride = 2)
            x = F.relu(self.conv3(x))
            x = F.relu(self.conv4(x))
            x = F.max_pool2d(F.relu(self.conv5(x)), kernel_size = 3, stride = 2)
            # Flatten everything except the batch dimension.
            x = torch.flatten(x, 1)
            # Dropout follows the module's train/eval mode, so model.eval() turns it
            # off for inference. It is not applied in place: overwriting the ReLU
            # output could corrupt a tensor autograd saved for the backward pass.
            x = F.relu(self.fc1(x))
            x = F.dropout(x, p = 0.5, training = self.training)
            x = F.relu(self.fc2(x))
            x = F.dropout(x, p = 0.5, training = self.training)
            x = self.fc3(x)
            return x
    
    
    
    if __name__ == '__main__':
        # Training entry point:
        # - build the model, data pipeline and optimizer;
        # - train for NUM_EPOCHS epochs, logging per-epoch loss and accuracy to
        #   TensorBoard;
        # - save a checkpoint every 10 epochs and after the final epoch.

        tbwrite = SummaryWriter(log_dir = LOG_DIR)
        print('tensorboardX summary write created')
        alexnet = AlexNet().to(device)
        print(alexnet)
        print('alexnet created')
        dataset = Datasets.ImageFolder(TRAIN_IMG_DIR, transform = transforms.Compose([
            # Take the central IMAGE_DIM x IMAGE_DIM patch (no resize beforehand).
            transforms.CenterCrop(IMAGE_DIM),
            # Convert the loaded image to a float tensor (channels first).
            transforms.ToTensor(),
            # Normalize each channel with the usual ImageNet mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]))
        print('dataset created')

        dataloader = Dataloader.DataLoader(
            dataset,
            shuffle = True,
            num_workers = 8,
            drop_last = True,
            batch_size = BATCH_SIZE)
        print('dataloader created')
        if len(dataloader) == 0:
            # drop_last=True discards a partial final batch. A dataset smaller
            # than BATCH_SIZE would therefore yield no batches, and the per-epoch
            # averages below would divide by zero.
            raise RuntimeError('training set {} has fewer than {} images'.format(TRAIN_IMG_DIR, BATCH_SIZE))

        # NOTE(review): LR_INIT / MOMENTUM / LR_DECAY are not used; Adam runs
        # with a fixed initial lr of 1e-4.
        optimizer = torch.optim.Adam(alexnet.parameters(), lr = 0.0001)
        # Multiply the lr by 0.1 every 30 epochs. With NUM_EPOCHS = 10 this
        # decay never triggers.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 30, gamma = 0.1)
        criterion = nn.CrossEntropyLoss()
        print('Start Training...')
        for epoch in range(NUM_EPOCHS):
            # Make sure dropout is active while training.
            alexnet.train()
            correct_total = 0
            img_total = 0
            loss_total = 0.0
            batch_count = 0
            for imgs, classes in tqdm(dataloader):
                imgs, classes = imgs.to(device), classes.to(device)
                optimizer.zero_grad()
                output = alexnet(imgs)
                loss_value = criterion(output, classes)
                loss_value.backward()
                optimizer.step()

                # Running training statistics (measured with dropout active).
                preds = torch.max(output, 1)[1]
                correct_total += (preds == classes).sum().item()
                img_total += classes.size(0)
                loss_total += loss_value.item()
                batch_count += 1

            # Mean loss over this epoch's batches only (the counter is reset
            # every epoch).
            loss_k = loss_total / batch_count
            accuracy = correct_total / img_total
            print('epoch: {} 	 Loss: {:.4f} 	 Acc: {:.2f}'.format(epoch + 1, loss_k, accuracy))
            tbwrite.add_scalar('loss', loss_k, epoch + 1)
            tbwrite.add_scalar('accuracy', accuracy, epoch + 1)

            # Advance the per-epoch learning-rate schedule.
            lr_scheduler.step()

            if (epoch + 1) % 10 == 0 or epoch + 1 == NUM_EPOCHS:
                # Keep the existing file-name pattern (including 'staes'):
                # predict.py loads 'alexnet_staes_e10.pkl' by this name.
                checkpoint_path = os.path.join(CHECKPOINT_DIR, 'alexnet_staes_e{}.pkl'.format(epoch + 1))
                state = {
                    'epoch' : epoch,
                    'optimizer' : optimizer.state_dict(),
                    'model' : alexnet.state_dict(),
                }
                torch.save(state, checkpoint_path)

        # Flush pending TensorBoard events and release the log file.
        tbwrite.close()

    predict.py

    import torch
    import torch.nn
    import torch.nn.functional as F
    import numpy as np
    import torchvision.datasets as Dataset
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from PIL import Image
    import matplotlib.pyplot as plt
    import os
    from train import AlexNet
    import cv2
    
    # Run inference on a CUDA device when one is available, otherwise on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    IMG_DIM = 227
    INPUT_ROOT_DIR = 'alexnet_data_in'
    OUTPUT_ROOT_DIR = 'alexnet_data_out'
    IMG_DIR = INPUT_ROOT_DIR + '/test_imagenet'

    # Checkpoint written by train.py after epoch 10. Despite the variable name,
    # this is a file path, not a directory.
    checkpoint_dir = os.path.join(OUTPUT_ROOT_DIR, 'models', 'alexnet_staes_e10.pkl')
    # map_location remaps tensors saved on a GPU onto the current device, so a
    # GPU-trained checkpoint also loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_dir, map_location=device)
    model = AlexNet()
    model.to(device)
    model.load_state_dict(checkpoint['model'])
    # Evaluation mode for layers that check self.training (e.g. dropout).
    model.eval()
    
    
    def open_list(dir):
        """Return the sorted names of the files directly inside *dir* (non-recursive).

        Sub-directories and their contents are not included.

        Raises:
            FileNotFoundError: if *dir* does not exist or cannot be listed.
        """
        # os.walk yields (dirpath, dirnames, filenames); only the first, top-level
        # entry is needed.
        for _root, _subdirs, filenames in os.walk(dir):
            return sorted(filenames)
        # os.walk yields nothing for a missing or unreadable directory. The old
        # behaviour implicitly returned None, which crashed the caller later.
        raise FileNotFoundError('cannot list directory: {!r}'.format(dir))
    
    def show(name, img):
        """Display *img* in an OpenCV window titled *name*.

        Blocks until a key is pressed, then closes all OpenCV windows.
        """
        title = name
        cv2.imshow(title, img)
        # A delay of 0 makes waitKey wait indefinitely for a key press.
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    
    
    # Preprocessing for every test image. It mirrors the training transform in
    # train.py: a 227x227 center crop, conversion to a tensor, and per-channel
    # normalization with the usual ImageNet mean/std.
    _preprocess_steps = [
        transforms.CenterCrop(IMG_DIM),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    transform = transforms.Compose(_preprocess_steps)
    
    
    img_list = open_list(IMG_DIR)


    # For each test image: print the predicted class index, then display the
    # image in a window titled with that index.
    for img_name in img_list:
        img_path = os.path.join(IMG_DIR, img_name)
        # Convert to RGB (as torchvision's default ImageFolder loader does during
        # training) so grayscale/RGBA files still give 3 channels. The context
        # manager closes the underlying file.
        with Image.open(img_path) as pil_img:
            img = transform(pil_img.convert('RGB'))
        # Tensor.to returns a new tensor, so it must be reassigned. Otherwise the
        # input stays on the CPU while the model may be on the GPU.
        img = img.to(device)
        # The model expects a batch dimension: (3, H, W) -> (1, 3, H, W).
        img = img.unsqueeze(0)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model(img)
        pred = torch.max(logits, 1)[1].item()
        print(pred)
        show(str(pred), cv2.imread(img_path))
    自己选择的路,跪着也要走完。朋友们,虽然这个世界日益浮躁起来,只要能够为了当时纯粹的梦想和感动坚持努力下去,不管其它人怎么样,我们也能够保持自己的本色走下去。
  • 相关阅读:
    MySQL启动和关闭命令总结
    MySQL数据库5.6版本首次安装Root密码问题
    tomcat 9性能调优注意事项
    扫除减脂之路上的几个小障碍
    MySQL常见面试题
    关于邮箱发送邮件二之附件及图片
    关于邮箱发送邮件
    关于算法
    python中常见的数据类型
    C++实现复数类的输入输出流以及+-*/的重载
  • 原文地址:https://www.cnblogs.com/WTSRUVF/p/15325507.html
Copyright © 2011-2022 走看看