zoukankan      html  css  js  c++  java
  • tensorflow——MNIST机器学习入门

    这里的代码在项目中执行下载并安装数据集。

    执行下面代码,训练、并评估模型:

     1 # _*_coding:utf-8_*_
     2 import inputdata
     3 mnist = inputdata.read_data_sets('MNIST_data', one_hot=True)
     4 
     5 import tensorflow as tf
     6 
     7 x = tf.placeholder("float",[None, 784])
     8 W = tf.Variable(tf.zeros([784,10]))
     9 b = tf.Variable(tf.zeros([10]))
    10 
    11 y = tf.nn.softmax(tf.matmul(x,W) + b)    # 预测值
    12 y_ = tf.placeholder("float", [None, 10])    # 真实值
    13 
    14 # 训练模型
    15 cross_entropy = -tf.reduce_sum(y_*tf.log(y))    # 交叉熵,交叉熵是用来衡量我们的预测用于描述真相
    16                                                 # 的低效性。注意这里是对所有记录所有假设的sum操作
    17 
    18 train_step= tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)    # 调用梯度下降,TensorFlow会自动使用反向传播算法
    19                                                                                # 并且通过minimize来指定要最小化的代价方程
    20 
    21 # 初始化变量
    22 init = tf.initialize_all_variables()
    23 
    24 # 运行
    25 sess = tf.Session()
    26 sess.run(init)
    27 
    28 for i in range(1000):
    29     batch_xs, batch_ys = mnist.train.next_batch(100)    # 随机训练 (stochastic training), 减小计算开销
    30     sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})
    31 
    32 # 评估模型
    33 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))    # argmax返回的是沿着某个轴的最大值的索引值
    34 
    35 accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))    # cast把布尔值转为浮点值
    36 
    37 print sess.run(accuracy, feed_dict={x: mnist.test.images, y_:mnist.test.labels})
    38 
    39 sess.close()

    运行结果:(由于是随机训练每次结果可能一样)

    0.9131
  • 相关阅读:
    14. Longest Common Prefix[E]最长公共前缀
    13. Roman to Integer[E]罗马数字转整数
    12. Integer to Roman[M]整数转罗马数字
    11. Container With Most Water[M]盛最多水的容器
    10. Regular Expression Matching[H]正则表达式匹配
    清除浮動,父類塌陷解決
    html 定位
    微信支付这个坑,终于过了
    浮动
    盒子模型高级应用
  • 原文地址:https://www.cnblogs.com/DianeSoHungry/p/7148359.html
Copyright © 2011-2022 走看看