zoukankan      html  css  js  c++  java
  • tensorflow和pytorch模型之间转换

    参考链接:
    https://github.com/bermanmaxim/jaccardSegment/blob/master/ckpt_to_dd.py

    一. tensorflow模型转pytorch模型

    import tensorflow as tf
    import deepdish as dd
    import argparse
    import os
    import numpy as np
    
    def tr(v):
        # tensorflow weights to pytorch weights
        if v.ndim == 4:
            return np.ascontiguousarray(v.transpose(3,2,0,1))
        elif v.ndim == 2:
            return np.ascontiguousarray(v.transpose())
        return v
    
    def read_ckpt(ckpt):
        # https://github.com/tensorflow/tensorflow/issues/1823
        reader = tf.train.NewCheckpointReader(ckpt)
        weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().items()}
        pyweights = {k: tr(v) for (k, v) in weights.items()}
        return pyweights
    if __name__ == '__main__':
        parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
        parser.add_argument("infile", type=str,
                            help="Path to the ckpt.")  # ***model.ckpt-22177***
        parser.add_argument("outfile", type=str, nargs='?', default='',
                            help="Output file (inferred if missing).")
        args = parser.parse_args()
        if args.outfile == '':
            args.outfile = os.path.splitext(args.infile)[0] + '.h5'
        outdir = os.path.dirname(args.outfile)
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        weights = read_ckpt(args.infile)
        dd.io.save(args.outfile, weights)
    

      

    1.运行上述代码后会得到model.h5模型,如下:
    备注:保持tensorflow和pytorch使用的python版本一致

    2.使用:在pytorch内加载改模型:
    这里假设网络保存时参数命名一致

    net = ...
    import torch
    import deepdish as dd
    net = resnet50(..)
    model_dict = net.state_dict()
    #先将参数值numpy转换为tensor形式
    pretrained_dict =  = dd.io.load('./model.h5')
    new_pre_dict = {}
    for k,v in pretrained_dict.items():
        new_pre_dict[k] = torch.Tensor(v)
    #更新
    model_dict.update(new_pre_dict)
    #加载
    net.load_state_dict(model_dict)
    

      

    二. pytorch转tensorflow(待续。。)

    原文:https://blog.csdn.net/weixin_42699651/article/details/88932670

  • 相关阅读:
    给列表项标记添加自定义图像
    双飞翼布局与圣杯布局
    CSS3 calc()
    CSS滚动视差
    应用层层面面试题汇总
    Linux下OpenSSL 安装
    深入理解:Android 编译系统
    ios 好去处
    IBOutlet & IBAction
    ar技术序章-SDK介绍和选择
  • 原文地址:https://www.cnblogs.com/qbdj/p/11024565.html
Copyright © 2011-2022 走看看