zoukankan      html  css  js  c++  java
  • pytorch实战(2)-----回归例子

    一、回归任务介绍:

    拟合一个二元函数 y = x ^ 2.

    二、步骤:

    1. 导入包
    2. 创建数据
    3. 构建网络
    4. 设置优化器和损失函数
    5. 前向和后向传播训练网络
    6. 画图

    三、代码:

    导入包:

    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F
    import matplotlib.pyplot as plt

    创建数据

    #torch中的数据要是二维的,unsqueeze是将一维数据转化成二维数据
    tmp = torch.linspace(-1,1,100)
    x = torch.unsqueeze(tmp,dim=1)
    y = x.pow(2) + 0.2*torch.rand(x.size())
    
    print(tmp)  #torch.Size([100])
    print(x)  #torch.Size([100, 1])
    #转成向量
    x,y = Variable(x),Variable(y)

       查看数据图像:

    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.show()

    构建网络

    #Net类继承了Module这个模块
    class Net(torch.nn.Module):
        def __init__(self,n_feature,n_hidden,n_output):
            #在搭建模型之前需要继承的一些信息,super表示继承nn.Module的信息,此步骤必须有
            super(Net,self).__init__()
            self.hidden = torch.nn.Linear(n_feature,n_hidden)
            self.predict = torch.nn.Linear(n_hidden,n_output)
        #神经网络前向传递的一个过程,流程图
        def forward(self,x):
            x = F.relu(self.hidden(x))
            x = self.predict(x)
            return x
    net = Net(1,10,1)
    plt.ion()
    plt.show()
    #可以看到搭建的图流程
    print(net)
     打印的结果:
    Net(
      (hidden): Linear(in_features=1, out_features=10, bias=True)
      (predict): Linear(in_features=10, out_features=1, bias=True)
    )

     设置优化器和损失函数

    optimizer = torch.optim.SGD(net.parameters(),lr = 0.5)  #传入网络的参数来优化它们
    loss_func = torch.nn.MSELoss()

    前向和后向传播训练网络

    for t in range(100):
        
        #forward
        prediction = net(x)
        loss = loss_func(prediction,y)  #预测值pre在前,实际值y在后,不然结果会不一样
        
        #backward()
        optimizer.zero_grad()   #梯度全部设为0
        loss.backward()  #loss计算参数的梯度
        optimizer.step()  #采用优化器以lr=0.5来优化梯度
        
    ###########################以下为可视化过程##################################
        if t % 5 == 0:
            plt.cla()
            plt.scatter(x.data.numpy(),y.data.numpy())
            plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
            plt.text(0.5,0,'Loss=%.4f' % loss.data[0],fontdict={'size':20,'color':'red'})
            plt.pause(0.1)
    plt.ioff()
    plt.show()

    训练结果:

    第一次:

    最后一次:

  • 相关阅读:
    Random 种子问题
    Matrix issue
    Two sum.
    Best Time to Buy and Sell Stock
    Maximum difference between two elements
    二分查找法的实现和应用汇总
    Why you want to be restrictive with shrink of database files [From karaszi]
    Palindrome
    NetBeans vs Eclipse 之性能参数对比 [java060515]
    国内各大互联网公司相关技术站点不完全收录[转]
  • 原文地址:https://www.cnblogs.com/Lee-yl/p/9885011.html
Copyright © 2011-2022 走看看