zoukankan      html  css  js  c++  java
  • 深度学习模型转换,以pytorch转tensorflow为例

    这里以onnx为中介进行转换。主要用到

    STEP1. 将pytorch 模型转换成onnx模型

    注意这里关键是要构造一个模型的输入输入,这里假设模型接受两个输入。

    pmodel = PytorchModel()
    dummy_input = (np.zeros((1, 30), dtype=np.float32), np.zeros((1, 2), dtype=np.float32))
    torch.onnx.export(pmodel, (torch.as_tensor(dummy_input[0]), torch.as_tensor(dummy_input[1])), "/tmp/xx.onnx",
                      verbose=True, input_names=['input1', 'input2'], output_names=['output1', 'output2'])

    参数 input_names表示模型的输入参数(随便起名字),output_names表示输出名字

    STEP 2. 将onnx模型转成tf

    这里需要借助onnx_tf这个库

    import onnx
    from onnx_tf.backend import prepare
    
    onnx_model = onnx.load("/tmp/xx.onnx")  # load onnx model
    tf_model = prepare(onnx_model)
    tf_model.export_graph("/tmp/xxpb/")  # export the model

    STEP 3 使用tensorflow模型

    import tensorflow as tf
    import io
    import numpy as np
    
    model_path = '/tmp/xxpb/'
    
    sess = tf.compat.v1.Session()
    metagraph = tf.compat.v1.saved_model.loader.load(sess, [tf.compat.v1.saved_model.tag_constants.SERVING], model_path)
    sig = metagraph.signature_def["serving_default"]
    input_dict = dict(sig.inputs)
    output_dict = dict(sig.outputs)
    print(input_dict, output_dict)
    output_stochastic_act_label_0 = output_dict["output_0"].name
    output_stochastic_act_label_1 = output_dict["output_1"].name
    
    input_state_label = None
    initial_state = None
    state = None
    if "state" in input_dict.keys():
        input_state_label = input_dict["state"].name
        strfile = io.StringIO()
        print(input_dict["state"].tensor_shape, file=strfile)
        lines = strfile.getvalue().split("
    ")
        dim_1 = int(lines[1].split(":")[1].strip(" "))
        dim_2 = int(lines[4].split(":")[1].strip(" "))
        initial_state = np.zeros((dim_1, dim_2), dtype=np.float32)
        state = np.zeros((dim_1, dim_2), dtype=np.float32)
    input_obs_label_1 = input_dict["input1"].name
    input_obs_label_0 = input_dict["input2"].name
    input_dict = {input_obs_label_0: np.zeros((1, 2), dtype=np.float32), input_obs_label_1:np.zeros((1, 30), dtype=np.float32)}
    out = sess.run((output_stochastic_act_label_0, output_stochastic_act_label_1), feed_dict=input_dict)
    print(out)

    注意这里的name需要重新设置一遍。





  • 相关阅读:
    There is no session with id session多人使用一个账号
    记录一次@Autowire和@Resource遇到的坑
    shiro 未认证登录统一处理以及碰到的问题记录
    Realm [*] was unable to find account data for the submitted AuthenticationToken
    springboot项目监听器不起作用
    发送邮件com.sun.mail.util.TraceInputStream.<init>(Ljava/io/InputStream;Lcom/sun/mail
    mysql查询重复数据记录
    使用shiro在网关层解决过滤url
    maven添加jetty插件,同时运行多个实例
    Linux 安装Zookeeper集群
  • 原文地址:https://www.cnblogs.com/MrLJC/p/14145763.html
Copyright © 2011-2022 走看看