zoukankan      html  css  js  c++  java
  • TensorFlow 同时调用多个预训练好的模型

    在某些任务中,我们需要针对不同的情况训练多个不同的神经网络模型,这时候,在测试阶段,我们就需要调用多个预训练好的模型分别来进行预测。

    弄明白了如何调用单个模型,其实调用多个模型也就顺理成章。我们只需要建立多个图,然后每个图导入一个模型,再针对每个图创建一个会话,分别进行预测即可。

    import tensorflow as tf
    import numpy as np
    
    # 建立两个 graph
    g1 = tf.Graph()
    g2 = tf.Graph()
    
    # 为每个 graph 建创建一个 session
    sess1 = tf.Session(graph=g1)
    sess2 = tf.Session(graph=g2)
    
    X_1 = None
    tst_1 = None
    yhat_1 = None
    
    X_2 = None
    tst_2 = None
    yhat_2 = None
    
    def load_model(sess):
        """
            Loading the pre-trained model and parameters.
        """
        global X_1, tst_1, yhat_1
        with sess1.as_default():
            with sess1.graph.as_default():
                modelpath = r'F:/resnet/model/new0.25-0.35/'
                saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
                saver.restore(sess1, tf.train.latest_checkpoint(modelpath))
                graph = tf.get_default_graph()
                X_1 = graph.get_tensor_by_name("X:0")
                tst_1 = graph.get_tensor_by_name("tst:0")
                yhat_1 = graph.get_tensor_by_name("tanh:0")
                print('Successfully load the model_1!')
    
    			
    def load_model_2():
        """
            Loading the pre-trained model and parameters.
        """
        global X_2, tst_2, yhat_2
        with sess2.as_default():
            with sess2.graph.as_default():
                modelpath = r'F:/resnet/model/new0.25-0.352/'
                saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
                saver.restore(sess2, tf.train.latest_checkpoint(modelpath))
                graph = tf.get_default_graph()
                X_2 = graph.get_tensor_by_name("X:0")
                tst_2 = graph.get_tensor_by_name("tst:0")
                yhat_2 = graph.get_tensor_by_name("tanh:0")
                print('Successfully load the model_2!')
    	
    def test_1(txtdata):
        """
            Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
            Test a single axample.
            Arg:
                    txtdata: Array in C.
            Returns:
                The normal of a face.
        """
        global X_1, tst_1, yhat_1
        data = np.array(txtdata)
        data = data.reshape(-1, 41, 41, 41, 3)
        output = sess1.run(yhat_1, feed_dict={X_1: data, tst_1: True})  # (100, 3)
        output = output.reshape(-1, 1)
        ret = output.tolist()
        return ret
    
    
    def test_2(txtdata):
        """
            Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
            Test a single axample.
            Arg:
                    txtdata: Array in C.
            Returns:
                The normal of a face.
        """
        global X_2, tst_2, yhat_2
    
        data = np.array(txtdata)
        data = data.reshape(-1, 41, 41, 41, 3)
        output = sess2.run(yhat_2, feed_dict={X_2: data, tst_2: True})  # (100, 3)
        output = output.reshape(-1, 1)
        ret = output.tolist()
    
        return ret
    
    

    最后,本程序只是为了说明问题,抛砖引玉,代码有很多冗余之处,不要模仿!

    获取更多精彩,请关注「seniusen」!
    seniusen

  • 相关阅读:
    fatal: 'origin' does not appear to be a git repository
    Mac cpu过高问题分析及解决
    Java8新特性之日期和时间
    Allure自动化测试报告之修改allure测试报告logo
    Allure自动化测试报告之修改allure测试报告名称
    jmeter压测过程中报java.lang.NoClassDefFoundError: org/bouncycastle/jce/provider/BouncyCastleProvider
    java.security.InvalidKeyException: Illegal key size or default parameters
    scanf中的%[^ ]%*c格式
    fork,vfork和clone底层实现
    僵尸进程的产生原因和避免方法
  • 原文地址:https://www.cnblogs.com/seniusen/p/9737428.html
Copyright © 2011-2022 走看看