zoukankan      html  css  js  c++  java
  • 在多分类任务实验中用 torch.nn 实现 L2 正则化

    1 导入实验所需要的包

    import torch
    import torch.nn as nn
    import numpy as np
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    %matplotlib inline

    2 下载MNIST数据集以及读取数据

    # MNIST training split, downloaded on first use; images are converted to tensors.
    train_dataset = torchvision.datasets.MNIST(
        root='../Datasets/MNIST',
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    )
    # MNIST test split from the same root directory.
    test_dataset = torchvision.datasets.MNIST(
        root='../Datasets/MNIST',
        train=False,
        transform=transforms.ToTensor(),
        download=True,
    )
    # Mini-batches of 32 samples; only the training batches are shuffled.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
    )

    3 初始化参数

    # Layer sizes: input features, output classes, hidden units.
    num_inputs, num_outputs, num_hiddens = 784, 10, 600


    def init_w_b():
        """Create trainable parameters for a two-layer network.

        Weights are sampled with NumPy from a normal distribution (mean 0,
        standard deviation 0.01); biases start at zero.  All four tensors are
        float32 leaf tensors that track gradients.

        Returns:
            A tuple ``(W1, b1, W2, b2)`` with shapes
            ``(num_inputs, num_hiddens)``, ``(num_hiddens,)``,
            ``(num_hiddens, num_outputs)`` and ``(num_outputs,)``.
        """
        def _gaussian(rows, cols):
            # Sample through NumPy so the values follow NumPy's global RNG state.
            samples = np.random.normal(0, 0.01, (rows, cols))
            return torch.tensor(samples, dtype=torch.float, requires_grad=True)

        def _zeros(size):
            return torch.zeros(size, dtype=torch.float, requires_grad=True)

        # W1 is sampled before W2 so the random stream is consumed in a fixed order.
        W1 = _gaussian(num_inputs, num_hiddens)
        b1 = _zeros(num_hiddens)
        W2 = _gaussian(num_hiddens, num_outputs)
        b2 = _zeros(num_outputs)
        return W1, b1, W2, b2

    4 定义模型

    class LinearNet(nn.Module):
        """Two-layer fully connected classifier: flatten -> linear -> ReLU -> linear.

        Args:
            num_inputs: Number of features per sample after flattening.
            num_outputs: Number of classes (size of the returned logit vector).
            num_hiddens: Width of the hidden layer.
        """

        def __init__(self,num_inputs, num_outputs, num_hiddens):
            super(LinearNet,self).__init__()
            self.linear1 = nn.Linear(num_inputs,num_hiddens)
            self.relu = nn.ReLU()
            self.linear2 = nn.Linear(num_hiddens,num_outputs)
            self.flatten  = nn.Flatten()

        def forward(self,x):
            """Return raw class logits of shape ``(batch, num_outputs)``."""
            x = self.flatten(x)
            x = self.linear1(x)
            x = self.relu(x)
            # No activation on the output layer: nn.CrossEntropyLoss expects
            # unnormalized logits, and clamping them with ReLU would discard
            # negative scores and zero their gradients.
            y = self.linear2(x)
            return y

    5 定义训练函数

    def train_torch(lamda):
        """Train the global ``net`` for 20 epochs and record the loss per epoch.

        Relies on module-level ``net``, ``loss``, ``optimizer_w``,
        ``optimizer_b``, ``train_loader`` and ``test_loader``.

        Args:
            lamda: L2 coefficient of the current run.  It is not read here; the
                regularization strength is whatever ``optimizer_w`` was
                configured with by the caller.

        Returns:
            ``(train_ls, test_ls)``: two lists with one entry per epoch holding
            the mean per-sample loss on the training and test sets.
        """
        num_epochs = 20
        train_ls, test_ls = [], []
        for epoch in range(num_epochs):
            ls, count = 0.0, 0
            for X,y in train_loader:
                l=loss(net(X),y)
                optimizer_w.zero_grad()
                optimizer_b.zero_grad()
                l.backward()
                optimizer_w.step()
                optimizer_b.step()
                # `loss` is nn.CrossEntropyLoss() in this file, which averages
                # over the batch by default; weight by batch size so the epoch
                # value is a true per-sample mean even if the last batch is short.
                ls += l.item() * y.shape[0]
                count += y.shape[0]
            train_ls.append(ls / max(count, 1))
            ls, count = 0.0, 0
            # Evaluation needs no gradients; skip building the autograd graph.
            with torch.no_grad():
                for X,y in test_loader:
                    l=loss(net(X),y)
                    ls += l.item() * y.shape[0]
                    count += y.shape[0]
            test_ls.append(ls / max(count, 1))

            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
        return train_ls,test_ls

    6 开始训练

    # Sweep over L2 coefficients; a fresh network is trained for each value and
    # its per-epoch loss curves are collected for plotting.
    Lamda = [0,0.1,0.2,0.3,0.4,0.5]
    torch_Train_ls, torch_Test_ls = [], []
    loss = nn.CrossEntropyLoss()
    for lamda in Lamda:
        net = LinearNet(num_inputs, num_outputs, num_hiddens)
        # The optimizers must own the network's parameters.  Optimizing the
        # free-standing tensors from init_w_b() would leave `net` untrained and
        # make weight_decay irrelevant to the measured loss.
        weights = [p for name, p in net.named_parameters() if name.endswith('weight')]
        biases = [p for name, p in net.named_parameters() if name.endswith('bias')]
        # weight_decay adds lamda * w to each weight gradient, i.e. an L2 penalty
        # of (lamda / 2) * ||w||^2; biases are left unregularized.
        optimizer_w = torch.optim.SGD(weights,lr = 0.001,weight_decay=lamda)
        optimizer_b = torch.optim.SGD(biases,lr = 0.001)
        train_ls, test_ls = train_torch(lamda)
        torch_Train_ls.append(train_ls)
        torch_Test_ls.append(test_ls)

    7 绘制训练集和测试集的loss曲线

    # One x position per epoch, numbered from 1 to match the training log.
    x = np.arange(1, len(torch_Train_ls[0]) + 1)
    # Draw one figure for the training curves and one for the test curves,
    # each with one line per L2 coefficient.
    for title, curves in (('train loss', torch_Train_ls), ('test loss', torch_Test_ls)):
        plt.figure(figsize=(10,8))
        for lam, curve in zip(Lamda, curves):
            plt.plot(x,curve,label= f'L2_Regularization:{lam}',linewidth=1.5)
        # Axis labels are set once per figure rather than once per curve.
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderaxespad = 0.)
        plt.title(title)
        plt.show()

    因上求缘,果上努力~~~~ 作者:希望每天涨粉,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15513546.html

  • 相关阅读:
    [转]C++ Operator Overloading Guidelines
    SICP学习笔记(2.2.1)
    .net中模拟键盘和鼠标操作
    javaScript系列 [17]运算符
    javaScript系列 [24]Math
    javaScript系列 [19]string
    javaScript系列 [22]引用类型
    javaScript系列 [12]Canvas绘图(曲线)
    javaScript系列 [15]Canvas绘图(压缩)
    javaScript系列 [21]Array
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15513546.html
Copyright © 2011-2022 走看看