model.py分析
import os
import math
import glob
import torch
import torch.nn as nn
import torchvision
from . import networks
from . import utils
from .renderer import Renderer


EPS = 1e-7


class Unsup3D():
    """Model wrapper that decomposes an image into depth, albedo, light and view.

    Owns the sub-networks (``netD``, ``netA``, ``netL``, ``netV`` and the
    optional ``netC``), one Adam optimizer per network, a ``Renderer`` and a
    perceptual loss. ``forward`` stores every intermediate result on ``self``
    so that ``visualize`` / ``save_results`` can read them afterwards.
    """

    def __init__(self, cfgs):
        """Read hyper-parameters from ``cfgs`` (anything with ``.get``) and build the networks."""
        self.model_name = cfgs.get('model_name', self.__class__.__name__)
        self.device = cfgs.get('device', 'cpu')
        self.image_size = cfgs.get('image_size', 64)
        self.min_depth = cfgs.get('min_depth', 0.9)
        self.max_depth = cfgs.get('max_depth', 1.1)
        self.border_depth = cfgs.get('border_depth', (0.7*self.max_depth + 0.3*self.min_depth))
        self.min_amb_light = cfgs.get('min_amb_light', 0.)
        self.max_amb_light = cfgs.get('max_amb_light', 1.)
        self.min_diff_light = cfgs.get('min_diff_light', 0.)
        self.max_diff_light = cfgs.get('max_diff_light', 1.)
        self.xyz_rotation_range = cfgs.get('xyz_rotation_range', 60)
        self.xy_translation_range = cfgs.get('xy_translation_range', 0.1)
        self.z_translation_range = cfgs.get('z_translation_range', 0.1)
        self.use_conf_map = cfgs.get('use_conf_map', True)
        self.lam_perc = cfgs.get('lam_perc', 1)
        self.lam_flip = cfgs.get('lam_flip', 0.5)
        self.lam_flip_start_epoch = cfgs.get('lam_flip_start_epoch', 0)
        self.lam_depth_sm = cfgs.get('lam_depth_sm', 0)
        self.lr = cfgs.get('lr', 1e-4)
        self.load_gt_depth = cfgs.get('load_gt_depth', False)
        self.renderer = Renderer(cfgs)

        ## networks and optimizers
        self.netD = networks.EDDeconv(cin=3, cout=1, nf=64, zdim=256, activation=None)
        self.netA = networks.EDDeconv(cin=3, cout=3, nf=64, zdim=256)
        self.netL = networks.Encoder(cin=3, cout=4, nf=32)
        self.netV = networks.Encoder(cin=3, cout=6, nf=32)
        if self.use_conf_map:
            self.netC = networks.ConfNet(cin=3, cout=2, nf=64, zdim=128)
        # Every attribute assigned so far whose name contains 'net' is a network;
        # this must run before any other attribute with 'net' in its name is set.
        self.network_names = [k for k in vars(self) if 'net' in k]
        self.make_optimizer = lambda model: torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

        ## other parameters
        self.PerceptualLoss = networks.PerceptualLoss(requires_grad=False)
        self.other_param_names = ['PerceptualLoss']

        ## rescalers: map a value in -1~1 linearly onto min~max
        self.depth_rescaler = lambda d : (1+d)/2 *self.max_depth + (1-d)/2 *self.min_depth
        self.amb_light_rescaler = lambda x : (1+x)/2 *self.max_amb_light + (1-x)/2 *self.min_amb_light
        self.diff_light_rescaler = lambda x : (1+x)/2 *self.max_diff_light + (1-x)/2 *self.min_diff_light

    def init_optimizers(self):
        """Create one optimizer per network, stored as ``optimizerX`` for each ``netX``."""
        self.optimizer_names = []
        for net_name in self.network_names:
            optimizer = self.make_optimizer(getattr(self, net_name))
            optim_name = net_name.replace('net','optimizer')
            setattr(self, optim_name, optimizer)
            self.optimizer_names += [optim_name]

    def load_model_state(self, cp):
        """Load network weights from checkpoint dict ``cp``; unknown keys are ignored."""
        for k in cp:
            if k and k in self.network_names:
                getattr(self, k).load_state_dict(cp[k])

    def load_optimizer_state(self, cp):
        """Load optimizer states from checkpoint dict ``cp``; unknown keys are ignored."""
        for k in cp:
            if k and k in self.optimizer_names:
                getattr(self, k).load_state_dict(cp[k])

    def get_model_state(self):
        """Return ``{network_name: state_dict}`` for every network."""
        states = {}
        for net_name in self.network_names:
            states[net_name] = getattr(self, net_name).state_dict()
        return states

    def get_optimizer_state(self):
        """Return ``{optimizer_name: state_dict}`` for every optimizer."""
        states = {}
        for optim_name in self.optimizer_names:
            states[optim_name] = getattr(self, optim_name).state_dict()
        return states

    def to_device(self, device):
        """Move all networks and the entries of ``other_param_names`` to ``device``."""
        self.device = device
        for net_name in self.network_names:
            setattr(self, net_name, getattr(self, net_name).to(device))
        if self.other_param_names:
            for param_name in self.other_param_names:
                setattr(self, param_name, getattr(self, param_name).to(device))

    def set_train(self):
        """Put every network into training mode."""
        for net_name in self.network_names:
            getattr(self, net_name).train()

    def set_eval(self):
        """Put every network into evaluation mode."""
        for net_name in self.network_names:
            getattr(self, net_name).eval()

    def photometric_loss(self, im1, im2, mask=None, conf_sigma=None):
        """Return the (optionally confidence-weighted, optionally masked) L1 loss.

        With ``conf_sigma`` the per-pixel error becomes
        ``|im1-im2| * sqrt(2) / sigma + log(sigma)``. With ``mask`` the result
        is the mask-weighted mean; otherwise it is the plain mean.
        """
        loss = (im1-im2).abs()
        if conf_sigma is not None:
            loss = loss *2**0.5 / (conf_sigma +EPS) + (conf_sigma +EPS).log()
        if mask is not None:
            mask = mask.expand_as(loss)
            loss = (loss * mask).sum() / mask.sum()
        else:
            loss = loss.mean()
        return loss

    def backward(self):
        """Zero all gradients, back-propagate ``self.loss_total`` and step every optimizer."""
        for optim_name in self.optimizer_names:
            getattr(self, optim_name).zero_grad()
        self.loss_total.backward()
        for optim_name in self.optimizer_names:
            getattr(self, optim_name).step()

    def forward(self, input):
        """Feedforward once.

        ``input`` is an image batch, or an ``(image, depth_gt)`` pair when
        ``load_gt_depth`` is set. The image is mapped with ``x*2-1`` before use
        (presumably from 0~1 to -1~1 -- confirm against the data loader).
        Returns a dict of metrics containing at least ``'loss'``.

        NOTE(review): reads ``self.trainer.current_epoch``; ``trainer`` is not
        assigned anywhere in this class and must be attached by the caller.
        """
        if self.load_gt_depth:
            input, depth_gt = input
        self.input_im = input.to(self.device) *2.-1.
        b, c, h, w = self.input_im.shape

        ## predict canonical depth
        self.canon_depth_raw = self.netD(self.input_im).squeeze(1)  # BxHxW
        # subtract the per-image mean before squashing with tanh
        self.canon_depth = self.canon_depth_raw - self.canon_depth_raw.view(b,-1).mean(1).view(b,1,1)
        self.canon_depth = self.canon_depth.tanh()
        self.canon_depth = self.depth_rescaler(self.canon_depth)

        ## optional depth smoothness loss (only used in synthetic car experiments):
        ## neighbouring pixels should have similar depth along height and width
        self.loss_depth_sm = ((self.canon_depth[:,:-1,:] - self.canon_depth[:,1:,:]) /(self.max_depth-self.min_depth)).abs().mean()
        self.loss_depth_sm += ((self.canon_depth[:,:,:-1] - self.canon_depth[:,:,1:]) /(self.max_depth-self.min_depth)).abs().mean()

        ## clamp border depth: the two left-most and two right-most columns are set to border_depth
        depth_border = torch.zeros(1,h,w-4).to(self.input_im.device)
        depth_border = nn.functional.pad(depth_border, (2,2), mode='constant', value=1)
        self.canon_depth = self.canon_depth*(1-depth_border) + depth_border *self.border_depth
        self.canon_depth = torch.cat([self.canon_depth, self.canon_depth.flip(2)], 0)  # flip (depth is a 3-dim tensor, width is dim 2)

        ## predict canonical albedo
        self.canon_albedo = self.netA(self.input_im)  # Bx3xHxW
        self.canon_albedo = torch.cat([self.canon_albedo, self.canon_albedo.flip(3)], 0)  # flip (albedo is a 4-dim tensor, width is dim 3)

        ## predict confidence map; channel 0 is for the original, channel 1 for the flipped reconstruction
        if self.use_conf_map:
            conf_sigma_l1, conf_sigma_percl = self.netC(self.input_im)  # Bx2xHxW
            self.conf_sigma_l1 = conf_sigma_l1[:,:1]
            self.conf_sigma_l1_flip = conf_sigma_l1[:,1:]
            self.conf_sigma_percl = conf_sigma_percl[:,:1]
            self.conf_sigma_percl_flip = conf_sigma_percl[:,1:]
        else:
            self.conf_sigma_l1 = None
            self.conf_sigma_l1_flip = None
            self.conf_sigma_percl = None
            self.conf_sigma_percl_flip = None

        ## step 1: predict lighting and shade the canonical image
        canon_light = self.netL(self.input_im).repeat(2,1)  # Bx4 before repeat
        self.canon_light_a = self.amb_light_rescaler(canon_light[:,:1])  # ambience term
        self.canon_light_b = self.diff_light_rescaler(canon_light[:,1:2])  # diffuse term
        # Columns 2 and 3 are the x/y components of the light direction.
        # Must be a tensor: a trailing comma here would make it a tuple and
        # break the torch.cat below.
        canon_light_dxy = canon_light[:,2:]
        self.canon_light_d = torch.cat([canon_light_dxy, torch.ones(b*2,1).to(self.input_im.device)], 1)
        self.canon_light_d = self.canon_light_d / ((self.canon_light_d**2).sum(1, keepdim=True))**0.5  # diffuse light direction (unit length)

        ## shading
        self.canon_normal = self.renderer.get_normal_from_depth(self.canon_depth)
        self.canon_diffuse_shading = (self.canon_normal * self.canon_light_d.view(-1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1)  # max(0, dot(n, l))
        canon_shading = self.canon_light_a.view(-1,1,1,1) + self.canon_light_b.view(-1,1,1,1)*self.canon_diffuse_shading  # ambient + diffuse * dot(n, l)
        self.canon_im = (self.canon_albedo/2+0.5) * canon_shading *2-1

        ## step 2: predict viewpoint transformation (3 rotations, 3 translations)
        self.view = self.netV(self.input_im).repeat(2,1)
        self.view = torch.cat([
            self.view[:,:3] *math.pi/180 *self.xyz_rotation_range,  # columns 0,1,2
            self.view[:,3:5] *self.xy_translation_range,  # columns 3,4
            self.view[:,5:] *self.z_translation_range], 1)  # column 5

        ## reconstruct input view
        self.renderer.set_transform_matrices(self.view)
        self.recon_depth = self.renderer.warp_canon_depth(self.canon_depth)
        self.recon_normal = self.renderer.get_normal_from_depth(self.recon_depth)
        grid_2d_from_canon = self.renderer.get_inv_warped_2d_grid(self.recon_depth)
        # NOTE(review): grid_sample is called without align_corners; its default
        # has changed between PyTorch releases -- confirm which convention the
        # renderer's grid normalisation expects.
        self.recon_im = nn.functional.grid_sample(self.canon_im, grid_2d_from_canon, mode='bilinear')

        ## mask out border pixels
        margin = (self.max_depth - self.min_depth) /2
        recon_im_mask = (self.recon_depth < self.max_depth+margin).float()  # invalid border pixels have been clamped at max_depth+margin
        recon_im_mask_both = recon_im_mask[:b] * recon_im_mask[b:]  # both original and flip reconstruction
        recon_im_mask_both = recon_im_mask_both.repeat(2,1,1).unsqueeze(1).detach()
        self.recon_im = self.recon_im * recon_im_mask_both

        ## render symmetry axis (two centre columns) for visualization
        canon_sym_axis = torch.zeros(h, w).to(self.input_im.device)
        canon_sym_axis[:, w//2-1:w//2+1] = 1
        self.recon_sym_axis = nn.functional.grid_sample(canon_sym_axis.repeat(b*2,1,1,1), grid_2d_from_canon, mode='bilinear')
        self.recon_sym_axis = self.recon_sym_axis * recon_im_mask_both
        green = torch.FloatTensor([-1,1,-1]).to(self.input_im.device).view(1,3,1,1)
        self.input_im_symline = (0.5*self.recon_sym_axis) *green + (1-0.5*self.recon_sym_axis) *self.input_im.repeat(2,1,1,1)

        ## loss function
        self.loss_l1_im = self.photometric_loss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_l1)
        self.loss_l1_im_flip = self.photometric_loss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_l1_flip)
        self.loss_perc_im = self.PerceptualLoss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_percl)
        self.loss_perc_im_flip = self.PerceptualLoss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_percl_flip)
        lam_flip = 1 if self.trainer.current_epoch < self.lam_flip_start_epoch else self.lam_flip
        self.loss_total = self.loss_l1_im + lam_flip*self.loss_l1_im_flip + self.lam_perc*(self.loss_perc_im + lam_flip*self.loss_perc_im_flip) + self.lam_depth_sm*self.loss_depth_sm

        metrics = {'loss': self.loss_total}

        ## compute accuracy if gt depth is available
        if self.load_gt_depth:
            self.depth_gt = depth_gt[:,0,:,:].to(self.input_im.device)
            self.depth_gt = (1-self.depth_gt)*2-1
            self.depth_gt = self.depth_rescaler(self.depth_gt)
            self.normal_gt = self.renderer.get_normal_from_depth(self.depth_gt)

            # mask out background
            mask_gt = (self.depth_gt<self.depth_gt.max()).float()
            mask_gt = (nn.functional.avg_pool2d(mask_gt.unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float()  # erode by 1 pixel
            mask_pred = (nn.functional.avg_pool2d(recon_im_mask[:b].unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float()  # erode by 1 pixel
            mask = mask_gt * mask_pred
            self.acc_mae_masked = ((self.recon_depth[:b] - self.depth_gt[:b]).abs() *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1)
            self.acc_mse_masked = (((self.recon_depth[:b] - self.depth_gt[:b])**2) *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1)
            self.sie_map_masked = utils.compute_sc_inv_err(self.recon_depth[:b].log(), self.depth_gt[:b].log(), mask=mask)
            self.acc_sie_masked = (self.sie_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1))**0.5
            self.norm_err_map_masked = utils.compute_angular_distance(self.recon_normal[:b], self.normal_gt[:b], mask=mask)
            self.acc_normal_masked = self.norm_err_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1)

            metrics['SIE_masked'] = self.acc_sie_masked.mean()
            metrics['NorErr_masked'] = self.acc_normal_masked.mean()

        return metrics

    def visualize(self, logger, total_iter, max_bs=25):
        """Write scalars, histograms, image grids and rotation videos to ``logger``.

        Uses the tensors stored by the last ``forward`` call; at most ``max_bs``
        samples are shown in the image grids. ``logger`` must provide
        ``add_scalar`` / ``add_histogram`` / ``add_image`` / ``add_video``.
        """
        b, c, h, w = self.input_im.shape
        b0 = min(max_bs, b)

        ## render rotations
        with torch.no_grad():
            v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b0,1)
            canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b0], self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5  # (B,T,C,H,W)
            canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b0].permute(0,3,1,2), self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5  # (B,T,C,H,W)

        input_im_symline = self.input_im_symline[:b0].detach().cpu() /2.+0.5
        canon_albedo = self.canon_albedo[:b0].detach().cpu() /2.+0.5
        canon_im = self.canon_im[:b0].detach().cpu() /2.+0.5
        recon_im = self.recon_im[:b0].detach().cpu() /2.+0.5
        recon_im_flip = self.recon_im[b:b+b0].detach().cpu() /2.+0.5
        canon_depth_raw_hist = self.canon_depth_raw.detach().unsqueeze(1).cpu()
        canon_depth_raw = self.canon_depth_raw[:b0].detach().unsqueeze(1).cpu() /2.+0.5
        canon_depth = ((self.canon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
        recon_depth = ((self.recon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
        canon_diffuse_shading = self.canon_diffuse_shading[:b0].detach().cpu()
        canon_normal = self.canon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
        recon_normal = self.recon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
        if self.use_conf_map:
            conf_map_l1 = 1/(1+self.conf_sigma_l1[:b0].detach().cpu()+EPS)
            conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b0].detach().cpu()+EPS)
            conf_map_percl = 1/(1+self.conf_sigma_percl[:b0].detach().cpu()+EPS)
            conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b0].detach().cpu()+EPS)

        canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_im_rotate, 1)]  # [(C,H,W)]*T
        canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0)  # (1,T,C,H,W)
        canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_normal_rotate, 1)]  # [(C,H,W)]*T
        canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0)  # (1,T,C,H,W)

        ## write summary
        logger.add_scalar('Loss/loss_total', self.loss_total, total_iter)
        logger.add_scalar('Loss/loss_l1_im', self.loss_l1_im, total_iter)
        logger.add_scalar('Loss/loss_l1_im_flip', self.loss_l1_im_flip, total_iter)
        logger.add_scalar('Loss/loss_perc_im', self.loss_perc_im, total_iter)
        logger.add_scalar('Loss/loss_perc_im_flip', self.loss_perc_im_flip, total_iter)
        logger.add_scalar('Loss/loss_depth_sm', self.loss_depth_sm, total_iter)

        logger.add_histogram('Depth/canon_depth_raw_hist', canon_depth_raw_hist, total_iter)
        vlist = ['view_rx', 'view_ry', 'view_rz', 'view_tx', 'view_ty', 'view_tz']
        for i in range(self.view.shape[1]):
            logger.add_histogram('View/'+vlist[i], self.view[:,i], total_iter)
        logger.add_histogram('Light/canon_light_a', self.canon_light_a, total_iter)
        logger.add_histogram('Light/canon_light_b', self.canon_light_b, total_iter)
        llist = ['canon_light_dx', 'canon_light_dy', 'canon_light_dz']
        for i in range(self.canon_light_d.shape[1]):
            logger.add_histogram('Light/'+llist[i], self.canon_light_d[:,i], total_iter)

        def log_grid_image(label, im, nrow=int(math.ceil(b0**0.5)), step=total_iter):
            # tile a batch of images into one grid and log it under `label`
            im_grid = torchvision.utils.make_grid(im, nrow=nrow)
            logger.add_image(label, im_grid, step)

        log_grid_image('Image/input_image_symline', input_im_symline)
        log_grid_image('Image/canonical_albedo', canon_albedo)
        log_grid_image('Image/canonical_image', canon_im)
        log_grid_image('Image/recon_image', recon_im)
        log_grid_image('Image/recon_image_flip', recon_im_flip)
        log_grid_image('Image/recon_side', canon_im_rotate[:,0,:,:,:])

        log_grid_image('Depth/canonical_depth_raw', canon_depth_raw)
        log_grid_image('Depth/canonical_depth', canon_depth)
        log_grid_image('Depth/recon_depth', recon_depth)
        log_grid_image('Depth/canonical_diffuse_shading', canon_diffuse_shading)
        log_grid_image('Depth/canonical_normal', canon_normal)
        log_grid_image('Depth/recon_normal', recon_normal)

        logger.add_histogram('Image/canonical_albedo_hist', canon_albedo, total_iter)
        logger.add_histogram('Image/canonical_diffuse_shading_hist', canon_diffuse_shading, total_iter)

        if self.use_conf_map:
            log_grid_image('Conf/conf_map_l1', conf_map_l1)
            logger.add_histogram('Conf/conf_sigma_l1_hist', self.conf_sigma_l1, total_iter)
            log_grid_image('Conf/conf_map_l1_flip', conf_map_l1_flip)
            logger.add_histogram('Conf/conf_sigma_l1_flip_hist', self.conf_sigma_l1_flip, total_iter)
            log_grid_image('Conf/conf_map_percl', conf_map_percl)
            logger.add_histogram('Conf/conf_sigma_percl_hist', self.conf_sigma_percl, total_iter)
            log_grid_image('Conf/conf_map_percl_flip', conf_map_percl_flip)
            logger.add_histogram('Conf/conf_sigma_percl_flip_hist', self.conf_sigma_percl_flip, total_iter)

        logger.add_video('Image_rotate/recon_rotate', canon_im_rotate_grid, total_iter, fps=4)
        logger.add_video('Image_rotate/canon_normal_rotate', canon_normal_rotate_grid, total_iter, fps=4)

        # visualize images and accuracy if gt is loaded
        if self.load_gt_depth:
            depth_gt = ((self.depth_gt[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
            normal_gt = self.normal_gt.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
            sie_map_masked = self.sie_map_masked[:b0].detach().unsqueeze(1).cpu() *1000
            norm_err_map_masked = self.norm_err_map_masked[:b0].detach().unsqueeze(1).cpu() /100

            logger.add_scalar('Acc_masked/MAE_masked', self.acc_mae_masked.mean(), total_iter)
            logger.add_scalar('Acc_masked/MSE_masked', self.acc_mse_masked.mean(), total_iter)
            logger.add_scalar('Acc_masked/SIE_masked', self.acc_sie_masked.mean(), total_iter)
            logger.add_scalar('Acc_masked/NorErr_masked', self.acc_normal_masked.mean(), total_iter)

            log_grid_image('Depth_gt/depth_gt', depth_gt)
            log_grid_image('Depth_gt/normal_gt', normal_gt)
            log_grid_image('Depth_gt/sie_map_masked', sie_map_masked)
            log_grid_image('Depth_gt/norm_err_map_masked', norm_err_map_masked)

    def save_results(self, save_dir):
        """Save images, light/view parameters and rotation videos of the last batch to ``save_dir``.

        When ``load_gt_depth`` is set, the per-sample scores of this batch are
        also appended to ``self.all_scores`` (created on first use).
        """
        b, c, h, w = self.input_im.shape

        with torch.no_grad():
            v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b,1)
            canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b], self.canon_depth[:b], v_before=v0, maxr=90, nsample=15)  # (B,T,C,H,W)
            canon_im_rotate = canon_im_rotate.clamp(-1,1).detach().cpu() /2+0.5
            canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b].permute(0,3,1,2), self.canon_depth[:b], v_before=v0, maxr=90, nsample=15)  # (B,T,C,H,W)
            canon_normal_rotate = canon_normal_rotate.clamp(-1,1).detach().cpu() /2+0.5

        input_im = self.input_im[:b].detach().cpu().numpy() /2+0.5
        input_im_symline = self.input_im_symline.detach().cpu().numpy() /2.+0.5
        canon_albedo = self.canon_albedo[:b].detach().cpu().numpy() /2+0.5
        canon_im = self.canon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5
        recon_im = self.recon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5
        recon_im_flip = self.recon_im[b:].clamp(-1,1).detach().cpu().numpy() /2+0.5
        canon_depth = ((self.canon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
        recon_depth = ((self.recon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
        canon_diffuse_shading = self.canon_diffuse_shading[:b].detach().cpu().numpy()
        canon_normal = self.canon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
        recon_normal = self.recon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
        if self.use_conf_map:
            conf_map_l1 = 1/(1+self.conf_sigma_l1[:b].detach().cpu().numpy()+EPS)
            conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b].detach().cpu().numpy()+EPS)
            conf_map_percl = 1/(1+self.conf_sigma_percl[:b].detach().cpu().numpy()+EPS)
            conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b].detach().cpu().numpy()+EPS)
        canon_light = torch.cat([self.canon_light_a, self.canon_light_b, self.canon_light_d], 1)[:b].detach().cpu().numpy()
        view = self.view[:b].detach().cpu().numpy()

        canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_im_rotate,1)]  # [(C,H,W)]*T
        canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0).numpy()  # (1,T,C,H,W)
        canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_normal_rotate,1)]  # [(C,H,W)]*T
        canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0).numpy()  # (1,T,C,H,W)

        sep_folder = True
        utils.save_images(save_dir, input_im, suffix='input_image', sep_folder=sep_folder)
        utils.save_images(save_dir, input_im_symline, suffix='input_image_symline', sep_folder=sep_folder)
        utils.save_images(save_dir, canon_albedo, suffix='canonical_albedo', sep_folder=sep_folder)
        utils.save_images(save_dir, canon_im, suffix='canonical_image', sep_folder=sep_folder)
        utils.save_images(save_dir, recon_im, suffix='recon_image', sep_folder=sep_folder)
        utils.save_images(save_dir, recon_im_flip, suffix='recon_image_flip', sep_folder=sep_folder)
        utils.save_images(save_dir, canon_depth, suffix='canonical_depth', sep_folder=sep_folder)
        utils.save_images(save_dir, recon_depth, suffix='recon_depth', sep_folder=sep_folder)
        utils.save_images(save_dir, canon_diffuse_shading, suffix='canonical_diffuse_shading', sep_folder=sep_folder)
        utils.save_images(save_dir, canon_normal, suffix='canonical_normal', sep_folder=sep_folder)
        utils.save_images(save_dir, recon_normal, suffix='recon_normal', sep_folder=sep_folder)
        if self.use_conf_map:
            utils.save_images(save_dir, conf_map_l1, suffix='conf_map_l1', sep_folder=sep_folder)
            utils.save_images(save_dir, conf_map_l1_flip, suffix='conf_map_l1_flip', sep_folder=sep_folder)
            utils.save_images(save_dir, conf_map_percl, suffix='conf_map_percl', sep_folder=sep_folder)
            utils.save_images(save_dir, conf_map_percl_flip, suffix='conf_map_percl_flip', sep_folder=sep_folder)
        utils.save_txt(save_dir, canon_light, suffix='canonical_light', sep_folder=sep_folder)
        utils.save_txt(save_dir, view, suffix='viewpoint', sep_folder=sep_folder)

        utils.save_videos(save_dir, canon_im_rotate_grid, suffix='image_video', sep_folder=sep_folder, cycle=True)
        utils.save_videos(save_dir, canon_normal_rotate_grid, suffix='normal_video', sep_folder=sep_folder, cycle=True)

        # save scores if gt is loaded
        if self.load_gt_depth:
            depth_gt = ((self.depth_gt[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
            normal_gt = self.normal_gt[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
            utils.save_images(save_dir, depth_gt, suffix='depth_gt', sep_folder=sep_folder)
            utils.save_images(save_dir, normal_gt, suffix='normal_gt', sep_folder=sep_folder)

            all_scores = torch.stack([
                self.acc_mae_masked.detach().cpu(),
                self.acc_mse_masked.detach().cpu(),
                self.acc_sie_masked.detach().cpu(),
                self.acc_normal_masked.detach().cpu()], 1)
            if not hasattr(self, 'all_scores'):
                self.all_scores = torch.FloatTensor()
            self.all_scores = torch.cat([self.all_scores, all_scores], 0)

    def save_scores(self, path):
        """Write the accumulated scores with a mean/std header to ``path`` (only if gt depth is loaded)."""
        # save scores if gt is loaded
        if self.load_gt_depth:
            header = 'MAE_masked, MSE_masked, SIE_masked, NorErr_masked'
            mean = self.all_scores.mean(0)
            std = self.all_scores.std(0)
            header = header + ' Mean: ' + ', '.join(['%.8f'%x for x in mean])
            header = header + ' Std: ' + ', '.join(['%.8f'%x for x in std])
            utils.save_scores(path, self.all_scores, header=header)
renderer.py
import torch
import math
import neural_renderer as nr
from .utils import *


EPS = 1e-7


class Renderer():
    """Pinhole-camera helper: lifts depth maps to 3D, warps them and renders views."""

    def __init__(self, cfgs):
        """Read camera settings from ``cfgs`` and build the intrinsics and the mesh renderer."""
        self.device = cfgs.get('device', 'cpu')
        self.image_size = cfgs.get('image_size', 64)
        self.min_depth = cfgs.get('min_depth', 0.9)
        self.max_depth = cfgs.get('max_depth', 1.1)
        self.rot_center_depth = cfgs.get('rot_center_depth', (self.min_depth+self.max_depth)/2)
        self.fov = cfgs.get('fov', 10)
        self.tex_cube_size = cfgs.get('tex_cube_size', 2)
        self.renderer_min_depth = cfgs.get('renderer_min_depth', 0.1)
        self.renderer_max_depth = cfgs.get('renderer_max_depth', 10.)

        #### camera intrinsics
        #             (u)   (x)
        #    d * K^-1 (v) = (y)
        #             (1)   (z)

        ## renderer for visualization: identity rotation, zero translation
        identity_rot = torch.eye(3, dtype=torch.float32).unsqueeze(0).to(self.device)
        zero_trans = torch.zeros(1, 3, dtype=torch.float32).to(self.device)

        # principal point sits at the image centre; focal length follows from the fov (degrees)
        centre = (self.image_size-1)/2
        focal = centre/(math.tan(self.fov/2 *math.pi/180))
        intrinsics = torch.FloatTensor([
            [focal, 0., centre],
            [0., focal, centre],
            [0., 0., 1.]]).to(self.device)
        self.inv_K = torch.inverse(intrinsics).unsqueeze(0)
        self.K = intrinsics.unsqueeze(0)

        self.renderer = nr.Renderer(
            camera_mode='projection',
            light_intensity_ambient=1.0,
            light_intensity_directional=0.,
            K=self.K, R=identity_rot, t=zero_trans,
            near=self.renderer_min_depth, far=self.renderer_max_depth,
            image_size=self.image_size, orig_size=self.image_size,
            fill_back=True,
            background_color=[1,1,1])

    def set_transform_matrices(self, view):
        """Store the rotation matrix and translation derived from ``view``."""
        self.rot_mat, self.trans_xyz = get_transform_matrices(view)

    def rotate_pts(self, pts, rot_mat):
        """Rotate ``pts`` by ``rot_mat`` about the point ``(0, 0, rot_center_depth)``."""
        pivot = torch.FloatTensor([0.,0.,self.rot_center_depth]).to(pts.device).view(1,1,3)
        centred = pts - pivot
        turned = centred.matmul(rot_mat.transpose(2,1))
        return turned + pivot

    def translate_pts(self, pts, trans_xyz):
        """Shift ``pts`` by ``trans_xyz``."""
        return pts + trans_xyz

    def depth_to_3d_grid(self, depth):
        """Back-project a (b, h, w) depth map to a (b, h, w, 3) grid of 3D points via ``inv_K``."""
        b, h, w = depth.shape
        pixels = get_grid(b, h, w, normalize=False).to(depth.device)  # b x h x w x 2 pixel coordinates
        depth = depth.unsqueeze(-1)  # b x h x w x 1
        homogeneous = torch.cat((pixels, torch.ones_like(depth)), dim=3)  # b x h x w x 3
        rays = homogeneous.matmul(self.inv_K.to(depth.device).transpose(2,1))
        return rays * depth

    def grid_3d_to_2d(self, grid_3d):
        """Project a (b, h, w, 3) grid with ``K`` and normalise pixel coordinates to -1~1."""
        b, h, w, _ = grid_3d.shape
        projected = (grid_3d / grid_3d[...,2:]).matmul(self.K.to(grid_3d.device).transpose(2,1))
        pixels = projected[:,:,:,:2]
        extent = torch.FloatTensor([w-1, h-1]).to(grid_3d.device).view(1,1,1,2)
        return pixels / extent *2.-1.  # normalize to -1~1

    def get_warped_3d_grid(self, depth):
        """Return the 3D points of ``depth`` after the stored rotation then translation."""
        b, h, w = depth.shape
        pts = self.depth_to_3d_grid(depth).reshape(b,-1,3)
        pts = self.rotate_pts(pts, self.rot_mat)
        pts = self.translate_pts(pts, self.trans_xyz)
        return pts.reshape(b,h,w,3)  # 3d vertices

    def get_inv_warped_3d_grid(self, depth):
        """Return the 3D points of ``depth`` after undoing the stored translation then rotation."""
        b, h, w = depth.shape
        pts = self.depth_to_3d_grid(depth).reshape(b,-1,3)
        pts = self.translate_pts(pts, -self.trans_xyz)
        pts = self.rotate_pts(pts, self.rot_mat.transpose(2,1))
        return pts.reshape(b,h,w,3)  # 3d vertices

    def get_warped_2d_grid(self, depth):
        """Project the forward-warped 3D grid of ``depth`` back to normalised 2D coordinates."""
        return self.grid_3d_to_2d(self.get_warped_3d_grid(depth))

    def get_inv_warped_2d_grid(self, depth):
        """Project the inverse-warped 3D grid of ``depth`` back to normalised 2D coordinates."""
        return self.grid_3d_to_2d(self.get_inv_warped_3d_grid(depth))

    def warp_canon_depth(self, canon_depth):
        """Render the depth map seen after applying the stored transform to ``canon_depth``."""
        b, h, w = canon_depth.shape
        vertices = self.get_warped_3d_grid(canon_depth).reshape(b,-1,3)
        faces = get_face_idx(b, h, w).to(canon_depth.device)
        rendered = self.renderer.render_depth(vertices, faces)

        # allow some margin out of valid range
        margin = (self.max_depth - self.min_depth) /2
        lower = self.min_depth-margin
        upper = self.max_depth+margin
        return rendered.clamp(min=lower, max=upper)

    def get_normal_from_depth(self, depth):
        """Return unit normals (b, h, w, 3) from central differences of the back-projected grid.

        The one-pixel border, where no central difference exists, is filled
        with ``(0, 0, 1)`` before normalisation.
        """
        b, h, w = depth.shape
        pts = self.depth_to_3d_grid(depth)
        tangent_u = pts[:,1:-1,2:] - pts[:,1:-1,:-2]
        tangent_v = pts[:,2:,1:-1] - pts[:,:-2,1:-1]
        normal = tangent_u.cross(tangent_v, dim=3)

        filler = torch.FloatTensor([0,0,1]).to(depth.device)
        side_pad = filler.repeat(b,h-2,1,1)   # left / right column
        row_pad = filler.repeat(b,1,w,1)      # top / bottom row
        normal = torch.cat([side_pad, normal, side_pad], 2)
        normal = torch.cat([row_pad, normal, row_pad], 1)

        length = ((normal**2).sum(3, keepdim=True))**0.5 + EPS
        return normal / length

    def render_yaw(self, im, depth, v_before=None, v_after=None, rotations=None, maxr=90, nsample=9, crop_mesh=None):
        """Render ``im`` textured on the mesh of ``depth`` under a sweep of yaw angles.

        ``v_before`` (if given) is inverted and applied first; ``rotations``
        defaults to ``nsample`` angles evenly spaced in +-``maxr`` degrees;
        ``v_after`` (if given) is applied after each yaw, and may carry one
        view per angle when it is 3-dimensional. ``crop_mesh`` is
        ``(top, bottom, left, right)`` in pixels. Returns b x t x c x h x w.
        """
        b, c, h, w = im.shape
        grid_3d = self.depth_to_3d_grid(depth)

        if crop_mesh is not None:
            top, bottom, left, right = crop_mesh  # pixels from border to be cropped
            # Collapse each cropped strip onto the first kept row / column.
            # The order of these in-place writes matters at the corners.
            if top > 0:
                for ch in (1, 2):
                    grid_3d[:,:top,:,ch] = grid_3d[:,top:top+1,:,ch].repeat(1,top,1)
            if bottom > 0:
                for ch in (1, 2):
                    grid_3d[:,-bottom:,:,ch] = grid_3d[:,-bottom-1:-bottom,:,ch].repeat(1,bottom,1)
            if left > 0:
                for ch in (0, 2):
                    grid_3d[:,:,:left,ch] = grid_3d[:,:,left:left+1,ch].repeat(1,1,left)
            if right > 0:
                for ch in (0, 2):
                    grid_3d[:,:,-right:,ch] = grid_3d[:,:,-right-1:-right,ch].repeat(1,1,right)

        grid_3d = grid_3d.reshape(b,-1,3)
        frames = []

        # inverse warp
        if v_before is not None:
            rot_mat, trans_xyz = get_transform_matrices(v_before)
            grid_3d = self.translate_pts(grid_3d, -trans_xyz)
            grid_3d = self.rotate_pts(grid_3d, rot_mat.transpose(2,1))

        if rotations is None:
            max_rad = math.pi/180*maxr
            rotations = torch.linspace(-max_rad, max_rad, nsample)

        for i, ri in enumerate(rotations):
            yaw_view = torch.FloatTensor([0, ri, 0]).to(im.device).view(1,3)
            rot_mat_i, _ = get_transform_matrices(yaw_view)
            grid_3d_i = self.rotate_pts(grid_3d, rot_mat_i.repeat(b,1,1))

            if v_after is not None:
                v_after_i = v_after[i] if len(v_after.shape) == 3 else v_after
                rot_mat, trans_xyz = get_transform_matrices(v_after_i)
                grid_3d_i = self.rotate_pts(grid_3d_i, rot_mat)
                grid_3d_i = self.translate_pts(grid_3d_i, trans_xyz)

            faces = get_face_idx(b, h, w).to(im.device)
            textures = get_textures_from_im(im, tx_size=self.tex_cube_size)
            warped_images = self.renderer.render_rgb(grid_3d_i, faces, textures).clamp(min=-1., max=1.)
            frames.append(warped_images)

        return torch.stack(frames, 1)  # b x t x c x h x w
utils.py
import torch


def mm_normalize(x, min=0, max=1):
    """Linearly rescale ``x`` so its smallest value maps to ``min`` and its largest to ``max``.

    A constant tensor (zero range) maps to ``min`` everywhere instead of
    producing NaN from a division by zero.
    """
    x_min = x.min()
    x_range = x.max() - x_min
    # Substitute 1 for a zero range: the numerator is all zeros in that case,
    # so the result is exactly ``min``.
    safe_range = torch.where(x_range == 0, torch.ones_like(x_range), x_range)
    x_z = (x - x_min) / safe_range
    return x_z * (max - min) + min


def rand_range(size, min, max):
    """Return a tensor of shape ``size`` drawn uniformly from ``[min, max)``."""
    return torch.rand(size)*(max-min)+min


def rand_posneg_range(size, min, max):
    """Return uniform samples from ``[min, max)`` with an independent random sign each."""
    sign = (torch.rand(size) > 0.5).type(torch.float)*2.-1.
    return sign*rand_range(size, min, max)


def get_grid(b, H, W, normalize=True):
    """Return a (b, H, W, 2) float grid of (x, y) coordinates.

    With ``normalize`` the coordinates span -1~1; otherwise they are the
    integer pixel indices ``0..W-1`` / ``0..H-1``.
    """
    if normalize:
        h_range = torch.linspace(-1,1,H)
        w_range = torch.linspace(-1,1,W)
    else:
        h_range = torch.arange(0,H)
        w_range = torch.arange(0,W)
    grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b,1,1,1).flip(3).float()  # flip h,w to x,y
    return grid


def get_rotation_matrix(tx, ty, tz):
    """Return the (n, 3, 3) rotation ``Rz(tz) @ Ry(ty) @ Rx(tx)`` for 1-D angle tensors in radians."""
    n = len(tx)
    m_x = torch.zeros((n, 3, 3)).to(tx.device)
    m_y = torch.zeros((n, 3, 3)).to(tx.device)
    m_z = torch.zeros((n, 3, 3)).to(tx.device)

    # rotation about the x axis
    m_x[:, 1, 1], m_x[:, 1, 2] = tx.cos(), -tx.sin()
    m_x[:, 2, 1], m_x[:, 2, 2] = tx.sin(), tx.cos()
    m_x[:, 0, 0] = 1

    # rotation about the y axis
    m_y[:, 0, 0], m_y[:, 0, 2] = ty.cos(), ty.sin()
    m_y[:, 2, 0], m_y[:, 2, 2] = -ty.sin(), ty.cos()
    m_y[:, 1, 1] = 1

    # rotation about the z axis
    m_z[:, 0, 0], m_z[:, 0, 1] = tz.cos(), -tz.sin()
    m_z[:, 1, 0], m_z[:, 1, 1] = tz.sin(), tz.cos()
    m_z[:, 2, 2] = 1
    return torch.matmul(m_z, torch.matmul(m_y, m_x))


def get_transform_matrices(view):
    """Split a (b, k) view tensor into a rotation matrix and a translation.

    Columns 0-2 are the x/y/z rotation angles. ``k == 6`` adds an xyz
    translation, ``k == 5`` an xy translation (z is zero) and ``k == 3``
    means no translation.

    Returns:
        ``(rot_mat, trans_xyz)`` with shapes (b, 3, 3) and (b, 1, 3).

    Raises:
        ValueError: if ``k`` is not 3, 5 or 6.
    """
    b = view.size(0)
    width = view.size(1)
    if width not in (3, 5, 6):
        raise ValueError(
            "view must have 3, 5 or 6 columns, got %d" % width)

    rx = view[:,0]
    ry = view[:,1]
    rz = view[:,2]
    if width == 6:
        trans_xyz = view[:,3:].reshape(b,1,3)
    elif width == 5:
        delta_xy = view[:,3:].reshape(b,1,2)
        trans_xyz = torch.cat([delta_xy, torch.zeros(b,1,1).to(view.device)], 2)
    else:
        trans_xyz = torch.zeros(b,1,3).to(view.device)
    rot_mat = get_rotation_matrix(rx, ry, rz)
    return rot_mat, trans_xyz


def get_face_idx(b, h, w):
    """Return (b, 2*(h-1)*(w-1), 3) int32 vertex indices triangulating an h x w pixel grid.

    Each grid cell is split into two triangles; vertices are numbered
    row-major.
    """
    idx_map = torch.arange(h*w).reshape(h,w)
    faces1 = torch.stack([idx_map[:h-1,:w-1], idx_map[1:,:w-1], idx_map[:h-1,1:]], -1).reshape(-1,3)
    faces2 = torch.stack([idx_map[:h-1,1:], idx_map[1:,:w-1], idx_map[1:,1:]], -1).reshape(-1,3)
    return torch.cat([faces1,faces2], 0).repeat(b,1,1).int()


def vcolor_to_texture_cube(vcolors):
    """Mix per-vertex colours (b x c x n x 3) into b x n x 2 x 2 x 2 x c texture cubes."""
    # input bxcxnx3
    b, c, n, f = vcolors.shape
    # one row per texel of the 2x2x2 cube, one column per triangle vertex
    coeffs = torch.FloatTensor(
        [[ 0.5,  0.5,  0.5],
         [ 0. ,  0. ,  1. ],
         [ 0. ,  1. ,  0. ],
         [-0.5,  0.5,  0.5],
         [ 1. ,  0. ,  0. ],
         [ 0.5, -0.5,  0.5],
         [ 0.5,  0.5, -0.5],
         [ 0. ,  0. ,  0. ]]).to(vcolors.device)
    return coeffs.matmul(vcolors.permute(0,2,3,1)).reshape(b,n,2,2,2,c)


def get_textures_from_im(im, tx_size=1):
    """Build per-face textures from an image batch (b x c x h x w).

    ``tx_size == 1`` gives one colour per face (b x n x 1 x 1 x 1 x c);
    ``tx_size == 2`` gives a 2x2x2 cube per face built from the face's three
    vertex colours (b x n x 2 x 2 x 2 x c).

    Raises:
        NotImplementedError: for any other ``tx_size``.
    """
    b, c, h, w = im.shape
    if tx_size == 1:
        textures = torch.cat([im[:,:,:h-1,:w-1].reshape(b,c,-1), im[:,:,1:,1:].reshape(b,c,-1)], 2)
        textures = textures.transpose(2,1).reshape(b,-1,1,1,1,c)
    elif tx_size == 2:
        textures1 = torch.stack([im[:,:,:h-1,:w-1], im[:,:,:h-1,1:], im[:,:,1:,:w-1]], -1).reshape(b,c,-1,3)
        textures2 = torch.stack([im[:,:,1:,:w-1], im[:,:,:h-1,1:], im[:,:,1:,1:]], -1).reshape(b,c,-1,3)
        textures = vcolor_to_texture_cube(torch.cat([textures1, textures2], 2))  # bxnx2x2x2xc
    else:
        raise NotImplementedError("Currently support texture size of 1 or 2 only.")
    return textures