zoukankan      html  css  js  c++  java
  • TensorFlow 内置重要函数解析

    概要

    本部分介绍一些在 TensorFlow 中内置的重要函数,了解这些函数有时候更加方便我们进行数据的处理或者构建神经网络。

    这些函数如下:

     
        tf.one_hot()
        tf.random_shuffle()
     


     

    主要内容

    tf.one_hot()

     
    这是一个用来生成符合 one_hot 编码的张量的函数。完整参数形式是:

    tf.one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
    

    下面我们一一通过实例来了解各个参数表示什么意思。

    为了容易理解,我们举个例子,比如我们熟悉的 mnist 数据集中标签的 one_hot 编码中,数字 4 是用向量 ([0,0,0,0,1,0,0,0,0,0]) 来表示的。

    • on_value ,float 类型,表示在 one_hot 编码中标签标记值,在上述编码中 on_value 的值就是 1
    • off_value, float 类型,就是标记点除外的其它值,即 off_value 为 0
    • indices ,一个列表,表示要生成的 one_hot 张量中标记值所在索引,即 indices = [4]
    • depth,int 类型,表示要生成的 one_hot 张量的长度,即 depth = 10
    • Axis,取值为 -1,0 或 1,Axis 取 -1 时造成的张量 shape=[indices 长度, depth],默认值虽是 None,但是和取 -1 效果一样。为 0 时 shape=[depth, indices 长度],取 1 时,比较复杂,是指在三维以上情况下,比方考虑批量输入中,有个批 batch 大小, shape=[batch, indices 长度, depth],具体的可以做下实验验证就好,不需要刻意去记。

    下面用代码验证一下:

    # -*- coding: utf-8 -*-
    """
    Created on Mon Jun  4 08:56:57 2018
    
    @author: zhoukui
    """
    
    import tensorflow as tf
    
    tf.reset_default_graph()
    
    indices = [0, 2, -1, 1, 2]
    depth = 4
    on_value = 3.0
    off_value = 0.0
    axis = -1
    
    t = tf.one_hot(indices, depth, on_value, off_value, axis)
    
    with tf.Session() as sess:
        print(sess.run(t))  #输出 [[ 3.  0.  0.  0.]
                            #     [ 0.  0.  3.  0.]
                            #     [ 0.  0.  0.  0.]
                            #     [ 0.  3.  0.  0.]
                            #     [ 0.  0.  3.  0.]]
    

     

    tf.random_shuffle()

     
    这个函数相对简单,它就一个参数 input,表示沿着 input 的第一维度进行随机重新排列,在进行数据分批的时候特别实用。实例如下:

    # -*- coding: utf-8 -*-
    """
    Created on Mon Jun  4 08:56:57 2018
    
    @author: zhoukui
    """
    
    import tensorflow as tf
    
    tf.reset_default_graph()
    
    input = tf.reshape(tf.linspace(1.0, 10.0, 10), (-1,2))
    
    tf.set_random_seed(666)  # 可以选择固定种子
    t = tf.random_shuffle(input)
    
    with tf.Session() as sess:
        
        print(sess.run(input)) # 输出 [[  1.   2.]
                               #       [  3.   4.]
                               #       [  5.   6.]
                               #       [  7.   8.]
                               #       [  9.  10.]]
        
        print(sess.run(t))  #输出 [[  7.   8.]
                            #      [  5.  6.]
                            #      [  1.   2.]
                            #      [  3.   4.]
                            #      [  9.   10.]]
    

     
     

  • 相关阅读:
    linux --- 3 vim 网络 用户 权限 软连接 压缩 定时任务 yum源
    linux --- 2.常用命令 , python3, django安装
    linux --- 1.初始linux
    admin ---11.admin , 展示列表 和 分页
    并发 ---- 6. IO 多路复用
    django基础 -- 10.form , ModelForm ,modelformset
    django基础 -- 9.中间件
    flask基础
    MySQL-数据库增删改查
    面试题目二
  • 原文地址:https://www.cnblogs.com/zhoukui/p/9157566.html
Copyright © 2011-2022 走看看