zoukankan      html  css  js  c++  java
  • pytorch(二十):MLP的visdom可视化

    一、VISDOM

     

     二、具体代码

    import torch
    import torch.nn as nn
    from torchvision import datasets,transforms
    import torch.optim as optim
    from torch.nn import functional as F
    from visdom import Visdom
    import torchvision
    
    
    # Hyper-parameters shared by the data loaders, the optimizer and the training loop.
    batch_size, learning_rate, epochs = (
        64,     # samples per mini-batch
        1e-2,   # SGD learning rate
        10,     # number of passes over the training set
    )
    
    
    # Both MNIST splits use identical preprocessing: convert images to tensors
    # (ToTensor), then normalize with mean 0.1307 and std 0.3081.
    # The training split is shuffled every epoch; the test split is read in order.
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            'datasets/mnist_data',
            train=True,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),                       # convert to tensor
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),  # normalize
            ]),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            'datasets/mnist_data/',
            train=False,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]),
        ),
        batch_size=batch_size,
        shuffle=False,
    )
    
    class MLP(nn.Module):
        """Three-layer fully connected classifier for flattened 28x28 images.

        Input: a tensor whose last dimension is 784 (one flattened 28x28 image).
        Output: 10 unnormalized class scores (logits) per sample, suitable for
        ``nn.CrossEntropyLoss``.
        """

        def __init__(self):
            super(MLP, self).__init__()

            self.model = nn.Sequential(
                nn.Linear(784, 200),
                nn.LeakyReLU(inplace=True),
                nn.Linear(200, 200),
                nn.LeakyReLU(inplace=True),
                # No activation after the output layer: CrossEntropyLoss expects
                # raw logits, and a trailing LeakyReLU would shrink every negative
                # score by its default slope (0.01), distorting the class scores.
                nn.Linear(200, 10),
            )

        def forward(self, x):
            """Map ``x`` with last dimension 784 to 10 class logits."""
            return self.model(x)
    
    # Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
    # The previous hard-coded 'cuda:0' crashed on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = MLP().to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    # CrossEntropyLoss applies log-softmax internally and expects unnormalized scores.
    criteon = nn.CrossEntropyLoss().to(device)

    # Connect to a Visdom server (presumably the default local server; confirm it is running).
    viz = Visdom()

    # Create the two plot windows with a single (0, 0) starting point; later
    # calls append to them by window name ('train_loss' and 'test').
    viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
    viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',
                                                       legend=['loss', 'acc.']))
    global_step = 0  # count of optimizer steps so far; used as the x-axis of the plots
    
    
    for epoch in range(epochs):

        # ---- training pass --------------------------------------------------
        net.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten each 1x28x28 image into a 784-element vector for the MLP.
            data = data.view(-1, 28 * 28)
            data, target = data.to(device), target.to(device)

            logits = net(data)
            loss = criteon(logits, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1
            # Append this batch's training loss to the 'train_loss' window.
            viz.line([loss.item()], [global_step], win='train_loss', update='append')

            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

        # ---- evaluation pass ------------------------------------------------
        net.eval()
        test_loss = 0
        correct = 0
        # Gradients are not needed for evaluation, so skip building the autograd graph.
        with torch.no_grad():
            for data, target in test_loader:
                data = data.view(-1, 28 * 28)
                data, target = data.to(device), target.to(device)
                logits = net(data)
                # Sum (rather than average) the loss within each batch. Dividing
                # by the dataset size below then gives the true per-sample mean,
                # which is also correct for a smaller final batch.
                test_loss += F.cross_entropy(logits, target, reduction='sum').item()

                pred = logits.argmax(dim=1)
                correct += pred.eq(target).sum().item()

        # Average before plotting, so the plot and the printout show the same number.
        test_loss /= len(test_loader.dataset)
        test_acc = correct / len(test_loader.dataset)

        viz.line([[test_loss, test_acc]],
                 [global_step], win='test', update='append')
        # Show the last test batch and the model's predictions for it.
        viz.images(data.view(-1, 1, 28, 28), win='x')
        viz.text(str(pred.detach().cpu().numpy()), win='pred',
                 opts=dict(title='pred'))

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * test_acc))
  • 相关阅读:
    jquery动画,获取,添加
    java c3p0连接池
    jquery尺寸
    jquery遍历
    jquery删除,停止,获取设置css,设置内容和属性,过滤
    java jdbc(数据库的添加,删除,修改,更新)
    博客开通
    很久没有空来了,一些最近的想法
    竖线的显示
    一个小问题,c++
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14026651.html
Copyright © 2011-2022 走看看