zoukankan      html  css  js  c++  java
  • pytorch处理模型过拟合

    演示代码如下

     1 import torch
     2 from torch.autograd import Variable
     3 import torch.nn.functional as F
     4 import matplotlib.pyplot as plt
     5 # make fake data
     6 n_data = torch.ones(100, 2)
     7 x0 = torch.normal(2*n_data, 1)      #每个元素(x,y)是从 均值=2*n_data中对应位置的取值,标准差为1的正态分布中随机生成的
     8 y0 = torch.zeros(100)               # 给每个元素一个0标签
     9 x1 = torch.normal(-2*n_data, 1)     # 每个元素(x,y)是从 均值=-2*n_data中对应位置的取值,标准差为1的正态分布中随机生成的
    10 y1 = torch.ones(100)                # 给每个元素一个1标签
    11 x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # shape (200, 2) FloatTensor = 32-bit floating
    12 y = torch.cat((y0, y1), ).type(torch.LongTensor)    # shape (200,) LongTensor = 64-bit integer
    13 # torch can only train on Variable, so convert them to Variable
    14 x, y = Variable(x), Variable(y)
    15 
    16 # draw the data
    17 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy())#c是一个颜色序列
    18 
    19 
    20 #plt.show()
    21 #神经网络模块
    22 net2 = torch.nn.Sequential(
    23     torch.nn.Linear(2,10),
    24     torch.nn.Dropout(0.2),#处理过拟合,当然这个模型本身很简单,不需要处理过拟合,这个只是一个演示
    25     torch.nn.ReLU(),
    26     torch.nn.Linear(10,2)
    27 )
    28 
    29 plt.ion()#在Plt.ion和plt.ioff之间的代码,交互绘图
    30 plt.show()
    31 #神经网络优化器,主要是为了优化我们的神经网络,使他在我们的训练过程中快起来,节省社交网络训练的时间。
    32 optimizer = torch.optim.SGD(net2.parameters(),lr = 0.01)#其实就是神经网络的反向传播,第一个参数是更新权重等参数,第二个对应的是学习率
    33 loss_func = torch.nn.CrossEntropyLoss()#标签误差代价函数
    34 
    35 for t in range(50):
    36     out = net2(x)
    37     loss = loss_func(out,y)#计算损失
    38     optimizer.zero_grad()#梯度置零
    39     loss.backward()#反向传播
    40     optimizer.step()#计算结点梯度并优化,
    41     if t % 2 == 0:
    42         net2.eval()#模型做预测的时候不需要dropout,切换为eval()模式
    43         plt.cla()# Clear axis即清除当前图形中的之前的轨迹
    44         prediction = torch.max(F.softmax(out), 1)[1]#转换为概率,后面的一是最大值索引,如果为0则返回最大值
    45         pred_y = prediction.data.numpy().squeeze()
    46         target_y = y.data.numpy()
    47         plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
    48         accuracy = sum(pred_y == target_y) / 200.#求准确率
    49         plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
    50         plt.pause(0.1)
    51         net2.train()#切花为训练模式
    52 
    53 plt.ioff()
    54 plt.show()

    注意model.eval和model.train的使用

  • 相关阅读:
    Cookie的定义和分类,及优缺点
    网页开发和设计
    电视精灵(新手练习项目)
    C#体检套餐项目
    C#简单的对象交互
    那些年我们学过的构造函数(构造方法,C#)
    员工打卡课后小项目
    SpringMVC类型转换器
    SpringMVC 异常处理3种方案
    SSH整合(一)hibernate+spring
  • 原文地址:https://www.cnblogs.com/henuliulei/p/11944737.html
Copyright © 2011-2022 走看看