zoukankan      html  css  js  c++  java
  • mxnet symbol reshape用法

    mx.symbol.reshape

    对于给定输入的array和其shape,可以返回一个含有新shape的一个copy。shape是整形元组类型,可以包含可选的几个负数。

    一些维度的可选值有:{0, -1, -2, -3, -4}

    1. 维度0的作用是复制输入的该维度到对应输出:

    data=mx.sym.Variable('data')   # 输入symbol
    data=mx.sym.Reshape(data=data, shape=(4,0,2))    # reshape目标
    print(data.infer_shape(data=(2,3,4))[1])    # 用输入形状推理输出形状,infer_shape用法见这里~

    输出: (4,3,2)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(2,0,0))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(2,3,4)

    2. 维度-1的作用是利用剩余的维度来推断该维度,要保持所有维度尺寸一样:

    data=mx.sym.Variable('data')    
    data=mx.sym.Reshape(data=data, shape=(6,1,-1))   
    print(data.infer_shape(data=(2,3,4))[1])    

    输出:(6,1,4)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(3,-1,8))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(3,1,8)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-1))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(24,)

    3. 维度-2的作用是拷贝全部或剩余的维度到输出

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-2))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(2,3,4)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(2,-2))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(2,3,4)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-2,1,1))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(2,3,4,1,1)

    4. 维度-3的作用是利用两个连续维度之积作为对应输出维度

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-3,4))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(6,4)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-3,-3))
    print(data.infer_shape(data=(2,3,4,5))[1])

    输出:(6,20)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(0,-3))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(2,12)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-3,-2))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(6,4)

    5. 维度-4的作用是将输入的一个维度划分成后续的两个维度(可含-1)

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-4,1,2,-2))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(1,2,3,4)

    这个稍难理解解释一下:输入是(2,3,4),reshape的目标是(-4,1,2,-2),且-4后续的两个维度为1和2,即希望将-4对应的维度(对应输入的2)分解成(1,2)。 此时-2将剩余的(3,4)拉过来就变成了(1,2,3,4)。

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(2,-4,-1,3,-2))
    print(data.infer_shape(data=(2,3,4))[1])

    输出:(2,1,3,4)

    也可以连续reshape:

    data=mx.sym.Variable('data')
    data=mx.sym.Reshape(data=data, shape=(-1, -4, -1, 1, 3, 256, 256))
    print(data.infer_shape(data=(16, 8, 3, 256, 256))[1])
    data=mx.sym.Reshape(data=data, shape=(-3,-3,-2))
    print(data.infer_shape(data=(16, 8, 3, 256, 256))[1])

    输出:16,8,1,3,256,256

    输出:128,3,256,256

  • 相关阅读:
    抽象工厂例子
    学习boost::asio一些小例子
    boost::asio学习(定时器)
    共享内存
    网络流程图
    粘包
    端游服务器群
    38 写一个函数,求一个字符串的长度,在main函数中输入字符串,并输出其长度。
    37 有n个人围成一圈,顺序排号,从第一个人开始报数(从1到3报数),凡报到3的人退出圈子,问最后留下的是原来第几号那位.
    36 有n个整数,使其前面各数顺序向后移n个位置,最后m个数变成最前面的m个数
  • 原文地址:https://www.cnblogs.com/king-lps/p/13127456.html
Copyright © 2011-2022 走看看