zoukankan      html  css  js  c++  java
  • 深度学习模型转换,以pytorch转tensorflow为例

    这里以onnx为中介进行转换。主要用到

    STEP1. 将pytorch 模型转换成onnx模型

    注意这里关键是要构造一个模型的输入输入,这里假设模型接受两个输入。

    pmodel = PytorchModel()
    dummy_input = (np.zeros((1, 30), dtype=np.float32), np.zeros((1, 2), dtype=np.float32))
    torch.onnx.export(pmodel, (torch.as_tensor(dummy_input[0]), torch.as_tensor(dummy_input[1])), "/tmp/xx.onnx",
                      verbose=True, input_names=['input1', 'input2'], output_names=['output1', 'output2'])

    参数 input_names表示模型的输入参数(随便起名字),output_names表示输出名字

    STEP 2. 将onnx模型转成tf

    这里需要借助onnx_tf这个库

    import onnx
    from onnx_tf.backend import prepare
    
    onnx_model = onnx.load("/tmp/xx.onnx")  # load onnx model
    tf_model = prepare(onnx_model)
    tf_model.export_graph("/tmp/xxpb/")  # export the model

    STEP 3 使用tensorflow模型

    import tensorflow as tf
    import io
    import numpy as np
    
    model_path = '/tmp/xxpb/'
    
    sess = tf.compat.v1.Session()
    metagraph = tf.compat.v1.saved_model.loader.load(sess, [tf.compat.v1.saved_model.tag_constants.SERVING], model_path)
    sig = metagraph.signature_def["serving_default"]
    input_dict = dict(sig.inputs)
    output_dict = dict(sig.outputs)
    print(input_dict, output_dict)
    output_stochastic_act_label_0 = output_dict["output_0"].name
    output_stochastic_act_label_1 = output_dict["output_1"].name
    
    input_state_label = None
    initial_state = None
    state = None
    if "state" in input_dict.keys():
        input_state_label = input_dict["state"].name
        strfile = io.StringIO()
        print(input_dict["state"].tensor_shape, file=strfile)
        lines = strfile.getvalue().split("
    ")
        dim_1 = int(lines[1].split(":")[1].strip(" "))
        dim_2 = int(lines[4].split(":")[1].strip(" "))
        initial_state = np.zeros((dim_1, dim_2), dtype=np.float32)
        state = np.zeros((dim_1, dim_2), dtype=np.float32)
    input_obs_label_1 = input_dict["input1"].name
    input_obs_label_0 = input_dict["input2"].name
    input_dict = {input_obs_label_0: np.zeros((1, 2), dtype=np.float32), input_obs_label_1:np.zeros((1, 30), dtype=np.float32)}
    out = sess.run((output_stochastic_act_label_0, output_stochastic_act_label_1), feed_dict=input_dict)
    print(out)

    注意这里的name需要重新设置一遍。





  • 相关阅读:
    VellCar(我的钢管车)
    我的留言发送提醒界面
    PHP邮箱验证是否有效
    PHP字符串操作常用函数
    让php永远后台运行
    找到的两个php爬虫,分享一下
    全国DNS服务器IP地址【电信、网通、铁通】
    我的新顶级域名vell001.ml
    PHP使用SwiftMailer发送邮件
    6个WordPress备份插件
  • 原文地址:https://www.cnblogs.com/MrLJC/p/14145763.html
Copyright © 2011-2022 走看看