zoukankan      html  css  js  c++  java
  • pytorch(十六):多层感知机多分类

    一、基本代码

    import torch
    import torch.optim as optim
    from torch.nn import functional as F
    import torch.nn as nn
    import torchvision
    
    # Weight layout is [ch_out, ch_in]; forward() therefore multiplies by the
    # transpose (x @ w.t()).
    # Each (weight, bias) pair is wrapped in parentheses so the statement may
    # span two lines and still unpack as a 2-tuple.
    w1, b1 = (torch.randn(200, 784, requires_grad=True),
              torch.zeros(200, requires_grad=True))
    w2, b2 = (torch.randn(200, 200, requires_grad=True),
              torch.zeros(200, requires_grad=True))
    w3, b3 = (torch.randn(10, 200, requires_grad=True),
              torch.zeros(10, requires_grad=True))

    # Overwrite the randn draws above, in place, with Kaiming (He) normal
    # initialisation. Biases keep their zero initialisation.
    torch.nn.init.kaiming_normal_(w1)
    torch.nn.init.kaiming_normal_(w2)
    torch.nn.init.kaiming_normal_(w3)
    
    def forward(x):
        """Run the three-layer MLP and return raw (unnormalised) class logits.

        ``x`` must have 784 features in its last dimension to match ``w1``
        (200 x 784); the training/test loops pass ``data.view(-1, 28*28)``.
        The result has 10 columns, one score per class, matching ``w3``
        (10 x 200).
        """
        x = F.relu(x @ w1.t() + b1)
        x = F.relu(x @ w2.t() + b2)
        # No activation on the output layer. nn.CrossEntropyLoss applies
        # log-softmax internally and expects unbounded logits. A ReLU here
        # would clamp every negative score to 0, and its zero gradient would
        # stop those logits from being updated.
        return x @ w3.t() + b3
    
    # Hyperparameters.
    learning_rate = 0.01
    epochs = 1
    batch_size = 64

    # One preprocessing pipeline shared by both splits: convert images to
    # tensors, then normalise with mean 0.1307 and std 0.3081.
    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    train_set = torchvision.datasets.MNIST('datasets/mnist_data',
                                           train=True,
                                           download=True,
                                           transform=mnist_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)

    test_set = torchvision.datasets.MNIST('datasets/mnist_data/',
                                          train=False,
                                          download=True,
                                          transform=mnist_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Plain SGD over the hand-built parameter tensors; cross-entropy on logits.
    params = [w1, b1, w2, b2, w3, b3]
    optimizer = optim.SGD(params, lr=learning_rate)
    criteon = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        # ---- training pass ----
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(-1, 28 * 28)   # flatten each image to a 784-vector
            logits = forward(data)
            loss = criteon(logits, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

        # ---- evaluation pass (runs once per epoch) ----
        test_loss = 0.0
        correct = 0
        # Evaluation needs no gradients; disabling autograd avoids building
        # and holding a graph for every test batch.
        with torch.no_grad():
            for data, target in test_loader:
                data = data.view(-1, 28 * 28)
                logits = forward(data)
                # criteon uses the default 'mean' reduction, so scale back to a
                # per-batch sum. Dividing by the dataset size below then gives
                # the true per-sample average, even for a short final batch.
                test_loss += criteon(logits, target).item() * target.size(0)

                pred = logits.argmax(dim=1)          # predicted class index
                correct += pred.eq(target).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

    二、截图

     

  • 相关阅读:
    loadrunner获取Http信息头中指定值作为参数
    soapUI使用-DataSource获取oracle库中的参数
    [转]vim编辑器---批量注释与反注释
    String() 函数把对象的值转换为字符串。
    自定义滚动条mCustomScrollbar
    css实现强制不换行/自动换行/强制换行
    在网页中添加新浪微博“加关注”按钮
    移动前端调试方案(Android + Chrome 实现远程调试)
    font-family
    移动端touch事件滚动
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14026475.html
Copyright © 2011-2022 走看看