zoukankan      html  css  js  c++  java
  • MLP实现Fashion image分类

    导入数据

    import torch
    import torch.nn as nn
    from torch.utils import data
    import torchvision
    from torchvision import transforms
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import torch.nn.functional as F
    from d2l import torch as d2l
    %matplotlib inline
    
    trans = transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(root = './image_classification/data',train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root='./image_classification/data',train=False, transform=trans, download=True)
    train_iter = data.DataLoader(mnist_train, batch_size=64, num_workers=4, shuffle=True)
    test_iter = data.DataLoader(mnist_test, batch_size=64, num_workers=4, shuffle=True)
    

    MLP

    class Linear(nn.Module):
        def __init__(self, indim, outdim):
            super(Linear, self).__init__()
            self.W = self.W = nn.Parameter(torch.FloatTensor(indim, outdim))
            nn.init.xavier_normal_(self.W)
            self.b = nn.Parameter(torch.FloatTensor(outdim))
            nn.init.normal_(self.b,mean=0, std=0.01)
            
        def forward(self, x):  
            return x @ self.W + self.b  
        
    class Net(nn.Module):
        def __init__(self, indim, classes):
            super(Net, self).__init__()
            self.Lay1 = Linear(indim, 256)
            self.Lay2 = Linear(256, 64)
            self.Lay3 = Linear(64, classes)
            
        def forward(self, x):
            x = x.reshape(-1, 784)
            x = self.Lay1(x)
            h = F.tanh(x)
            h = F.tanh(self.Lay2(h))
            o = self.Lay3(h)
            return F.log_softmax(o) 
    

    我这里使用了log_softmax,意味着我后面直接用nllLoss就行,而不用交叉熵损失函数。

    定义准确度

    def accuracy(y_hat, y):
        y_hat = y_hat.argmax(dim=1)
        cmp = (y_hat.type(y.dtype) == y)
        return cmp.type(y.dtype).sum()
    
    def evaluate(net, test_iter):
        metrics = d2l.Accumulator(2)
        for x, y in test_iter:
            y_hat = net(x)
            metrics.add(accuracy(y_hat, y), len(y))
        return metrics[0] / metrics[1]
    

    这个Accumulator是一个累加器,因为我们训练的时候都是一个batch一个batch进行训练,如果要获得一个epoch的训练误差或者说准确度,就需要定义这么一个累加器。因为test数据也是一个batch读取,所以我们也需要做一个累加器。

    训练

    indim = 784
    classes = 10
    net = Net(indim, classes)
    loss = nn.NLLLoss(reduction='mean')
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=5e-5)
    def train(net, train_iter, test_iter, loss, optimizer, num_epochs):
        animator = d2l.Animator(xlabel='epoch', xlim = [1, num_epochs], ylim=[0, 1], 
                                legend=['train_loss', 'train_acc', 'test_acc'])
        for i in range(num_epochs):
            metrics = d2l.Accumulator(3) # 记录预测准确数、loss、以及size
            for x, y in train_iter:
                y_hat = net(x)
                l = loss(y_hat, y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                metrics.add(accuracy(y_hat, y), float(l)*len(y), y.numel())
            train_acc = metrics[0] / metrics[2]
            train_loss = metrics[1] / metrics[2]
            test_acc = evaluate(net, test_iter)
            animator.add(i+1, (train_loss, train_acc, test_acc))
            if i == num_epochs-1:
                print("最终的train loss:{} train acc:{} test acc:{}".format(train_loss,train_acc, test_acc)) 
    
    train(net, train_iter, test_iter, loss, optimizer, 20)
    

    这里的Animator函数,看了很久也没看懂,就记着它的api调用方法吧。

    image-20211110113308324

    预测

    def predict_ch3(net, test_iter, n=6):  #@save
        """预测标签(定义见第3章)。"""
        for X, y in test_iter:
            break
        trues = d2l.get_fashion_mnist_labels(y)
        preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
        titles = [true +'
    ' + pred for true, pred in zip(trues, preds)]
        d2l.show_images(
            X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
    
    predict_ch3(net, test_iter)
    
  • 相关阅读:
    Systemverilog for design 笔记(三)
    SystemVerilog for design 笔记(二)
    Systemverilog for design 笔记(一)
    假如m是奇数,且m>=3,证明m(m² -1)能被8整除
    SharpSvn操作 -- 获取Commit节点列表
    GetRelativePath获取相对路径
    Dictionary(支持 XML 序列化),注意C#中原生的Dictionary类是无法进行Xml序列化的
    Winform中Checkbox与其他集合列表类型之间进行关联
    Image(支持 XML 序列化),注意C#中原生的Image类是无法进行Xml序列化的
    修复使用<code>XmlDocument</code>加载含有DOCTYPE的Xml时,加载后增加“[]”字符的错误
  • 原文地址:https://www.cnblogs.com/kalicener/p/15532717.html
Copyright © 2011-2022 走看看