zoukankan      html  css  js  c++  java
  • TensorFlow 模型前向传播、保存 ckpt、TensorBoard 查看、ckpt 转 pb、pb 转 SNPE dlc 实例

    参考:

    TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

    TensorFlow 模型保存与恢复

    snpe

    TensorFlow 模型前向传播、保存 ckpt、TensorBoard 查看、ckpt 转 pb、pb 转 SNPE dlc 实例

    log文件

    输入节点名、图像高度、图像宽度、图像通道数（对应 -i 参数）：

    input0 6,6,3

    输出节点（对应 --out_node 参数）：

     --out_node add 

    snpe-tensorflow-to-dlc --graph ./simple_snpe_log/model200.pb -i input0 6,6,3 --out_node add

    #coding:utf-8
    #http://blog.csdn.net/zhuiqiuk/article/details/53376283
    #http://blog.csdn.net/gan_player/article/details/77586489
    from __future__ import absolute_import, unicode_literals
    import tensorflow as tf
    import shutil
    import os.path
    from tensorflow.python.framework import graph_util
    import mxnet as mx
    import numpy as np
    import random
    import cv2
    from time import sleep
    from easydict import EasyDict as edict
    import logging   
    import math
    import tensorflow as tf
    import numpy as np
    
    def FullyConnected(input, fc_weight, fc_bias, name):
        """Return the fully connected output ``input @ fc_weight + fc_bias``.

        ``name`` is accepted but not used. The bias is added with the ``+``
        operator, so TensorFlow gives that node its default name (``add`` for
        the first addition in the graph); the export steps rely on the output
        node being called "add".
        """
        product = tf.matmul(input, fc_weight)
        return product + fc_bias
    
    def inference(body, name_class,outchannel):
        """Build a 3x3 conv layer followed by a fully connected layer.

        All weights are deterministic so the forward pass can be compared
        against another framework:
        - the conv kernel is ``arange`` values scaled into [0, 1);
        - the FC weights are all ones;
        - the FC bias is all zeros.

        Args:
            body: 4-D NHWC input tensor. H, W and C must be statically known,
                because they are read through ``get_shape()``.
            name_class: number of output units of the fully connected layer.
            outchannel: number of conv output channels.

        Returns:
            A tuple ``(conv, logits)``:
            - ``conv`` is the "conv0" output tensor (NHWC);
            - ``logits`` is the FC output of shape ``(batch, name_class)``.
        """
        wkernel = 3
        inchannel = body.get_shape()[3].value
        num_weights = wkernel * wkernel * inchannel * outchannel
        # The kernel is generated in OIHW order (out, in, kh, kw) and scaled
        # into [0, 1).
        conv_weight = np.arange(num_weights, dtype=np.float32).reshape(
            (outchannel, inchannel, wkernel, wkernel))
        conv_weight = conv_weight / num_weights
        print("conv_weight ", conv_weight)
        # tf.nn.conv2d expects HWIO (kh, kw, in, out).
        conv_weight = conv_weight.transpose(2, 3, 1, 0)
        conv_weight = tf.Variable(conv_weight, dtype=np.float32, name="conv_weight")
        conv = tf.nn.conv2d(conv_weight_input if False else body, conv_weight,
                            strides=[1, 1, 1, 1], padding='SAME', name="conv0") \
            if False else tf.nn.conv2d(body, conv_weight, strides=[1, 1, 1, 1],
                                       padding='SAME', name="conv0")
        conv_shape = conv.get_shape()
        dim = conv_shape[1].value * conv_shape[2].value * conv_shape[3].value
        # -1 keeps whatever batch size the input has, instead of hard-coding 1.
        flat = tf.reshape(conv, [-1, dim], name="fc0")
        fc_weight = tf.Variable(np.ones((dim, name_class), dtype=np.float32),
                                dtype=np.float32, name="fc_weight")
        fc_bias = tf.Variable(np.zeros((1, name_class), dtype=np.float32),
                              dtype=np.float32, name="fc_bias")
        # The (1, name_class) bias broadcasts over the batch dimension.
        logits = FullyConnected(flat, fc_weight, fc_bias, "fc0")
        return conv, logits
    
    # Output directory shared by all steps below: it receives the checkpoint
    # (saveckpt), the TensorBoard event file (LoadModelToTensorBoard) and the
    # frozen model200.pb (ckptToPb).
    export_dir = "simple_snpe_log"
    def saveckpt():
        """Build the demo network, run one forward pass and save a checkpoint.

        A deterministic 1x6x6x3 test image is fed through ``inference``.
        The conv output and the FC output are printed, and the variables are
        saved under ``export_dir`` with the prefix ``model.ckpt`` and global
        step 200.
        """
        import time
        height = 6
        width = 6
        inchannel = 3
        outchannel = 3
        with tf.Graph().as_default():
            input_image = tf.placeholder("float", [1, height, width, inchannel], name="input0")
            conv, logdit = inference(input_image, 10, outchannel)
            init = tf.global_variables_initializer()
            with tf.Session() as sess:
                sess.run(init)
                # The test image is built in NCHW order and then transposed to
                # TensorFlow's NHWC layout.
                num_pixels = 1 * inchannel * height * width
                img = (np.arange(num_pixels, dtype=np.float32)
                       .reshape((1, inchannel, height, width))
                       / num_pixels * 255.0 - 127.5)
                print("img", img)
                img = img.transpose(0, 2, 3, 1)
                since = time.time()
                # Fetch both outputs in one run, so the timing covers a single
                # forward pass.
                fc, conv_out = sess.run([logdit, conv], {input_image: img})
                time_elapsed = time.time() - since
                print("tf inference time ", str(time_elapsed))
                # Convert NHWC back to NCHW, the layout the image was built in.
                print("conv", conv_out.transpose(0, 3, 1, 2))
                print("fc", fc)

                saver = tf.train.Saver()
                step = 200
                # Create the log directory if it doesn't exist.
                if not os.path.isdir(export_dir):
                    os.makedirs(export_dir)
                checkpoint_file = os.path.join(export_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
    
    def LoadModelToTensorBoard():
        """Write the saved model's graph to ``export_dir`` for TensorBoard.

        The step-200 meta graph written by ``saveckpt`` is imported into a
        fresh graph, and that graph is written as an event file by
        ``tf.summary.FileWriter``.
        """
        checkpoint_file = os.path.join(export_dir, 'model.ckpt-200.meta')
        # Import into a private graph. This keeps the process-wide default
        # graph free of nodes that later steps would import a second time.
        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(checkpoint_file)
            print(saver)
            summary_write = tf.summary.FileWriter(export_dir, graph)
            print(summary_write)
            # Flush and close the writer so the event file is fully written.
            summary_write.close()
    
    def ckptToPb():
        """Freeze the latest checkpoint in ``export_dir`` into ``model200.pb``.

        The newest checkpoint is restored, its variables are folded into
        constants, and only the subgraph needed to compute the "add" node is
        kept. The serialized GraphDef is written to
        ``<export_dir>/model200.pb``.

        Raises:
            IOError: if no checkpoint is found in ``export_dir``.
        """
        ckpt = tf.train.get_checkpoint_state(export_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise IOError("no checkpoint found in %s" % export_dir)
        print("model ", ckpt.model_checkpoint_path)
        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
            with tf.Session() as sess:
                # Do NOT run global_variables_initializer after this restore:
                # it would overwrite the restored weights with initial values.
                saver.restore(sess, ckpt.model_checkpoint_path)
                output_graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, graph.as_graph_def(), ['add'])
        model_name = os.path.join(export_dir, 'model200.pb')
        with tf.gfile.GFile(model_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    
    def PbTest():
        """Run the frozen ``model200.pb`` on the same test image as ``saveckpt``.

        The frozen GraphDef is loaded into a fresh graph and the deterministic
        image is fed to "input0". The "conv0" and "add" outputs are printed so
        they can be compared with the checkpoint run.
        """
        import time
        height = 6
        width = 6
        inchannel = 3
        with tf.Graph().as_default() as graph:
            output_graph_def = tf.GraphDef()
            output_graph_path = os.path.join(export_dir, 'model200.pb')
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")

            input_image = graph.get_tensor_by_name("input0:0")
            fc0_output = graph.get_tensor_by_name("add:0")
            conv = graph.get_tensor_by_name("conv0:0")

            # ckptToPb froze all variables into constants, so no variable
            # initialization is needed here.
            with tf.Session() as sess:
                # Build the image in NCHW order, then transpose to NHWC.
                num_pixels = 1 * inchannel * height * width
                img = (np.arange(num_pixels, dtype=np.float32)
                       .reshape((1, inchannel, height, width))
                       / num_pixels * 255.0 - 127.5)
                print("img", img)
                img = img.transpose(0, 2, 3, 1)
                since = time.time()
                fc0_value, conv_value = sess.run([fc0_output, conv], {input_image: img})
                time_elapsed = time.time() - since
                print("tf inference time ", str(time_elapsed))
                # Convert NHWC back to NCHW, the layout the image was built in.
                print("conv", conv_value.transpose(0, 3, 1, 2))
                print("fc0_output", fc0_value)
    
    if __name__ == '__main__':
        # The pipeline runs in this order:
        #   1. build, run and checkpoint the model;
        #   2. dump the graph for TensorBoard;
        #   3. freeze the checkpoint to a .pb file;
        #   4. run the frozen .pb.
        for stage in (saveckpt, LoadModelToTensorBoard, ckptToPb, PbTest):
            stage()
  • 相关阅读:
    前端学PHP之错误处理
    mysql数据库学习目录
    前端学数据库之存储
    前端学数据库之函数
    用shell脚本监控进程是否存在 不存在则启动的实例
    在notepad++里面使用正则表达式替换掉所有行逗号前面内容
    mysql合并 两个count语句一次性输出结果的方法
    硬件中断和DPC一直占40-52%左右 解决方法
    解决secureCRT 数据库里没有找到防火墙 '无' 此会话降尝试不通过防火墙进行连接。
    Java eclipse下 Ant build.xml实例详解 附完整项目源码
  • 原文地址:https://www.cnblogs.com/adong7639/p/9241855.html
Copyright © 2011-2022 走看看