zoukankan      html  css  js  c++  java
  • 【tf.keras】TensorFlow 1.x 到 2.0 的 API 变化

    TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了。tf.keras 从 1.x 版本迁移到 2.0 版本,需要注意几个地方。

    1. 设置随机种子

    import tensorflow as tf
    
    # TF 1.x
    tf.set_random_seed(args.seed)
    # TF 2.0
    tf.random.set_seed(args.seed)
    

    2. 设置并行线程数和动态分配显存

    import tensorflow as tf
    from tensorflow.python.keras import backend as K
    
    import os
    # 将程序限定在一块GPU上
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    
    # TF 1.x
    config = tf.ConfigProto(intra_op_parallelism_threads=1,
                             inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True  # 不全部占满显存, 按需分配
    K.set_session(tf.Session(config=config))
    
    # TF 2.0,由于之前限定了GPU可见范围,这里只能看到0号GPU
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    print(gpus)
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, enable=True)
    

    3. model.compile() 中设置 metrics=['acc'] 或者 ['accuracy'],会影响 model.fit() 生成的 log,callbacks.ModelCheckpoint 需要对应填 val_acc 或者 val_accuracy:

    from tensorflow.python.keras import callbacks
    
    # TF 2.0, acc and val_acc
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['acc'])
    ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_acc', mode='max',
                                                verbose=1, save_best_only=True, save_weights_only=True)
    
    # TF 2.0, accuracy and val_accuracy
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_accuracy', mode='max',
                                                verbose=1, save_best_only=True, save_weights_only=True)
    

    4. 舍弃 model.fit_generator() 函数

    model.fit_generator() 函数在 TF 2.x 中合并到 model.fit() 函数中,并且在 TF 2.0 版本,该函数有问题,不能很好利用 GPU,训练速度很慢:
    Performance: Training is much slower in TF v2.0.0 VS v1.14.0 when using Tf.Keras and model.fit_generator #33024

    TF 2.0 版本的 model.fit() 在传入 generator 时需要手动设置 model.fit(shuffle=False)。

    解决办法:直接使用 model.fit() 函数,并且升级到 TF 2.1。

  • 相关阅读:
    HDU 3085 Nightmare Ⅱ[双向广搜]
    HDU 4028 The time of a day [离散化DP]
    HDU4027 Can you answer these queries? [线段树]
    HDU 4331 Image Recognition [边上全为1构成的正方形个数]
    HDU4026 Unlock the Cell Phone [状态压缩DP]
    HDU 4333 Revolving Digits [扩展KMP]
    HDU4335 What is N? [数论(欧拉函数)]
    工程与管理
    项目管理笔记一
    通过100个单词掌握英语语法(七)ask
  • 原文地址:https://www.cnblogs.com/wuliytTaotao/p/12016656.html
Copyright © 2011-2022 走看看