欢迎关注WX公众号:【程序员管小亮】
最近看到了tf.shape(x)、x.shape和x.get_shape()三个函数,不知道他们的差别,所以记录一下。
import tensorflow as tf
x = tf.constant([[0,1,2],[3,4,5],[6,7,8]])
print(type(x.shape))
print(type(x.get_shape()))
print(type(tf.shape(x)))
> <class 'tensorflow.python.framework.tensor_shape.TensorShapeV1'>
> <class 'tensorflow.python.framework.tensor_shape.TensorShapeV1'>
> <class 'tensorflow.python.framework.ops.Tensor'>
可以看到s.shape和x.get_shape()都是返回TensorShapeV1类型对象,而tf.shape(x)返回的是Tensor类型对象。
除此之外,对tf.shape(x)来说,其中x可以是tensor,也可不是tensor,返回是一个tensor。而对x.get_shape()来说,只有tensor有这个方法, 返回是一个tuple。
所以,如果在运行下面代码的时候,
x = tf.placeholder(tf.float32, shape=[None, 227] )
想知道None到底是多少,这时候,只能通过tf.shape(x)[0]这种方式来获得。
而想要获得维度信息,则需要调用前两种方法。
import tensorflow as tf
x = tf.constant([[0,1,2],[3,4,5],[6,7,8]])
print(x.shape)
print(x.get_shape())
print(tf.shape(x))
print(tf.rank(x))
> (3, 3)
> (3, 3)
> Tensor("Shape_3:0", shape=(2,), dtype=int32)
> Tensor("Rank_2:0", shape=(), dtype=int32)
或者是调用ts.as_list()方法,返回的是Python的list。
import tensorflow as tf
x = tf.constant([[0,1,2],[3,4,5],[6,7,8]])
x.shape.as_list()
#x.get_shape().as_list()
> [3, 3]
python课程推荐。