  • TensorFlow: converting a ckpt model to an hdf5 model for PyTorch

    Reference: https://github.com/bermanmaxim/jaccardSegment/blob/master/ckpt_to_dd.py

    import tensorflow as tf
    import deepdish as dd
    import argparse
    import os
    import numpy as np
    
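    def tr(v):
        # Convert a TensorFlow weight array to the PyTorch layout.
        if v.ndim == 4:
            # conv kernel: (H, W, in, out) -> (out, in, H, W)
            return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
        elif v.ndim == 2:
            # dense kernel: (in, out) -> (out, in)
            return np.ascontiguousarray(v.transpose())
        # biases and other tensors are returned unchanged
        return v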
    
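    def read_ckpt(ckpt):
        # https://github.com/tensorflow/tensorflow/issues/1823
        # TF 1.x API; under TF 2.x, tf.train.load_checkpoint(ckpt) returns an equivalent reader
        reader = tf.train.NewCheckpointReader(ckpt)
        # iterate over variable names (dict.iteritems() exists only in Python 2)
        weights = {n: reader.get_tensor(n) for n in reader.get_variable_to_shape_map()}
        pyweights = {k: tr(v) for (k, v) in weights.items()}
        return pyweights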
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
        parser.add_argument("infile", type=str,
                            help="Path to the ckpt.")
        parser.add_argument("outfile", type=str, nargs='?', default='',
                            help="Output file (inferred if missing).")
        args = parser.parse_args()
        if args.outfile == '':
            args.outfile = os.path.splitext(args.infile)[0] + '.h5'
        outdir = os.path.dirname(args.outfile)
        # dirname is '' when the output file is in the current directory
        if outdir and not os.path.exists(outdir):
            os.makedirs(outdir)
        weights = read_ckpt(args.infile)
        dd.io.save(args.outfile, weights)
        # reload the file to check that the saved weights can be read back
        weights2 = dd.io.load(args.outfile)
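
The layout change performed by `tr` can be checked with NumPy alone, without a checkpoint. The sketch below repeats the same transposition rule and applies it to made-up arrays. The shapes (a 3x3 kernel with 16 input and 32 output channels, a 128-to-10 dense layer) are assumptions chosen for illustration and do not come from any real model.

```python
import numpy as np


def tr(v):
    """Reorder a TensorFlow weight array into the PyTorch layout."""
    if v.ndim == 4:
        # conv kernel: (H, W, in, out) -> (out, in, H, W)
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    elif v.ndim == 2:
        # dense kernel: (in, out) -> (out, in)
        return np.ascontiguousarray(v.transpose())
    # anything else (e.g. a 1-D bias) is returned unchanged
    return v


# Hypothetical weights, for illustration only.
conv_tf = np.arange(3 * 3 * 16 * 32, dtype=np.float32).reshape(3, 3, 16, 32)
dense_tf = np.zeros((128, 10), dtype=np.float32)
bias_tf = np.zeros(10, dtype=np.float32)

conv_pt = tr(conv_tf)
dense_pt = tr(dense_tf)
bias_pt = tr(bias_tf)

print(conv_pt.shape)   # (32, 16, 3, 3)
print(dense_pt.shape)  # (10, 128)
print(bias_pt.shape)   # (10,)
```

The rule assumes ordinary convolution and dense kernels. Other 4-D weights, for example depthwise convolution kernels, may need a different permutation.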
  • Original article: https://www.cnblogs.com/wangyarui/p/9076401.html