zoukankan      html  css  js  c++  java
  • pytorch重要函数介绍

    一、torch.nn.Embedding

    模块可以看做一个字典,字典中每个索引对应一个词和词的embedding形式。利用这个模块,可以给词做embedding的初始化操作

    torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

    num_embeddings :字典中词的个数

    embedding_dim:embedding的维度

    padding_idx(索引指定填充):如果给定,则遇到padding_idx中的索引,则将其位置填0(0是默认值)。

    输入输出:

    input:(∗) , LongTensor 结构

    output:(*,e):*是input的大小,e是embedding_dim,即每个词的embedding的维度

    注:embeddings中的值是正态分布N(0,1)中随机取值。

    注意:这里的embedding向量都存储在nn.Embedding.weight变量里,要打印所有的embedding向量,需要

    print(embeddings.weight)

    import torch
    import torch.nn as nn
    x = torch.LongTensor([[1,2,4],[4,3,2]])
    embeddings = nn.Embedding(5,5,padding_idx=4) #5个词,每个词也是5维
    print(embeddings(x))
    print(embeddings(x).size())
     
     
    output:
    tensor([[[ 0.8839, -1.2889,  0.0697, -0.9998, -0.7471],
             [-0.5681,  0.8486,  0.8176,  0.8349,  0.1719],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],  ->index=4 赋值 0
     
            [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],   ->index=4 赋值 0
             [ 1.4224,  0.2333,  1.9383, -0.7320,  0.9987],
             [-0.5681,  0.8486,  0.8176,  0.8349,  0.1719]]],
           grad_fn=<EmbeddingBackward>)
    torch.Size([2, 3, 5])

     二、unsqueeze与squeeze函数

    torch.unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度,比如原本有个三行的数据(3),在0的位置加了一维就变成一行三列(1,3)。注意unsqueeze()必须加上参数,表示在哪个维度前面加上1个维度

    a=torch.arange(6)
    a=a.view(2,3)
    print(a)
    #在第2个维度前面加1个为1的维度。
    print(a.unsqueeze(1)) a.unsqueeze(1).shape
    tensor([[0, 1, 2],
            [3, 4, 5]])
    tensor([[[0, 1, 2]],
    
            [[3, 4, 5]]])
    
    Out[58]:
    torch.Size([2, 1, 3])
    a=torch.arange(6)
    a=a.view(2,3)
    print(a)
    #在第1个维度前面加上维度为1的维度。
    print(a.unsqueeze(0)) a.unsqueeze(0).shape
    tensor([[0, 1, 2],
            [3, 4, 5]])
    tensor([[[0, 1, 2],
             [3, 4, 5]]])
    
    Out[60]:
    torch.Size([1, 2, 3])
    unsqueeze(-1)表示在最后后面加上维度为1的维度。


    torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行。squeeze(a)就是将a中所有为1的维度删掉,不为1的维度没有影响。注意只会影响维度为1的维度。

    b=torch.arange(3)
    b=b.view(3,1)
    print(b)
    #判断第二个维度是否为1,如果为1就去掉该维度。(注:第1个维度是0,第二个维度是1。)
    print(b.squeeze(1)) b.squeeze(1).shape
    tensor([[0],
            [1],
            [2]])
    tensor([0, 1, 2])
    
    Out[52]:
    torch.Size([3])
    b=torch.arange(3)
    b=b.view(3,1)
    print(b)
    #第1个维度不是1,所以不受影响,原b没有变化。
    print(b.squeeze(0)) b.squeeze(0).shape
    tensor([[0],
            [1],
            [2]])
    tensor([[0],
            [1],
            [2]])
    
    Out[53]:
    torch.Size([3, 1])
    b=torch.arange(3)
    b=b.view(3,1,1)
    print(b)
    #不加位置表示去掉所有维度为1的
    print(b.squeeze()) b.squeeze().shape
    tensor([[[0]],
    
            [[1]],
    
            [[2]]])
    tensor([0, 1, 2])

    三、Numpy.cumsum(a,axis=None)

    1、不指定axis
    a = np.array([[1, 2, 3], [4, 5, 6]])
    print(a)
    a.cumsum()
    [[1 2 3]
     [4 5 6]]
    
    Out[61]:
    array([ 1,  3,  6, 10, 15, 21], dtype=int32)
    a=np.array([2,3])
    print(a)
    a.cumsum()
    [2 3]
    
    Out[62]:
    array([2, 5], dtype=int32)

    2、指定axis
    a = np.array([[1, 2, 3], [4, 5, 6]])
    print(a)
    a.cumsum(axis=0)
    [[1 2 3]
     [4 5 6]]
    
    Out[63]:
    array([[1, 2, 3],
           [5, 7, 9]], dtype=int32)
    a = np.array([[1, 2, 3], [4, 5, 6]])
    print(a)
    a.cumsum(axis=1)
    [[1 2 3]
     [4 5 6]]
    
    Out[64]:
    array([[ 1,  3,  6],
           [ 4,  9, 15]], dtype=int32)

    技巧:array前面加一个元素
    a=np.array([2,3])
    np.array((0,*a))
    array([0, 2, 3])

    四、torch.sum
    a = torch.ones((2, 3))
    a1 =  torch.sum(a)
    a2 =  torch.sum(a, dim=0)
    a3 =  torch.sum(a, dim=1)
    a
    a1
    a2
    a3
    tensor([[1., 1., 1.],
            [1., 1., 1.]])
    Out[90]:
    tensor(6.)
    Out[90]:
    tensor([2., 2., 2.])
    Out[90]:
    tensor([3., 3.])
  • 相关阅读:
    C 语言编程经典 100 例
    visual studio.net已检测到指定的web服务器运行的不是asp.net1.1版。无法运行asp.net web应用程序
    如何编译及运行java
    VBScript 函数集
    SQL SERVER定时作业的设置方法
    显示桌面按钮不小心被删,有什么办法找回?
    随机抽取n个记录的SQL
    打开项目时提示如下错误:Visual Studio .NET 无法创建应用程序 。问题很可能是因为本地 Web 服务器上没有安装所需的组件
    简单的数据库连接
    ASP中各种数据库连接代码
  • 原文地址:https://www.cnblogs.com/gczr/p/14353477.html
Copyright © 2011-2022 走看看