zoukankan      html  css  js  c++  java
  • Pytorch的torch.cat实例

    import torch
    

      

    通过 help((torch.cat)) 可以查看 cat 的用法
    cat(seq,dim,out=None)
     
    其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b),  a,b 为两个可以连接的序列
    dim 表示以哪个维度连接,dim=0, 横向连接
                          dim=1,纵向连接
     
    

     

    #实例:
     
        #dim=0 时:
        
        import torch
        n_data = torch.ones((100,2))
        x0_data = torch.normal(2*n_data,1)
        y0_data = torch.zeros((100,1))
        x1_data = torch.normal(-2*n_data,1)
        y1_data = torch.ones((100,1))
        x_data = torch.cat((x0_data,x1_data),0).type(torch.FloatTensor)
        y_data = torch.cat((y0_data,y1_data),0).type(torch.LongTensor)
        print('x_data的形状:',x_data.shape)
        print("y_data的形状:",y_data.shape)
    

      

    result:
        
        x_data的形状: torch.Size([200, 2])
        y_data的形状: torch.Size([200, 1])
    

      

    #实例:
     
        #dim=1 时:
        
        import torch
        n_data = torch.ones((100,2))
        x0_data = torch.normal(2*n_data,1)
        y0_data = torch.zeros((100,1))
        x1_data = torch.normal(-2*n_data,1)
        y1_data = torch.ones((100,1))
        x_data = torch.cat((x0_data,x1_data),1).type(torch.FloatTensor)
        y_data = torch.cat((y0_data,y1_data),1).type(torch.LongTensor)
        print('x_data的形状:',x_data.shape)
        print("y_data的形状:",y_data.shape)
    

      

    result:
     
        x_data的形状: torch.Size([100, 4])
        y_data的形状: torch.Size([100, 2])
    

      

  • 相关阅读:
    printf和sprintf
    操作数、运算符、表达式
    全自动加法机
    Ascll、GB2312、Ansi
    数组
    循环
    编程命名规范
    浮点数及缺陷
    Android编码规范
    RGB着色对照表
  • 原文地址:https://www.cnblogs.com/pythonClub/p/10412418.html
Copyright © 2011-2022 走看看