zoukankan      html  css  js  c++  java
  • 在多分类任务实验中用torch.nn实现dropout

    1 导入需要的包

    import torch
    import torch.nn as nn
    import numpy as np
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt

    2 下载MNIST数据集以及读取数据

    # Load MNIST (downloaded on first run) and wrap each split in a
    # mini-batch iterator. Training data is shuffled; test data is not.
    to_tensor = transforms.ToTensor()
    mnist_train = torchvision.datasets.MNIST(
        root='../Datasets/MNIST', train=True, download=True, transform=to_tensor)
    mnist_test = torchvision.datasets.MNIST(
        root='../Datasets/MNIST', train=False, download=True, transform=to_tensor)
    batch_size = 256
    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=0)

    3 定义模型

    class LinearNet(nn.Module):
        """Two-hidden-layer MLP for MNIST classification with dropout.

        Dropout is applied after each hidden ReLU activation. The forward
        pass returns raw logits, as expected by nn.CrossEntropyLoss.

        Args:
            num_inputs: flattened input size (784 for 28x28 MNIST).
            num_outputs: number of classes (10 for MNIST).
            num_hiddens1: width of the first hidden layer.
            num_hiddens2: width of the second hidden layer.
            drop_prob1: dropout probability after the first hidden layer.
            drop_prob2: dropout probability after the second hidden layer.
        """
        def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob1, drop_prob2):
            super(LinearNet, self).__init__()
            self.linear1 = nn.Linear(num_inputs, num_hiddens1)
            self.relu = nn.ReLU()
            self.drop1 = nn.Dropout(drop_prob1)
            self.linear2 = nn.Linear(num_hiddens1, num_hiddens2)
            self.drop2 = nn.Dropout(drop_prob2)
            self.linear3 = nn.Linear(num_hiddens2, num_outputs)
            self.flatten = nn.Flatten()

        def forward(self, x):
            """Return class logits of shape (batch, num_outputs)."""
            x = self.flatten(x)
            x = self.drop1(self.relu(self.linear1(x)))
            x = self.drop2(self.relu(self.linear2(x)))
            # Bug fix: do NOT apply ReLU to the output layer. ReLU would
            # clamp all negative logits to zero, which distorts the
            # log-softmax inside CrossEntropyLoss and degrades training.
            return self.linear3(x)

    4 定义训练模型

    def train(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):
        """Train `net` with `optimizer` and record losses each epoch.

        Args:
            net: the model (an nn.Module).
            train_iter / test_iter: iterables of (X, y) mini-batches.
            loss: criterion returning the mean loss over a batch.
            num_epochs: number of passes over the training data.
            batch_size, params, lr: kept for interface compatibility; the
                optimizer already holds the parameters and learning rate.
            optimizer: a torch.optim optimizer over net's parameters.

        Returns:
            (train_ls, test_ls): per-epoch average loss per sample.
        """
        train_ls, test_ls = [], []
        for epoch in range(num_epochs):
            net.train()  # enable dropout during training
            ls, count = 0, 0
            for X, y in train_iter:
                l = loss(net(X), y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                # loss() returns the batch mean, so weight by batch size
                # to get a correct per-sample average over the epoch.
                ls += l.item() * y.shape[0]
                count += y.shape[0]
            train_ls.append(ls / count)
            # Bug fix: evaluate in eval mode so dropout is disabled, and
            # skip gradient tracking for the test pass.
            net.eval()
            ls, count = 0, 0
            with torch.no_grad():
                for X, y in test_iter:
                    l = loss(net(X), y)
                    ls += l.item() * y.shape[0]
                    count += y.shape[0]
            test_ls.append(ls / count)
            if (epoch + 1) % 5 == 0:
                print('epoch: %d, train loss: %f, test loss: %f' % (epoch + 1, train_ls[-1], test_ls[-1]))
        return train_ls, test_ls

    5 比较不同dropout的影响

    # Sweep the dropout probability from 0.0 to 1.0 in steps of 0.1 and
    # collect one train/test loss curve per setting.
    num_inputs, num_hiddens1, num_hiddens2, num_outputs = 784, 256, 256, 10
    num_epochs = 20
    lr = 0.001
    drop_probs = np.arange(0, 1.1, 0.1)
    Train_ls, Test_ls = [], []
    for drop_prob in drop_probs:
        # Fresh model for each setting; the same rate is used after both
        # hidden layers.
        net = LinearNet(num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob, drop_prob)
        for param in net.parameters():
            nn.init.normal_(param, mean=0, std=0.01)
        loss = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(net.parameters(), lr)
        # Bug fix: the original passed the bound method `net.parameters`
        # (uncalled) as the params argument; call it to pass the iterator.
        train_ls, test_ls = train(net, train_iter, test_iter, loss, num_epochs, batch_size, net.parameters(), lr, optimizer)
        Train_ls.append(train_ls)
        Test_ls.append(test_ls)

    6 绘制不同dropout损失图

    # Plot one training-loss curve per dropout probability.
    # Bug fix: np.linspace(0, N, N) yields N non-integer points with step
    # N/(N-1); use integer epoch numbers 1..N for the x-axis instead.
    x = np.arange(1, len(train_ls) + 1)
    plt.figure(figsize=(10, 8))
    for i in range(len(drop_probs)):
        plt.plot(x, Train_ls[i], label='drop_prob=%.1f' % (drop_probs[i]), linewidth=1.5)
    # Axis labels only need to be set once, not on every loop iteration.
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc=2, bbox_to_anchor=(1.05, 1.0), borderaxespad=0.)
    plt.title('train loss with dropout')
    plt.show()

    nn.Flatten() demo

    # nn.Flatten() demo: by default it flattens every dimension except the
    # batch dimension (start_dim=1), so (2, 5, 5) becomes (2, 25).
    # Renamed the tensor to avoid shadowing the builtin `input`.
    demo_input = torch.randn(2, 5, 5)
    m = nn.Sequential(
        nn.Flatten()
    )
    output = m(demo_input)
    output.size()

    因上求缘,果上努力~~~~ 作者:每天卷学习,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15513445.html

  • 相关阅读:
    有关android UI 线程
    lang3日期工具类底层源码解析
    JSON业务模型拆解技巧
    math3底层源码解决多元方程组
    关于日期解析Scala语言
    maven仓库支持cdh版本配置
    kudu数据库个人简单的总结
    json数据入库kafka
    json数据写入hbase
    一只鸟的故事
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15513445.html
Copyright © 2011-2022 走看看