zoukankan      html  css  js  c++  java
  • AlexNet复现代码

    train.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import os
    from tensorboardX import SummaryWriter
    import torchvision.datasets as Datasets
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from tqdm import tqdm
    
    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Settings used by the training script below.
    NUM_EPOCHS = 10
    BATCH_SIZE = 4
    IMAGE_DIM = 227     # side length of the square centre crop fed to the network
    NUM_CLASSES = 10    # width of the final classification layer

    # NOTE(review): these four values are not referenced by any code visible
    # in this file (the optimizer below is Adam with its own learning rate).
    MOMENTUM = 0.9
    LR_DECAY = 0.0005
    LR_INIT = 0.01
    DEVICE_DIS = [0, 1, 2, 3]

    # Dataset and output paths.
    INPUT_ROOT_DIR = 'alexnet_data_in'
    OUTPUT_DIR = 'alexnet_data_out'
    TRAIN_IMG_DIR = '/'.join((INPUT_ROOT_DIR, 'imagenet'))
    LOG_DIR = '/'.join((OUTPUT_DIR, 'tblogs'))
    CHECKPOINT_DIR = '/'.join((OUTPUT_DIR, 'models'))

    # Create the checkpoint directory (and its parents) up front; this runs at
    # import time as well, e.g. when another script imports this module.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    
    
    
    class AlexNet(nn.Module):
        """AlexNet-style classifier: five conv layers followed by three linear layers.

        ``fc1`` expects ``256 * 6 * 6`` flattened features, which is what the
        convolution/pooling stack produces for 227x227 inputs.

        Args:
            num_classes: Width of the output layer. Defaults to the module-level
                ``NUM_CLASSES`` when omitted, so ``AlexNet()`` behaves as before.
        """

        def __init__(self, num_classes=None):
            super(AlexNet, self).__init__()
            if num_classes is None:
                num_classes = NUM_CLASSES
            self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
            self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
            # NOTE(review): the AlexNet paper uses a 3x3 kernel here. The 2x2
            # kernel with padding 1 grows the 13x13 map to 14x14, which still
            # pools down to 6x6, so shapes line up. It is left unchanged so
            # that checkpoints saved with this layout keep loading.
            self.conv3 = nn.Conv2d(256, 384, 2, padding=1)
            self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
            self.conv5 = nn.Conv2d(384, 256, 3, padding=1)
            self.fc1 = nn.Linear((256 * 6 * 6), 4096)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc3 = nn.Linear(4096, num_classes)

        def forward(self, x):
            """Return the raw class scores (no softmax) for a batch of images.

            Args:
                x: Image batch; the first dimension is treated as the batch size.

            Returns:
                The output of ``fc3``, one row of scores per input image.
            """
            # conv -> ReLU -> local response norm -> 3x3/2 max-pool, twice.
            x = F.local_response_norm(F.relu(self.conv1(x)), size=5, alpha=0.0001, beta=0.75, k=2)
            x = F.max_pool2d(x, kernel_size=3, stride=2)
            x = F.local_response_norm(F.relu(self.conv2(x)), size=5, alpha=0.0001, beta=0.75, k=2)
            x = F.max_pool2d(x, kernel_size=3, stride=2)
            x = F.relu(self.conv3(x))
            x = F.relu(self.conv4(x))
            x = F.max_pool2d(F.relu(self.conv5(x)), kernel_size=3, stride=2)
            # Flatten everything except the batch dimension.
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            # Dropout must follow the module's train/eval mode; hard-coding
            # training=True kept it active during inference. It is also no
            # longer applied in place, so the ReLU output is not overwritten.
            x = F.dropout(x, p=0.5, training=self.training)
            x = F.relu(self.fc2(x))
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.fc3(x)
            return x
    
    
    
    if __name__ == '__main__':
        # Training entry point: builds the model, the ImageFolder dataset and the
        # optimizer, trains for NUM_EPOCHS epochs, logs per-epoch loss/accuracy
        # to tensorboardX and saves a checkpoint every 10 epochs.

        tbwrite = SummaryWriter(log_dir=LOG_DIR)
        print('tensorboardX summary write created')
        alexnet = AlexNet().to(device)
        print(alexnet)
        print('alexnet created')
        dataset = Datasets.ImageFolder(TRAIN_IMG_DIR, transform=transforms.Compose([
            transforms.CenterCrop(IMAGE_DIM),
            # Convert the loaded image to a tensor.
            transforms.ToTensor(),
            # Normalize the tensor channel-wise.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]))
        print('dataset created')

        dataloader = Dataloader.DataLoader(
            dataset,
            shuffle=True,
            num_workers=8,
            drop_last=True,
            batch_size=BATCH_SIZE)
        print('dataloader created')
        optimizer = torch.optim.Adam(alexnet.parameters(), lr=0.0001)
        # Multiply the learning rate by 0.1 every 30 epochs.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        loss = nn.CrossEntropyLoss()
        print('Start Training...')
        alexnet.train()
        for epoch in range(NUM_EPOCHS):
            # Per-epoch accumulators. The batch counter is reset together with
            # the loss sum so the reported loss is a true per-epoch mean (the
            # previous counter kept growing across epochs).
            true_total = 0
            img_total = 0
            loss_total = 0.0
            batch_total = 0
            for imgs, classes in tqdm(dataloader):
                imgs, classes = imgs.to(device), classes.to(device)
                optimizer.zero_grad()
                output = alexnet(imgs)
                loss_value = loss(output, classes)
                loss_value.backward()
                optimizer.step()

                # Predicted class = index of the highest score in each row.
                preds = torch.max(output, 1)[1]
                # .item() keeps the running count a plain Python int instead
                # of a tensor living on the training device.
                true_total += torch.sum(preds == classes).item()
                img_total += len(classes)
                loss_total += loss_value.item()
                batch_total += 1

            # max(..., 1) avoids a ZeroDivisionError if the loader yields no batch.
            loss_k = loss_total / max(batch_total, 1)
            accuracy = true_total / max(img_total, 1)

            # Log the epoch summary to stdout and tensorboard.
            print('epoch: {} \t Loss: {:.4f} \t Acc: {:.2f}'.format(epoch + 1, loss_k, accuracy))
            tbwrite.add_scalar('loss', loss_k, epoch + 1)
            tbwrite.add_scalar('accuracy', accuracy, epoch + 1)

            lr_scheduler.step()

            if (epoch + 1) % 10 == 0:
                checkpoint_path = os.path.join(CHECKPOINT_DIR, 'alexnet_staes_e{}.pkl'.format(epoch + 1))
                # The stored 'epoch' is the zero-based loop index, while the
                # file name uses the one-based epoch number.
                state = {
                    'epoch': epoch,
                    'optimizer': optimizer.state_dict(),
                    'model': alexnet.state_dict(),
                }
                torch.save(state, checkpoint_path)

        # Close the writer so the event file is finalized.
        tbwrite.close()

    predict.py

    import torch
    import torch.nn
    import torch.nn.functional as F
    import numpy as np
    import torchvision.datasets as Dataset
    import torchvision.transforms as transforms
    import torch.utils.data.dataloader as Dataloader
    from PIL import Image
    import matplotlib.pyplot as plt
    import os
    from train import AlexNet
    import cv2
    
    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    IMG_DIM = 227
    INPUT_ROOT_DIR = 'alexnet_data_in'
    OUTPUT_ROOT_DIR = 'alexnet_data_out'
    IMG_DIR = INPUT_ROOT_DIR + '/test_imagenet'

    # Despite its name, checkpoint_dir is the path of the checkpoint FILE.
    checkpoint_dir = os.path.join(OUTPUT_ROOT_DIR, 'models', 'alexnet_staes_e10.pkl')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_dir, map_location=device)
    model = AlexNet()
    model.to(device)
    model.load_state_dict(checkpoint['model'])
    # Switch to inference mode so train-only behaviour is turned off.
    model.eval()
    
    
    def open_list(dir):
        """Return the names of the files located directly inside ``dir``.

        Sub-directories are not descended into and their names are not
        included.

        Args:
            dir: Path of the directory to list.

        Returns:
            A list of file names (not full paths). An empty list is returned
            when ``dir`` does not exist or cannot be walked, instead of the
            ``None`` the previous implementation fell through to.
        """
        # os.walk yields (dirpath, dirnames, filenames); only the first tuple,
        # i.e. the top-level directory, is needed.
        _dirpath, _dirnames, filenames = next(os.walk(dir), (dir, [], []))
        return filenames
    
    def show(name, img):
        """Display ``img`` in an OpenCV window titled ``name`` until a key is pressed.

        Args:
            name: Window title.
            img: Image as accepted by ``cv2.imshow``.
        """
        try:
            cv2.imshow(name, img)
            # 0 means wait indefinitely for a key press.
            cv2.waitKey(0)
        finally:
            # Always tear the window down, even if the wait is interrupted.
            cv2.destroyAllWindows()
    
    
    # Preprocessing applied to every test image: centre-crop to
    # IMG_DIM x IMG_DIM, convert to a tensor, then normalize each channel.
    # These are the same three steps the training script applies.
    _preprocess_steps = [
        transforms.CenterCrop(IMG_DIM),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    transform = transforms.Compose(_preprocess_steps)
    
    
    # open_list may hand back None when IMG_DIR is missing; treat that as
    # "no images" rather than failing in the loop header.
    img_list = open_list(IMG_DIR) or []


    # Classify every file in IMG_DIR, print the predicted class index and
    # show the original image in a window titled with that index.
    for img_name in img_list:
        img_path = os.path.join(IMG_DIR, img_name)
        # Force three channels so grayscale/RGBA files match the 3-channel
        # Normalize step; the context manager closes the file handle.
        with Image.open(img_path) as pil_img:
            img = transform(pil_img.convert('RGB'))
        # Note: the model expects a leading batch dimension. Tensor.to()
        # returns a new tensor, so its result has to be kept.
        img = img.unsqueeze(0).to(device)
        # No gradients are needed for prediction.
        with torch.no_grad():
            preds = model(img)
        preds = torch.max(preds, 1)[1]
        print(preds.item())
        img = cv2.imread(img_path)
        show('{}'.format(preds.item()), img)
    自己选择的路,跪着也要走完。朋友们,虽然这个世界日益浮躁起来,只要能够为了当时纯粹的梦想和感动坚持努力下去,不管其它人怎么样,我们也能够保持自己的本色走下去。
  • 相关阅读:
    git报错
    rabbitmq关于guest用户登录失败解决方法
    【转】Linux下RabbitMQ服务器搭建(单实例)
    saltstack安装配置(yum)
    linux下搭建禅道项目管理系统
    git用户限制ssh登录服务器
    中央定调,“新基建”彻底火了!这七大科技领域要爆发
    数据可视化使用小贴士,这样的错误别再犯了
    5G国战:一部国家奋斗的血泪史,看看各国是如何角力百年?
    还没有一个人能够把并发编程讲解的这么透彻
  • 原文地址:https://www.cnblogs.com/WTSRUVF/p/15325507.html
Copyright © 2011-2022 走看看