zoukankan      html  css  js  c++  java
  • pytorch中tensor.mean(axis, keepdim)

     1 import numpy as np
     2 import torch
     3 
     4 x=[
     5 [[1,2,3,4],
     6  [5,6,7,8],
     7  [9,10,11,12]],
     8 
     9 [[13,14,15,16],
    10  [17,18,19,20],
    11  [21,22,23,24]]
    12 ]
    13 x=torch.tensor(x).float()
    14 #
    15 print("shape of x:")  ##[2,3,4]
    16 print(x.shape)
    17 #
    18 print("shape of x.mean(axis=0,keepdim=True):")          #[1, 3, 4]
    19 print(x.mean(axis=0,keepdim=True).shape)
    20 print(x.mean(axis=0,keepdim=True))
    21 #
    22 print("shape of x.mean(axis=0,keepdim=False):")         #[3, 4]
    23 print(x.mean(axis=0,keepdim=False).shape)
    24 print(x.mean(axis=0,keepdim=False))
    25 #
    26 print("shape of x.mean(axis=1,keepdim=True):")          #[2, 1, 4]
    27 print(x.mean(axis=1,keepdim=True).shape)
    28 print(x.mean(axis=1,keepdim=True))
    29 #
    30 print("shape of x.mean(axis=1,keepdim=False):")         #[2, 4]
    31 print(x.mean(axis=1,keepdim=False).shape)
    32 print(x.mean(axis=1,keepdim=False))
    33 #
    34 print("shape of x.mean(axis=2,keepdim=True):")          #[2, 3, 1]
    35 print(x.mean(axis=2,keepdim=True).shape)
    36 print(x.mean(axis=2,keepdim=True))
    37 #
    38 print("shape of x.mean(axis=2,keepdim=False):")         #[2, 3]
    39 print(x.mean(axis=2,keepdim=False).shape)
    40 print(x.mean(axis=2,keepdim=False))
    shape of x:
    torch.Size([2, 3, 4])
    shape of x.mean(axis=0,keepdim=True):
    torch.Size([1, 3, 4])
    tensor([[[ 7.,  8.,  9., 10.],
             [11., 12., 13., 14.],
             [15., 16., 17., 18.]]])
    shape of x.mean(axis=0,keepdim=False):
    torch.Size([3, 4])
    tensor([[ 7.,  8.,  9., 10.],
            [11., 12., 13., 14.],
            [15., 16., 17., 18.]])
    shape of x.mean(axis=1,keepdim=True):
    torch.Size([2, 1, 4])
    tensor([[[ 5.,  6.,  7.,  8.]],
    
            [[17., 18., 19., 20.]]])
    shape of x.mean(axis=1,keepdim=False):
    torch.Size([2, 4])
    tensor([[ 5.,  6.,  7.,  8.],
            [17., 18., 19., 20.]])
    shape of x.mean(axis=2,keepdim=True):
    torch.Size([2, 3, 1])
    tensor([[[ 2.5000],
             [ 6.5000],
             [10.5000]],
    
            [[14.5000],
             [18.5000],
             [22.5000]]])
    shape of x.mean(axis=2,keepdim=False):
    torch.Size([2, 3])
    tensor([[ 2.5000,  6.5000, 10.5000],
            [14.5000, 18.5000, 22.5000]])

    keepdim=True
    运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1);

    keepdim=False
    运算完之后一般少一维度,求平均变为1的那一维没有了;

    axis=k
    按第k维运算,其他维度不遍,第k维变为1

    # print(x.mean().shape)
    # print(x.mean())

    shape of x:
    torch.Size([2, 3, 4])
    torch.Size([])
    tensor(12.5000)#所有值的平均值

  • 相关阅读:
    Oracle 建用户、 表空间脚本
    Java常见Jar包的用途
    EF:无法检查模型兼容性,因为数据库不包含模型元数据。
    Eclipse -Xms256M -Xmx640M -XX:PermSize=256m -XX:MaxPermSize=768m
    CentOS远程连接Windows操作系统
    spring boot / cloud (二十) 相同服务,发布不同版本,支撑并行的业务需求
    jvm
    jvm
    spring boot / cloud (十九) 并发消费消息,如何保证入库的数据是最新的?
    spring boot / cloud (十八) 使用docker快速搭建本地环境
  • 原文地址:https://www.cnblogs.com/tingtin/p/13617470.html
Copyright © 2011-2022 走看看