zoukankan      html  css  js  c++  java
  • Pytorch torch.cat(inputs, dimension=0)

    1. torch.cat(inputs, dimension=0)说明

     torch.cat用于对tensor的拼接,dim默认为0,即从第一维度拼接。表示为4维的图像tensor中,第一维默认为batchSize,第二维为channel(通道),第三维为height(图片的高),第四维为width(图片的宽),一般需要基于通道进行拼接。

    2. 例子

    2.1 定义输入

    2.1.1 code

        # ====================================
        # 定义两个4维tensor数据:
        # (batchSize, channel, height, width),
        # 这里定义的一个是一个4维数据,可以定义其
        # 他维度的数据。
        # ====================================
        data1 = torch.rand([1, 1, 3, 3])
        data2 = torch.rand([1, 1, 3, 3])
        print("data1_shape: ", data1.shape)
        print("data1: ", data1)
        print("data2_shape: ", data2.shape)
        print("data2: ", data2)

    2.1.2 输出显示

     data1_shape和data2_shape是tensor的维度信息,代表2个4维tensor。

    2.2 拼接数据

    2.2.1 code

       # ====================================
        # 拼接数据,可以根据dim进行调整,此处的
        # dim = 0: 代表基于batchSize拼接
        # dim = 1: 代表基于通道拼接
        # dim = 2: 代表基于高拼接
        # dim = 3: 代表基于宽拼接
        # ====================================
        data3 = torch.cat([data1, data2], dim=0)
        data4 = torch.cat([data1, data2], dim=1)
        data5 = torch.cat([data1, data2], dim=2)
        data6 = torch.cat([data1, data2], dim=3)
    
        print("data3_shape: ", data3.shape)
        print("data3: ", data3)
    
        print("data4_shape: ", data4.shape)
        print("data4: ", data4)
    
        print("data5_shape: ", data5.shape)
        print("data5: ", data5)
    
        print("data6_shape: ", data6.shape)
        print("data6: ", data6)

    2.2.2 输出显示

    分别从batchSize,channel,height,width进行拼接。

     

  • 相关阅读:
    P3916 图的遍历
    P1656 炸铁路
    P6722 「MCOI-01」Village 村庄
    P1341 无序字母对
    P1072 [NOIP2009 提高组] Hankson 的趣味题
    10大主流自动化测试工具介绍
    Altium Designer中off grid pin问题的解决方法
    Easylogging++的使用及扩展
    博客园粒子特效稳定版
    C#中使用jieba.NET、WordCloudSharp制作词云图
  • 原文地址:https://www.cnblogs.com/haifwu/p/12790416.html
Copyright © 2011-2022 走看看