zoukankan      html  css  js  c++  java
  • 深度学习模型转换,以pytorch转tensorflow为例

    这里以onnx为中介进行转换。主要用到

    STEP1. 将pytorch 模型转换成onnx模型

    注意这里关键是要构造一个模型的输入输入,这里假设模型接受两个输入。

    pmodel = PytorchModel()
    dummy_input = (np.zeros((1, 30), dtype=np.float32), np.zeros((1, 2), dtype=np.float32))
    torch.onnx.export(pmodel, (torch.as_tensor(dummy_input[0]), torch.as_tensor(dummy_input[1])), "/tmp/xx.onnx",
                      verbose=True, input_names=['input1', 'input2'], output_names=['output1', 'output2'])

    参数 input_names表示模型的输入参数(随便起名字),output_names表示输出名字

    STEP 2. 将onnx模型转成tf

    这里需要借助onnx_tf这个库

    import onnx
    from onnx_tf.backend import prepare
    
    onnx_model = onnx.load("/tmp/xx.onnx")  # load onnx model
    tf_model = prepare(onnx_model)
    tf_model.export_graph("/tmp/xxpb/")  # export the model

    STEP 3 使用tensorflow模型

    import tensorflow as tf
    import io
    import numpy as np
    
    model_path = '/tmp/xxpb/'
    
    sess = tf.compat.v1.Session()
    metagraph = tf.compat.v1.saved_model.loader.load(sess, [tf.compat.v1.saved_model.tag_constants.SERVING], model_path)
    sig = metagraph.signature_def["serving_default"]
    input_dict = dict(sig.inputs)
    output_dict = dict(sig.outputs)
    print(input_dict, output_dict)
    output_stochastic_act_label_0 = output_dict["output_0"].name
    output_stochastic_act_label_1 = output_dict["output_1"].name
    
    input_state_label = None
    initial_state = None
    state = None
    if "state" in input_dict.keys():
        input_state_label = input_dict["state"].name
        strfile = io.StringIO()
        print(input_dict["state"].tensor_shape, file=strfile)
        lines = strfile.getvalue().split("
    ")
        dim_1 = int(lines[1].split(":")[1].strip(" "))
        dim_2 = int(lines[4].split(":")[1].strip(" "))
        initial_state = np.zeros((dim_1, dim_2), dtype=np.float32)
        state = np.zeros((dim_1, dim_2), dtype=np.float32)
    input_obs_label_1 = input_dict["input1"].name
    input_obs_label_0 = input_dict["input2"].name
    input_dict = {input_obs_label_0: np.zeros((1, 2), dtype=np.float32), input_obs_label_1:np.zeros((1, 30), dtype=np.float32)}
    out = sess.run((output_stochastic_act_label_0, output_stochastic_act_label_1), feed_dict=input_dict)
    print(out)

    注意这里的name需要重新设置一遍。





  • 相关阅读:
    Robin Hood CodeForces
    Arthur and Questions CodeForces
    AC日记——过河卒 洛谷 1002
    加密(模拟)
    AC日记——codevs 1086 栈 (卡特兰数)
    AC日记——搞笑世界杯 codevs 1060(dp)
    AC日记—— codevs 1031 质数环(搜索)
    AC日记——产生数 codevs 1009 (弗洛伊德)(组合数学)
    AC日记——阶乘之和 洛谷 P1009(高精度)
    AC日记——逃跑的拉尔夫 codevs 1026 (搜索)
  • 原文地址:https://www.cnblogs.com/MrLJC/p/14145763.html
Copyright © 2011-2022 走看看