zoukankan      html  css  js  c++  java
  • Tensorflow学习二_mnist入门

    前两天刚刚装好我的Tensorflow,于是今天通过tensorflow的中文网站(http://www.tensorfly.cn/tfdoc/get_started/introduction.html),

    准备开始学习关于tensorflow的入门——mnist手写字母的识别入门。

    在此主要记录一下我运行我的第一个代码时,出现的小错误。

    一、简单示例

    在简介中有一段使用 Python API 撰写的 TensorFlow 示例代码,

    直接拿来运行:出现错误

      File "D:/tensorflow/python文件/tensorflow1.py", line 37
        print step, sess.run(W), sess.run(b)
                 ^
    SyntaxError: invalid syntax
    

    后来通过网上搜索发现,是因为在官网上所用的代码是python2.x,而我使用的是python3,

    1、改xrange为range

    2、修改print格式

    运行成功

    import tensorflow as tf
    import numpy as np
    
    # 使用 NumPy 生成假数据(phony data), 总共 100 个点.
    x_data = np.float32(np.random.rand(2, 100)) # 随机输入
    y_data = np.dot([0.100, 0.200], x_data) + 0.300
    
    # 构造一个线性模型
    # 
    b = tf.Variable(tf.zeros([1]))
    W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
    y = tf.matmul(W, x_data) + b
    
    # 最小化方差
    loss = tf.reduce_mean(tf.square(y - y_data))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    
    # 初始化变量
    init = tf.initialize_all_variables()
    
    # 启动图 (graph)
    sess = tf.Session()
    sess.run(init)
    
    # 拟合平面
    for step in range(0, 201):
        sess.run(train)
        if step % 20 == 0:
            print (step, sess.run(W), sess.run(b))
    View Code

    二、mnist入门

    在下载数据集的时候,官网上提供了两种方法:一是下载代码并导入到项目,二是直接用python源代码自动下载和安装。

    在这里,我是直接用python源代码下载和安装。

    #导入数据集
    import input_data 
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    #实现回归模型
    import tensorflow as tf
    x = tf.placeholder("float", [None, 784])
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x,W) + b)
    
    #训练模型
    y_ = tf.placeholder("float", [None,10])
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    
    for i in range(1000):
      batch_xs, batch_ys = mnist.train.next_batch(100)
      sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    
    #评估模型
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
    

     报错:

    ImportError: No module named 'input_data'
    

    将import input_data 代码换成  from tensorflow.examples.tutorials.mnist import input_data

    运行成功:

    运行结果的正确率是92%

  • 相关阅读:
    今天解决了一个很郁闷的问题!
    解决了安装golive后html文件图标显示错误的问题
    [转载]Asp.Net 2.0 发布问题
    使用 Visual Studio 2005 构建“WPFE”项目
    Ajax学习网址备忘录
    [原创首发]深圳博客问测系统正式发布啦!
    如何在用户控件里联动Dropdownlist
    [转载]在ASP.NET中值得注意的两个地方
    [转]Prototype 1.5 Ajax 使用教程
    1038 Recover the Smallest Number (30 分)(贪心)
  • 原文地址:https://www.cnblogs.com/smile321/p/11205527.html
Copyright © 2011-2022 走看看