zoukankan      html  css  js  c++  java
  • tensorflow 模型前向传播 保存ckpt tensorboard查看 ckpt转pb pb 转snpe dlc 实例

    参考:

    TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

    TensorFlow 模型保存与恢复

    snpe

    tensorflow 模型前向传播 保存ckpt  tensorboard查看 ckpt转pb  pb 转snpe dlc 实例

    log文件

    输入节点 图像高度 图像宽度 图像通道数

    input0 6,6,3

    输出节点

     --out_node add 

    snpe-tensorflow-to-dlc --graph ./simple_snpe_log/model200.pb -i input0 6,6,3 --out_node add

    #coding:utf-8
    #http://blog.csdn.net/zhuiqiuk/article/details/53376283
    #http://blog.csdn.net/gan_player/article/details/77586489
    from __future__ import absolute_import, unicode_literals
    import tensorflow as tf
    import shutil
    import os.path
    from tensorflow.python.framework import graph_util
    import mxnet as mx
    import numpy as np
    import random
    import cv2
    from time import sleep
    from easydict import EasyDict as edict
    import logging   
    import math
    import tensorflow as tf
    import numpy as np
    
    def FullyConnected(input, fc_weight, fc_bias, name):
        """Apply a dense layer: ``matmul(input, fc_weight) + fc_bias``.

        Args:
            input: Left operand handed to ``tf.matmul``.
            fc_weight: Right operand handed to ``tf.matmul``.
            fc_bias: Value added to the matmul result.
            name: Accepted for call-site symmetry; not used by this function.

        Returns:
            The result of the matrix product followed by the bias addition.
        """
        projected = tf.matmul(input, fc_weight)
        # Keep the ``+`` operator here: the export steps in this file look up a
        # node called "add", which is presumably the one this addition creates.
        return projected + fc_bias
    
    def inference(body, name_class,outchannel):
        """Build a one-convolution, one-dense toy network on top of ``body``.

        The convolution uses a 3x3 kernel with deterministic ramp weights, and
        the dense layer uses all-ones weights and a zero bias, so results are
        reproducible.

        Args:
            body: Rank-4 input tensor; its 4th dimension is read as the number
                of input channels and must be statically known.
            name_class: Number of outputs of the dense layer.
            outchannel: Number of output channels of the convolution.

        Returns:
            A ``(conv, fc)`` tuple: the convolution output and the dense-layer
            output.
        """
        ksize = 3
        in_ch = body.get_shape()[3].value
        n_weights = ksize * ksize * in_ch * outchannel

        # Ramp 0..n-1 laid out as (out, in, kh, kw), then scaled by 1/n.
        kernel_np = np.arange(n_weights, dtype=np.float32)
        kernel_np = kernel_np.reshape((outchannel, in_ch, ksize, ksize)) / n_weights
        print("conv_weight ", kernel_np)

        # Reorder to (kh, kw, in, out) before handing the filter to conv2d.
        kernel_var = tf.Variable(kernel_np.transpose(2, 3, 1, 0), dtype=np.float32, name = "conv_weight")
        conv = tf.nn.conv2d(body, kernel_var, strides=[1, 1, 1, 1], padding='SAME', name = "conv0")

        # Flatten everything except the (size-1) batch axis.
        dims = conv.get_shape()
        flat_dim = 1
        for axis in (1, 2, 3):
            flat_dim = flat_dim * dims[axis].value
        flat = tf.reshape(conv, [1, flat_dim], name = "fc0")

        dense_w = tf.Variable(np.ones((flat_dim, name_class)), dtype=np.float32, name="fc_weight")
        dense_b = tf.Variable(np.zeros((1, name_class)), dtype=np.float32, name="fc_bias")
        logits = FullyConnected(flat, dense_w, dense_b, "fc0")
        return conv, logits
    
    # Directory the functions below use for the checkpoint/meta files, the
    # TensorBoard event file and the frozen model200.pb.
    export_dir = "simple_snpe_log"
    def saveckpt():
        """Build the toy network, run one forward pass, and save a checkpoint.

        The network is built in its own ``tf.Graph`` on a ``[1, 6, 6, 3]`` float
        placeholder named "input0". A deterministic ramp image is fed through
        it, the conv and dense outputs are printed together with the elapsed
        time, and the session is saved to ``<export_dir>/model.ckpt`` with
        global step 200. ``export_dir`` is created if it does not exist.
        """
        import time

        height = 6
        width = 6
        inchannel = 3
        outchannel = 3
        step = 200

        with tf.Graph().as_default():
            input_image = tf.placeholder("float", [1, height, width, inchannel], name = "input0")
            conv, logdit = inference(input_image,10,outchannel)
            init = tf.global_variables_initializer()
            with tf.Session() as sess:
                sess.run(init)

                # Ramp image built channel-first and scaled to [-127.5, 127.5).
                # The parentheses are required: without them the continuation
                # line is a syntax error.
                img = (np.arange(height * width * inchannel, dtype=np.float32).reshape((1,inchannel,height,width))
                       / (1 * inchannel * height * width) * 255.0 - 127.5)
                print("img",img)
                # Move channels last to match the placeholder's layout.
                img = img.transpose(0,2,3,1)

                since = time.time()
                fc = sess.run(logdit,{input_image:img})
                conv = sess.run(conv, {input_image: img})
                time_elapsed = time.time() - since
                print("tf inference time ", str(time_elapsed))
                # NOTE(review): conv is channels-last here, so transpose(0, 2, 3, 1)
                # does not give a channels-first view ((0, 3, 1, 2) would). Left
                # unchanged because it only affects this diagnostic print -
                # confirm the intended layout.
                print("conv", conv.transpose(0, 2, 3, 1))
                print("fc", fc)

                # Save the checkpoint into the shared module-level export_dir.
                saver = tf.train.Saver()
                if not os.path.isdir(export_dir):  # create the log directory if missing
                    os.makedirs(export_dir)

                checkpoint_file = os.path.join(export_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
    
    def LoadModelToTensorBoard():
        """Write the saved model's graph to an event file for TensorBoard.

        Imports ``<export_dir>/model.ckpt-200.meta`` into a fresh graph and
        writes that graph into ``export_dir`` via ``tf.summary.FileWriter``.
        The writer is closed before returning so the event file is flushed.
        """
        checkpoint_file = os.path.join(export_dir, 'model.ckpt-200.meta')
        # Import into a dedicated graph so the process-wide default graph is
        # not filled with nodes that a later import would collide with.
        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(checkpoint_file)
            print(saver)
        summary_write = tf.summary.FileWriter(export_dir , graph)
        try:
            print(summary_write)
        finally:
            # Flush and release the event file; it was previously left open.
            summary_write.close()
    
    def ckptToPb():
        """Freeze the latest checkpoint in ``export_dir`` into ``model200.pb``.

        The checkpoint's meta graph is imported into a fresh graph, the saved
        variable values are restored, every variable reachable from the "add"
        output node is converted to a constant, and the resulting GraphDef is
        serialized to ``<export_dir>/model200.pb``.

        Raises:
            IOError: If ``export_dir`` contains no checkpoint state.
        """
        ckpt = tf.train.get_checkpoint_state(export_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise IOError("no checkpoint found in " + export_dir)
        print("model ", ckpt.model_checkpoint_path)

        # Use a dedicated graph: importing into the shared default graph would
        # clash with nodes another import may already have placed there.
        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
            with tf.Session() as sess:
                # restore() initializes the variables from the checkpoint. Do
                # not run a global initializer afterwards, since that would
                # overwrite the restored values.
                saver.restore(sess,ckpt.model_checkpoint_path)
                output_graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, graph.as_graph_def(), ['add'])

        model_name = os.path.join(export_dir, 'model200.pb')
        with tf.gfile.GFile(model_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    
    def PbTest():
        """Load ``model200.pb`` and run the same forward pass as ``saveckpt``.

        The frozen GraphDef is read from ``<export_dir>/model200.pb`` and
        imported into a fresh graph without a name prefix. The "input0:0",
        "conv0:0" and "add:0" tensors are looked up by name, and the ramp image
        is fed through them. The outputs and the elapsed time are printed.
        """
        import time

        height = 6
        width = 6
        inchannel = 3
        output_graph_path = os.path.join(export_dir,'model200.pb')

        with tf.Graph().as_default() as graph:
            output_graph_def = tf.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")

            with tf.Session() as sess:
                # No variable initializer is needed: ckptToPb writes a graph
                # whose variables were already converted to constants.
                input_image = graph.get_tensor_by_name("input0:0")
                fc0_output = graph.get_tensor_by_name("add:0")
                conv = graph.get_tensor_by_name("conv0:0")

                # Ramp image built channel-first and scaled to [-127.5, 127.5).
                # The parentheses are required: without them the continuation
                # line is a syntax error.
                img = (np.arange(height * width * inchannel, dtype=np.float32).reshape((1,inchannel,height,width))
                       / (1 * inchannel * height * width) * 255.0 - 127.5)
                print("img",img)
                # Move channels last before feeding the placeholder.
                img = img.transpose(0,2,3,1)

                since = time.time()
                fc0_output = sess.run(fc0_output,{input_image:img})
                conv = sess.run(conv, {input_image: img})
                time_elapsed = time.time() - since
                print("tf inference time ", str(time_elapsed))
                # NOTE(review): conv is channels-last here, so transpose(0, 2, 3, 1)
                # does not give a channels-first view ((0, 3, 1, 2) would). Left
                # unchanged because it only affects this diagnostic print -
                # confirm the intended layout.
                print("conv", conv.transpose(0, 2, 3, 1))
                print("fc0_output", fc0_output)
    
    if __name__ == '__main__':
        # Run the pipeline in order:
        #   1. save a checkpoint
        #   2. dump the graph for TensorBoard
        #   3. freeze the checkpoint to a .pb
        #   4. run the frozen .pb
        for stage in (saveckpt, LoadModelToTensorBoard, ckptToPb, PbTest):
            stage()
  • 相关阅读:
    windows下Graphviz安装及入门教程
    安装配置Xdebug模块详解
    Git结合tar自动打升级包
    git stash命令详解
    redis启动出错Creating Server TCP listening socket 127.0.0.1:6379: bind: No error(转)
    航司二字码
    BeanCreationException: Error creating bean with name 'classPathFileSystemWatcher'之解决办法
    org.springframework.beans.factory.BeanCreationException: Error creating bean with name 'jpaMappingContext之解决办法
    Postman测试后台使用@RequestBody接收参数的坑
    Idea必知必会
  • 原文地址:https://www.cnblogs.com/adong7639/p/9241855.html
Copyright © 2011-2022 走看看