zoukankan      html  css  js  c++  java
  • Keras的一些功能函数

    1、模型的信息提取

    1 # 节点信息提取
    2 config = model.get_config()  # 把model中的信息,solver.prototxt和train.prototxt信息提取出来
    3 model = Model.from_config(config)  # 还回去
    4 # or, for Sequential:
    5 model = Sequential.from_config(config) # 重构一个新的Model模型,用去其他训练,fine-tuning比较好用

    2、模型概况查询

    # 1、模型概括打印
    model.summary()
    
    # 2、返回代表模型的JSON字符串,仅包含网络结构,不包含权值。可以从JSON字符串中重构原模型:
    from models import model_from_json
    
    json_string = model.to_json()
    model = model_from_json(json_string)
    
    # 3、model.to_yaml:与model.to_json类似,同样可以从产生的YAML字符串中重构模型
    from models import model_from_yaml
    
    yaml_string = model.to_yaml()
    model = model_from_yaml(yaml_string)
    
    # 4、权重获取
    model.get_layer()      #依据层名或下标获得层对象
    model.get_weights()    #返回模型权重张量的列表,类型为numpy array
    model.set_weights()    #从numpy array里将权重载入给模型,要求数组具有与model.get_weights()相同的形状。
    
    # 查看model中Layer的信息
    model.layers 查看layer信息

    3、模型保存与加载

    model.save_weights(filepath)
    # 将模型权重保存到指定路径,文件类型是HDF5(后缀是.h5)
    
    model.load_weights(filepath, by_name=False)
    # 从HDF5文件中加载权重到当前模型中, 默认情况下模型的结构将保持不变。
    # 如果想将权重载入不同的模型(有些层相同)中,则设置by_name=True,只有名字匹配的层才会载入权重

    4、在keras中设定GPU的大小

    import tensorflow as tf
    from keras.backend.tensorflow_backend import set_session
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    set_session(tf.Session(config=config))

    5、训练与保存模型

    filepath = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
    checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
    # fit model
    model.fit(x, y, epochs=20, verbose=2, callbacks=[checkpoint], validation_data=(x, y))

    6、在keras中使用tensorboard

    RUN = RUN + 1 if 'RUN' in locals() else 1   # locals() 函数会以字典类型返回当前位置的全部局部变量。
    
        LOG_DIR = model_save_path + '/training_logs/run{}'.format(RUN)
        LOG_FILE_PATH = LOG_DIR + '/checkpoint-{epoch:02d}-{val_loss:.4f}.hdf5'   # 模型Log文件以及.h5模型文件存放地址
    
        tensorboard = TensorBoard(log_dir=LOG_DIR, write_images=True)
        checkpoint = ModelCheckpoint(filepath=LOG_FILE_PATH, monitor='val_loss', verbose=1, save_best_only=True)
        early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
    
        history = model.fit_generator(generator=gen.generate(True), steps_per_epoch=int(gen.train_batches / 4),
                                      validation_data=gen.generate(False), validation_steps=int(gen.val_batches / 4),
                                      epochs=EPOCHS, verbose=1, callbacks=[tensorboard, checkpoint, early_stopping])
    谢谢!
  • 相关阅读:
    Django的日志操作,记录一下自己的使用
    初学jupyter 与爬虫
    mysql的库名或者表名带空格不能删除的问题
    Linux的vim命令的快捷键
    DjangoORM相关(多表操作)
    DjangoORM相关(单表操作)
    Django模板
    Django URL相关
    Monkeyrunner学习记录之运行模拟器
    Monkeyrunner学习记录之环境搭建
  • 原文地址:https://www.cnblogs.com/ylxn/p/10721575.html
Copyright © 2011-2022 走看看