zoukankan      html  css  js  c++  java
  • pytorch使用说明

    pytorch使用说明

    1.torch和numpy的转换

    import torch
    import numpy as np
    
    np_data = np.arange(6).reshape((2, 3))
    torch_data = torch.from_numpy(np_data)
    tensor2array = torch_data.numpy()
    

    2.torch中的数学运算

    # abs 绝对值计算
    data = [-1, -2, 1, 2]
    tensor = torch.FloatTensor(data)  # 转换成32位浮点 tensor
    print(
        '
    abs',
        '
    numpy: ', np.abs(data),          # [1 2 1 2]
        '
    torch: ', torch.abs(tensor)      # [1 2 1 2]
    )
    
    # sin   三角函数 sin
    print(
        '
    sin',
        '
    numpy: ', np.sin(data),      # [-0.84147098 -0.90929743  0.84147098  0.90929743]
        '
    torch: ', torch.sin(tensor)  # [-0.8415 -0.9093  0.8415  0.9093]
    )
    
    # mean  均值
    print(
        '
    mean',
        '
    numpy: ', np.mean(data),         # 0.0
        '
    torch: ', torch.mean(tensor)     # 0.0
    )
    
    # matrix multiplication 矩阵点乘
    data = [[1,2], [3,4]]
    tensor = torch.FloatTensor(data)  # 转换成32位浮点 tensor
    # correct method
    print(
        '
    matrix multiplication (matmul)',
        '
    numpy: ', np.matmul(data, data),     # [[7, 10], [15, 22]]
        '
    torch: ', torch.mm(tensor, tensor)   # [[7, 10], [15, 22]]
    )
    
    # !!!!  下面是错误的方法 !!!!
    data = np.array(data)
    print(
        '
    matrix multiplication (dot)',
        '
    numpy: ', data.dot(data),        # [[7, 10], [15, 22]] 在numpy 中可行
        '
    torch: ', tensor.dot(tensor)     # torch 会转换成 [1,2,3,4].dot([1,2,3,4) = 30.0
    )
    
    

    3. 什么是Variable

    在Torch中的Variable就是一个存放会变化的值的地理位置。里面的值会不停的变化。其中的值就是torch的Tensor.如果用Variable进行计算,那返回的也是一个同类型的Variable.

    定义一个Variable:

    import torch
    from torch.autograd import Variable # torch 中 Variable 模块
    
    # 先生鸡蛋
    tensor = torch.FloatTensor([[1,2],[3,4]])
    # 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
    variable = Variable(tensor, requires_grad=True)
    
    print(tensor)
    """
     1  2
     3  4
    [torch.FloatTensor of size 2x2]
    """
    
    print(variable)
    """
    Variable containing:
     1  2
     3  4
    [torch.FloatTensor of size 2x2]
    """
    
    

    对比一下tensor的计算和variable的计算

    t_out = torch.mean(tensor*tensor)       # x^2
    v_out = torch.mean(variable*variable)   # x^2
    print(t_out)
    print(v_out)    # 7.5
    

    时刻计住,Variable计算是,它在背景幕布后面一步步默默搭建着一个庞大的系统,叫做计算图,computational graph.这个图将所有的计算步骤(节点)都连接起来,最后进行误差反向传递的时候一次性将所有variable里面的修改幅度(梯度)都计算出来,而tensor就没有这个能力。

    获取Variable里面的数据

    直接print(variable)只会输出Variable形式的数据,在很多时候是用不了的(画图), 所以我们要将其变成tensor形式。

    print(variable)     #  Variable 形式
    """
    Variable containing:
     1  2
     3  4
    [torch.FloatTensor of size 2x2]
    """
    print(variable.data)    # tensor 形式
    """
     1  2
     3  4
    [torch.FloatTensor of size 2x2]
    """
    print(variable.data.numpy())    # numpy 形式
    """
    [[ 1.  2.]
     [ 3.  4.]]
    """
    
    

    4.激活函数

    import torch
    import numpy as np
    import torch
    import torch.nn.functional as F     # 激励函数都在这
    from torch.autograd import Variable
    
    # 做一些假数据来观看图像
    x = torch.linspace(-5, 5, 200)  # x data (tensor), shape=(100, 1)
    x = Variable(x)
    x_np = x.data.numpy()   # 换成 numpy array, 出图时用
    
    # 几种常用的 激励函数
    y_relu = F.relu(x).data.numpy()
    y_sigmoid = F.sigmoid(x).data.numpy()
    y_tanh = F.tanh(x).data.numpy()
    y_softplus = F.softplus(x).data.numpy()
    # y_softmax = F.softmax(x)  softmax 比较特殊, 不能直接显示, 不过他是关于概率的, 用于分类
    
    if __name__ == '__main__':
        import matplotlib.pyplot as plt  # python 的可视化模块, 我有教程 			   (https://morvanzhou.github.io/tutorials/data-manipulation/plt/)
    
    	plt.figure(1, figsize=(8, 6))
    	plt.subplot(221)
    	plt.plot(x_np, y_relu, c='red', label='relu')
    	plt.ylim((-1, 5))
        plt.legend(loc='best')
    
        plt.subplot(222)
        plt.plot(x_np, y_sigmoid, c='red', label='sigmoid')
        plt.ylim((-0.2, 1.2))
        plt.legend(loc='best')
    
        plt.subplot(223)
        plt.plot(x_np, y_tanh, c='red', label='tanh')
        plt.ylim((-1.2, 1.2))
        plt.legend(loc='best')
    
        plt.subplot(224)
        plt.plot(x_np, y_softplus, c='red', label='softplus')
        plt.ylim((-0.2, 6))
        plt.legend(loc='best')
    
        plt.show()
    
    
  • 相关阅读:
    uploadify上传文件代码
    事务处理拼接sql语句对数据库的操作.异常回滚
    Scrum【转】
    Redis
    mybatis
    Spring MVC
    IOC的理解(转载)
    spring IOC与AOP
    git
    python基础2
  • 原文地址:https://www.cnblogs.com/o-v-o/p/10946130.html
Copyright © 2011-2022 走看看