zoukankan      html  css  js  c++  java
  • 吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型

    # 1. Define a custom model and train it.
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Set TensorFlow's logging threshold to INFO so informational messages are shown.
    _log_verbosity = tf.logging.INFO
    tf.logging.set_verbosity(_log_verbosity)
    
    def lenet(x, is_training):
        """Build a LeNet-style CNN and return the output of its final 10-unit layer.

        Args:
            x: input tensor; it is reshaped here to [-1, 28, 28, 1], so each
                example must hold 28*28 values (presumably flattened 28x28
                grayscale images -- confirm against the input pipeline).
            is_training: whether the graph is being built for training; the
                dropout layer is active only when this is true.

        Returns:
            Tensor produced by a 10-unit dense layer with no activation
            (unnormalized scores, one per class).
        """
        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # Conv block 1: 32 filters of size 5, ReLU, then 2x2 max-pool with stride 2.
        conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
        conv1 = tf.layers.max_pooling2d(conv1, 2, 2)

        # Conv block 2: 64 filters of size 3, ReLU, then 2x2 max-pool with stride 2.
        conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu)
        conv2 = tf.layers.max_pooling2d(conv2, 2, 2)

        fc1 = tf.contrib.layers.flatten(conv2)
        # A nonlinearity is needed here: without it this 1024-unit layer and the
        # final 10-unit layer collapse into a single linear transformation.
        fc1 = tf.layers.dense(fc1, 1024, activation=tf.nn.relu)
        fc1 = tf.layers.dropout(fc1, rate=0.4, training=is_training)
        return tf.layers.dense(fc1, 10)
    
    def model_fn(features, labels, mode, params):
        """Estimator model function built around `lenet`.

        Args:
            features: dict of input tensors; the "image" entry is fed to `lenet`.
            labels: class-index labels used by the loss and the accuracy metric
                (not used in PREDICT mode).
            mode: a `tf.estimator.ModeKeys` value selecting which spec to build.
            params: hyperparameter dict; must contain "learning_rate".

        Returns:
            A `tf.estimator.EstimatorSpec`:
              * PREDICT: predictions {"result": argmax class index}.
              * EVAL: loss plus an "accuracy" metric.
              * TRAIN: loss plus a gradient-descent train_op.
        """
        logits = lenet(features["image"], mode == tf.estimator.ModeKeys.TRAIN)
        predicted_classes = tf.argmax(logits, 1)

        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode, predictions={"result": predicted_classes})

        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))

        if mode == tf.estimator.ModeKeys.EVAL:
            # tf.metrics.accuracy's signature is (labels, predictions); pass by
            # keyword so the roles cannot be swapped.
            eval_metric_ops = {
                "accuracy": tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
            }
            return tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

        # TRAIN mode: only here do we need the optimizer and its update op.
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"])
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
    
    # Raw string: in a normal literal "\201" is an octal escape (chr(0o201)),
    # which would silently corrupt the "201806-github" path component.
    mnist = input_data.read_data_sets(
        r"F:\TensorFlowGoogle\201806-github\datasets\MNIST_data", one_hot=False)

    model_params = {"learning_rate": 0.01}
    estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params)

    # Labels are cast to int32 before being fed through the input function.
    # num_epochs=None makes the input repeat indefinitely; training length
    # is bounded by `steps` below.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"image": mnist.train.images},
        y=mnist.train.labels.astype(np.int32),
        num_epochs=None,
        batch_size=128,
        shuffle=True)

    estimator.train(input_fn=train_input_fn, steps=30000)

    # 2. Evaluate the model on the test data.
    # A single, unshuffled pass over the test set.
    test_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"image": mnist.test.images},
        y=mnist.test.labels.astype(np.int32),
        num_epochs=1,
        batch_size=128,
        shuffle=False)

    test_results = estimator.evaluate(input_fn=test_input_fn)
    accuracy_score = test_results["accuracy"]
    # The "\n" must be an escape inside one literal; a real line break inside
    # a plain string literal is a SyntaxError.
    print("\nTest accuracy: %g %%" % (accuracy_score * 100))
    # 3. Prediction: run the trained model on the first 10 test images.
    first_ten_images = mnist.test.images[:10]
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"image": first_ten_images},
        num_epochs=1,
        shuffle=False,
    )

    predictions = estimator.predict(input_fn=predict_input_fn)
    # Number the outputs from 1; each item exposes the "result" prediction key.
    for position, prediction in enumerate(predictions, start=1):
        print("Prediction %s: %s" % (position, prediction["result"]))
  • 相关阅读:
    Opportunities
    去考試6/16
    WP数据绑定 GIS
    wp 之path详细 以及一个关于LinearGradientBrush 的动画 GIS
    windows phone 多触控画图并保存到 手机图片库 GIS
    windows phone 的多任务 GIS
    导航基础 GIS
    Windows Phone 7 页面旋转动画 GIS
    一个小范围 滑动的动画 GIS
    wp 的手势Gestures: flick, pan, and stretch GIS
  • 原文地址:https://www.cnblogs.com/tszr/p/12097244.html
Copyright © 2011-2022 走看看