zoukankan      html  css  js  c++  java
  • Keras模型的保存方式

    Keras模型的保存方式

    在运行并且训练出一个模型后获得了模型的结构与许多参数,为了防止再次训练以及需要更好地去使用,我们需要保存当前状态

    基本保存方式 h5

    # 此处假设model为一个已经训练好的模型类
    model.save('my_model.h5')
    

    转换为json格式存储基本参数

    # 此处假设model为一个已经训练好的模型类
    json_string = model.to_json()
    open('my_model_architecture.json','w').write(json_string)
    

    转换为二进制pb格式

    以下代码为我从网络中寻找到的,可以将模型中的内容转换为pb格式,但需要更改其中的h5为你的模型的h5

    import sys \
    from keras.models import load_model \
    import tensorflow as tf
    import os
    import os.path as osp
    from keras import backend as K
    
    def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
        from tensorflow.python.framework.graph_util import convert_variables_to_constants
        graph = session.graph
        with graph.as_default():
            freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) 
            output_names = output_names or [] 
            output_names += [v.op.name for v in tf.global_variables()] 
            input_graph_def = graph.as_graph_def()
            if clear_devices:
                for node in input_graph_def.node:
                    node.device = ""
            frozen_graph = convert_variables_to_constants(session, input_graph_def,output_names,freeze_var_names)
        return frozen_graph
    
    input_fld = sys.path[0] 
    weight_file = 'my_model.h5'
    output_graph_name = 'tensor_model.pb'
    
    output_fld = input_fld + '/tensorflow_model/'
    if not os.path.isdir(output_fld):
        os.mkdir(output_fld) 
        weight_file_path = osp.join(input_fld, weight_file)
    K.set_learning_phase(0) 
    net_model = load_model(weight_file_path) 
    print('input is :', net_model.input.name) 
    print ('output is:', net_model.output.name) 
    sess = K.get_session() 
    frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name]) 
    from tensorflow.python.framework import graph_io 
    graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False) 
    print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))
    
  • 相关阅读:
    JVM调优总结(四)-垃圾回收面临的问题
    JVM调优总结(三)-基本垃圾回收算法
    JVM调优总结(二)-一些概念
    Java8 Lambda表达式教程
    Java8 Lambda表达式教程
    Java8 Lambda表达式教程
    JVM调优总结(一)-- 一些概念
    Hibernate 3中如何获得库表所有字段的名称
    easyUI-datagrid带有工具栏和分页器的数据网格
    easyui-tabs
  • 原文地址:https://www.cnblogs.com/Phoenix-blog/p/10166031.html
Copyright © 2011-2022 走看看