zoukankan      html  css  js  c++  java
  • pytorch基础-搭建网络

    搭建网络的步骤大致为以下:

    1.准备数据

    2. 定义网络结构model

    3. 定义损失函数
    4. 定义优化算法 optimizer
    5. 训练
      5.1 准备好tensor形式的输入数据和标签(可选)
      5.2 前向传播计算网络输出output和计算损失函数loss
      5.3 反向传播更新参数
        以下三句话一句也不能少:
        5.3.1 optimizer.zero_grad()  将上次迭代计算的梯度值清0
        5.3.2 loss.backward()  反向传播,计算梯度值
        5.3.3 optimizer.step()  更新权值参数
      5.4 保存训练集上的loss和验证集上的loss以及准确率以及打印训练信息。(可选
    6. 图示训练过程中loss和accuracy的变化情况(可选)
    7. 在测试集上测试

    代码注释都写的很详细 

     1 import torch
     2 import torch.nn.functional as F
     3 import matplotlib.pyplot as plt
     4 
     5 # 1.准备数据 generate data
     6 x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
     7 print(x.shape)
     8 y=x*x+0.2*torch.rand(x.size())
     9 #显示数据散点图
    10 plt.scatter(x.data.numpy(),y.data.numpy())
    11 
    12 # 2.定义网络结构 build net
    13 class Net(torch.nn.Module):
    14     #n_feature:输入特征个数  n_hidden:隐藏层个数 n_output:输出层个数
    15     def __init__(self,n_feature,n_hidden,n_output):
    16         # super表示继承Net的父类,并同时初始化父类的参数
    17         super(Net,self).__init__()
    18         # nn.Linear代表线性层 代表y=w*x+b  其中w的shape为[n_hidden,n_feature] b的shape为[n_hidden]
    19         # y=w^T*x+b 这里w的维度是转置前的维度 所以是反的
    20         self.hidden =torch.nn.Linear(n_feature,n_hidden)
    21         self.predict =torch.nn.Linear(n_hidden,n_output)
    22         print(self.hidden.weight)
    23         print(self.predict.weight)
    24     #定义一个前向传播过程函数
    25     def forward(self, x):
    26         #         n_feature  n_hidden  n_output
    27         #举例(2,5,1)   2         5         1
    28         #                    -  **  -
    29         #             ** - - -  **  - -
    30         #                    -  **  - - - **
    31         #             ** - - -  **  - -
    32         #                    -  **  -
    33         #            输入层    隐藏层    输出层
    34         x=F.relu(self.hidden(x))
    35         x=self.predict(x)
    36         return x
    37 # 实例化一个网络为net
    38 net = Net(n_feature=1,n_hidden=10,n_output=1)
    39 print(net)
    40 # 3.定义损失函数 这里使用均方误差(mean square error)
    41 loss_func=torch.nn.MSELoss()
    42 # 4.定义优化器 这里使用随机梯度下降
    43 optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
    44 #定义300遍更新 每10遍显示一次
    45 plt.ion()
    46 # 5.训练
    47 for t in range(100):
    48     prediction = net(x)     # input x and predict based on x
    49     loss = loss_func(prediction, y)     # must be (1. nn output, 2. target)
    50     # 5.3反向传播三步不可少
    51     optimizer.zero_grad()   # clear gradients for next train
    52     loss.backward()         # backpropagation, compute gradients
    53     optimizer.step()        # apply gradients
    54 
    55     if t % 10 == 0:
    56         # plot and show learning process
    57         plt.cla()
    58         plt.scatter(x.data.numpy(), y.data.numpy())
    59         plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    60         plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color':  'red'})
    61         plt.show()
    62         plt.pause(0.1)
    63 
    64 plt.ioff()

    参考:莫烦python

  • 相关阅读:
    在使用触摸屏的情况下插拔USB鼠标,鼠标箭头消失
    使用网卡在接收数据包时不会自动组包
    linux开机发现会有个kworker进程规律性占用CPU负载超过50%
    系统时间是否可以精确到ms级别?
    linux开机进入登录界面,输入密码后屏幕黑屏3-10s,然后重新回到登录界面
    linux多网卡情况下,一个网卡进行组播,一个网卡进行点播,同时配置网关后无法通信
    linux中常见内存分配函数(kmalloc,vmalloc等)
    linux内核中的两个标记GFP_KERNEL和GFP_ATOMIC作用是什么?
    gcc: error: unrecognized argument in option ‘-mabi=aapcs-linux’
    shell脚本100例
  • 原文地址:https://www.cnblogs.com/bob-jianfeng/p/11407955.html
Copyright © 2011-2022 走看看