zoukankan      html  css  js  c++  java
  • TF2常用函数

    1.

    • 利用 tf.cast (张量名,dtype=数据类型)强制将 Tensor 转换为该数据类型;
    • 利用tf.reduce_min (张量名)计算张量维度上元素的最小值;
    • 利用tf.reduce_max (张量名)计算张量维度上元素的最大值。

    举例如下:

     2.维度的定义

    由上图可知对于一个二维张量,如果 axis=0 表示纵向操作(沿经度方向) ,axis=1表示横向操作(沿纬度方向)。
    可用 tf.reduce_mean (张量名,axis=操作轴)计算张量沿着指定维度的平均值;可用 f.reduce_sum (张量名,axis=操作轴)计算张量沿着指定维度的和,如不指定 axis,则表示对所有元素进行操作。

    3.tf.Variable()函数

    可利用 tf.Variable(initial_value,trainable,validate_shape,name)函数可以将变量标记为“可训练”的被它标记了的变量,会在反向传播中记录自己的梯度信息
    其中
    • initial_value 默认为 None,可以搭配 tensorflow 随机生成函数来初始化参数;
    • trainable 默认为 True,表示可以后期被算法优化的,如果不想该变量被优化,即改为 False;
    • validate_shape 默认为 True,形状不接受更改,如果需要更改,validate_shape=False;
    • name 默认为 None,给变量确定名称。
    举例如下:w = tf.Variable(tf.random.normal([2, 2], mean=0, stddev=1))
    表示首先随机生成正态分布随机数,再给生成的随机数标记为可训练,这样在反向传播中就可以通过梯度下降更新参数 w了

     4.对应元素的四则运算

    • 利用 TensorFlow 中函数对张量进行四则运算。利用 tf.add (张量 1,张量2)实现两个张量的对应元素相加;
    • 利用 tf.subtract (张量 1,张量 2)实现两个张量的对应元素相减;
    • 利用 tf.multiply (张量 1,张量 2)实现两个张量的对应元素相乘;
    • 利用 tf.divide (张量 1,张量 2)实现两个张量的对应元素相除。
    注:只有维度相同的张量才可以做四则运算
    举例如下:

     

     5.利用 TensorFlow 中函数对张量进行幂次运算。

    • 可用 tf.square (张量名)计算某个张量的平方;
    • 利用 tf.pow (张量名,n 次方数)计算某个张量的 n 次方;
    • 利用 tf.sqrt (张量名)计算某个张量的开方。

    举例如下:

    6.可利用 tf.matmul(矩阵 1,矩阵 2)实现两个矩阵的相乘

    a = tf.ones([3, 2])
    b = tf.fill([2, 3], 3.)
    print(tf.matmul(a, b))
    输出结果:tf.Tensor([[6. 6. 6.] [6. 6. 6.] [6. 6. 6.]], shape=(3, 3), dtype=float32),
    即 a为一个 3 行 2 列的全 1 矩阵,b 为 2 行 3 列的全 3 矩阵,二者进行矩阵相乘。
     
    7.可利用 tf.data.Dataset.from_tensor_slices((输入特征, 标签))切分传入张量的第一维度,生成输入特征/标签对,构建数据集,
    此函数对 Tensor 格式与 Numpy格式均适用,其切分的是第一维度,表征数据集中数据的数量,之后切分 batch等操作都以第一维为基础。
    举例如下:
    features = tf.constant([12,23,10,17])
    labels = tf.constant([0, 1, 1, 0])
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    print(dataset)
    for element in dataset:
    print(element)
    

     8.可利用 tf.GradientTape( )函数搭配 with 结构计算损失函数在某一张量处的梯度

    举例如下:

    9.可利用 enumerate(列表名)函数枚举出每一个元素,并在元素前配上对应的索引号,常在 for 循环中使用。
    举例如下:

     

    10.可用 tf.one_hot(待转换数据,depth=几分类)函数实现用独热码表示标签
    在分类问题中很常见。标记类别为为 1 和 0,其中 1 表示是,0 表示非。
    如在鸢尾花分类任务中,如果标签是 1,表示分类结果是 1 杂色鸢尾,其用把它用独热码表示就是 0,1,0,这样可以表示出每个分类的概率:也就是百分之 0 的可能是 0狗尾草鸢尾,百分百的可能是 1 杂色鸢尾,百分之 0 的可能是弗吉尼亚鸢尾。
    举例如下: 
    classes = 3
    labels = tf.constant([1,0,2])
    output = tf.one_hot( labels, depth=classes )
    print(output)
    输出结果:tf.Tensor([[0. 1. 0.] [1. 0. 0.] [0. 0. 1.]], shape=(3, 3), dtype=float32)
    索引从 0 开始,待转换数据中各元素值应小于 depth,若带转换元素值大于等于depth,则该元素输出编码为 [0, 0 … 0, 0]。即 depth 确定列数,待转换元素的个数确定行数。
    举例如下:
    classes = 3
    labels = tf.constant([1,4,2]) # 输入的元素值 4 超出 depth-1
    output = tf.one_hot(labels,depth=classes)
    print(output) 
    输出结果:tf.Tensor([[0. 1. 0.] [0. 0. 0.] [0. 0. 1.]], shape=(3, 3), dtype=float32),即元素 4 对应的输出编码为[0. 0. 0.]。

    11. tf.nn.softmax

    可利用 tf.nn.softmax( )函数使前向传播的输出值符合概率分布,进而与独热码形式的标签作比较,其计算公式为
                                                            
     其中 yi是前向传播的输出。在前一部分,我们得到了前向传播的输出值,分别为 1.01、2.01、-0.66,通过上述计算公式,可计算对应的概率值

    上式中,0.256 表示为 0 类鸢尾的概率是 25.6%,0.695 表示为 1 类鸢尾的概率是69.5%,0.048 表示为 2 类鸢尾的概率是 4.8%。

     

     

    y = tf.constant ( [1.01, 2.01, -0.66] )
    y_pro = tf.nn.softmax(y)
    print("After softmax, y_pro is:", y_pro)
    输出结果:
    After softmax, y_pro is:tf.Tensor([0.25598174 0.69583046 0.0481878], shape=(3,), dtype=float32)

     12.assign_sub

    可利用 assign_sub 对参数实现自更新。使用此函数前需利用 tf.Variable定义变量 w为可训练(可自更新)
    举例如下:
    w = tf.Variable(4)
    w.assign_sub(1)
    print(w)
    输出结果:<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=3> 即实现了参数
    w自减 1。
    注:直接调用 tf.assign_sub 会报错,要用 w.assign_sub
     
    13.可利用 tf.argmax (张量名,axis=操作轴)返回张量沿指定维度最大值的索引

    14 tf.reshape()

     

  • 相关阅读:
    android gradle 打包命令
    android RRO
    android adb 常用命令
    mui执行滑动事件: Unable to preventDefault inside passive event listener
    获取 input[type=file] 文件上传尺寸
    MySQL:You can't specify target table for update in FROM clause
    input标签中autocomplete="off" 失效的解决办法
    @media属性针对苹果手机写法
    centos7 下mysql5.7修改默认编码格式为UTF-8
    使用Mui加载数据后a标签点击事件失效
  • 原文地址:https://www.cnblogs.com/GumpYan/p/13549293.html
Copyright © 2011-2022 走看看