zoukankan      html  css  js  c++  java
  • pytorch中的Variable

    """
    Variable为tensor数据构建计算图,便于网络的运算
    """
    import torch
    from torch.autograd import Variable
    
    tensor = torch.FloatTensor([[1,2],[3,4]])            # 创建一个tensor类型的数据
    variable = Variable(tensor, requires_grad=True)      # 创建一个variable类型的数据
    
    print(tensor)       # [torch.FloatTensor of size 2x2]
    print(variable)     # [torch.FloatTensor of size 2x2]
    
    t_out = torch.mean(tensor*tensor)       
    v_out = torch.mean(variable*variable) 
    print(t_out)
    print(v_out)    # 7.5
    
    v_out.backward()    # 从v_out开始反向传播

    # 计算谁的梯度,就让开始反向传播的变量对谁进行求导 # v_out = 1/4 * sum(variable*variable) # the gradients w.r.t the variable, d(v_out)/d(variable) = 1/4*2*variable = variable/2 print(variable.grad) ''' 0.5000 1.0000 1.5000 2.0000 ''' print(variable) # variable格式 """ Variable containing: 1 2 3 4 [torch.FloatTensor of size 2x2] """ print(variable.data) # tensor格式 """ 1 2 3 4 [torch.FloatTensor of size 2x2] """ print(variable.data.numpy()) # variable是Variable数据类型,variable.data是tensor类型,variable不可转换为numpy类型 """ [[ 1. 2.] [ 3. 4.]] """
  • 相关阅读:
    MvvmTest
    win8 app 相关的几个网站
    autp
    分析WPF代码工具
    mdsn
    线程和委托
    C#guanli
    学习Boost小结(一)
    Boost.test库的配置
    自己真是太没正事了.
  • 原文地址:https://www.cnblogs.com/czz0508/p/10333359.html
Copyright © 2011-2022 走看看