zoukankan      html  css  js  c++  java
  • 机器学习tensorflow框架初试

    本文来自网易云社区

    作者:汪洋


    前言

    新手学习可以点击参考Google的教程。开始前,我们先在本地安装好 TensorFlow机器学习框架。 

    1. 首先我们在本地window下安装好python环境,约定安装3.6版本;

    2. 安装Anaconda工具集后,创建名为 tensorflow 的conda 环境:conda create -n tensorflow pip python=3.6;

    3. conda切换环境:activate tensorflow;

    4. 我们安装支持CPU的TensorFlow版本(快速):pip install --ignore-installed --upgrade tensorflow;

    5. 最后验证安装是否成功,进入 python dos命名,输入以下代码校验:

      import tensorflow as tf
      hello = tf.constant('Hello, TensorFlow')
      sess = tf.Session()
      print(sess.run(hello))

      输出Hello, TensorFlow,表示成功了。如果失败的话,就选择低版本重新安装如:pip install --ignore-installed --upgrade tensorflow==1.5.0。
      其它安装方式点击参考教程


    监督学习实践

    官方针对新手演示了一个入门示例,点击教程可查看,本文就围绕这个教程分享。

    1.分类

    官方示例里讲解了分类鸢尾花问题的解决,我们想到的就是用监督学习训练机器模型。采用这种学习方式后,我们需要确定用鸢尾花的哪些特征来分类,鸢尾花的特征还是蛮多的,官方示例里用的是花萼和花瓣的长度和宽度。
    鸢尾花种类非常多,官方也仅是针对三种进行分类:

    expected = ['Setosa', 'Versicolor', 'Virginica']

    接下来就是获取大量数据,进行预处理,官方示例里直接引用了他人整理的数据源,省略了前期数据处理步骤,前5条数据结构如下:


    SepalLengthSepalWidthPetalLengthPetalWidthSpecies
    06.42.85.62.22
    15.02.33.31.01
    24.92.54.51.72
    34.93.11.50.10
    45.73.81.70.10

    说明:

    1. 最后一列代表着鸢尾花的品种,也就是说它是监督学习中的标签;

    2. 中间四列从左到右表示花萼的长度和宽度、花瓣的长度和宽度;

    3. 表格数据代表了从120个样本的数据集中抽集的5个样本;
      机器学习一般依赖数值,因此当前数据集中标签值都为数字,对应关系: 

    012
    SetosaVersicolorVirginica

    接下来将编写代码,先复习下概念,模型指特征和标签之间的关系;训练指机器学习阶段,这个阶段模型不断优化。示例里选择的监督试学习方式,模型通过包含标签的样本进行训练。

    2. 导入和解析数据集 

    首先我们要获取训练集和测试集,其中训练集是训练模型的样本,测试集是评估训练后模型效果的样本。
    首先设置我们选择的数据集地址

     """训练集"""TRAN_URL = "http://download.tensorflow.org/data/iris_training.csv""""测试集"""TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

    使用tensorflow.keras.utils.get_file函数下载数据集,该方法第一个参数为文件名称,第二个参数为下载地址,点击查看详细)。

    import tensorflow as tfdef download():
        train_path = tf.keras.utils.get_file('iris_training.csv', TRAN_URL)
        test_path = tf.keras.utils.get_file('iris_test.csv', TEST_URL)    return train_path, test_path

    然后用pandas.read_csv函数解析下载的数据,解析后生成的格式是一个表格,然后再分成特征列表和标签列表,返回训练集和测试集

    import pandas as pd
    CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',                    'PetalLength', 'PetalWidth', 'Species']def load_data(y_species='Species'):
        train_path, test_path = download()
        train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
        train_x, train_y = train, train.pop(y_species)
    
        test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
        test_x, test_y = test, test.pop(y_species)    return (train_x, train_y), (test_x, test_y)

    3. 特征列-数值列 

    我们已经获取到数据集,在tensorflow中需要将数据转换为模型(Estimator)可以使用的数据结构,这时候调用tf.feature_column模块中的函数来转换。鸢尾花例子中,需将特征数据转换为浮点数,调用tf.feature_column.numeric_column方法。

    import iris_data
    
    (train_x, train_y), (test_x, test_y) = iris_data.load_data()
    my_feature_columns = []for key in train_x.keys():
        my_feature_columns.append(tf.feature_column.numeric_column(key=key))

    其中key是 ['SepalLength' , 'SepalWidth' , 'PetalLength' , 'PetalWidth'] 其中之一。

    4. 模型选择 

    官方例子中选择全连接神经网络解决鸢尾花问题,用神经网络来发现特征与标签之间的复杂关系。tensorflow中,通过实例化一个Estimator类指定模型类型,这里我们使用官方提供的预创建的Estimator类,tf.estimator.DNNClassifier,此Estimator会构建一个对样本进行分类的神经网络。

    classifier = tf.estimator.DNNClassifier(
        feature_columns = my_feature_columns,
        hidden_units = [10,10],
        n_classes = 3)

    feature_columns 参数指训练的特征列(这里是数值列);
    hidden_units 参数定义神经网络内每个隐藏层中的神经元数量,这里设置了2个隐藏层,每个隐藏层中神经元数量都是10个;
    n_classes 参数表示要预测的标签数量,这里我们需要预测3个品种;
    其它参数点击查看

    5. 训练模型 

    上一步我们已经创建了一个学习模型,接下来将数据导入到模型中进行训练。tensorflow中,调用Estimator对象的train方法训练。

    classifier.train(
        input_fn = lambda:iris_data.train_input_fn(train_x, train_y, 100)
        steps = 1000)

    input_fn 参数表示提供训练数据的函数; steps 参数表示训练迭代次数;
    在train_input_fn函数里,我们将数据转换为 train方法所需的格式。 

    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    为了保证训练效果,训练样本需随机排序。buffer_size 设置大于样本数(120),可确保数据得到充分的随机化处理。 

    dataset = dataset.shuffle(1000)

    为了保证训练期间,有无限量的训练样本,需调用 tf.data.Dataset.repeat。

    dataset = dataset.repeat()

    train方法一次处理一批样本, tf.data.Dataset.batch 方法通过组合多个样本创建一个批次,这里组合多个包含100个样本的批次。

    dataset = dataset.batch(100)

    6. 模型评估 

    接下来我们将训练好的模型预测效果。tensorflow中,每个Estimator对象提供了evaluate方法。

    eval_result = classifier.evaluate(
        input_fn = lambda:iris_data.eval_input_fn(test_x, test_y, 100)
    )

    在eval_input_fn函数里,我们将数据转换为 evaluate方法所需的格式。实现跟训练一样,只是无需随机化处理和无限量重复使用测试集。

    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    dataset.batch(100);return dataset

    7. 预测 

    接下来将该模型对无标签样本进行预测。官方手动提供了三个无标签样本。

    predict_x = {    'SepalLength': [5.1, 5.9, 6.9],    'SepalWidth': [3.3, 3.0, 3.1],    'PetalLength': [1.7, 4.2, 5.4],    'PetalWidth': [0.5, 1.5, 2.1],
    }

    tensorflow中,每个Estimator对象提供了predict方法。

    predictions = classifier.predict(
        input_fn = lambda:iris_data.eval_input_fn(predict_x, labels=None, 100)
    )

    改造下eval_input_fn方法,使其能够接受 labels = none 情况

    features=dict(features)if labels is None:
        inputs = featureselse:
        inputs = (features, labels)
    dataset = tf.data.Dataset.from_tensor_slices(inputs)

    接下来打印下预测结果, predictions 中 class_ids表示可能性最大的品种,probabilities 表示每个品种的概率

    for pred_dict in predictions:
        class_id = pred_dict['class_ids'][0]
        probability = pred_dict['probabilities'][class_id]
        print(class_id, probability)

    结果如下:



    00.99706334
    10.997407
    20.97377485


    结尾

    通过官方例子,新手可初步了解其使用,当然更深入的使用还得学习理论和多使用API。本文是根据官方例子,作为新手重新梳理了一遍。



    网易云免费体验馆,0成本体验20+款云产品

    更多网易研发、产品、运营经验分享请访问网易云社区


    相关文章:
    【推荐】 云计算交互设计师的正确出装姿势

  • 相关阅读:
    win7下的vxworks总结
    ubuntu 无法获得锁 /var/lib/dpkg/lock
    项目中用到了的一些批处理文件
    win7下安装 WINDRIVER.TORNADO.V2.2.FOR.ARM
    使用opencv统计视频库的总时长
    January 05th, 2018 Week 01st Friday
    January 04th, 2018 Week 01st Thursday
    January 03rd, 2018 Week 01st Wednesday
    January 02nd, 2018 Week 01st Tuesday
    January 01st, 2018 Week 01st Monday
  • 原文地址:https://www.cnblogs.com/163yun/p/9722699.html
Copyright © 2011-2022 走看看