zoukankan      html  css  js  c++  java
  • Tensorflow中Tensor对象的常用方法(持续更新)

      Tensor是Tensorflow中重要的对象。下面是Tensor的常用方法,后面还会写一篇随笔记录Variable的用法。

      1. 生成一个(常)Tensor对象

    >>>A = tf.constant(4)
    >>>B = tf.constant([[1, 2], [3, 4]))
    >>>A
    <tf.Tensor: id=76, shape=(), dtype=int32, numpy=4>
    >>>B
    <tf.Tensor: id=77, shape=(2, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4]], dtype=int32)>
    

      Tensor对象和ndarray对象看起来很像但也有差别,一个最大的差别是Tensor是不可变的(immutable)。这意味着你永远也无法随心所欲的对Tensor进行赋值,只能新创建一个新的Tensor。

      2. 和Ndarray的互相转换

    >>>B.numpy()
    array([[1, 2],
           [3, 4]], dtype=int32)
    >>>np.array(B)
    array([[1, 2],
           [3, 4]], dtype=int32)
    >>>D = np.array([[1, 2], [3, 4]])
    >>>tf.convert_to_tensor(D, dtype='int32')
    <tf.Tensor: id=79, shape=(2, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4]], dtype=int32)>
    

      Tensorflow2引入了叫做Eager execution的机制,让Tensor和ndarray具有一样的运算灵活性。除了以上的转换方式之外,任意的Tensorflow操作都可以生成(返回)Tensor对象。

      3. 矩阵运算

    >>>a = tf.constant([[1, 2],
                     [3, 4]])
    >>>b = tf.constant([[1, 1],
                     [1, 1]]) # Could have also said `tf.ones([2,2])`
    >>>print(tf.add(a, b), "
    ")
    tf.Tensor(
    [[2 3]
     [4 5]], shape=(2, 2), dtype=int32) 
    >>>print(tf.multiply(a, b), "
    ")
    tf.Tensor(
    [[1 2]
     [3 4]], shape=(2, 2), dtype=int32)
    >>>print(tf.matmul(a, b), "
    ")
    tf.Tensor(
    [[3 3]
     [7 7]], shape=(2, 2), dtype=int32) 
    

      以上三个操作返回的都是Tensor对象,同时这三个操作可以使用'+', '*', '@'代替。

      4. 三种常用的操作

    >>>c = tf.constant([[4.0, 5.0], [10.0, 1.0]])
    >>>print(tf.reduce_max(c)) # Find the largest value
    tf.Tensor(10.0, shape=(), dtype=float32)
    >>>print(tf.argmax(c)) # Find the index of the largest value
    tf.Tensor([1 0], shape=(2,), dtype=int64)
    >>>print(tf.nn.softmax(c)) # # Compute the softmax
    tf.Tensor(
    [[2.6894143e-01 7.3105860e-01]
     [9.9987662e-01 1.2339458e-04]], shape=(2, 2), dtype=float32)
    

      三种看名字就能看出功能的操作,其中tf.reduce_XX()是tensorflow中降维的操作。类似的操作:'reduce_all', 'reduce_any', 'reduce_logsumexp', 'reduce_max', 'reduce_mean', 'reduce_min', 'reduce_prod', 'reduce_sum'。

      

       5. Dtype转换  

    >>>the_f64_tensor = tf.constant([2.2, 3.3, 4.4], dtype=tf.float64)
    >>>the_f16_tensor = tf.cast(the_f64_tensor, dtype=tf.float16)
    # Now, let's cast to an uint8 and lose the decimal precision
    >>>the_u8_tensor = tf.cast(the_f16_tensor, dtype=tf.uint8) 
    >>>print(the_u8_tensor)
    tf.Tensor([2 3 4], shape=(3,), dtype=uint8)
    

      numpy中可以使用astype()来进行转换,Tensorflow中则使用tf.cast()方法来转化不同数据类型的Tensor。

       

      6. 广播操作

      Tensor的广播操作和numpy中基本完全一样,机制可以看这篇文章:https://jakevdp.github.io/PythonDataScienceHandbook/02.05-computation-on-arrays-broadcasting.html

  • 相关阅读:
    PHP算法每日一练 双向链表
    Web开发者必备的十大免费在线工具网站
    使用PXE+DHCP+Apache+Kickstart无人值守安装CentOS5.5
    linux服务器状态、性能相关命令
    PHP算法每日一练 单链表
    [转]DELPHI2006中for in语句的应用
    [转]Delphi线程类
    [转]解耦:Delphi下IoC 模式的实现
    [DELPHI]单例模式(singleton) 陈省
    [转][Delphi]解决窗体闪烁的方法
  • 原文地址:https://www.cnblogs.com/chester-cs/p/13020001.html
Copyright © 2011-2022 走看看