# enhance_raw.py # transform from single frame into multi-frame enhanced single raw from __future__ import division import os, time, scipy.io import tensorflow as tf import numpy as np import rawpy import glob from model_sid_latest import network_enhance_raw import platform import os from tensorflow.python.tools import freeze_graph os.environ["CUDA_VISIBLE_DEVICES"] = "0" if platform.system() == 'Windows': data_dir = 'D:/data/LightOnOff/' elif platform.system() == 'Linux': data_dir = './dataset/LightOnOff/' else: print('platform not supported!') assert False checkpoint_dir = './model_light_on_off/' result_dir = './out_light_on_off/' log_dir = './log_light_on_off/' learning_rate = 1e-4 save_model_every_n_epoch = 10 max_epoch = 20000 if platform.system() == 'Windows': save_output_every_n_steps = 1 else: save_output_every_n_steps = 100 # BBF100-2 bbf_w = 4032 bbf_h = 3024 patch_h = 512 patch_w = 512 patch_h = 800 patch_w = 1024 max_level = 1023 black_level = 64 tf.reset_default_graph() # set up dataset train_ids = os.listdir(data_dir) train_ids.sort() def preprocess(raw, bl, wl): im = raw.raw_image_visible.astype(np.float32) im = np.maximum(im - bl, 0) return im / (wl - bl) def pack_raw_bbf(path): raw = rawpy.imread(path) bl = 64 wl = 1023 im = preprocess(raw, bl, wl) im = np.expand_dims(im, axis=2) H = im.shape[0] W = im.shape[1] if raw.raw_pattern[0, 0] == 0: # CFA=RGGB out = np.concatenate((im[0:H:2, 0:W:2, :], im[0:H:2, 1:W:2, :], im[1:H:2, 1:W:2, :], im[1:H:2, 0:W:2, :]), axis=2) elif raw.raw_pattern[0,0] == 2: # BGGR out = np.concatenate((im[1:H:2, 1:W:2, :], im[0:H:2, 1:W:2, :], im[0:H:2, 0:W:2, :], im[1:H:2, 0:W:2, :]), axis=2) elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 0: # GRBG out = np.concatenate((im[0:H:2, 1:W:2, :], im[0:H:2, 0:W:2, :], im[1:H:2, 0:W:2, :], im[1:H:2, 1:W:2, :]), axis=2) elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 2: # GBRG out = np.concatenate((im[1:H:2, 0:W:2, :], im[0:H:2, 0:W:2, :], im[0:H:2, 1:W:2, :], im[1:H:2, 1:W:2, :]), axis=2) else: assert False wb = np.array(raw.camera_whitebalance) wb[3] = wb[1] wb = wb / wb[1] out = np.minimum(out * wb, 1.0) # normalize the brightness # out = np.minimum(out * 0.2 / np.maximum(1e-6, np.mean(out[:, :, 1])), 1.0) h_, w_ = im.shape[0]//2, im.shape[1]//2 out_16bit_ = np.zeros([h_, w_, 4], dtype=np.uint16) out_16bit_[:, :, :] = np.uint16(out[:, :, :] * (wl - bl)) del out return out_16bit_ def raw2rgb(raw): # GRBG assert len(raw.shape)==3 h, w = raw.shape[0]<<1, raw.shape[1]<<1 rgb = np.zeros([h, w, 3]) rgb[0:h:2, 0:w:2, 1] = raw[:, :, 1] rgb[0:h:2, 1:w:2, 0] = raw[:, :, 0] rgb[1:h:2, 0:w:2, 2] = raw[:, :, 2] rgb[1:h:2, 1:w:2, 1] = raw[:, :, 3] return rgb def max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center): return np.maximum( np.maximum( np.maximum( np.maximum( np.maximum( np.maximum( np.maximum( np.maximum(left, left_top), top), top_right), right), right_bottom), bottom), bottom_left), center) def demosaic(rgb): for chn_id in range(3): left = rgb[0:-2, 1:-1, chn_id] left_top = rgb[0:-2, 0:-2, chn_id] top = rgb[0:-2, 1:-1, chn_id] top_right = rgb[0:-2, 2:, chn_id] right = rgb[1:-1, 2:, chn_id] right_bottom = rgb[2:, 2:, chn_id] bottom = rgb[2:, 1:-1, chn_id] bottom_left = rgb[2:, 0:-2, chn_id] center = rgb[1:-1, 1:-1, chn_id] rgb[1:-1, 1:-1, chn_id] = max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center) return rgb def gray_ps(rgb): return np.power(np.power(rgb[:, :, 0], 2.2) * 0.2973 + np.power(rgb[:,:,1], 2.2) * 0.6274 + np.power(rgb[:,:,2], 2.2) * 0.0753, 1/2.2) + 1e-7 def gamma_correction(x, curve_ratio): gray_scale = np.expand_dims(gray_ps(x), axis=-1) gray_scale_new = np.power(gray_scale, curve_ratio) return np.minimum(x * gray_scale_new / gray_scale, 1.0) # setting the ratio of GPU global memory usage gpu_options = tf.GPUOptions(allow_growth=True) sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input') gt_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4]) out_im = network_enhance_raw(in_im, patch_h, patch_w) norm_im = tf.minimum(tf.maximum(out_im, 0.0), 1.0) ssim_loss = 1 - tf.image.ssim_multiscale(norm_im[0], gt_im[0], 1.0) l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(norm_im - gt_im), axis=-1)) l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(norm_im - gt_im), axis=-1)) # G_loss = ssim_loss G_loss = l1_loss + l2_loss tf.summary.scalar('G_loss', G_loss) tf.summary.scalar('MS-SSIM Loss', ssim_loss) tf.summary.scalar('L1 Loss', l1_loss) tf.summary.scalar('L2 Loss', l2_loss) t_vars = tf.trainable_variables() lr = tf.placeholder(tf.float32) G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss) saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt: print('loaded ' + ckpt.model_checkpoint_path) saver.restore(sess, ckpt.model_checkpoint_path) # save the images for tracking training states if not os.path.isdir(result_dir): os.mkdir(result_dir) g_loss = np.zeros((500, 1)) merged = tf.summary.merge_all() writer = tf.summary.FileWriter(log_dir, sess.graph) gt_files = [None] * len(train_ids) input_files = [None] * len(train_ids) input_images = [None] * len(train_ids) gt_images = [None] * len(train_ids) for i in range(0, len(train_ids)): gt_files[i] = glob.glob(os.path.join(data_dir, train_ids[i]) + '/*on*.dng')[0] input_files[i] = glob.glob(os.path.join(data_dir, train_ids[i]) + '/*off*.dng') input_images[i] = [None] * len(input_files[i]) steps = 0 st = time.time() for epoch in range(0, max_epoch): for ind in np.random.permutation(len(train_ids)): steps += 1 sid = np.random.randint(0, len(input_files[ind])) if input_images[ind][sid] is None: input_images[ind][sid] = np.expand_dims(pack_raw_bbf(input_files[ind][sid]), axis=0) if gt_images[ind] is None: gt_images[ind] = np.expand_dims(np.maximum(pack_raw_bbf(gt_files[ind]), 0), axis=0) # random cropping xx = np.random.randint(0, bbf_w//2 - patch_w) yy = np.random.randint(0, bbf_h//2 - patch_h) input_patch = np.float32(input_images[ind][sid][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level) gt_patch = np.float32(gt_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level) # random flipping if np.random.randint(2, size=1)[0] == 1: # random flip input_patch = np.flip(input_patch, axis=1) gt_patch = np.flip(gt_patch, axis=1) if np.random.randint(2, size=1)[0] == 1: input_patch = np.flip(input_patch, axis=0) gt_patch = np.flip(gt_patch, axis=0) # if np.random.randint(2, size=1)[0] == 1: # random transpose # input_patch = np.transpose(input_patch, (0, 2, 1, 3)) # gt_patch = np.transpose(gt_patch, (0, 2, 1, 3)) # summary, _, G_current, output = sess.run( # [merged, G_opt, G_loss, out_im], # feed_dict={ # in_im: input_patch, # gt_im: gt_patch, # lr: learning_rate}) # g_loss[ind] = G_current summary, output = sess.run( [merged, out_im], feed_dict={ in_im: input_patch, gt_im: gt_patch, lr: learning_rate }) # saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch) # print('model saved.') # exit(0) tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model_raw2raw.pb') freeze_graph.freeze_graph( 'output_model/pb_model/model_raw2raw.pb', '', False, './model_light_on_off/0.ckpt', 'gen/output', 'save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', True, "") exit(0) if steps % save_output_every_n_steps == 0: loss_ = np.mean(g_loss[np.where(g_loss)]) cost_ = (time.time() - st)/save_output_every_n_steps st = time.time() print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_)) writer.add_summary(summary, global_step=steps) # save the current output image for network inspection out_ = np.minimum(np.maximum(output, 0), 1) in_rgb = gamma_correction(demosaic(raw2rgb(input_patch[0])), 0.35) gt_rgb = gamma_correction(demosaic(raw2rgb(gt_patch[0])), 0.35) out_rgb = gamma_correction(demosaic(raw2rgb(out_[0])), 0.35) temp = np.concatenate((in_rgb, gt_rgb, out_rgb), axis=1) scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255) .save(result_dir + '/%d_%s_00.jpg' % (epoch, train_ids[ind])) # clean up the memory if necessary if platform.system() == 'Windows': input_images[ind][sid] = None gt_images[ind] = None if epoch % save_model_every_n_epoch == 0: saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch) print('model saved.')