zoukankan      html  css  js  c++  java
  • Pytorch之认识Variable

    Tensor是Pytorch的一个完美组件(可以生成高维数组),但是要构建神经网络还是远远不够的,我们需要能够计算图的Tensor,那就是Variable。Variable是对Tensor的一个封装,操作和Tensor是一样的,但是每个Variable都有三个属性,Varibale的Tensor本身的.data,对应Tensor的梯度.grad,以及这个Variable是通过什么方式得到的.grad_fn。

    # 通过一下方式导入Variable
    from torch.autograd import Variable
    import torch
    x_tensor = torch.randn(10,5)
    y_tensor = torch.randn(10,5)

    #将tensor转换成Variable
    x = Variable(x_tensor,requires_grad=True) #Varibale 默认时不要求梯度的,如果要求梯度,需要说明
    y = Variable(y_tensor,requires_grad=True)
    z = torch.sum(x + y)
    print(z.data)
    print(z.grad_fn)

    z.backward()
    print(x.grad)
    print(y.grad)

    tensor(7.0406)
    <SumBackward0 object at 0x000002A557C47908>
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    上面打印出了z的Tensor数值,以及通过.grad_fn得到其是通过sum这种方式得到的,通过.grad得到了x和y的梯度

    #构建一个y = x^2 函数 求x = 2 的导数
    import numpy as np
    import torch
    from torch.autograd import Variable
    # 1、画出函数图像
    import matplotlib.pyplot as plt
    x = np.arange(-3,3.01,0.1)
    y = x**2
    plt.plot(x,y)
    plt.plot(2,4,'ro')
    plt.show()

    #定义点variable类型的x = 2

    x = Variable(torch.FloatTensor([2]),requires_grad=True)
    y = x ** 2
    y.backward()
    print(x.grad)





  • 相关阅读:
    MyEclipse添加SVN插件
    Postgresql的character varying = bytea问题
    Hibernate主键增加方式
    java配置环境变量
    Maven常用构建命令
    Postgresql的主键自增长
    js判断金额
    最精简的SQL教程
    SQL练习1:统计班级男女生人数
    sql 百万级数据库优化方案
  • 原文地址:https://www.cnblogs.com/ryluo/p/10190218.html
Copyright © 2011-2022 走看看