  • 吴裕雄--天生自然 TensorFlow high-level wrappers: Estimator-DNNClassifier

    # 1. Define the model. (This listing uses the TensorFlow 1.x API.)
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    tf.logging.set_verbosity(tf.logging.INFO)
    
    # Raw string: in a normal string literal "\201" would be read as an octal escape.
    mnist = input_data.read_data_sets(r"F:\TensorFlowGoogle\201806-github\datasets\MNIST_data", one_hot=False)
    
    # Define the model's input: one numeric feature holding the flattened 28x28 image.
    feature_columns = [tf.feature_column.numeric_column("image", shape=[784])]
    
    # Define the model with DNNClassifier.
    estimator = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[500],  # a single hidden layer with 500 units
        n_classes=10,
        optimizer=tf.train.AdamOptimizer(),
        model_dir="log")
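
To make the shape of this model concrete, here is a minimal NumPy sketch of the forward pass that such a network performs: 784 inputs, one hidden layer of 500 ReLU units (ReLU is the default activation of DNNClassifier), and 10 output classes. The names `init_params` and `forward` and the random weight initialization are assumptions made for illustration only; they are not how the Estimator initializes or stores its variables.

```python
import numpy as np


def init_params(n_in=784, n_hidden=500, n_classes=10, seed=0):
    """Create illustrative (hypothetical) weights for a 784-500-10 network."""
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(0.0, 0.01, size=(n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(0.0, 0.01, size=(n_hidden, n_classes)),
        "b2": np.zeros(n_classes),
    }


def forward(params, images):
    """Return class probabilities for a batch of flattened images."""
    # Hidden layer with ReLU activation.
    hidden = np.maximum(0.0, images @ params["W1"] + params["b1"])
    # Output layer produces one logit per class.
    logits = hidden @ params["W2"] + params["b2"]
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


params = init_params()
batch = np.random.default_rng(1).random((128, 784))  # stand-in for 128 images
probs = forward(params, batch)
print(probs.shape)  # (128, 10)
print(sum(p.size for p in params.values()))  # 397510 trainable values
```

The parameter count, 784*500 + 500 + 500*10 + 10 = 397510, gives a rough sense of the size of the model that `hidden_units=[500]` describes.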

    # 2. Train the model.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"image": mnist.train.images},
        y=mnist.train.labels.astype(np.int32),
        num_epochs=None,  # repeat the data indefinitely; `steps` below ends training
        batch_size=128,
        shuffle=True)
    
    estimator.train(input_fn=train_input_fn, steps=10000)
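
The interplay of `num_epochs`, `batch_size`, `shuffle` and `steps` can be sketched with a small pure-Python generator. This is a simplified illustration, not the actual `numpy_input_fn` implementation (which is queue-based and may differ in details such as how batches behave at epoch boundaries); `batch_indices` is a hypothetical helper name.

```python
import itertools
import random


def batch_indices(num_examples, batch_size, num_epochs=None, shuffle=True, seed=0):
    """Yield lists of example indices, one list per batch.

    num_epochs=None repeats the data forever; an integer stops after
    that many passes over the data.
    """
    rng = random.Random(seed)
    epoch = 0
    while num_epochs is None or epoch < num_epochs:
        indices = list(range(num_examples))
        if shuffle:
            rng.shuffle(indices)
        for start in range(0, num_examples, batch_size):
            yield indices[start:start + batch_size]
        epoch += 1


# A single unshuffled pass, as in an evaluation-style input function.
print(list(batch_indices(10, 4, num_epochs=1, shuffle=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# With num_epochs=None the generator never ends on its own,
# so the consumer decides how many steps to take.
for step, batch in enumerate(itertools.islice(batch_indices(10, 4), 5)):
    print(step, len(batch))
```

In the listing above, 10000 steps at a batch size of 128 consume 1,280,000 examples, which is roughly 23 passes over the data, assuming the default 55,000-image MNIST training split.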

    # 3. Test the model.
    test_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"image": mnist.test.images},
        y=mnist.test.labels.astype(np.int32),
        num_epochs=1,  # a single pass over the test set
        batch_size=128,
        shuffle=False)
    
    test_results = estimator.evaluate(input_fn=test_input_fn)
    accuracy_score = test_results["accuracy"]
    print("\nTest accuracy: %g %%" % (accuracy_score*100))
    
    print(test_results)
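
For reference, the "accuracy" value is the fraction of test examples whose predicted class matches the label. A minimal sketch of that computation and of the `%g %%` formatting follows; the `accuracy` function and the sample labels are made up for illustration and are not part of the Estimator API.

```python
def accuracy(labels, predictions):
    """Return the fraction of predictions that equal their labels."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if not labels:
        raise ValueError("at least one example is required")
    correct = sum(1 for y, p in zip(labels, predictions) if y == p)
    return correct / len(labels)


labels = [7, 2, 1, 0]        # made-up ground-truth digits
predictions = [7, 2, 1, 4]   # made-up model outputs, one of them wrong
score = accuracy(labels, predictions)

# "%g" drops trailing zeros and "%%" prints a literal percent sign.
message = "Test accuracy: %g %%" % (score * 100)
print(message)  # Test accuracy: 75 %
```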

  • Original article: https://www.cnblogs.com/tszr/p/12096916.html