zoukankan      html  css  js  c++  java
  • MxNet 模型转Tensorflow pb模型

    用mmdnn实现模型转换

    参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af

    1. 安装mmdnn
      pip install mmdnn
    2. 准备好mxnet模型的.json文件和.params文件, 以InsightFace MxNet r50为例        https://github.com/deepinsight/insightface
    3. 用mmdnn运行命令行
      python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d resnet50 --inputShape 3,112,112 

       会生成resnet50.json(可视化文件) resnet50.npy(权重参数) resnet50.pb(网络结构)三个文件。

    4. 用mmdnn运行命令行
      python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath resnet50.pb --IRWeightPath resnet50.npy --dstModelPath tf_resnet50.py 

       生成tf_resnet50.py文件,可以调用tf_resnet50.py中的KitModel函数加载npy权重参数重新生成原网络框架。

    5. 打开tf_resnet.py文件,修改load_weights()中的代码 (tensorflow=1.14.0报错) 

       try:
              weights_dict = np.load(weight_file).item()
          except:
              weights_dict = np.load(weight_file, encoding='bytes').item()

      改为

       try:
              weights_dict = np.load(weight_file, allow_pickle=True).item()
      except:
              weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()
    6. 基于resnet50.npy和tf_resnet50.py文​​件,固化参数,生成PB文件:

      import tensorflow as tf
      import tf_resnet50 as tf_fun
      def netWork():
          model=tf_fun.KitModel("./resnet50.npy")
          return model
      def freeze_graph(output_graph):
          output_node_names = "output"
          data,fc1=netWork()
          fc1=tf.identity(fc1,name="output")
      
          graph = tf.get_default_graph()  # 獲得默認的圖
          input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖
          init = tf.global_variables_initializer()
          with tf.Session() as sess:
              sess.run(init)
              output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
                  sess=sess,
                  input_graph_def=input_graph_def,  # 等於:sess.graph_def
                  output_node_names=output_node_names.split(","))  # 如果有多個輸出節點,以逗號隔開
      
              with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
                  f.write(output_graph_def.SerializeToString())  # 序列化輸出
      
      if __name__ == '__main__':
          freeze_graph("frozen_insightface_r50.pb")
          print("finish!")
    7. 采用tensorflow的post-train quantization离线量化方法(有一定的精度损失)转换成tflite模型,从而完成端侧的模型部署:
      import tensorflow as tf
      
      convert=tf.lite.TFLiteConverter.from_frozen_graph("frozen_insightface_r50.pb",input_arrays=["data"],output_arrays=["output"],
                                                        input_shapes={"data":[1,112,112,3]})
      convert.post_training_quantize=True
      tflite_model=convert.convert()
      open("quantized_insightface_r50.tflite","wb").write(tflite_model)
      print("finish!")
  • 相关阅读:
    hadoop再次集群搭建(3)-如何选择相应的hadoop版本
    48. Rotate Image
    352. Data Stream as Disjoint Interval
    163. Missing Ranges
    228. Summary Ranges
    147. Insertion Sort List
    324. Wiggle Sort II
    215. Kth Largest Element in an Array
    快速排序
    280. Wiggle Sort
  • 原文地址:https://www.cnblogs.com/qiangz/p/11134240.html
Copyright © 2011-2022 走看看