zoukankan      html  css  js  c++  java
  • tf.stack() /tf.unstack()

    tf.stack函数

    tf.stack(
        values,
        axis=0,(default)
        name='stack'
    )

    将 values 中的张量列表打包成一个张量,该张量比 values 中的每个张量都高一个秩,通过沿 axis 维度打包。给定一个形状为(A, B, C)的张量的长度 N 的列表;

    如果 axis == 0,那么 output 张量将具有形状(N, A, B, C)。如果 axis == 1,那么 output 张量将具有形状(A, N, B, C)。

    如果 axis == 2,那么 output 张量将具有形状( A, B, N, C)。如果 axis == 3,那么 output 张量将具有形状(A, B, C, N)。

    函数参数:

    • values:具有相同形状和类型的 Tensor 对象列表.
    • axis:一个 int,要一起堆叠的轴,默认为第一维,负值环绕,所以有效范围是[-(R+1), R+1).
    • name:此操作的名称(可选).

    函数返回值:

    • output:与 values 具有相同的类型的堆叠的 Tensor.

    可能引发的异常:

    • ValueError:如果 axis 超出范围 [ - (R + 1),R + 1),则引发此异常
     
    import tensorflow as tf
    a = tf.constant([[1,2,3],[4,5,6]])#2*3
    b = tf.constant([[7,8,9],[0,1,7]])
    c1 = tf.stack([a,b],axis = 0)
    c2 = tf.stack([a,b],axis = 1)
    c3 = tf.stack([a,b],axis = 2)
    #c4 = tf.stack([a,b],axis = 3)
    
    with tf.Session() as sess:
        result1 = sess.run(c1)
        print('1  :',result1)
        print(result1.shape)
        result2= sess.run(c2)
        print('2 :',result2)
        print(result2.shape)
        result3 = sess.run(c3)
        print('3  :',result3)
        print(result3.shape)

    1 : [[[1 2 3]
    [4 5 6]]

    [[7 8 9]
    [0 1 7]]]
    (2, 2, 3)
    2 : [[[1 2 3]
    [7 8 9]]

    [[4 5 6]
    [0 1 7]]]
    (2, 2, 3)
    3 : [[[1 7]
    [2 8]
    [3 9]]

    [[4 0]
    [5 1]
    [6 7]]]
    (2, 3, 2)

    tf.unstack()
    tf.unstack(value, num=None, axis=0, name=’unstack’)
    以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R-1的张量。即将一组张量以指定的轴,减少一个维度。正好和stack()相反。

    将张量value分割成num个张量数组。如果num没有指定,则是根据张量value的形状来指定。如果value.shape[axis]不存在,则抛出ValueError的异常。

    假如一个张量的形状是(A, B, C, D)。
    如果axis == 0,则输出的张量是value[i, :, :, :],i取值为[0,A),每个输出的张量的形状为(B,C,D)。
    如果axis == 1,则输出的张量是value[:, i, :, :],i取值为[0,B),每个输出的张量的形状为(A,C,D)。
    如果axis == 2,则输出的张量是value[:, :, i, :],i取值为[0,C),每个输出的张量的形状为(A,B,D)。依次类推。

    axis可这样理解:unstack就是要将一个张量降低为低一个维度的张量数组。axis就是将axis指定的维度,用所有这个张量里同维度的数据代替。

    参数:
    value: 一个将要被降维的维度大于0的张量。
    num: 整数。指定的维度axis的长度。如果设置为None(默认值),将自动求值。
    axis: 整数.以轴axis指定的维度来转变 默认是第一个维度即axis=0。支持负数。取值范围为[-R, R)
    name: 这个操作的名字(可选)
    返回:
    从张量value降维后的张量数组。
    异常:
    ValueError: 如果num没有指定并且无法求出来。
    ValueError: 如果axis超出范围 [-R, R)。

    import tensorflow as tf
    import numpy as np
    t=np.random.randint(1,10,(3,4))
    ustack1=tf.unstack(t,axis=1)#4个(3)
    ustack2=tf.unstack(t,axis=0)#3个(4)
    sess=tf.Session()
    print(t)
    print(sess.run(ustack1))
    print(sess.run(ustack2))

    [[6 8 4 1]
    [4 9 2 7]
    [2 6 1 3]]
    [array([6, 4, 2]), array([8, 9, 6]), array([4, 2, 1]), array([1, 7, 3])]
    [array([6, 8, 4, 1]), array([4, 9, 2, 7]), array([2, 6, 1, 3])]

     
  • 相关阅读:
    linux 解压tgz 文件指令
    shell 脚本没有执行权限 报错 bash: ./myshell.sh: Permission denied
    linux 启动solr 报错 Your Max Processes Limit is currently 31202. It should be set to 65000 to avoid operational disruption.
    远程查询批量导入数据
    修改 MZTreeView 赋权节点父节点选中子节点自动选中的问题
    关于乱码的问题解决记录
    我的网站优化之路
    对设计及重构的一点反思
    我的五年岁月
    奔三的路上
  • 原文地址:https://www.cnblogs.com/tingtin/p/12554233.html
Copyright © 2011-2022 走看看