zoukankan      html  css  js  c++  java
  • TensorFlow 同时调用多个预训练好的模型

    在某些任务中,我们需要针对不同的情况训练多个不同的神经网络模型,这时候,在测试阶段,我们就需要调用多个预训练好的模型分别来进行预测。

    弄明白了如何调用单个模型,其实调用多个模型也就顺理成章。我们只需要建立多个图,然后每个图导入一个模型,再针对每个图创建一个会话,分别进行预测即可。

    import tensorflow as tf
    import numpy as np
    
    # 建立两个 graph
    g1 = tf.Graph()
    g2 = tf.Graph()
    
    # 为每个 graph 建创建一个 session
    sess1 = tf.Session(graph=g1)
    sess2 = tf.Session(graph=g2)
    
    X_1 = None
    tst_1 = None
    yhat_1 = None
    
    X_2 = None
    tst_2 = None
    yhat_2 = None
    
    def load_model(sess):
        """
            Loading the pre-trained model and parameters.
        """
        global X_1, tst_1, yhat_1
        with sess1.as_default():
            with sess1.graph.as_default():
                modelpath = r'F:/resnet/model/new0.25-0.35/'
                saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
                saver.restore(sess1, tf.train.latest_checkpoint(modelpath))
                graph = tf.get_default_graph()
                X_1 = graph.get_tensor_by_name("X:0")
                tst_1 = graph.get_tensor_by_name("tst:0")
                yhat_1 = graph.get_tensor_by_name("tanh:0")
                print('Successfully load the model_1!')
    
    			
    def load_model_2():
        """
            Loading the pre-trained model and parameters.
        """
        global X_2, tst_2, yhat_2
        with sess2.as_default():
            with sess2.graph.as_default():
                modelpath = r'F:/resnet/model/new0.25-0.352/'
                saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
                saver.restore(sess2, tf.train.latest_checkpoint(modelpath))
                graph = tf.get_default_graph()
                X_2 = graph.get_tensor_by_name("X:0")
                tst_2 = graph.get_tensor_by_name("tst:0")
                yhat_2 = graph.get_tensor_by_name("tanh:0")
                print('Successfully load the model_2!')
    	
    def test_1(txtdata):
        """
            Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
            Test a single axample.
            Arg:
                    txtdata: Array in C.
            Returns:
                The normal of a face.
        """
        global X_1, tst_1, yhat_1
        data = np.array(txtdata)
        data = data.reshape(-1, 41, 41, 41, 3)
        output = sess1.run(yhat_1, feed_dict={X_1: data, tst_1: True})  # (100, 3)
        output = output.reshape(-1, 1)
        ret = output.tolist()
        return ret
    
    
    def test_2(txtdata):
        """
            Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
            Test a single axample.
            Arg:
                    txtdata: Array in C.
            Returns:
                The normal of a face.
        """
        global X_2, tst_2, yhat_2
    
        data = np.array(txtdata)
        data = data.reshape(-1, 41, 41, 41, 3)
        output = sess2.run(yhat_2, feed_dict={X_2: data, tst_2: True})  # (100, 3)
        output = output.reshape(-1, 1)
        ret = output.tolist()
    
        return ret
    
    

    最后,本程序只是为了说明问题,抛砖引玉,代码有很多冗余之处,不要模仿!

    获取更多精彩,请关注「seniusen」!
    seniusen

  • 相关阅读:
    ZOJ 3765 Lights (zju March I)伸展树Splay
    UVA 11922 伸展树Splay 第一题
    UVALive 4794 Sharing Chocolate DP
    ZOJ 3757 Alice and Bod 模拟
    UVALive 3983 捡垃圾的机器人 DP
    UVA 10891 SUM游戏 DP
    poj 1328 Radar Installatio【贪心】
    poj 3264 Balanced Lineup【RMQ-ST查询区间最大最小值之差 +模板应用】
    【转】RMQ-ST算法详解
    poj 3083 Children of the Candy Corn 【条件约束dfs搜索 + bfs搜索】【复习搜索题目一定要看这道题目】
  • 原文地址:https://www.cnblogs.com/seniusen/p/9737428.html
Copyright © 2011-2022 走看看