zoukankan      html  css  js  c++  java
  • tensorflow 模型前向传播 保存ckpt tensorboard查看 ckpt转pb pb 转snpe dlc 实例

    参考:

    TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

    TensorFlow 模型保存与恢复

    snpe

    tensorflow 模型前向传播 保存ckpt tensorboard查看 ckpt转pb pb 转snpe dlc 实例

    log文件

    输入节点 图像高度 图像宽度 图像通道数

    input0 6,6,3

    输出节点

     --out_node add 

    snpe-tensorflow-to-dlc --graph ./simple_snpe_log/model200.pb -i input0 6,6,3 --out_node add

    #coding:utf-8
    #http://blog.csdn.net/zhuiqiuk/article/details/53376283
    #http://blog.csdn.net/gan_player/article/details/77586489
    from __future__ import absolute_import, unicode_literals
    import tensorflow as tf
    import shutil
    import os.path
    from tensorflow.python.framework import graph_util
    import mxnet as mx
    import numpy as np
    import random
    import cv2
    from time import sleep
    from easydict import EasyDict as edict
    import logging   
    import math
    import tensorflow as tf
    import numpy as np
    
    def FullyConnected(input, fc_weight, fc_bias, name):
        """Apply a dense layer: ``input @ fc_weight + fc_bias``.

        ``name`` is accepted but not used. The bias is added with the ``+``
        operator; the rest of this script looks the result up by the op name
        ``add``.
        """
        product = tf.matmul(input, fc_weight)
        return product + fc_bias
    
    def inference(body, name_class, outchannel):
        """Build a deterministic conv -> flatten -> dense test graph.

        Args:
            body: input tensor; dimension 3 is read as the input channel count
                (presumably NHWC -- TODO confirm).
            name_class: number of output units of the dense layer.
            outchannel: number of 3x3 convolution filters.

        Returns:
            ``(conv, logits)``: the ``conv0`` convolution output and the
            dense-layer output.
        """
        ksize = 3
        in_ch = body.get_shape()[3].value
        n_weights = ksize * ksize * in_ch * outchannel
        # Ramp weights generated in (out, in, kh, kw) order, scaled into [0, 1)
        # so every run produces identical numbers.
        oihw = np.arange(n_weights, dtype=np.float32).reshape((outchannel, in_ch, ksize, ksize))
        oihw = oihw / n_weights
        print("conv_weight ", oihw)
        # tf.nn.conv2d expects filters laid out as (kh, kw, in, out).
        kernel = tf.Variable(oihw.transpose(2, 3, 1, 0), dtype=np.float32, name="conv_weight")
        conv = tf.nn.conv2d(body, kernel, strides=[1, 1, 1, 1], padding='SAME', name="conv0")

        dims = conv.get_shape()
        flat_len = dims[1].value * dims[2].value * dims[3].value
        flat = tf.reshape(conv, [1, flat_len], name="fc0")
        # All-ones weights / zero bias: the dense output is just a sum of the
        # flattened conv activations.
        weight = tf.Variable(np.ones((flat_len, name_class)), dtype=np.float32, name="fc_weight")
        bias = tf.Variable(np.zeros((1, name_class)), dtype=np.float32, name="fc_bias")
        logits = FullyConnected(flat, weight, bias, "fc0")
        return conv, logits
    
    # Directory that receives the checkpoint (model.ckpt-200*), the TensorBoard
    # event file and the frozen model200.pb.
    export_dir = "simple_snpe_log"
    def saveckpt():
        """Build the test graph, run it once on a synthetic image, and save a checkpoint.

        The checkpoint is written as ``<export_dir>/model.ckpt-200`` (with its
        ``.meta`` and data files). ``export_dir`` is created if it does not exist.
        """
        import time

        height = 6
        width = 6
        inchannel = 3
        outchannel = 3
        step = 200
        with tf.Graph().as_default():
            input_image = tf.placeholder("float", [1, height, width, inchannel], name="input0")
            conv, logdit = inference(input_image, 10, outchannel)
            init = tf.global_variables_initializer()
            with tf.Session() as sess:
                sess.run(init)
                # Deterministic ramp image built in NCHW order, scaled into
                # [-127.5, 127.5). The enclosing parentheses are required: a bare
                # line break before "/" is a SyntaxError.
                img = (np.arange(height * width * inchannel, dtype=np.float32)
                       .reshape((1, inchannel, height, width))
                       / (1 * inchannel * height * width) * 255.0 - 127.5)
                print("img", img)
                # NCHW -> NHWC, the layout of the "input0" placeholder.
                img = img.transpose(0, 2, 3, 1)
                since = time.time()
                fc, conv_out = sess.run([logdit, conv], {input_image: img})
                time_elapsed = time.time() - since
                print("tf inference time ", str(time_elapsed))
                # conv_out is NHWC; (0, 3, 1, 2) is the inverse of the transpose
                # above, so the output is printed in the same NCHW layout as the
                # input image. (The old (0, 2, 3, 1) produced an N,W,C,H layout.)
                print("conv", conv_out.transpose(0, 3, 1, 2))
                print("fc", fc)

                saver = tf.train.Saver()
                if not os.path.isdir(export_dir):  # create the log directory if needed
                    os.makedirs(export_dir)
                checkpoint_file = os.path.join(export_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
    
    def LoadModelToTensorBoard():
        """Import the saved meta graph and write it out as a TensorBoard event file.

        The graph is imported into its own ``tf.Graph`` rather than the process-wide
        default graph. This stops its nodes from colliding with later imports of
        the same meta graph. View the result with ``tensorboard --logdir <export_dir>``.
        """
        checkpoint_file = os.path.join(export_dir, 'model.ckpt-200.meta')
        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(checkpoint_file)
            print(saver)
            summary_write = tf.summary.FileWriter(export_dir, graph)
            print(summary_write)
            # Flush and close so the event file is completely written before
            # returning, instead of relying on interpreter shutdown.
            summary_write.close()
    
    def ckptToPb():
        """Freeze the latest checkpoint in ``export_dir`` into ``model200.pb``.

        Restores the checkpoint into a fresh graph. Every variable feeding the
        ``add`` output node is converted to a constant, and the resulting
        GraphDef is written to ``<export_dir>/model200.pb``.

        Raises:
            IOError: if ``export_dir`` contains no checkpoint.
        """
        ckpt = tf.train.get_checkpoint_state(export_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise IOError("no checkpoint found in " + export_dir)
        print("model ", ckpt.model_checkpoint_path)
        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
            with tf.Session() as sess:
                # The restore IS the variable initialization. Running
                # global_variables_initializer() afterwards would overwrite the
                # restored weights with their initial values.
                saver.restore(sess, ckpt.model_checkpoint_path)
                output_graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, graph.as_graph_def(), ['add'])
        model_name = os.path.join(export_dir, 'model200.pb')
        with tf.gfile.GFile(model_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    
    def PbTest():
        """Load the frozen ``model200.pb`` and run it on the same synthetic image as saveckpt().

        Prints the input image, the ``conv0`` output (in NCHW order) and the
        ``add`` output. These can be compared with the values saveckpt() prints.
        """
        import time

        height = 6
        width = 6
        inchannel = 3
        output_graph_path = os.path.join(export_dir, 'model200.pb')
        with tf.Graph().as_default() as graph:
            output_graph_def = tf.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")

            # The frozen graph contains only constants, so there is nothing to
            # initialize (the old initialize_all_variables() call was a no-op).
            with tf.Session() as sess:
                input_image = graph.get_tensor_by_name("input0:0")
                fc0_output = graph.get_tensor_by_name("add:0")
                conv = graph.get_tensor_by_name("conv0:0")

                # Same NCHW ramp image as saveckpt(); the parentheses are needed
                # for the multi-line expression to parse.
                img = (np.arange(height * width * inchannel, dtype=np.float32)
                       .reshape((1, inchannel, height, width))
                       / (1 * inchannel * height * width) * 255.0 - 127.5)
                print("img", img)
                img = img.transpose(0, 2, 3, 1)  # NCHW -> NHWC
                since = time.time()
                fc_val, conv_val = sess.run([fc0_output, conv], {input_image: img})
                time_elapsed = time.time() - since
                print("tf inference time ", str(time_elapsed))
                # NHWC -> NCHW: the inverse of the input transpose above.
                print("conv", conv_val.transpose(0, 3, 1, 2))
                print("fc0_output", fc_val)
    
    if __name__ == '__main__':
        # Pipeline, in order: 1) save checkpoint, 2) dump graph for TensorBoard,
        # 3) freeze checkpoint to .pb, 4) run the frozen .pb as a sanity check.
        for stage in (saveckpt, LoadModelToTensorBoard, ckptToPb, PbTest):
            stage()
  • 相关阅读:
    程序员常去的14个顶级开发社区
    为何技术领域中女程序员较少?
    为何技术领域中女程序员较少?
    为何技术领域中女程序员较少?
    关于HTTP和HTTPS的区别
    关于HTTP和HTTPS的区别
    关于HTTP和HTTPS的区别
    Coupled model
    java和javascript日期详解
    Java 线程总结(十四)
  • 原文地址:https://www.cnblogs.com/adong7639/p/9241855.html
Copyright © 2011-2022 走看看