zoukankan      html  css  js  c++  java
  • Google机器学习笔记(七)TF.Learn 手写文字识别

    转载请注明作者:梦里风林
    Google Machine Learning Recipes 7
    官方中文博客 - 视频地址
    Github工程地址 https://github.com/ahangchen/GoogleML
    欢迎Star,也欢迎到Issue区讨论

    mnist问题

    • 计算机视觉领域的Hello world
    • 给定55000个图片,处理成28*28的二维矩阵,矩阵中每个值表示一个像素点的灰度,作为feature
    • 给定每张图片对应的字符,作为label,总共有10个label,是一个多分类问题

    TensorFlow

    • 可以按教程用Docker安装,也可以直接在Linux上安装
    • 你可能会担心,不用Docker的话怎么开那个notebook呢?其实notebook就在主讲人的Github页
    • 可以用这个Chrome插件:npviewer直接在浏览器中阅读ipynb格式的文件,而不用在本地启动iPython notebook
    • 我们的教程在这里:ep7.ipynb
    • 把代码从ipython notebook中整理出来:tflearn_mnist.py

    代码分析

    • 下载数据集
    mnist = learn.datasets.load_dataset('mnist')
    

    恩,就是这么简单,一行代码下载解压mnist数据,每个img已经灰度化成长784的数组,每个label已经one-hot成长度10的数组

    在我的深度学习笔记看One-hot是什么东西

    • numpy读取图像到内存,用于后续操作,包括训练集(只取前10000个)和验证集
    data = mnist.train.images
    labels = np.asarray(mnist.train.labels, dtype=np.int32)
    test_data = mnist.test.images
    test_labels = np.asarray(mnist.test.labels, dtype=np.int32)
    max_examples = 10000
    data = data[:max_examples]
    labels = labels[:max_examples]
    
    • 可视化图像
    def display(i):
        img = test_data[i]
        plt.title('Example %d. Label: %d' % (i, test_labels[i]))
        plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
        plt.show()
    

    用matplotlib展示灰度图

    • 训练分类器
      • 提取特征(这里每个图的特征就是784个像素值)
    feature_columns = learn.infer_real_valued_columns_from_input(data)
    
    • 创建线性分类器并训练
    classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10)
    classifier.fit(data, labels, batch_size=100, steps=1000)
    

    注意要制定n_classes为labels的数量

    • 分类器实际上是在根据每个feature判断每个label的可能性,
    • 不同的feature有的重要,有的不重要,所以需要设置不同的权重
    • 一开始权重都是随机的,在fit的过程中,实际上就是在调整权重

    • 最后可能性最高的label就会作为预测输出

    • 传入测试集,预测,评估分类效果

    result = classifier.evaluate(test_data, test_labels)
    print result["accuracy"]
    

    速度非常快,而且准确率达到91.4%

    可以只预测某张图,并查看预测是否跟实际图形一致

    # here's one it gets right
    print ("Predicted %d, Label: %d" % (classifier.predict(test_data[0]), test_labels[0]))
    display(0)
    # and one it gets wrong
    print ("Predicted %d, Label: %d" % (classifier.predict(test_data[8]), test_labels[8]))
    display(8)
    
    • 可视化权重以了解分类器的工作原理
    weights = classifier.weights_
    a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)
    

    • 这里展示了8个张图中,每个像素点(也就是feature)的weights,
    • 红色表示正的权重,蓝色表示负的权重
    • 作用越大的像素,它的颜色越深,也就是权重越大
    • 所以权重中红色部分几乎展示了正确的数字

    Next steps

  • 相关阅读:
    java里的分支语句--程序运行流程的分类(顺序结构,分支结构,循环结构)
    Java里的构造函数(构造方法)
    Java里this的作用和用法
    JAVA中的重载和重写
    从键盘接收字符类型的数据并实现剪刀石头布的规则
    使用Notepad++编码编译时报错(已解决?)
    云就是网络,云计算呢
    使用JavaMail创建邮件和发送邮件
    mysql锁机制
    java中几种常用的设计模式
  • 原文地址:https://www.cnblogs.com/hellocwh/p/5783249.html
Copyright © 2011-2022 走看看