zoukankan      html  css  js  c++  java
  • tf.estimator

    estimator同keras是tensorflow的高级API。在tensorflow1.13以上,estimator已经作为一个单独的package从tensorflow分离出来了。
    estimator抽象了tensorflow底层的api, 同keras一样,他分离了model和data, 不同于keras这个不得不认养的儿子,estimator作为tensorflow的亲儿子,天生具有分布式的基因,更容易在生产环境里面使用

    tensorflow官方文档提供了比较详细的estimator程序的构建过程:
    https://www.tensorflow.org/guide#estimators

    tensorflow model提供了estimator构建的mnist程序:
    https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py

    estimator模型由model_fn决定:
    官方文档:

    其中features, labels是必需的。model, params, config参数是可选的
    如下是estiamtor定义的一个模型:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    def (features, labels, mode, params):
    """DNN with three hidden layers and learning_rate=0.1."""

    net = tf.feature_column.input_layer(features, params['feature_columns'])
    for units in params['hidden_units']:
    net = t 大专栏  tf.estimatorf.layers.dense(net, units=units, activation=tf.nn.relu)

    # Compute logits (1 per class).
    logits = tf.layers.dense(net, params['n_classes'], activation=None)

    # Compute predictions.
    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
    'class_ids': predicted_classes[:, tf.newaxis],
    'probabilities': tf.nn.softmax(logits),
    'logits': logits,
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # Compute loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Compute evaluation metrics.
    accuracy = tf.metrics.accuracy(labels=labels,
    predictions=predicted_classes,
    name='acc_op')
    metrics = {'accuracy': accuracy}
    tf.summary.scalar('accuracy', accuracy[1])

    if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
    mode, loss=loss, eval_metric_ops=metrics)

    # Create training op.
    assert mode == tf.estimator.ModeKeys.TRAIN

    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    可以看见features, labels分别是模型的input,output。所以是必须的。但是在有的模型里面。比如bert预训练的模型里面,我们不需要训练,所以只用到features。

  • 相关阅读:
    Java 抽象类 初学者笔记
    JAVA super关键字 初学者笔记
    Java 标准输入流 初学者笔记
    JAVA 将对象引用作为参数修改实例对象参数 初学者笔记
    JAVA 根据类构建数组(用类处理数组信息) 初学者笔记
    JAVA 打印日历 初学者笔记
    Python 测试代码 初学者笔记
    Python 文件&异常 初学者笔记
    Python 类 初学者笔记
    ubuntu网络连接失败
  • 原文地址:https://www.cnblogs.com/lijianming180/p/12366348.html
Copyright © 2011-2022 走看看