zoukankan      html  css  js  c++  java
  • 『TensorFlow』张量尺寸获取

    tf.shape(a)和a.get_shape()比较

    相同点:都可以得到tensor a的尺寸

    不同点:tf.shape()中a 数据的类型可以是tensor, list, array

        a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

    import tensorflow as tf  
    import numpy as np  
    
    x=tf.constant([[1,2,3],[4,5,6]]  
    y=[[1,2,3],[4,5,6]]  
    z=np.arange(24).reshape([2,3,4]))  
    
    sess=tf.Session()  
    # tf.shape()  
    x_shape=tf.shape(x)                    #  x_shape 是一个tensor  
    y_shape=tf.shape(y)                    #  <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32>  
    z_shape=tf.shape(z)                    #  <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32>  
    print sess.run(x_shape)              # 结果:[2 3]  
    print sess.run(y_shape)              # 结果:[2 3]  
    print sess.run(z_shape)              # 结果:[2 3 4]  
    
    
    # a.get_shape()  
    # 返回的是TensorShape([Dimension(2), Dimension(3)]),
    # 不能使用 sess.run() 因为返回的不是tensor 或string,而是元组  
    x_shape=x.get_shape()  
    x_shape=x.get_shape().as_list()  # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3]  
    y_shape=y.get_shape()  # AttributeError: 'list' object has no attribute 'get_shape'  
    z_shape=z.get_shape()  # AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'  
    或者a.shape.as_list()
    

    tf.shape(x)

      tf.shape()中x数据类型可以是tensor,list,array,返回是一个tensor.

    shape=tf.placeholder(tf.float32, shape=[None, 227,227,3] )

      我们经常会这样来feed数据,如果在运行的时候想知道None到底是多少,这时候,只能通过tf.shape(x)[0]这种方式来获得.

      由于返回的时tensor,所以我们可以使用其他tensorflow节点操作进行处理,如下面的转置卷积中,使用stack来合并各个shape的分量,

    def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
        with tf.variable_scope('conv_transpose'):
    
            shape = [kernel, kernel, output_filters, input_filters]
            weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
    
            batch_size = tf.shape(x)[0]
            height = tf.shape(x)[1] * strides
            width = tf.shape(x)[2] * strides
            output_shape = tf.stack([batch_size, height, width, output_filters])
    return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose')
    

    tensor.get_shape()

      只有tensor有这个方法, 返回是一个tuple。也正是由于返回的是TensorShape([Dimension(2), Dimension(3)])这样的元组,所以可以调用as_list化为[2, 3]样list,或者get_shape()[i].value得到具体值.

    tensor.set_shape()

      设置tensor的shape,一般不会用到,在tfrecode中,由于解析出来的tensor不会被设置shape,后续的函数是需要shape的维度等相关属性的,所以这里会使用.

  • 相关阅读:
    关于asp.net中Repeater控件的一些应用
    Linux查看程序端口占用情况
    php 验证身份证有效性,根据国家标准GB 11643-1999 15位和18位通用
    给Nginx配置一个自签名的SSL证书
    让你提升命令行效率的 Bash 快捷键 [完整版]
    关系数据库常用SQL语句语法大全
    php 跨域 form提交 2种方法
    Vimium~让您的Chrome起飞
    vim tab设置为4个空格
    CENTOS 搭建SVN服务器(附自动部署到远程WEB)
  • 原文地址:https://www.cnblogs.com/hellcat/p/8568099.html
Copyright © 2011-2022 走看看