zoukankan      html  css  js  c++  java
  • Task1.PyTorch的基本概念

    1.什么是Pytorch,为什么选择Pytroch?

      PyTorch的前身便是Torch,其底层和Torch框架一样,但是使用Python重新写了很多内容,不仅更加灵活,支持动态图,而且提供了Python接口。是一个以Python优先的深度学习框架,不仅能够实现强大的GPU加速,同时还支持动态神经网络,这是很多主流深度学习框架比如Tensorflow等都不支持的。
      PyTorch既可以看作加入了GPU支持的numpy,同时也可以看成一个拥有自动求导功能的强大的深度神经网络。除了Facebook外,它已经被Twitter、CMU和Salesforce等机构采用
      为什么选择PyTorch呢?

      因为PyTorch是当前难得的简洁优雅且高效快速的框架。在笔者眼里,PyTorch达到目前深度学习框架的最高水平。当前开源的框架中,没有哪一个框架能够在灵活性、易用性、速度这三个方面有两个能同时超过PyTorch。下面是许多研究人员选择PyTorch的原因。

      • 简洁:PyTorch的设计追求最少的封装,尽量避免重复造轮子。不像TensorFlow中充斥着session、graph、operation、name_scope、variable、tensor、layer等全新的概念,PyTorch的设计遵循  tensor→variable(autograd)→nn.Module 三个由低到高的抽象层次,分别代表高维数组(张量)、自动求导(变量)和神经网络(层/模块),而且这三个抽象之间联系紧密,可以同时进行修改和操作。
    简洁的设计带来的另外一个好处就是代码易于理解。PyTorch的源码只有TensorFlow的十分之一左右,更少的抽象、更直观的设计使得PyTorch的源码十分易于阅读。在笔者眼里,PyTorch的源码甚至比许多框架的文档更容易理解。

      • 速度:PyTorch的灵活性不以速度为代价,在许多评测中,PyTorch的速度表现胜过TensorFlow和Keras等框架 。框架的运行速度和程序员的编码水平有极大关系,但同样的算法,使用PyTorch实现的那个更有可能快过用其他框架实现的。

      • 易用:PyTorch是所有的框架中面向对象设计的最优雅的一个。PyTorch的面向对象的接口设计来源于Torch,而Torch的接口设计以灵活易用而著称,Keras作者最初就是受Torch的启发才开发了Keras。PyTorch继承了Torch的衣钵,尤其是API的设计和模块的接口都与Torch高度一致。PyTorch的设计最符合人们的思维,它让用户尽可能地专注于实现自己的想法,即所思即所得,不需要考虑太多关于框架本身的束缚。

      • 活跃的社区:PyTorch提供了完整的文档,循序渐进的指南,作者亲自维护的论坛 供用户交流和求教问题。Facebook 人工智能研究院对PyTorch提供了强力支持,作为当今排名前三的深度学习研究机构,FAIR的支持足以确保PyTorch获得持续的开发更新,不至于像许多由个人开发的框架那样昙花一现。

      在PyTorch推出不到一年的时间内,各类深度学习问题都有利用PyTorch实现的解决方案在GitHub上开源。同时也有许多新发表的论文采用PyTorch作为论文实现的工具,PyTorch正在受到越来越多人的追捧 。作为论文实现的工具,PyTorch正在受到越来越多人的追捧 。如果说 TensorFlow的设计是“Make It Complicated”,Keras的设计是“Make It Complicated And Hide It”,那么PyTorch的设计真正做到了“Keep it Simple,Stupid”。简洁即是美。使用TensorFlow能找到很多别人的代码,使用PyTorch能轻松实现自己的想法

    2.Pytroch的安装

    pip3 install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-win_amd64.whl

    安装好后Pytorch后,还有个torchvision安装,这个主要集成了一些数据集,深度学习模型,一些转换等,以后需要使用还是很方便的。
    pip3 install torchvision

    3、测试安装

    进入Python环境测试安装的包是否成功,如下图所示。import torch导入没报错就说明安装成功,torch.cuda.is_available()是True就表示支持GPU啦~

     

    4.PyTorch基础概念

    Tensor(张量)

      Tensors与Numpy中的 ndarrays类似,但是在PyTorch中 Tensors可以使用GPU进行计算。

    Tensor是神经网络框架中重要的基础数据类型,可以简单理解为N维数组的容器对象。Tensor之间的通过运算进行连接,从而形成计算图。

    自动求导

      PyTorch 中所有神经网络的核心是 autograd 包。

      autograd包为张量上的所有操作提供了自动求导。 它是一个在运行时定义的框架,这意味着反向传播是根据你的代码来确定如何运行,并且每次迭代可以是不同的。

    神经网络

      torch.nn模块提供了创建神经网络的基础构件,这些层都继承自Module类。当实现神经网络时需要继承自此模块,并在初始化函数中创建网络需要包含的层,并实现forward函数完成前向计算,网络的反向计算会由自动求导机制处理。


    5.通用代码实现流程(实现一个深度学习的代码流程)

     手写数字识别

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torch.utils.data as Data
    import torchvision
    import matplotlib.pyplot as plt
    %matplotlib inline
    torch.manual_seed(1)
    EPORCH=1
    BATCH_SIZE=50
    LR=0.001
    DOWNLOAD_MNIST = False
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,#this is train data,
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )
    print(train_data.train_data.size())
    print(train_data.train_labels.size())
    plt.imshow(train_data.train_data[0].numpy(),cmap='gray')
    plt.title('%i'%train_data.train_labels[0])
    plt.show()
    1 train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
    2 test_data = torchvision.datasets.MNIST(root='./mnist',train=False)
    3 test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1)).type(torch.FloatTensor)[:2000]/255
    4 test_y = test_data.test_labels[:2000]
     1 class CNN(nn.Module):
     2     def __init__(self):
     3         super(CNN,self).__init__()
     4         self.conv1=nn.Sequential(
     5             nn.Conv2d(
     6                 in_channels=1,
     7                 out_channels=16,
     8                 kernel_size=5,
     9                 stride=1,
    10                 padding=2,
    11             ),
    12             nn.ReLU(),
    13             nn.MaxPool2d(kernel_size=2),
    14         )
    15         self.conv2=nn.Sequential(
    16             nn.Conv2d(16,32,5,1,2),
    17             nn.ReLU(),
    18             nn.MaxPool2d(2),
    19         )
    20         self.out=nn.Linear(32*7*7,10)  #fully connected layer,output10 classes
    21     def forward(self,x):
    22         x = self.conv1(x)
    23         x = self.conv2(x)
    24         x = x.view(x.size(0),-1)
    25         output = self.out(x)
    26         return output
    1 cnn = CNN()
    2 print(cnn)
    optimizer = torch.optim.Adam(cnn.parameters(),lr=LR) #optimize all cnn parameters
    loss_func = nn.CrossEntropyLoss()
     1 from matplotlib import cm
     2 # try: from sklearn.manifold import TSNE; HAS_SK = True
     3 # except: HAS_SK = False; print('Please install sklearn for layer visualization')
     4 # def plot_with_labels(lowDWeights, labels):
     5 #     plt.cla()
     6 #     X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
     7 #     for x, y, s in zip(X, Y, labels):
     8 #         c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
     9 #     plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)
    10 
    11 # plt.ion()
    12 for epoch in range(EPORCH):
    13     for step ,(x,y) in enumerate(train_loader):
    14         b_x = Variable(x)
    15         b_y = Variable(y)
    16         
    17         output=cnn(b_x)
    18         loss = loss_func(output,b_y)
    19         optimizer.zero_grad()
    20         loss.backward()
    21         optimizer.step()
    22         if step % 50 == 0:
    23             test_output = cnn(test_x)
    24             pred_y = torch.max(test_output,1)[1].data.squeeze()
    25             accuracy = (pred_y == test_y).sum().item() / test_y.size(0)
    26             print('Epoch: ',epoch,'| train loss: %.4f'%loss.item(),'|test accuracy: %.2f'%accuracy)
    27 test_output = cnn(test_x[:10])
    28 pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
    29 print(pred_y,'prediction number')
    30 print(test_y[:10].numpy(),'real number')

    参考:https://blog.csdn.net/qingxuanmingye/article/details/90138744
    https://blog.csdn.net/broadview2006/article/details/79147351

  • 相关阅读:
    关于轨道交通的一些知识点和关键词
    关于芯片的一些关键词
    关于ADC采集
    Linux记录
    在VMware运行Linux下,密码错误的原因
    气体传感器
    AD采集问题
    Maven [ERROR] 不再支持源选项 5,请使用 7 或更高版本的解决办法
    Maven 专题(九):后记
    Maven 专题(六):Maven核心概念详解(二)
  • 原文地址:https://www.cnblogs.com/NPC-assange/p/11308981.html
Copyright © 2011-2022 走看看