zoukankan      html  css  js  c++  java
  • tensorflow

    节点(Node)表示数学操作,多维数据数组,也就是张量(tensor),由线(edges)联系,表示节点之间的输入输出关系

    计算图computational graph是TF中很重要的一个概念,其是由一系列节点(nodes)组成的图模型,每个节点对应的是TF的一个算子(operation)。每个算子会有输入与输出,并且输入和输出都是张量。所以我们使用TF的算子可以构建自己的深度学习模型,其背后的就是一个计算图。还有一点这个计算图是静态的,意思是这个计算图每个节点接收什么样的张量和输出什么样的张量已经固定下来。要运行这个计算图,你需要开启一个会话(session),在session中这个计算图才可以真正运行。

    会话(Session):为了获得图的计算结果,图必须在会话中被启动。图是会话类型的一个成员,会话类型还包括一个runner,负责执行这张图。会话的主要任务是在图运算时分配CPU或GPU。

    Runner:在建立图之后,必须使用会话中的Runner来运行图,才能得到结果。在运行图时,需要为所有的变量和占位符赋值,否则就会报错。

    TensorflowSharp中的几类主要变量

    • Const:常量,这很好理解。它们在定义时就必须被赋值,而且值永远无法被改变。
    • Placeholder:占位符。这是一个在定义时不需要赋值,但在使用之前必须赋值(feed)的变量,通常用作训练数据。
    • Variable:变量,它和占位符的不同是它在定义时需要赋值,而且它的数值是可以在图的计算过程中随时改变的。因此,占位符通常用作图的输入(即训练数据),而变量用作图中可以被“训练”或“学习”的那些tensor,例如y=ax+b中的a和b。
    """A very simple MNIST classifier.
    See extensive documentation at
    http://tensorflow.org/tutorials/mnist/beginners/index.md
    """
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    # Import data
    from tensorflow.examples.tutorials.mnist import input_data
    
    import tensorflow as tf
    
    flags = tf.app.flags
    FLAGS = flags.FLAGS
    flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data') # 把数据放在/tmp/data文件夹中
    
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)   # 读取数据集
    
    
    # 建立抽象模型
    x = tf.placeholder(tf.float32, [None, 784]) # 占位符
    y = tf.placeholder(tf.float32, [None, 10])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    a = tf.nn.softmax(tf.matmul(x, W) + b)
    
    # 定义损失函数和训练方法
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(a), reduction_indices=[1]))  # 损失函数为交叉熵
    optimizer = tf.train.GradientDescentOptimizer(0.5) # 梯度下降法,学习速率为0.5
    train = optimizer.minimize(cross_entropy) # 训练目标:最小化损失函数
    
    # Test trained model
    correct_prediction = tf.equal(tf.argmax(a, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # Train
    sess = tf.InteractiveSession()      # 建立交互式会话
    tf.initialize_all_variables().run()
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        train.run({x: batch_xs, y: batch_ys})
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
    
  • 相关阅读:
    Debian/Ubuntu/Raspbian 时间同步
    linux 安裝mitmproxy
    Raspbian Lite Guide GUI 树莓派安装桌面
    SSH连接 提示 ssh_exchange_identification: Connection closed by remote host
    Navicat15 永久激活版教程
    docker企业级镜像仓库Harbor管理
    centos7.4安装docker
    Linux系统硬件时间12小时制和24小时制表示设置
    windows server 2012 R2系统安装部署SQLserver2016企业版(转)
    细说show slave status参数详解
  • 原文地址:https://www.cnblogs.com/rjxu/p/13257563.html
Copyright © 2011-2022 走看看