zoukankan      html  css  js  c++  java
  • TensorFlow的checkpoint文件转换为pb文件

    由于项目需要,需要将TensorFlow保存的模型从ckpt文件转换为pb文件。

    import os
    from tensorflow.python import pywrap_tensorflow
    from net2use import inception_resnet_v2_small#这里使用自己定义的模型函数即可
    import tensorflow as tf
    if __name__=='__main__':
        pb_file = "./model/output.pb"
        ckpt_file = "./model/model.ckpt-652900"
        '''
    这里的节点名字可能跟设想的有出入,最直接的方法是直接输出ckpt中保存的节点名字,然后对应着找节点名字,具体的进入convert_variables_to_constants函数的实现中graph_util_impl.py,130行的函数:_assert_nodes_are_present 添加代码
        print('在图中的节点是:')
        for din in name_to_node:
            print('{},在图中'.format(din))
    然后运行代码,若正确就会直接保存;若失败则会保存失败,找好输出节点的名字,在output_node_names 中添加就好
    '''
        output_node_names = ["embedding"]
    
        with tf.name_scope('input'):
            image = tf.placeholder(tf.float32,shape=(None,79,199,1),name='input_image')
    
    
        net, endpoints=inception_resnet_v2_small(image, is_training=False)
        embedding = tf.nn.l2_normalize(net,1,1e-10,name='embedding')
    
        config=tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.per_process_gpu_memory_fraction = 0.45
        sess  = tf.Session(config = config)
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_file)
        print('read success')
        converted_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                    input_graph_def  = sess.graph.as_graph_def(),
                                    output_node_names = output_node_names)
    
        with tf.gfile.GFile(pb_file, "wb") as f:
            f.write(converted_graph_def.SerializeToString())
    
        print('保存成功')
    
  • 相关阅读:
    javascript学习6
    javascript学习5
    javascript学习4
    javaccript学习3
    javaccript学习2
    javaccript学习1
    C++ 线性表实现
    深入解析策略模式(转)
    CentOS7安装MySQL
    万能媒体播放器 PotPlayer
  • 原文地址:https://www.cnblogs.com/kuadoh/p/11983978.html
Copyright © 2011-2022 走看看