zoukankan      html  css  js  c++  java
  • 神经网络架构PYTORCH-宏观分析

    基本概念和功能:

    PyTorch是一个能够提供两种高级功能的python开发包,这两种高级功能分别是:
      使用GPU做加速的矢量计算
      具有自动重放功能的深度神经网络
    从细的粒度来分,PyTorch是一个包含如下类别的库:

    1. Torch:类似于Numpy的通用数组库,可以在将张量类型转换为2 (torch.cuda.TensorFloat)并在GPU上进行计算。
    2. torch.autograd 支持全微分张量运算的基于磁带的自动微分库
    3. torch.nn 一个具有最大设计灵活性的高度集成的神经网络库
    4. torch.multiprocessing python的多重处理系统,通常用在数据加载和高强度的训练
    5. torch.utils 数据记载,训练和转换的接口函数
    6. torch.legacy(.nn/.optim) 从Torch上移植过来的代码,为了保证向后兼容.

    安装指南:

     安装有两种方式,一种是库文件安装详见目录:https://pytorch.org/

    另外一种是源码安装:在github上把东西下载下来:https://github.com/pytorch/pytorch.git

    下载之首先要进行源码安装,在根目录下执行:

    python setup.py install

    这个是linux下的源码安装,安装过程中很多情况下会缺少一些库,这个要根据实际的问题去谷歌搜,答案都能找到的.

    源码分析:

    源码的目录如下所示:

    分解:

    • aten: 在torch中实现矢量运算的简单的矢量库.
    • caffe2:caffe2的源码和例子
    • docs: 该系统的文档
    • third_party 第三方的库文件和和源码
    • torch torch的源码和使用例子
    • binaries 各种基准的生成源码

    最简实例:

    下面一个例子是使用PyTorch做线性回归的例子,源码如下:

     1 # -*- coding: utf-8 -*-
     2 
     3 import torch
     4 import torch.optim as optim
     5 import matplotlib.pyplot as plt
     6 
     7 learning_rate = 0.001
     8 
     9 def get_fake_data(batch_size=32):
    10     ''' y=x*2+3 '''
    11     x = torch.randn(batch_size, 1) * 20
    12     y = x * 2 + 3 + torch.randn(batch_size, 1)
    13     return x, y
    14 
    15 x, y = get_fake_data()
    16 
    17 class LinerRegress(torch.nn.Module):
    18     def __init__(self):
    19         super(LinerRegress, self).__init__()
    20         self.fc1 = torch.nn.Linear(1, 1)
    21 
    22     def forward(self, x):
    23         return self.fc1(x)
    24 
    25 
    26 net = LinerRegress()
    27 loss_func = torch.nn.MSELoss()
    28 optimzer = optim.SGD(net.parameters(), lr=learning_rate)
    29 
    30 for i in range(40000):
    31 
    32     optimzer.zero_grad()
    33 
    34     out = net(x)
    35     loss = loss_func(out, y)
    36     loss.backward()
    37 
    38     optimzer.step()
    39 
    40 w, b = [param.item() for param in net.parameters()]
    41 print w, b  # 2.01146, 3.184525
    42 
    43 # 显示原始点与拟合直线
    44 plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())
    45 plt.plot(x.squeeze().numpy(), (x*w + b).squeeze().numpy())
    46 plt.show()

      运行结果:

      到此为止,PyTorch的基本认识算是结束,后面就要开始深入的分析它在各个方面的应用和代码了.

  • 相关阅读:
    MySQL5.7安装教程(压缩版)
    ASP.Net 添加 Interop for Word, excel 插件
    OpenLayers-加载地图数据(Google Map)
    OpenLayers学习方法
    OpenLayers 项目分析——(二)源代码总体结构分析{感谢原作者对于大家的贡献!}
    OpenLayers项目分析——(一)项目介绍{感谢原作者为大家的贡献!}
    OpenLayers 项目分析——(三)BaseTypes {感谢原文作者对于大家的贡献!!!}
    GeoServer安装及配置过程(配置shapefile)
    如何查询postgreSQL 里面某个数据库中所有用户定义的数据表的名字
    WebGIS(PostgreSQL+GeoServer+OpenLayers)之三 OpenLayers客户端数据显示
  • 原文地址:https://www.cnblogs.com/dylancao/p/9855695.html
Copyright © 2011-2022 走看看