zoukankan      html  css  js  c++  java
  • TensorFlow高层次机器学习API (tf.contrib.learn)

    TensorFlow高层次机器学习API (tf.contrib.learn)

    1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格式数据

    2.tf.contrib.learn.DNNClassifier 建立DNN模型(classifier)

    3.classifer.fit 训练模型

    4.classifier.evaluate 评价模型

    5.classifier.predict 预测新样本

    完整代码:

    复制代码
     1 from __future__ import absolute_import
     2 from __future__ import division
     3 from __future__ import print_function
     4 
     5 import tensorflow as tf
     6 import numpy as np
     7 
     8 # Data sets
     9 IRIS_TRAINING = "iris_training.csv"
    10 IRIS_TEST = "iris_test.csv"
    11 
    12 # Load datasets.
    13 training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    14     filename=IRIS_TRAINING,
    15     target_dtype=np.int,
    16     features_dtype=np.float32)
    17 test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    18     filename=IRIS_TEST,
    19     target_dtype=np.int,
    20     features_dtype=np.float32)
    21 
    22 # Specify that all features have real-value data
    23 feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
    24 
    25 # Build 3 layer DNN with 10, 20, 10 units respectively.
    26 classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
    27                                             hidden_units=[10, 20, 10],
    28                                             n_classes=3,
    29                                             model_dir="/tmp/iris_model")
    30 
    31 # Fit model.
    32 classifier.fit(x=training_set.data,
    33                y=training_set.target,
    34                steps=2000)
    35 
    36 # Evaluate accuracy.
    37 accuracy_score = classifier.evaluate(x=test_set.data,
    38                                      y=test_set.target)["accuracy"]
    39 print('Accuracy: {0:f}'.format(accuracy_score))
    40 
    41 # Classify two new flower samples.
    42 new_samples = np.array(
    43     [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
    44 y = list(classifier.predict(new_samples, as_iterable=True))
    45 print('Predictions: {}'.format(str(y)))
    复制代码

     结果:

    Accuracy:0.966667

  • 相关阅读:
    Apollo的Oracle适配改动
    尝试Java,从入门到Kotlin(下)
    尝试Java,从入门到Kotlin(上)
    RabbitMQ权限控制原理
    一文彻底掌握二叉查找树(多组动图)(史上最全总结)
    图解:深度优先搜索与广度优先搜索及其六大应用
    图解:如何理解与实现散列表
    图解:什么是“图”?
    查找算法系列文(一)一文入门二叉树
    线性表(数组、链表、队列、栈)详细总结
  • 原文地址:https://www.cnblogs.com/bonelee/p/7903436.html
Copyright © 2011-2022 走看看