Problem
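After Keras has saved a model to a weights file in h5 or hdf5 format, the usual way to get it back is keras.models.load_model(), which restores both the model structure and the weights, or model.load_weights(), which loads the weights into an existing model.
However, for transfer learning, or when the model outputs nan or inf, you need to load some of the weights by hand in order to inspect or modify them, and for that you have to know how to work with the HDF5 file directly.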
Solution
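In Python, the h5py library is the usual tool for reading and writing HDF5 files. It is convenient because the weights it returns are NumPy arrays that can be used directly.
HDF5 itself is loosely comparable to XML or JSON in that it is a general-purpose way to represent tree-structured data. Elements are located by paths of the form /group/subgroup.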
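To make the path-based addressing concrete, here is a minimal sketch that writes a toy file and reads it back. The file name toy.h5 and the names dense_1 and kernel:0 are made up for the illustration; they only imitate the kind of layout the listing further down walks through, and a real Keras file may differ.

```python
import os
import tempfile

import h5py
import numpy as np

# Hypothetical file and group names, chosen only for this illustration.
path = os.path.join(tempfile.mkdtemp(), 'toy.h5')

with h5py.File(path, 'w') as f:
    # Intermediate groups are created automatically from the path.
    f.create_dataset('/model_weights/dense_1/dense_1/kernel:0',
                     data=np.arange(6, dtype=np.float32).reshape(2, 3))

with h5py.File(path, 'r') as f:
    group = f['/model_weights/dense_1']   # a group behaves much like a dict
    child_names = list(group.keys())
    # Indexing a dataset with [()] reads all of it into a NumPy array.
    kernel = f['/model_weights/dense_1/dense_1/kernel:0'][()]

print(child_names, type(kernel).__name__, kernel.shape)
```

Because the array is copied into memory by [()], it stays usable after the file has been closed.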
Without further ado, here is the example code:
# test.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # pick the GPU before TensorFlow starts using CUDA
import cv2 as cv
import numpy as np
import glob
from model import build_encoder_decoder, build_refinement
from data_generator import normalize_input, denormalize_output, depth_random_scale_shift
import h5py as h5
import tensorflow as tf  # the session/variable calls below are TensorFlow 1.x API

if __name__ == '__main__':
    img_rows, img_cols = 288, 384
    channel = 4
    model_path = '../Models/DIM/final.00000000-0.0607.hdf5'
    coarse = build_encoder_decoder(img_rows, img_cols, train=False)
    fine = build_refinement(coarse, train=False)
    #fine.summary()

    # Walk /model_weights/<layer>/<layer>/<variable> and collect every array
    # into a dict keyed by '<layer>/<variable>'.
    f = h5.File(model_path, 'r')
    w = f['/model_weights']
    params = dict()
    for k in w.keys():
        if len(w[k].keys()) > 0:
            sub_keys = w[k].keys()
            for k2 in sub_keys:
                assert k == k2
                for k3 in w[k][k2].keys():
                    name_ = k + '/' + k3
                    val_ = w[k][k2][k3][()]  # read the dataset as a numpy array
                    params[name_] = val_
                    if np.any(np.isnan(val_)):
                        print('NAN VAR FOUND: ')
                        print(name_)
                        #val_[np.where(np.isnan(val_))] = 0.0
                    if np.any(np.isinf(val_)):
                        print('INF VAR FOUND: ')
                        print(name_)
                        #val_[np.where(np.isinf(val_))] = 0.0
    f.close()  # every array has already been copied into memory
    #print(params)

    # Assign each loaded array to the graph variable with the same name.
    tf_vars = tf.trainable_variables()
    assert len(params) == len(tf_vars)
    sess = tf.Session()
    t_deconv6 = sess.graph.get_tensor_by_name('deconv6/Relu:0')
    for v in tf_vars:
        sess.run(v.assign(params[v.name]))
    # check all uninitialized variables
    var_unset = tf.report_uninitialized_variables(tf.global_variables())
    print(sess.run(var_unset))
    #fine.load_weights(model_path)
    t_out = fine.outputs[0]
    t_in = fine.inputs[0]

    # Run the network on every test image and save input and output side by side.
    files_rgbd = glob.glob('../Datasets/DIM/test/rgbd-1/*.PNG')
    #files_gt = glob.glob('../Datasets/DIM/test/gt/*.PNG')
    #assert len(files_gt) == len(files_rgbd)
    x_test = np.zeros([1, img_rows, img_cols, 4], dtype=np.float32)
    for i in range(len(files_rgbd)):
        print(files_rgbd[i])
        rgbd = cv.imread(files_rgbd[i], -1)  # -1: keep the 4th (depth) channel
        rgbd = cv.resize(rgbd, (img_cols, img_rows))
        # alpha = cv.imread(files_gt[i], 0)
        # alpha = cv.resize(alpha, (img_cols, img_rows))
        #rgbd[:,:,3], alpha = depth_random_scale_shift(rgbd[:,:,3], alpha, 255)
        rgb = rgbd[:,:,:3]
        disp = np.stack([rgbd[:,:,3]]*3, axis=-1)
        #gt = np.stack([alpha]*3, axis=-1)
        x_test[0, :, :, :4] = normalize_input(rgbd)
        out, deconv6 = sess.run([t_out, t_deconv6], feed_dict={t_in: x_test})
        #print(deconv6)
        #print(out)
        #exit(0)
        out = denormalize_output(out[0,:,:,0])
        out = np.stack([out] * 3, axis=-1)
        #merged = np.concatenate((rgb, disp, out, gt), axis=1)
        merged = np.concatenate((rgb, disp, out), axis=1)
        cv.imwrite(files_rgbd[i].replace('rgbd-1', 'out'), merged)
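The code above loads every weight from the weights file and checks each one for validity, that is, whether it contains invalid values such as INF or NAN. When one is found, the name of the offending variable is printed.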
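The same check can be written without hard-coding the nesting depth. The following is only a sketch, not the original author's method: it relies on h5py's Group.visititems to reach every dataset under a starting group, and the helper name find_nonfinite and the demo file contents are invented for the example.

```python
import os
import tempfile

import h5py
import numpy as np


def find_nonfinite(h5_path, root='/'):
    """Return {dataset path: (nan_count, inf_count)} for datasets with bad values."""
    bad = {}

    def visit(name, obj):
        # visititems also calls us for groups; only float datasets are checked.
        if isinstance(obj, h5py.Dataset) and obj.dtype.kind == 'f':
            val = obj[()]
            n_nan = int(np.isnan(val).sum())
            n_inf = int(np.isinf(val).sum())
            if n_nan or n_inf:
                bad[name] = (n_nan, n_inf)

    with h5py.File(h5_path, 'r') as f:
        f[root].visititems(visit)
    return bad


# Demo on a throwaway file with hypothetical layer names.
demo = os.path.join(tempfile.mkdtemp(), 'demo.h5')
with h5py.File(demo, 'w') as f:
    f.create_dataset('layer_a/layer_a/kernel:0',
                     data=np.ones((2, 2), dtype=np.float32))
    f.create_dataset('layer_b/layer_b/bias:0',
                     data=np.array([0.0, np.nan, np.inf], dtype=np.float32))

report = find_nonfinite(demo)
print(report)
```

For a Keras weights file like the one above, passing root='/model_weights' would restrict the scan to the weights, assuming the file uses that group name.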
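It also shows how to combine this with TensorFlow: the dictionary mapping variable names to weights, parsed from the file, is used to initialize the variables of the computation graph by hand.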
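The listing only asserts that the number of loaded arrays equals the number of variables, so a renamed layer would surface as a KeyError and a changed shape as an error inside assign. For transfer learning, where a partial match is expected, a pre-check along these lines may be more informative. This is a framework-free sketch: check_assignable and the sample names are hypothetical, and in the TensorFlow 1.x code above the expected shapes could presumably be built as {v.name: v.shape.as_list() for v in tf.trainable_variables()}.

```python
import numpy as np


def check_assignable(params, expected_shapes):
    """Compare a name->array dict with a name->shape dict of graph variables.

    Returns (missing, unexpected, mismatched) lists of names so the caller
    can decide whether a partial load is acceptable.
    """
    missing = sorted(set(expected_shapes) - set(params))
    unexpected = sorted(set(params) - set(expected_shapes))
    mismatched = sorted(
        name for name in set(params) & set(expected_shapes)
        if tuple(params[name].shape) != tuple(expected_shapes[name])
    )
    return missing, unexpected, mismatched


# Hypothetical weights read from a file ...
params = {
    'conv1/kernel:0': np.zeros((3, 3, 4, 8), dtype=np.float32),
    'conv1/bias:0': np.zeros((8,), dtype=np.float32),
    'old_head/kernel:0': np.zeros((8, 2), dtype=np.float32),
}
# ... and the shapes the current graph expects.
expected = {
    'conv1/kernel:0': (3, 3, 4, 8),
    'conv1/bias:0': (16,),
    'new_head/kernel:0': (8, 1),
}

missing, unexpected, mismatched = check_assignable(params, expected)
print('missing:', missing)
print('unexpected:', unexpected)
print('mismatched:', mismatched)
```

Only names that are in neither the missing nor the mismatched list can be assigned safely; the rest keep whatever initialization the graph gives them.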