zoukankan      html  css  js  c++  java
  • 如何手动解析Keras等框架保存的HDF5格式的权重文件

    问题

    在使用Keras保存为h5或者hdf5格式的模型权重文件后,一般采用keras.models.load_model()恢复模型结构和权重,或者采用model.load_weights()导入权重。
    但是,在进行迁移学习或者模型输出nan或inf时,需要手动导入部分权重进行查看或者修改之类,就不得不学会操作HDF5格式文件了。

    解答

    在Python中,常采用h5py库对HDF5文件进行读写操作,这是非常方便的,其导出的权重都是numpy矩阵,可以直接应用。
    HDF5格式本身类似于XML或者JSON,是一种通用的树状结构文档的表示方式,通过 $ /根目录/子目录 $ 的路径形式定位元素。
    话不多说,上示例代码:

    # test.py
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
    
    import cv2 as cv
    import numpy as np
    import glob
    
    from model import build_encoder_decoder, build_refinement
    from data_generator import normalize_input, denormalize_output, depth_random_scale_shift
    
    import h5py as h5
    import tensorflow as tf
    
    if __name__ == '__main__':
        img_rows, img_cols = 288, 384
        channel = 4
    
        model_path = '../Models/DIM/final.00000000-0.0607.hdf5'
        coarse = build_encoder_decoder(img_rows, img_cols, train=False)
        fine = build_refinement(coarse, train=False)
        #fine.summary()
    
        f = h5.File(model_path, 'r')
        w = f['/model_weights']
        params = dict()
        for k in w.keys():
            if len(w[k].keys()) > 0:
                sub_keys = w[k].keys()
                for k2 in sub_keys:
                    assert k == k2
                    for k3 in w[k][k2].keys():
                        name_ = k + '/' + k3
                        val_ = w[k][k2][k3][()]
                        params[name_] = val_
                        if np.any(np.isnan(val_)):
                            print('NAN VAR FOUND: ')
                            print(name_)
                            #val_[np.where(np.isnan(val_))] = 0.0
                        if np.any(np.isinf(val_)):
                            print('INF VAR FOUND: ')
                            print(name_)
                            #val_[np.where(np.isnan(val_))] = 0.0
    
        #print(params)
        vars = tf.trainable_variables()
        assert len(params) == len(vars)
    
        sess = tf.Session()
    
        t_deconv6 = sess.graph.get_tensor_by_name('deconv6/Relu:0')
    
        for i in range(len(vars)):
            sess.run(vars[i].assign(params[vars[i].name]))
    
        # check all uninitialized variables
        var_unset = tf.report_uninitialized_variables(tf.global_variables())
        print(sess.run(var_unset))
        #fine.load_weights(model_path)
        t_out = fine.outputs[0]
        t_in = fine.inputs[0]
    
        files_rgbd = glob.glob('../Datasets/DIM/test/rgbd-1/*.PNG')
        #files_gt = glob.glob('../Datasets/DIM/test/gt/*.PNG')
        #assert len(files_gt) == len(files_rgbd)
        x_test = np.zeros([1, img_rows, img_cols, 4], dtype=np.float32)
    
        for i in range(len(files_rgbd)):
            print(files_rgbd[i])
            rgbd = cv.imread(files_rgbd[i], -1)
            rgbd = cv.resize(rgbd, (img_cols, img_rows))
    
            # alpha = cv.imread(files_gt[i], 0)
            # alpha = cv.resize(alpha, (img_cols, img_rows))
            #rgbd[:,:,3], alpha = depth_random_scale_shift(rgbd[:,:,3], alpha, 255)
    
            rgb = rgbd[:,:,:3]
            disp = np.stack([rgbd[:,:,3]]*3, axis=-1)
            #gt = np.stack([alpha]*3, axis=-1)
    
            x_test[0, :, :, :4] = normalize_input(rgbd)
            out, deconv6 = sess.run([t_out, t_deconv6], feed_dict={t_in: x_test})
            #print(deconv6)
            #print(out)
            #exit(0)
    
            out = denormalize_output(out[0,:,:,0])
            out = np.stack([out] * 3, axis=-1)
            #merged = np.concatenate((rgb, disp, out, gt), axis=1)
            merged = np.concatenate((rgb, disp, out), axis=1)
            cv.imwrite(files_rgbd[i].replace('rgbd-1', 'out'), merged)
    

    上述代码展示的功能就是从权重文件中导入所有的权重,并且依次检查每个权重的正确性(合法性),是否包含INF或者NAN这类无效值,发现时,打印对应的名称和其值。
    同时上述代码也给出了结合Tensorflow,手动利用解析出来的变量-权重字典,对计算图进行变量初始化。

  • 相关阅读:
    团队项目-选题报告
    第一次结对编程作业
    第一次个人编程作业
    第一次博客作业
    Java web的读取Excel简单Demo
    Java一些常见的出错异常处理
    JSTL截取字符串
    DATAX动态参数数据传递
    DataX实现oracle到oracle之间的数据传递
    DataX安装环境搭建
  • 原文地址:https://www.cnblogs.com/thisisajoke/p/13932517.html
Copyright © 2011-2022 走看看