  • How the TensorFlow Estimator and model_fn communicate

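    When writing a custom estimator, it is very important to understand how the Estimator relates to model_fn and
    to its other arguments. In short, the Estimator takes the arguments it has been given and feeds them into
    model_fn; model_fn is the main consumer of that data.
    Compared with the estimators in scikit-learn and Spark, TensorFlow's Estimator is more abstract: the model,
    built under the guidance of its hyperparameters, is itself passed in as an argument, so the structure and
    definition of the Estimator change very little when the model changes. In Spark and scikit-learn, by contrast,
    each algorithm usually has its own estimator with its own constructor. TensorFlow's Estimator is more heavily
    parameterized and leaves more freedom, so its parameter management differs from the other two.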
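The class docstring quoted further down says that `params` and `config` are handed to model_fn only if model_fn has a parameter of that name. The following is a minimal sketch of that idea in plain Python using `inspect.signature`. The names `call_model_fn`, `full_model_fn`, and `minimal_model_fn`, as well as the string standing in for a config object, are hypothetical and introduced only for illustration; TensorFlow's own dispatch may differ in detail, for example in how it validates mismatches.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params=None, config=None):
    """Forward only the optional arguments that model_fn's signature names."""
    accepted = inspect.signature(model_fn).parameters
    optional = {"mode": mode, "params": params, "config": config}
    kwargs = {name: value for name, value in optional.items() if name in accepted}
    # features and labels are always supplied, in that order.
    return model_fn(features, labels, **kwargs)


def full_model_fn(features, labels, mode, params, config):
    # Hypothetical model function that asks for every optional argument.
    return {"mode": mode, "lr": params["learning_rate"], "config": config}


def minimal_model_fn(features, labels):
    # Hypothetical model function that asks for none of them.
    return {"n_features": len(features)}


# The structure of the hyperparameter dict is entirely up to the developer.
hyperparams = {"learning_rate": 0.1}

full = call_model_fn(
    full_model_fn, [1.0, 2.0], [0, 1], "train", hyperparams, "run-config")
minimal = call_model_fn(
    minimal_model_fn, [1.0, 2.0], [0, 1], "train", hyperparams, "run-config")
```

Both calls receive the same inputs; what each model function sees is decided by its own signature, which is why the wrapper does not have to change when the model does.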

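    In short, to use the data passed to it, the Estimator has to know what that data looks like. Java enforces this
    with static types; Python relies on duck typing; elsewhere metadata describes the data, or the parties follow
    a tacit convention or an explicit protocol. There is no enforced constraint between Estimator and model_fn;
    the two rely on a convention, and that convention is spelled out in the English documentation below.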
    Depending on the value of `mode`, the `tf.estimator.EstimatorSpec` returned by model_fn requires different arguments, namely:

    * For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.
    * For `mode == ModeKeys.EVAL`: required field is `loss`.
    * For `mode == ModeKeys.PREDICT`: required fields are `predictions`.
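The convention above can be sketched as a small lookup table. This is only an illustration in plain Python: the mode constants, `REQUIRED_FIELDS`, and `missing_fields` are hypothetical names introduced here, and the value returned by model_fn is modeled as a plain dict rather than one of TensorFlow's own objects.

```python
# Hypothetical stand-ins for the three modes; not TensorFlow's own ModeKeys class.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"

# Required fields per mode, copied from the list above.
REQUIRED_FIELDS = {
    TRAIN: ("loss", "train_op"),
    EVAL: ("loss",),
    PREDICT: ("predictions",),
}


def missing_fields(mode, spec):
    """Return the required fields that `spec` (a plain dict) leaves unset for `mode`."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError("unknown mode: %r" % (mode,))
    return [name for name in REQUIRED_FIELDS[mode] if spec.get(name) is None]


# A training spec that forgot its train_op, and a complete prediction spec.
incomplete_train = missing_fields(TRAIN, {"loss": 0.5})
complete_predict = missing_fields(PREDICT, {"predictions": [0, 1, 1]})
```

A model function written against this convention typically branches on `mode` and fills in only the fields that mode needs.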
    

    class Estimator(object):
    """Estimator class to train and evaluate TensorFlow models.

    The Estimator object wraps a model which is specified by a model_fn,
    which, given inputs and a number of other parameters, returns the ops
    necessary to perform training, evaluation, or predictions.

    All outputs (checkpoints, event files, etc.) are written to model_dir, or a
    subdirectory thereof. If model_dir is not set, a temporary directory is
    used.

    The config argument can be passed a tf.estimator.RunConfig object containing
    information about the execution environment. It is passed on to the
    model_fn, if the model_fn has a parameter named "config" (and input
    functions in the same manner). If the config parameter is not passed, it is
    instantiated by the Estimator. Not passing config means that defaults useful
    for local execution are used. Estimator makes config available to the model
    (for instance, to allow specialization based on the number of workers
    available), and also uses some of its fields to control internals, especially
    regarding checkpointing.

    The params argument contains hyperparameters. It is passed to the
    model_fn, if the model_fn has a parameter named "params", and to the input
    functions in the same manner. Estimator only passes params along, it does
    not inspect it. The structure of params is therefore entirely up to the
    developer.

    None of Estimator's methods can be overridden in subclasses (its
    constructor enforces this). Subclasses should use model_fn to configure
    the base class, and may add methods implementing specialized functionality.

    @compatibility(eager)
    Calling methods of Estimator will work while eager execution is enabled.
    However, the model_fn and input_fn are not executed eagerly; Estimator
    will switch to graph mode before calling all user-provided functions (incl.
    hooks), so their code has to be compatible with graph mode execution. Note
    that input_fn code using tf.data generally works in both graph and eager
    modes.
    @end_compatibility
    """

    def __init__(self, model_fn, model_dir=None, config=None, params=None,
                 warm_start_from=None):
    """Constructs an Estimator instance.

    See [estimators](https://tensorflow.org/guide/estimators) for more
    information.
    
    To warm-start an `Estimator`:
    
    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```
    
    For more details on warm-start configuration, see
    `tf.estimator.WarmStartSettings`.
    
    Args:
      model_fn: Model function. Follows the signature:
    
        * Args:
    
          * `features`: This is the first item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `tf.Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `tf.Tensor` or `dict` of same (for multi-head models).
                 If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
                 be passed. If the `model_fn`'s signature does not accept
                 `mode`, the `model_fn` must still be able to handle
                 `labels=None`.
          * `mode`: Optional. Specifies if this is training, evaluation or
                 prediction. See `tf.estimator.ModeKeys`.
          * `params`: Optional `dict` of hyperparameters.  Will receive what
                 is passed to Estimator in `params` parameter. This allows
                 to configure Estimators from hyper parameter tuning.
          * `config`: Optional `estimator.RunConfig` object. Will receive what
                 is passed to Estimator as its `config` parameter, or a default
                 value. Allows setting up things in your `model_fn` based on
                 configuration such as `num_ps_replicas`, or `model_dir`.
    
        * Returns:
          `tf.estimator.EstimatorSpec`
    
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model. If `PathLike` object, the
        path will be resolved. If `None`, the model_dir in `config` will be used
        if set. If both are set, they must be same. If both are `None`, a
        temporary directory will be used.
      config: `estimator.RunConfig` configuration object.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
              Keys are names of parameters, values are basic python types.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
                       warm-start from, or a `tf.estimator.WarmStartSettings`
                       object to fully configure warm-starting.  If the string
                       filepath is provided instead of a
                       `tf.estimator.WarmStartSettings`, then all variables are
                       warm-started, and it is assumed that vocabularies
                       and `tf.Tensor` names are unchanged.
    
    Raises:
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
  • Original article: https://www.cnblogs.com/wdmx/p/10010433.html