zoukankan      html  css  js  c++  java
  • Pytorch模型之卷积神经网络

    PyTorch实现卷积神经网络

    import torch 
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms
    
    
    # Select the compute device: the first CUDA GPU when one is available,
    # otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # Hyperparameters
    num_epochs = 5          # full passes over the training set
    num_classes = 10        # size of the network's output layer
    batch_size = 100        # samples per mini-batch for both loaders
    learning_rate = 0.001   # step size handed to the optimizer
    
    # MNIST datasets: both splits share the same root directory and the same
    # ToTensor transform; download=True is requested for the training split only.
    _mnist_common = dict(root='../../data/', transform=transforms.ToTensor())
    train_dataset = torchvision.datasets.MNIST(train=True, download=True,
                                               **_mnist_common)
    test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_common)

    # Wrap the datasets in DataLoaders; only the training data is shuffled.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=False)
    
    # CNN with two convolutional layers
    class ConvNet(nn.Module):
        """Two-stage convolutional classifier followed by one linear layer.

        Each stage is Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d -> ReLU
        -> MaxPool2d(2, stride 2): the convolution keeps the spatial size and
        the pooling step halves it (rounding down).

        Args:
            num_classes: Number of output values produced by the linear layer.
            in_channels: Number of channels in the input images.
            image_size: Input size, either a single int for square images or a
                ``(height, width)`` pair. The default of 28 gives the linear
                layer ``7*7*32`` input features.

        Raises:
            ValueError: If the image is too small to leave at least one pixel
                after both pooling steps.
        """

        def __init__(self, num_classes=10, in_channels=1, image_size=28):
            super(ConvNet, self).__init__()
            if isinstance(image_size, int):
                height = width = image_size
            else:
                height, width = image_size
            # Two 2x2 / stride-2 poolings each floor-halve the spatial size.
            feat_h = height // 2 // 2
            feat_w = width // 2 // 2
            if feat_h < 1 or feat_w < 1:
                raise ValueError(
                    'image_size {!r} is too small for two 2x2 poolings'
                    .format(image_size))
            self.layer1 = nn.Sequential(
                nn.Conv2d(in_channels, 16, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2))
            self.layer2 = nn.Sequential(
                nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2))
            # layer2 emits 32 channels of feat_h x feat_w feature maps.
            self.fc = nn.Linear(feat_h * feat_w * 32, num_classes)

        def forward(self, x):
            """Return a (batch, num_classes) tensor for a batch of images ``x``."""
            out = self.layer1(x)
            out = self.layer2(out)
            # Flatten everything except the batch dimension for the linear layer.
            out = out.reshape(out.size(0), -1)
            out = self.fc(out)
            return out
    
    # Build the network on the selected device, with its loss and optimizer.
    model = ConvNet(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    for epoch in range(1, num_epochs + 1):
        for step, (images, labels) in enumerate(train_loader, start=1):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            loss = criterion(model(images), labels)

            # Backward pass and parameter update; gradients are cleared first
            # because backward() accumulates into existing gradients.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report the current mini-batch loss every 100 steps.
            if step % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch, num_epochs, step, total_step, loss.item()))
    
    # Evaluate the model
    # eval mode: batchnorm uses moving mean/variance instead of mini-batch statistics
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradient tracking is needed for evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Predicted class = index of the largest output along dimension 1.
            predicted = model(images).max(dim=1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

    # Save the model checkpoint (parameters and buffers only, via state_dict)
    torch.save(model.state_dict(), 'model.ckpt')
    
  • 相关阅读:
    php命令注入
    mysql事务
    安装php环境
    移除服务器缓存实例
    show user profile synchronization tools
    manual start user profile import
    JSON is undefined. Infopath Form People Picker in SharePoint 2013
    asp.net web 应用站点支持域账户登录
    Load SharePoint environment by PowerShell
    sharepoint 2016 download
  • 原文地址:https://www.cnblogs.com/liulunyang/p/14490493.html
Copyright © 2011-2022 走看看