zoukankan      html  css  js  c++  java
  • TensorFlow入门

    TensorFlow入门 - 使用TensorFlow给鸢尾花分类(线性模型)


    第一个例子将使用TensorFlow封装的Estimator来实现一个简单的Classifier,该Classifier能够区分3种比较难分辨的鸢尾花,分别是Iris Setosa(山鸢尾)、Iris Versicolour(变色鸢尾)和Iris Virginica(维吉尼亚鸢尾)。

    不同种类植物有不同的性状,我们对不同的鸢尾花的区分是根据它们的某些性状来进行的。具体地说,我们通过4个特征值来对鸢尾花进行区分,这四个特征值(单位cm)包括sepal length(萼片长度)、sepal width(萼片宽度)、petal length(花瓣长度)、petal width(花瓣宽度)。

    环境搭建


    本例使用Jupyter Notebook进行具体实现,使用它需要安装Anaconda,这个部分参考博主之前的博文。
    Ubuntu16.04使用Anaconda5搭建TensorFlow使用环境 图文详细教程

    Jupyter Notebook即此前的Ipython Notebook,是一个web应用程序,可以以文档形式保存所有输入和输出。

    鸢尾花数据集


    鸢尾花数据集是一个经典的机器学习数据集,非常适合用来入门。它包括5列数据:前4列代表4个特征值即sepal length(萼片长度)、sepal width(萼片宽度)、petal length(花瓣长度)、petal width(花瓣宽度);最后一列为Species,即鸢尾花的种类,是我们训练目标,在机器学习中称作label。这种数据也被称作标记数据(labeled data)。
    鸢尾花数据集

    机器学习中,为了保证测试结果的准确性,一般会从数据集中抽取一部分数据专门留作测试,其余数据用于训练。本例使用了两个CSV格式的数据文件,一个是iris_training.csv即训练文件,另一个是iris_test.csv,即测试文件。

    具体实现

    在tensorflow虚拟环境中启动jupyter notebook
    steve@steve-Lenovo-V2000:~$ source activate tensorflow
    (tensorflow) steve@steve-Lenovo-V2000:~$ jupyter notebook
    In[1]       
    import tensorflow as tf
    import numpy as np
    
    print(tf.__version__)
    
    1.3.0
    In[2]       
    from tensorflow.contrib.learn.python.learn.datasets import base
    
    #所用的数据集文件
    IRIS_TRAINING = "iris_training.csv"
    IRIS_TEST = "iris_test.csv"
    #加载数据集
    training_set = base.load_csv_with_header(filename = IRIS_TRAINING, 
                                             features_dtype = np.float32,
                                             target_dtype = np.int)
    test_set = base.load_csv_with_header(filename = IRIS_TEST, 
                                             features_dtype = np.float32,
                                             target_dtype = np.int)
    
    print(training_set.data)
    print(training_set.target)
    
    [[ 6.4000001   2.79999995  5.5999999   2.20000005]
     [ 5.          2.29999995  3.29999995  1.        ]
     [ 4.9000001   2.5         4.5         1.70000005]
     [ 4.9000001   3.0999999   1.5         0.1       ]
     [ 5.69999981  3.79999995  1.70000005  0.30000001]
     [ 4.4000001   3.20000005  1.29999995  0.2       ]
     [ 5.4000001   3.4000001   1.5         0.40000001]
     [ 6.9000001   3.0999999   5.0999999   2.29999995]
     [ 6.69999981  3.0999999   4.4000001   1.39999998]
     [ 5.0999999   3.70000005  1.5         0.40000001]
     [ 5.19999981  2.70000005  3.9000001   1.39999998]
     [ 6.9000001   3.0999999   4.9000001   1.5       ]
     [ 5.80000019  4.          1.20000005  0.2       ]
     [ 5.4000001   3.9000001   1.70000005  0.40000001]
     [ 7.69999981  3.79999995  6.69999981  2.20000005]
     [ 6.30000019  3.29999995  4.69999981  1.60000002]
     [ 6.80000019  3.20000005  5.9000001   2.29999995]
     [ 7.5999999   3.          6.5999999   2.0999999 ]
     [ 6.4000001   3.20000005  5.30000019  2.29999995]
     [ 5.69999981  4.4000001   1.5         0.40000001]
     [ 6.69999981  3.29999995  5.69999981  2.0999999 ]
     [ 6.4000001   2.79999995  5.5999999   2.0999999 ]
     [ 5.4000001   3.9000001   1.29999995  0.40000001]
     [ 6.0999999   2.5999999   5.5999999   1.39999998]
     [ 7.19999981  3.          5.80000019  1.60000002]
     [ 5.19999981  3.5         1.5         0.2       ]
     [ 5.80000019  2.5999999   4.          1.20000005]
     [ 5.9000001   3.          5.0999999   1.79999995]
     [ 5.4000001   3.          4.5         1.5       ]
     [ 6.69999981  3.          5.          1.70000005]
     [ 6.30000019  2.29999995  4.4000001   1.29999995]
     [ 5.0999999   2.5         3.          1.10000002]
     [ 6.4000001   3.20000005  4.5         1.5       ]
     [ 6.80000019  3.          5.5         2.0999999 ]
     [ 6.19999981  2.79999995  4.80000019  1.79999995]
     [ 6.9000001   3.20000005  5.69999981  2.29999995]
     [ 6.5         3.20000005  5.0999999   2.        ]
     [ 5.80000019  2.79999995  5.0999999   2.4000001 ]
     [ 5.0999999   3.79999995  1.5         0.30000001]
     [ 4.80000019  3.          1.39999998  0.30000001]
     [ 7.9000001   3.79999995  6.4000001   2.        ]
     [ 5.80000019  2.70000005  5.0999999   1.89999998]
     [ 6.69999981  3.          5.19999981  2.29999995]
     [ 5.0999999   3.79999995  1.89999998  0.40000001]
     [ 4.69999981  3.20000005  1.60000002  0.2       ]
     [ 6.          2.20000005  5.          1.5       ]
     [ 4.80000019  3.4000001   1.60000002  0.2       ]
     [ 7.69999981  2.5999999   6.9000001   2.29999995]
     [ 4.5999999   3.5999999   1.          0.2       ]
     [ 7.19999981  3.20000005  6.          1.79999995]
     [ 5.          3.29999995  1.39999998  0.2       ]
     [ 6.5999999   3.          4.4000001   1.39999998]
     [ 6.0999999   2.79999995  4.          1.29999995]
     [ 5.          3.20000005  1.20000005  0.2       ]
     [ 7.          3.20000005  4.69999981  1.39999998]
     [ 6.          3.          4.80000019  1.79999995]
     [ 7.4000001   2.79999995  6.0999999   1.89999998]
     [ 5.80000019  2.70000005  5.0999999   1.89999998]
     [ 6.19999981  3.4000001   5.4000001   2.29999995]
     [ 5.          2.          3.5         1.        ]
     [ 5.5999999   2.5         3.9000001   1.10000002]
     [ 6.69999981  3.0999999   5.5999999   2.4000001 ]
     [ 6.30000019  2.5         5.          1.89999998]
     [ 6.4000001   3.0999999   5.5         1.79999995]
     [ 6.19999981  2.20000005  4.5         1.5       ]
     [ 7.30000019  2.9000001   6.30000019  1.79999995]
     [ 4.4000001   3.          1.29999995  0.2       ]
     [ 7.19999981  3.5999999   6.0999999   2.5       ]
     [ 6.5         3.          5.5         1.79999995]
     [ 5.          3.4000001   1.5         0.2       ]
     [ 4.69999981  3.20000005  1.29999995  0.2       ]
     [ 6.5999999   2.9000001   4.5999999   1.29999995]
     [ 5.5         3.5         1.29999995  0.2       ]
     [ 7.69999981  3.          6.0999999   2.29999995]
     [ 6.0999999   3.          4.9000001   1.79999995]
     [ 4.9000001   3.0999999   1.5         0.1       ]
     [ 5.5         2.4000001   3.79999995  1.10000002]
     [ 5.69999981  2.9000001   4.19999981  1.29999995]
     [ 6.          2.9000001   4.5         1.5       ]
     [ 6.4000001   2.70000005  5.30000019  1.89999998]
     [ 5.4000001   3.70000005  1.5         0.2       ]
     [ 6.0999999   2.9000001   4.69999981  1.39999998]
     [ 6.5         2.79999995  4.5999999   1.5       ]
     [ 5.5999999   2.70000005  4.19999981  1.29999995]
     [ 6.30000019  3.4000001   5.5999999   2.4000001 ]
     [ 4.9000001   3.0999999   1.5         0.1       ]
     [ 6.80000019  2.79999995  4.80000019  1.39999998]
     [ 5.69999981  2.79999995  4.5         1.29999995]
     [ 6.          2.70000005  5.0999999   1.60000002]
     [ 5.          3.5         1.29999995  0.30000001]
     [ 6.5         3.          5.19999981  2.        ]
     [ 6.0999999   2.79999995  4.69999981  1.20000005]
     [ 5.0999999   3.5         1.39999998  0.30000001]
     [ 4.5999999   3.0999999   1.5         0.2       ]
     [ 6.5         3.          5.80000019  2.20000005]
     [ 4.5999999   3.4000001   1.39999998  0.30000001]
     [ 4.5999999   3.20000005  1.39999998  0.2       ]
     [ 7.69999981  2.79999995  6.69999981  2.        ]
     [ 5.9000001   3.20000005  4.80000019  1.79999995]
     [ 5.0999999   3.79999995  1.60000002  0.2       ]
     [ 4.9000001   3.          1.39999998  0.2       ]
     [ 4.9000001   2.4000001   3.29999995  1.        ]
     [ 4.5         2.29999995  1.29999995  0.30000001]
     [ 5.80000019  2.70000005  4.0999999   1.        ]
     [ 5.          3.4000001   1.60000002  0.40000001]
     [ 5.19999981  3.4000001   1.39999998  0.2       ]
     [ 5.30000019  3.70000005  1.5         0.2       ]
     [ 5.          3.5999999   1.39999998  0.2       ]
     [ 5.5999999   2.9000001   3.5999999   1.29999995]
     [ 4.80000019  3.0999999   1.60000002  0.2       ]
     [ 6.30000019  2.70000005  4.9000001   1.79999995]
     [ 5.69999981  2.79999995  4.0999999   1.29999995]
     [ 5.          3.          1.60000002  0.2       ]
     [ 6.30000019  3.29999995  6.          2.5       ]
     [ 5.          3.5         1.60000002  0.60000002]
     [ 5.5         2.5999999   4.4000001   1.20000005]
     [ 5.69999981  3.          4.19999981  1.20000005]
     [ 4.4000001   2.9000001   1.39999998  0.2       ]
     [ 4.80000019  3.          1.39999998  0.1       ]
     [ 5.5         2.4000001   3.70000005  1.        ]]
    [2 1 2 0 0 0 0 2 1 0 1 1 0 0 2 1 2 2 2 0 2 2 0 2 2 0 1 2 1 1 1 1 1 2 2 2 2
    2 0 0 2 2 2 0 0 2 0 2 0 2 0 1 1 0 1 2 2 2 2 1 1 2 2 2 1 2 0 2 2 0 0 1 0 2 2 0 1 1 1 2 0 1 1 1 2 0 1 1 1 0 2 1 0 0 2 0 0 2 1 0 0 1 0 1 0 0 0 0 1 0 2 1 0 2 0 1 1 0 0 1]
    (第一个list是4个特征值,第二个list是目标结果,即鸢尾的种类,用int的012表示Iris Setosa(山鸢尾)、Iris Versicolour(变色鸢尾)和Iris Virginica(维吉尼亚鸢尾)。)
    In[3]   
        #构建模型
        #假定所有的特征都有一个实数值作为数据
        feature_name = "flower_features"
    feature_columns = [tf.feature_column.numeric_column(feature_name, shape = [4])]
        classifier = tf.estimator.LinearClassifier(
                     feature_columns = feature_columns,
                         n_classes = 3,
                         model_dir = "/tmp/iris_model")
    
    INFO:tensorflow:Using default config.
    INFO:tensorflow:Using config: {'_model_dir': '/tmp/iris_model', '_tf_random_seed': 1, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100}
    
    In[4]   
    # define input function 定义一个输入函数,用于为模型产生数据
    def input_fn(dataset):
                def _fn():
                    features = {feature_name: tf.constant(dataset.data)}
                    label = tf.constant(dataset.target)
                    return features, label
                return _fn
    print(input_fn(training_set)())
    
    ({'flower_features': <tf.Tensor 'Const:0' shape=(120, 4) dtype=float32>}, <tf.Tensor 'Const_1:0' shape=(120,) dtype=int64>)
    In[5]   
    # 数据流向
    # raw data -> input_fn -> feature columns -> model
    # fit model 训练模型
    classifier.train(input_fn = input_fn(training_set), steps = 1000)
    print('fit already done.')
    
    INFO:tensorflow:Create CheckpointSaverHook.
    INFO:tensorflow:Saving checkpoints for 1 into /tmp/iris_model/model.ckpt.
    INFO:tensorflow:loss = 131.833, step = 1
    INFO:tensorflow:global_step/sec: 1396.3
    INFO:tensorflow:loss = 37.1391, step = 101 (0.072 sec)
    INFO:tensorflow:global_step/sec: 1279.85
    INFO:tensorflow:loss = 27.8594, step = 201 (0.078 sec)
    INFO:tensorflow:global_step/sec: 1400.15
    INFO:tensorflow:loss = 23.0449, step = 301 (0.071 sec)
    INFO:tensorflow:global_step/sec: 1293.92
    INFO:tensorflow:loss = 20.058, step = 401 (0.077 sec)
    INFO:tensorflow:global_step/sec: 1610.43
    INFO:tensorflow:loss = 18.0083, step = 501 (0.062 sec)
    INFO:tensorflow:global_step/sec: 1617.19
    INFO:tensorflow:loss = 16.505, step = 601 (0.062 sec)
    INFO:tensorflow:global_step/sec: 1602.84
    INFO:tensorflow:loss = 15.3496, step = 701 (0.062 sec)
    INFO:tensorflow:global_step/sec: 1799.5
    INFO:tensorflow:loss = 14.43, step = 801 (0.056 sec)
    INFO:tensorflow:global_step/sec: 1577.18
    INFO:tensorflow:loss = 13.6782, step = 901 (0.063 sec)
    INFO:tensorflow:Saving checkpoints for 1000 into /tmp/iris_model/model.ckpt.
    INFO:tensorflow:Loss for final step: 13.0562.
    fit already done.
    In[6]   
    # Evaluate accuracy 评估模型的准确度
    accuracy_score = classifier.evaluate(input_fn = input_fn(test_set),
                                         steps = 100)["accuracy"]
    print('
    Accuracy: {0:f}'.format(accuracy_score))
    
    INFO:tensorflow:Starting evaluation at 2018-03-03-12:07:04
    INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-1000
    INFO:tensorflow:Evaluation [1/100]
    INFO:tensorflow:Evaluation [2/100]
    INFO:tensorflow:Evaluation [3/100]
    INFO:tensorflow:Evaluation [4/100]
    INFO:tensorflow:Evaluation [5/100]
    INFO:tensorflow:Evaluation [6/100]
    INFO:tensorflow:Evaluation [7/100]
    INFO:tensorflow:Evaluation [8/100]
    ……
    INFO:tensorflow:Evaluation [98/100]
    INFO:tensorflow:Evaluation [99/100]
    INFO:tensorflow:Evaluation [100/100]
    INFO:tensorflow:Finished evaluation at 2018-03-03-12:07:05
    INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.966667, average_loss = 0.120964, global_step = 1000, loss = 3.62893
    Accuracy: 0.966667

    总结与说明


    本例主要使用了TensorFlow封装的高级API,即Estimator。Estimator已经对训练过程进行了封装,因此我们只需要配置就可以进行使用。

        classifier = tf.estimator.LinearClassifier(
                     feature_columns = feature_columns,
                         n_classes = 3,
                         model_dir = "/tmp/iris_model")

    这是构建模型所使用的代码,它定义了一个简单的线性模型,并配置了三个参数:feature_columns即特征值,已在前面定义;n_class即分类的总数,本例为3;model_dir即模型的存储路径。

    本例所搭建的线性模型的最终准确度达到了96.66667%。这是一个不错的数值,因为这意味着从统计方面来说该模型能从100朵鸢尾中正确区分96朵鸢尾的品种。事实上,如果让一个真实的人来对100朵鸢尾做出品种的区分,他也有可能区分错其中4朵甚至更多。当然这并不意味着我们对此感到满足,因为这是一个示例的简单模型,我们应当追求实际应用模型的准确率超过99%!

    以上过程也给出了我们机器学习模型搭建的基本步骤,即:
    这里写图片描述

    本例参考自Plain and Simple Estimators - YouTube,中文字幕以及详细解释参考机器学习 | 更进一步,用评估器给花卉分类,本文着重于其具体实现部分,给代码加了比较详细的注释。

  • 相关阅读:
    Android 目前最稳定和高效的UI适配方案
    寄Android开发Gradle你需要知道的知识
    Android精讲--界面编程5(AdapterView及其子类)
    Android精讲--界面编程4(ImageView及其子类)
    Android精讲--界面编程3(TextView及其子类)
    Android精讲--界面编程2(布局管理器)
    Android的基类Context和View
    Android里的前端界面
    Android的活动Activity
    Android基础入门
  • 原文地址:https://www.cnblogs.com/wanghongze95/p/13842561.html
Copyright © 2011-2022 走看看