1. torch.cat(inputs, dimension=0)说明
torch.cat用于对tensor的拼接,dim默认为0,即从第一维度拼接。表示为4维的图像tensor中,第一维默认为batchSize,第二维为channel(通道),第三维为height(图片的高),第四维为width(图片的宽),一般需要基于通道进行拼接。
2. 例子
2.1 定义输入
2.1.1 code
# ==================================== # 定义两个4维tensor数据: # (batchSize, channel, height, width), # 这里定义的一个是一个4维数据,可以定义其 # 他维度的数据。 # ==================================== data1 = torch.rand([1, 1, 3, 3]) data2 = torch.rand([1, 1, 3, 3]) print("data1_shape: ", data1.shape) print("data1: ", data1) print("data2_shape: ", data2.shape) print("data2: ", data2)
2.1.2 输出显示
data1_shape和data2_shape是tensor的维度信息,代表2个4维tensor。
2.2 拼接数据
2.2.1 code
# ==================================== # 拼接数据,可以根据dim进行调整,此处的 # dim = 0: 代表基于batchSize拼接 # dim = 1: 代表基于通道拼接 # dim = 2: 代表基于高拼接 # dim = 3: 代表基于宽拼接 # ==================================== data3 = torch.cat([data1, data2], dim=0) data4 = torch.cat([data1, data2], dim=1) data5 = torch.cat([data1, data2], dim=2) data6 = torch.cat([data1, data2], dim=3) print("data3_shape: ", data3.shape) print("data3: ", data3) print("data4_shape: ", data4.shape) print("data4: ", data4) print("data5_shape: ", data5.shape) print("data5: ", data5) print("data6_shape: ", data6.shape) print("data6: ", data6)
2.2.2 输出显示
分别从batchSize,channel,height,width进行拼接。