zoukankan      html  css  js  c++  java
  • tensorflow中手写识别笔记

    教程链接:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-c1ov28so.html

    tensorflow 常用的函数:

    # 导入tensorflow,使用tf代替

    import tensorflow as tf

    # 计算x,和w的乘积,这里计算x矩阵和w矩阵的乘积

    tf.matmul(x, w)      

    #  先计算labels和logits的交叉熵(区别),在对结果进行归一化处理,softmax参考

    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)

    # 然后求交叉熵的平均值

    cross_entrony = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

    # 以梯度下降法,0.5的幅度,减小交叉熵

    tf.train.GradientDescentOptimizer(0.5).minimize(cross_entrony)

    # 初始化变量(tf,Variable())

    tf.global_variables_initializer().run()

    # 获取一行最大值的索引

    tf.argmax(y, 1)

    # 比较a和b对应位置是否是相同的,返回结果是bool类型

    tf.equal(a, b)

    # 把x的值转化为另一种y类型

    tf.cast(x, y)

    代码整体解读:

    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # 加载数据
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  # 下载数据集,存储在/home/msl/Downloads
    
    # 构建回归模型
    x = tf.placeholder(tf.float32, [None, 784])  # None * 784   测试集[60000, 784]
    w = tf.Variable(tf.zeros([784, 10]))  # 784 * 10            和每个像素相乘,得到[None, 10],即为labels
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, w) + b  # 预测值
    
    # 使用梯度下降法最小化交叉熵
    
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entrony = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))  # 计算预测值和真实值的区别,并求均值
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entrony)
    
    # 初始化变量
    # init = tf.global_variables_initializer()
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    
    # 开始训练
    old = time.time()
    with tf.device("/gpu:0"):   # 使用gpu为:/gpu:0
        for i in range(1000):
            batch_xs, batch_xy = mnist.train.next_batch(100)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_xy})
    print(time.time() - old)
    
    # 评估模型
    # tf.argmax(y, 1)返回y中每行的最大值的索引
    # tf.equal(x, y)判断x和y的值是否一致,返回值为bool类型
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # tf.cast(a, b)把a转化为b类型, 再求平均值
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))  # 使用测试集评估模型
  • 相关阅读:
    版本控制,django缓存,跨域问题解决
    Linux之文件系统结构
    Linux之定时任务
    Linux之LVM
    Linux之硬盘与分区
    Linux之文件重定向与查找
    Linux之文件压缩
    Linux之文件权限
    Linux之用户管理
    Linux之文件管理
  • 原文地址:https://www.cnblogs.com/smartmsl/p/10877683.html
Copyright © 2011-2022 走看看