zoukankan      html  css  js  c++  java
  • 【pytorch】学习笔记(四)-搭建神经网络进行关系拟合

    【pytorch学习笔记】-搭建神经网络进行关系拟合

    学习自莫烦python

    目标

    1.创建一些围绕y=x^2+噪声这个函数的散点
    2.用神经网络模型来建立一个可以代表他们关系的线条

    建立数据集

    import torch
    from torch.autograd import Variable
    import  torch.nn.functional as F
    import matplotlib.pyplot as plt
    
    x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#一维变二维,x从-1到1,切分为100份
    y=x.pow(2)+0.2*torch.rand(x.size())#创建一些围绕着这y=x^2的随机点的散点
    
    # plt.scatter(x.data.numpy(),y.data.numpy())#画图
    # plt.show()
    
    x,y=Variable(x),Variable(y)#构造神经网络要使用Variable类型
    

    建立神经网络

    1.继承torch.nn.Module模块
    2.定义__init__函数,在初始化函数中定义输入层到隐藏层,从隐藏层再到输出层各个层的神经元个数
    3.再一层层搭建(forward(x))层于层的关系链接

    class Net(torch.nn.Module):
        def __init__(self,n_feature,n_hidden,n_ouput):#初始化信息
            super(Net, self).__init__()
            self.hidden=torch.nn.Linear(n_feature,n_hidden,n_ouput)#隐藏层线性输出
            self.predict=torch.nn.Linear(n_hidden,n_ouput)#输出层线性输出
    
        def forward(self,x):#前向传递的过程
            #正向传播输入值,神经网络输出预测值
            x=F.relu(self.hidden(x))#激励函数加工一下
            x=self.predict(x)#输出值预测值
            return x
    

    训练神经网络

    1.定义训练工具optimizer,输入神经网络参数和学习效率
    2.定义误差函数,使用均方差来计算实际值y和训练输出值之间的误差
    3.每次训练向神经网络输入x,得到预测值,计算误差
    4.注意要清空上一步的残余更新参数值
    5.误差反向传播, 计算参数更新值
    6.将参数更新值施加到 net 的 parameters 上

    for t in range(200):#训练200次
        prediction=net(x)#输入输入值
        loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
        optimizer.zero_grad()#梯度清零
        loss.backward()#反向传递
        optimizer.step()#优化梯度
    

    可视化训练过程

    for t in range(200):#训练200次
        prediction=net(x)#输入输入值
        loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
        optimizer.zero_grad()#梯度清零
        loss.backward()#反向传递
        optimizer.step()#优化梯度
        # 接着上面来
        if t % 5 == 0:
            # plot and show learning process
            plt.cla()
            plt.scatter(x.data.numpy(), y.data.numpy())
            plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
            plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
            plt.pause(0.1)
    

    完整代码

    import torch
    from torch.autograd import Variable
    import  torch.nn.functional as F
    import matplotlib.pyplot as plt
    
    x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#一维变二维
    y=x.pow(2)+0.2*torch.rand(x.size())
    
    # plt.scatter(x.data.numpy(),y.data.numpy())
    # plt.show()
    
    x,y=Variable(x),Variable(y)#构造神经网络的是琥珀要使用Variable类型的
    
    class Net(torch.nn.Module):
        def __init__(self,n_feature,n_hidden,n_ouput):#初始化信息
            super(Net, self).__init__()
            self.hidden=torch.nn.Linear(n_feature,n_hidden,n_ouput)#隐藏层线性输出
            self.predict=torch.nn.Linear(n_hidden,n_ouput)#输出层线性输出
    
        def forward(self,x):#前向传递的过程
            #正向传播输入值,神经网络输出预测值
            x=F.relu(self.hidden(x))#激励函数加工一下
            x=self.predict(x)#输出值预测值
            return x
    
    net=Net(n_feature=1,n_hidden=10,n_ouput=1)#输入值是一个,隐藏层有10个神经元,输出值为y值
    print(net)
    
    optimizer=torch.optim.SGD(net.parameters(),lr=0.5)#输入神经网络的所有参数,学习效率,这个是训练工具
    loss_func=torch.nn.MSELoss()#误差处理均方差
    
    plt.ion()   # 画图
    plt.show()
    
    for t in range(200):#训练200次
        prediction=net(x)#输入输入值
        loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
        optimizer.zero_grad()#梯度清零
        loss.backward()#反向传递
        optimizer.step()#优化梯度
        # 接着上面来
        if t % 5 == 0:
            # plot and show learning process
            plt.cla()
            plt.scatter(x.data.numpy(), y.data.numpy())
            plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
            plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
            plt.pause(0.1)
    
    

    过程结果

    中间过程省略一部分...

  • 相关阅读:
    Android开发 使用 adb logcat 显示 Android 日志
    【嵌入式开发】向开发板中烧写Linux系统-型号S3C6410
    C语言 结构体相关 函数 指针 数组
    C语言 命令行参数 函数指针 gdb调试
    C语言 指针数组 多维数组
    Ubuntu 基础操作 基础命令 热键 man手册使用 关机 重启等命令使用
    C语言 内存分配 地址 指针 数组 参数 实例解析
    CRT 环境变量注意事项
    hadoop 输出文件 key val 分隔符
    com.mysql.jdbc.exceptions.MySQLNonTransientConnectionException: Too many connections
  • 原文地址:https://www.cnblogs.com/mengxiaoleng/p/11792802.html
Copyright © 2011-2022 走看看