zoukankan      html  css  js  c++  java
  • tensorflow模型的保存与恢复,以及ckpt到pb的转化

    转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html

    一、模型的保存

    使用tensorflow训练模型的过程中,需要适时对模型进行保存,以及对保存的模型进行restore,以便后续对模型进行处理。如:测试、部署、拿别的模型进行fine-tune等。

    保存模型是整个内容的第一步,操作十分简单,只需要创建一个saver,并在一个Session里完成保存。

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.save(sess, model_name)

    以上代码在0.11以下版本的tensorflow里会保存与下面类似的3个文件

    checkpoint
    model.ckpt-1000.meta
    model.ckpt-1000.ckpt

    其中checkpoint列出保存的所有模型以及最近的模型;meta文件是模型定义的内容;ckpt(或data和index)文件是保存的模型数据。

    除了上面最简单的保存方式,也可以指定保存的步数,多长时间保存一次,磁盘上最多保存几个模型(将前面的删除以保持固定个数),需要做的是在创建saver时指定参数

    saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)

    其中,savable_variables指定待保存的变量,比如指定为tf.global_variables()保存所有global变量;指定为[v1, v2]保存v1和v2两个变量,如果省略,则保存所有。

    max_to_keep指定磁盘上最多保存有几个模型。

    keep_checkpoint_every_n_hours指定多少小时保存一次。

    保存模型时指定参数

    saver.save(sess, 'model_name', global_step=step, write_meta_graph=False)

    其中,可以指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph,等等。

    二、模型的恢复及查看模型参数

    with tf.Session() as sess:
        # 加载模型定义的graph
        saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
        # 方式一:加载指定文件夹下最近保存的一个模型的数据
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        # 方式二:指定具体某个数据,需要注意的是,指定的文件不要包含后缀
        # saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))
    
        # 查看模型中的trainable variables
        tvs = [v for v in tf.trainable_variables()]
        for v in tvs:
            print(v.name)
            print(sess.run(v))
    
        # 查看模型中的所有tensor或者operations
        gv = [v for v in tf.global_variables()]
        for v in gv:
            print(v.name)
    
        # 获得几乎所有的operations相关的tensor
        ops = [o for o in sess.graph.get_operations()]
        for o in ops:
            print(o.name)

    说明:

    1、global_variables()比trainable_variables()多了一些非trainable的变量,比如定义时指定为trainable=False的变量,或Optimizer相关的变量。

    2、sess.graph.get_operations()可以换为tf.get_default_graph().get_operations(),二者区别无非是graph明确的时候可以直接使用前者,否则需要使用后者。

    三、将ckpt转化为pb

    freeze_graph就是将模型固化,具体说就是将训练数据和模型固化成pb文件。

    参数: (必选: 表示必须有值;可选: 表示可以为空):
    1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)
    2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。
    3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False
    4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。
    5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。
    6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all
    7、filename_tensor_name:(可选)已弃用。默认:save/Const:0
    8、output_graph:(必选)用来保存整合后的模型输出文件。
    9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)
    10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
    11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。 

    if __name__ == '__main__':
        args = parse_args()
    
        # model path
        demonet = args.demo_net
        dataset = args.dataset
        tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])
    
        if not os.path.isfile(tfmodel + '.meta'):
            print(tfmodel)
            raise IOError(('{:s} not found.
    Did you download the proper networks from '
                           'our server and place them properly?').format(tfmodel + '.meta'))
    
        # set config
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
    
        # init session
        sess = tf.Session(config=tfconfig)
        # load network
        if demonet == 'vgg16':
            net = vgg16(batch_size=1)
        else:
            raise NotImplementedError
    
        net.create_architecture(sess, "TEST", 4,
                                tag='default', anchor_scales=[8, 16, 32])
        saver = tf.train.Saver()
        saver.restore(sess, tfmodel)
    
        # 保存图
        tf.train.write_graph(sess.graph_def, 'pb/pb_model', 'model.pb')
        # 把图和参数结构一起
        freeze_graph.freeze_graph('pb/pb_model/model.pb',
                                  '',
                                  False,
                                  tfmodel,
                                  'vgg_16/cls_score/BiasAdd,vgg_16/cls_prob,vgg_16/bbox_pred/BiasAdd,vgg_16/rois/PyFunc',
                                  'save/restore_all',
                                  'save/Const:0',
                                  'pb/pb_model/frozen_model.pb',
                                  False,
                                  "")
  • 相关阅读:
    Day60
    Day53
    Day50
    Day49
    Day48
    Day47
    Day46(2)
    Day46(1)
    Day45
    Day44
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11887811.html
Copyright © 2011-2022 走看看