zoukankan      html  css  js  c++  java
  • keras_0_快速开始部分

    1. "sample", "batch", "epoch" 分别是什么?

    为了正确地使用 Keras,以下是必须了解和理解的一些常见定义:

    • Sample: 样本,数据集中的一个元素,一条数据。
    • 例1: 在卷积神经网络中,一张图像是一个样本。
    • 例2: 在语音识别模型中,一段音频是一个样本。
    • Batch: 批,含有 N 个样本的集合。每一个 batch 的样本都是独立并行处理的。在训练时,一个 batch 的结果只会用来更新一次模型。 一个 batch 的样本通常比单个输入更接近于总体输入数据的分布,batch 越大就越近似。然而,每个 batch 将花费更长的时间来处理,并且仍然只更新模型一次。在推理(评估/预测)时,建议条件允许的情况下选择一个尽可能大的 batch,(因为较大的 batch 通常评估/预测的速度会更快)。
    • Epoch: 轮次,通常被定义为 「在整个数据集上的一轮迭代」,用于训练的不同的阶段,这有利于记录和定期评估。
    • 当在 Keras 模型的 fit 方法中使用 evaluation_dataevaluation_split 时,评估将在每个 epoch 结束时运行。
    • 在 Keras 中,可以添加专门的用于在 epoch 结束时运行的 callbacks 回调。例如学习率变化和模型检查点(保存)。

    2. 为什么训练误差比测试误差高很多?

    1. Keras 模型有两种模式:训练和测试。正则化机制,如 Dropout 和 L1/L2 权重正则化,在测试时是关闭的。(是这样吗?我不太确定!)

    2. 此外,训练误差是每批训练数据的平均误差。由于你的模型是随着时间而变化的,一个 epoch 中的第一批数据的误差通常比最后一批的要。另一方面,测试误差是模型在一个 epoch 训练完后计算的,因而误差较小。

    3. 如何在keras中获取可复现的实验结果?

    1. 一个例子:
      • PYTHONHASHSEED
      • numpy中使用的随机数,python中使用的随机数
      • 强制 TensorFlow 使用单线程。
    import numpy as np
    import tensorflow as tf
    import random as rn
    
    # 以下是 Python 3.2.3 以上所必需的,
    # 为了使某些基于散列的操作可复现。
    # https://docs.python.org/3.4/using/cmdline.html#envvar-PYTHONHASHSEED
    # https://github.com/keras-team/keras/issues/2280#issuecomment-306959926
    
    import os
    os.environ['PYTHONHASHSEED'] = '0'
    
    # 以下是 Numpy 在一个明确的初始状态生成固定随机数字所必需的。
    
    np.random.seed(42)
    
    # 以下是 Python 在一个明确的初始状态生成固定随机数字所必需的。
    
    rn.seed(12345)
    
    # 强制 TensorFlow 使用单线程。
    # 多线程是结果不可复现的一个潜在的来源。
    # 更多详情,见: https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
    
    session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    
    from keras import backend as K
    
    # `tf.set_random_seed()` 将会以 TensorFlow 为后端,
    # 在一个明确的初始状态下生成固定随机数字。
    # 更多详情,见: https://www.tensorflow.org/api_docs/python/tf/set_random_seed
    
    tf.set_random_seed(1234)
    
    sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    K.set_session(sess)
    
    # 剩余代码 ...
    
  • 相关阅读:
    js异步编程
    gitreset
    js数据类型
    vuex报错
    个人管理系统综述
    ffmpeg第7篇:数据流选择神器map指令
    eltable多选框根据条件隐藏显示
    [域渗透内网渗透] 从 web 到域控,你未曾设想的攻击链
    宽字节第二期线下培训开始招生啦!!!
    cve20212394 weblogic反序列化漏洞分析
  • 原文地址:https://www.cnblogs.com/LS1314/p/10380589.html
Copyright © 2011-2022 走看看