  • CVPR 2020 best paper code analysis

    Analysis of model.py

    import os
    import math
    import glob
    import torch
    import torch.nn as nn
    import torchvision
    from . import networks
    from . import utils
    from .renderer import Renderer
    
    
    EPS = 1e-7
    
    
    class Unsup3D():
        def __init__(self, cfgs):
            self.model_name = cfgs.get('model_name', self.__class__.__name__)
            self.device = cfgs.get('device', 'cpu')
            self.image_size = cfgs.get('image_size', 64)
            self.min_depth = cfgs.get('min_depth', 0.9)
            self.max_depth = cfgs.get('max_depth', 1.1)
            self.border_depth = cfgs.get('border_depth', (0.7*self.max_depth + 0.3*self.min_depth))
            self.min_amb_light = cfgs.get('min_amb_light', 0.)
            self.max_amb_light = cfgs.get('max_amb_light', 1.)
            self.min_diff_light = cfgs.get('min_diff_light', 0.)
            self.max_diff_light = cfgs.get('max_diff_light', 1.)
            self.xyz_rotation_range = cfgs.get('xyz_rotation_range', 60)
            self.xy_translation_range = cfgs.get('xy_translation_range', 0.1)
            self.z_translation_range = cfgs.get('z_translation_range', 0.1)
            self.use_conf_map = cfgs.get('use_conf_map', True)
            self.lam_perc = cfgs.get('lam_perc', 1)
            self.lam_flip = cfgs.get('lam_flip', 0.5)
            self.lam_flip_start_epoch = cfgs.get('lam_flip_start_epoch', 0)
            self.lam_depth_sm = cfgs.get('lam_depth_sm', 0)
            self.lr = cfgs.get('lr', 1e-4)
            self.load_gt_depth = cfgs.get('load_gt_depth', False)
            self.renderer = Renderer(cfgs)
    
            ## networks and optimizers
            self.netD = networks.EDDeconv(cin=3, cout=1, nf=64, zdim=256, activation=None)
            self.netA = networks.EDDeconv(cin=3, cout=3, nf=64, zdim=256)
            self.netL = networks.Encoder(cin=3, cout=4, nf=32)
            self.netV = networks.Encoder(cin=3, cout=6, nf=32)
            if self.use_conf_map:
                self.netC = networks.ConfNet(cin=3, cout=2, nf=64, zdim=128)
            self.network_names = [k for k in vars(self) if 'net' in k]
            self.make_optimizer = lambda model: torch.optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
    
            ## other parameters
            self.PerceptualLoss = networks.PerceptualLoss(requires_grad=False)
            self.other_param_names = ['PerceptualLoss']
    
        ## depth rescaler: -1~1 -> min_depth~max_depth
            self.depth_rescaler = lambda d : (1+d)/2 *self.max_depth + (1-d)/2 *self.min_depth
            self.amb_light_rescaler = lambda x : (1+x)/2 *self.max_amb_light + (1-x)/2 *self.min_amb_light
            self.diff_light_rescaler = lambda x : (1+x)/2 *self.max_diff_light + (1-x)/2 *self.min_diff_light
    
        def init_optimizers(self):
            self.optimizer_names = []
            for net_name in self.network_names:
                optimizer = self.make_optimizer(getattr(self, net_name))
                optim_name = net_name.replace('net','optimizer')
                setattr(self, optim_name, optimizer)
                self.optimizer_names += [optim_name]
    
        def load_model_state(self, cp):
            for k in cp:
                if k and k in self.network_names:
                    getattr(self, k).load_state_dict(cp[k])
    
        def load_optimizer_state(self, cp):
            for k in cp:
                if k and k in self.optimizer_names:
                    getattr(self, k).load_state_dict(cp[k])
    
        def get_model_state(self):
            states = {}
            for net_name in self.network_names:
                states[net_name] = getattr(self, net_name).state_dict()
            return states
    
        def get_optimizer_state(self):
            states = {}
            for optim_name in self.optimizer_names:
                states[optim_name] = getattr(self, optim_name).state_dict()
            return states
    
        def to_device(self, device):
            self.device = device
            for net_name in self.network_names:
                setattr(self, net_name, getattr(self, net_name).to(device))
            if self.other_param_names:
                for param_name in self.other_param_names:
                    setattr(self, param_name, getattr(self, param_name).to(device))
    
        def set_train(self):
            for net_name in self.network_names:
                getattr(self, net_name).train()
    
        def set_eval(self):
            for net_name in self.network_names:
                getattr(self, net_name).eval()
    
        def photometric_loss(self, im1, im2, mask=None, conf_sigma=None):
            loss = (im1-im2).abs()
            if conf_sigma is not None:
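                # negative log-likelihood of a Laplace distribution with scale sigma/sqrt(2): sqrt(2)*|im1-im2|/sigma + log(sigma), constants dropped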
                loss = loss *2**0.5 / (conf_sigma +EPS) + (conf_sigma +EPS).log()
            if mask is not None:
                mask = mask.expand_as(loss)
                loss = (loss * mask).sum() / mask.sum()
            else:
                loss = loss.mean()
            return loss
    
        def backward(self):
            for optim_name in self.optimizer_names:
                getattr(self, optim_name).zero_grad()
            self.loss_total.backward()
            for optim_name in self.optimizer_names:
                getattr(self, optim_name).step()
    
        def forward(self, input):
            """Feedforward once."""
            if self.load_gt_depth:
                input, depth_gt = input
            self.input_im = input.to(self.device) *2.-1.
            b, c, h, w = self.input_im.shape
    
            ## predict canonical depth, ok
            # networks.EDDeconv(cin=3, cout=1, nf=64, zdim=256, activation=None)
            self.canon_depth_raw = self.netD(self.input_im).squeeze(1)  # BxHxW
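            # center the raw depth per image, squash to (-1, 1) with tanh, then map linearly to [min_depth, max_depth]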
            self.canon_depth = self.canon_depth_raw - self.canon_depth_raw.view(b,-1).mean(1).view(b,1,1)
            self.canon_depth = self.canon_depth.tanh()
            self.canon_depth = self.depth_rescaler(self.canon_depth)
    
        ## optional depth smoothness loss (only used in the synthetic car experiments):
        # penalizes absolute differences between neighboring pixels along height and width
            self.loss_depth_sm = ((self.canon_depth[:,:-1,:] - self.canon_depth[:,1:,:]) /(self.max_depth-self.min_depth)).abs().mean()
            self.loss_depth_sm += ((self.canon_depth[:,:,:-1] - self.canon_depth[:,:,1:]) /(self.max_depth-self.min_depth)).abs().mean()
    
        ## clamp border depth: mask the border pixels, ok
            depth_border = torch.zeros(1,h,w-4).to(self.input_im.device)
            depth_border = nn.functional.pad(depth_border, (2,2), mode='constant', value=1)
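            # depth_border is 1 on the 2-pixel left/right borders and 0 elsewhere, so those columns are fixed to border_depth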
            self.canon_depth = self.canon_depth*(1-depth_border) + depth_border *self.border_depth
            self.canon_depth = torch.cat([self.canon_depth, self.canon_depth.flip(2)], 0)  # flip, because depth is a 3-dim tensor
    
            ## predict canonical albedo, ok
            self.canon_albedo = self.netA(self.input_im)  # Bx3xHxW
            self.canon_albedo = torch.cat([self.canon_albedo, self.canon_albedo.flip(3)], 0)  # flip, albedo is a 4-dim tensor
    
            ## predict confidence map, share backbone, two-stream net
            if self.use_conf_map:
                conf_sigma_l1, conf_sigma_percl = self.netC(self.input_im)  # Bx2xHxW
                self.conf_sigma_l1 = conf_sigma_l1[:,:1]
                self.conf_sigma_l1_flip = conf_sigma_l1[:,1:]
                self.conf_sigma_percl = conf_sigma_percl[:,:1]
                self.conf_sigma_percl_flip = conf_sigma_percl[:,1:]
            else:
                self.conf_sigma_l1 = None
                self.conf_sigma_l1_flip = None
                self.conf_sigma_percl = None
                self.conf_sigma_percl_flip = None
    
        ####################################### reconstruction Π(Λ(a,d,l), d, w): Λ(a,d,l) is the shading step, Π(·, d, w) is the camera/reprojection step
            # step 1:
            ## predict lighting, predict by NetL, 
            ## get: canon_light_a, canon_light_b, canon_light_d
            canon_light = self.netL(self.input_im).repeat(2,1)  # Bx4, self.netL = networks.Encoder(cin=3, cout=4, nf=32)
            self.canon_light_a = self.amb_light_rescaler(canon_light[:,:1])  # ambient term k_a (channel 0), one scalar per image
            self.canon_light_b = self.diff_light_rescaler(canon_light[:,1:2])  # diffuse term k_d (channel 1), one scalar per image
            canon_light_dxy = canon_light[:,2:]  # channels 2 and 3: light direction in x and y
            self.canon_light_d = torch.cat([canon_light_dxy, torch.ones(b*2,1).to(self.input_im.device)], 1)  # fix the z component to 1
            self.canon_light_d = self.canon_light_d / ((self.canon_light_d**2).sum(1, keepdim=True))**0.5  # unit diffuse light direction
    
        ## shading
        # canon_im = canon_albedo * (k_a + k_d * max(0, <canon_normal, canon_light_d>))
            self.canon_normal = self.renderer.get_normal_from_depth(self.canon_depth)  # depth to normals, via finite differences along u and v
            self.canon_diffuse_shading = (self.canon_normal * self.canon_light_d.view(-1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1)  # max(0, dot(n, l))
            canon_shading = self.canon_light_a.view(-1,1,1,1) + self.canon_light_b.view(-1,1,1,1)*self.canon_diffuse_shading  # k_a + k_d*dot(n,l)
            self.canon_im = (self.canon_albedo/2+0.5) * canon_shading *2-1  # albedo * shading, mapped back to [-1, 1]
    
    
            ############################ 
            # step 2: 
            ## predict viewpoint transformation
            self.view = self.netV(self.input_im).repeat(2,1)  # networks.Encoder(cin=3, cout=6, nf=32), R and T
            self.view = torch.cat([
                self.view[:,:3] *math.pi/180 *self.xyz_rotation_range, # 0,1,2
                self.view[:,3:5] *self.xy_translation_range,           # 3,4
                self.view[:,5:] *self.z_translation_range], 1)         # 5
    
            ## reconstruct input view
            self.renderer.set_transform_matrices(self.view)
            self.recon_depth = self.renderer.warp_canon_depth(self.canon_depth) #torch.Size([128, 64, 64])
    
            self.recon_normal = self.renderer.get_normal_from_depth(self.recon_depth) #torch.Size([128, 64, 64, 3]), get new normal as before
            grid_2d_from_canon = self.renderer.get_inv_warped_2d_grid(self.recon_depth) #torch.Size([128, 64, 64, 2]), 2d sampling grid induced by the warp

            # grid_2d_from_canon is warped; sampling canon_im with it gives the reconstruction on a uniform pixel grid
            self.recon_im = nn.functional.grid_sample(self.canon_im, grid_2d_from_canon, mode='bilinear') #canon_im: torch.Size([128, 3, 64, 64])
    
        ## mask out border pixels
            margin = (self.max_depth - self.min_depth) /2
            recon_im_mask = (self.recon_depth < self.max_depth+margin).float()  # invalid border pixels have been clamped at max_depth+margin
            recon_im_mask_both = recon_im_mask[:b] * recon_im_mask[b:]  # both original and flip reconstruction
            recon_im_mask_both = recon_im_mask_both.repeat(2,1,1).unsqueeze(1).detach()
            self.recon_im = self.recon_im * recon_im_mask_both
    
            #######################################
    
        ## render the symmetry axis (for visualization)
            canon_sym_axis = torch.zeros(h, w).to(self.input_im.device)
            canon_sym_axis[:, w//2-1:w//2+1] = 1
            self.recon_sym_axis = nn.functional.grid_sample(canon_sym_axis.repeat(b*2,1,1,1), grid_2d_from_canon, mode='bilinear')
            self.recon_sym_axis = self.recon_sym_axis * recon_im_mask_both
            green = torch.FloatTensor([-1,1,-1]).to(self.input_im.device).view(1,3,1,1)
            self.input_im_symline = (0.5*self.recon_sym_axis) *green + (1-0.5*self.recon_sym_axis) *self.input_im.repeat(2,1,1,1)
            
    
    
            ## loss function
            self.loss_l1_im = self.photometric_loss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_l1)
            self.loss_l1_im_flip = self.photometric_loss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_l1_flip)
            self.loss_perc_im = self.PerceptualLoss(self.recon_im[:b], self.input_im, mask=recon_im_mask_both[:b], conf_sigma=self.conf_sigma_percl)
            self.loss_perc_im_flip = self.PerceptualLoss(self.recon_im[b:], self.input_im, mask=recon_im_mask_both[b:], conf_sigma=self.conf_sigma_percl_flip)
            lam_flip = 1 if self.trainer.current_epoch < self.lam_flip_start_epoch else self.lam_flip
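            # total: photometric L1 + perceptual loss, each with a flipped-reconstruction counterpart weighted by lam_flip (full weight before lam_flip_start_epoch), plus depth smoothness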
            self.loss_total = self.loss_l1_im + lam_flip*self.loss_l1_im_flip + self.lam_perc*(self.loss_perc_im + lam_flip*self.loss_perc_im_flip) + self.lam_depth_sm*self.loss_depth_sm
    
            metrics = {'loss': self.loss_total}
    
            ## compute accuracy if gt depth is available
            if self.load_gt_depth:
                self.depth_gt = depth_gt[:,0,:,:].to(self.input_im.device)
                self.depth_gt = (1-self.depth_gt)*2-1
                self.depth_gt = self.depth_rescaler(self.depth_gt)
                self.normal_gt = self.renderer.get_normal_from_depth(self.depth_gt)
    
                # mask out background
                mask_gt = (self.depth_gt<self.depth_gt.max()).float()
                mask_gt = (nn.functional.avg_pool2d(mask_gt.unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float()  # erode by 1 pixel
                mask_pred = (nn.functional.avg_pool2d(recon_im_mask[:b].unsqueeze(1), 3, stride=1, padding=1).squeeze(1) > 0.99).float()  # erode by 1 pixel
                mask = mask_gt * mask_pred
                self.acc_mae_masked = ((self.recon_depth[:b] - self.depth_gt[:b]).abs() *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1)
                self.acc_mse_masked = (((self.recon_depth[:b] - self.depth_gt[:b])**2) *mask).view(b,-1).sum(1) / mask.view(b,-1).sum(1)
                self.sie_map_masked = utils.compute_sc_inv_err(self.recon_depth[:b].log(), self.depth_gt[:b].log(), mask=mask)
                self.acc_sie_masked = (self.sie_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1))**0.5
                self.norm_err_map_masked = utils.compute_angular_distance(self.recon_normal[:b], self.normal_gt[:b], mask=mask)
                self.acc_normal_masked = self.norm_err_map_masked.view(b,-1).sum(1) / mask.view(b,-1).sum(1)
    
                metrics['SIE_masked'] = self.acc_sie_masked.mean()
                metrics['NorErr_masked'] = self.acc_normal_masked.mean()
    
            return metrics
    
        def visualize(self, logger, total_iter, max_bs=25):
            b, c, h, w = self.input_im.shape
            b0 = min(max_bs, b)
    
            ## render rotations
            with torch.no_grad():
                v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b0,1)
                canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b0], self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5  # (B,T,C,H,W)
                canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b0].permute(0,3,1,2), self.canon_depth[:b0], v_before=v0, maxr=90).detach().cpu() /2.+0.5  # (B,T,C,H,W)
    
            input_im = self.input_im[:b0].detach().cpu() /2.+0.5
            input_im_symline = self.input_im_symline[:b0].detach().cpu() /2.+0.5
            canon_albedo = self.canon_albedo[:b0].detach().cpu() /2.+0.5
            canon_im = self.canon_im[:b0].detach().cpu() /2.+0.5
            recon_im = self.recon_im[:b0].detach().cpu() /2.+0.5
            recon_im_flip = self.recon_im[b:b+b0].detach().cpu() /2.+0.5
            canon_depth_raw_hist = self.canon_depth_raw.detach().unsqueeze(1).cpu()
            canon_depth_raw = self.canon_depth_raw[:b0].detach().unsqueeze(1).cpu() /2.+0.5
            canon_depth = ((self.canon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
            recon_depth = ((self.recon_depth[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
            canon_diffuse_shading = self.canon_diffuse_shading[:b0].detach().cpu()
            canon_normal = self.canon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
            recon_normal = self.recon_normal.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
            if self.use_conf_map:
                conf_map_l1 = 1/(1+self.conf_sigma_l1[:b0].detach().cpu()+EPS)
                conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b0].detach().cpu()+EPS)
                conf_map_percl = 1/(1+self.conf_sigma_percl[:b0].detach().cpu()+EPS)
                conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b0].detach().cpu()+EPS)
    
            canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_im_rotate, 1)]  # [(C,H,W)]*T
            canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0)  # (1,T,C,H,W)
            canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b0**0.5))) for img in torch.unbind(canon_normal_rotate, 1)]  # [(C,H,W)]*T
            canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0)  # (1,T,C,H,W)
    
            ## write summary
            logger.add_scalar('Loss/loss_total', self.loss_total, total_iter)
            logger.add_scalar('Loss/loss_l1_im', self.loss_l1_im, total_iter)
            logger.add_scalar('Loss/loss_l1_im_flip', self.loss_l1_im_flip, total_iter)
            logger.add_scalar('Loss/loss_perc_im', self.loss_perc_im, total_iter)
            logger.add_scalar('Loss/loss_perc_im_flip', self.loss_perc_im_flip, total_iter)
            logger.add_scalar('Loss/loss_depth_sm', self.loss_depth_sm, total_iter)
    
            logger.add_histogram('Depth/canon_depth_raw_hist', canon_depth_raw_hist, total_iter)
            vlist = ['view_rx', 'view_ry', 'view_rz', 'view_tx', 'view_ty', 'view_tz']
            for i in range(self.view.shape[1]):
                logger.add_histogram('View/'+vlist[i], self.view[:,i], total_iter)
            logger.add_histogram('Light/canon_light_a', self.canon_light_a, total_iter)
            logger.add_histogram('Light/canon_light_b', self.canon_light_b, total_iter)
            llist = ['canon_light_dx', 'canon_light_dy', 'canon_light_dz']
            for i in range(self.canon_light_d.shape[1]):
                logger.add_histogram('Light/'+llist[i], self.canon_light_d[:,i], total_iter)
    
            def log_grid_image(label, im, nrow=int(math.ceil(b0**0.5)), iter=total_iter):
                im_grid = torchvision.utils.make_grid(im, nrow=nrow)
                logger.add_image(label, im_grid, iter)
    
            log_grid_image('Image/input_image_symline', input_im_symline)
            log_grid_image('Image/canonical_albedo', canon_albedo)
            log_grid_image('Image/canonical_image', canon_im)
            log_grid_image('Image/recon_image', recon_im)
            log_grid_image('Image/recon_image_flip', recon_im_flip)
            log_grid_image('Image/recon_side', canon_im_rotate[:,0,:,:,:])
    
            log_grid_image('Depth/canonical_depth_raw', canon_depth_raw)
            log_grid_image('Depth/canonical_depth', canon_depth)
            log_grid_image('Depth/recon_depth', recon_depth)
            log_grid_image('Depth/canonical_diffuse_shading', canon_diffuse_shading)
            log_grid_image('Depth/canonical_normal', canon_normal)
            log_grid_image('Depth/recon_normal', recon_normal)
    
            logger.add_histogram('Image/canonical_albedo_hist', canon_albedo, total_iter)
            logger.add_histogram('Image/canonical_diffuse_shading_hist', canon_diffuse_shading, total_iter)
    
            if self.use_conf_map:
                log_grid_image('Conf/conf_map_l1', conf_map_l1)
                logger.add_histogram('Conf/conf_sigma_l1_hist', self.conf_sigma_l1, total_iter)
                log_grid_image('Conf/conf_map_l1_flip', conf_map_l1_flip)
                logger.add_histogram('Conf/conf_sigma_l1_flip_hist', self.conf_sigma_l1_flip, total_iter)
                log_grid_image('Conf/conf_map_percl', conf_map_percl)
                logger.add_histogram('Conf/conf_sigma_percl_hist', self.conf_sigma_percl, total_iter)
                log_grid_image('Conf/conf_map_percl_flip', conf_map_percl_flip)
                logger.add_histogram('Conf/conf_sigma_percl_flip_hist', self.conf_sigma_percl_flip, total_iter)
    
            logger.add_video('Image_rotate/recon_rotate', canon_im_rotate_grid, total_iter, fps=4)
            logger.add_video('Image_rotate/canon_normal_rotate', canon_normal_rotate_grid, total_iter, fps=4)
    
            # visualize images and accuracy if gt is loaded
            if self.load_gt_depth:
                depth_gt = ((self.depth_gt[:b0] -self.min_depth)/(self.max_depth-self.min_depth)).detach().cpu().unsqueeze(1)
                normal_gt = self.normal_gt.permute(0,3,1,2)[:b0].detach().cpu() /2+0.5
                sie_map_masked = self.sie_map_masked[:b0].detach().unsqueeze(1).cpu() *1000
                norm_err_map_masked = self.norm_err_map_masked[:b0].detach().unsqueeze(1).cpu() /100
    
                logger.add_scalar('Acc_masked/MAE_masked', self.acc_mae_masked.mean(), total_iter)
                logger.add_scalar('Acc_masked/MSE_masked', self.acc_mse_masked.mean(), total_iter)
                logger.add_scalar('Acc_masked/SIE_masked', self.acc_sie_masked.mean(), total_iter)
                logger.add_scalar('Acc_masked/NorErr_masked', self.acc_normal_masked.mean(), total_iter)
    
                log_grid_image('Depth_gt/depth_gt', depth_gt)
                log_grid_image('Depth_gt/normal_gt', normal_gt)
                log_grid_image('Depth_gt/sie_map_masked', sie_map_masked)
                log_grid_image('Depth_gt/norm_err_map_masked', norm_err_map_masked)
    
        def save_results(self, save_dir):
            b, c, h, w = self.input_im.shape
    
            with torch.no_grad():
                v0 = torch.FloatTensor([-0.1*math.pi/180*60,0,0,0,0,0]).to(self.input_im.device).repeat(b,1)
                canon_im_rotate = self.renderer.render_yaw(self.canon_im[:b], self.canon_depth[:b], v_before=v0, maxr=90, nsample=15)  # (B,T,C,H,W)
                canon_im_rotate = canon_im_rotate.clamp(-1,1).detach().cpu() /2+0.5
                canon_normal_rotate = self.renderer.render_yaw(self.canon_normal[:b].permute(0,3,1,2), self.canon_depth[:b], v_before=v0, maxr=90, nsample=15)  # (B,T,C,H,W)
                canon_normal_rotate = canon_normal_rotate.clamp(-1,1).detach().cpu() /2+0.5
    
            input_im = self.input_im[:b].detach().cpu().numpy() /2+0.5
            input_im_symline = self.input_im_symline.detach().cpu().numpy() /2.+0.5
            canon_albedo = self.canon_albedo[:b].detach().cpu().numpy() /2+0.5
            canon_im = self.canon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5
            recon_im = self.recon_im[:b].clamp(-1,1).detach().cpu().numpy() /2+0.5
            recon_im_flip = self.recon_im[b:].clamp(-1,1).detach().cpu().numpy() /2+0.5
            canon_depth = ((self.canon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
            recon_depth = ((self.recon_depth[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
            canon_diffuse_shading = self.canon_diffuse_shading[:b].detach().cpu().numpy()
            canon_normal = self.canon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
            recon_normal = self.recon_normal[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
            if self.use_conf_map:
                conf_map_l1 = 1/(1+self.conf_sigma_l1[:b].detach().cpu().numpy()+EPS)
                conf_map_l1_flip = 1/(1+self.conf_sigma_l1_flip[:b].detach().cpu().numpy()+EPS)
                conf_map_percl = 1/(1+self.conf_sigma_percl[:b].detach().cpu().numpy()+EPS)
                conf_map_percl_flip = 1/(1+self.conf_sigma_percl_flip[:b].detach().cpu().numpy()+EPS)
            canon_light = torch.cat([self.canon_light_a, self.canon_light_b, self.canon_light_d], 1)[:b].detach().cpu().numpy()
            view = self.view[:b].detach().cpu().numpy()
    
            canon_im_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_im_rotate,1)]  # [(C,H,W)]*T
            canon_im_rotate_grid = torch.stack(canon_im_rotate_grid, 0).unsqueeze(0).numpy()  # (1,T,C,H,W)
            canon_normal_rotate_grid = [torchvision.utils.make_grid(img, nrow=int(math.ceil(b**0.5))) for img in torch.unbind(canon_normal_rotate,1)]  # [(C,H,W)]*T
            canon_normal_rotate_grid = torch.stack(canon_normal_rotate_grid, 0).unsqueeze(0).numpy()  # (1,T,C,H,W)
    
            sep_folder = True
            utils.save_images(save_dir, input_im, suffix='input_image', sep_folder=sep_folder)
            utils.save_images(save_dir, input_im_symline, suffix='input_image_symline', sep_folder=sep_folder)
            utils.save_images(save_dir, canon_albedo, suffix='canonical_albedo', sep_folder=sep_folder)
            utils.save_images(save_dir, canon_im, suffix='canonical_image', sep_folder=sep_folder)
            utils.save_images(save_dir, recon_im, suffix='recon_image', sep_folder=sep_folder)
            utils.save_images(save_dir, recon_im_flip, suffix='recon_image_flip', sep_folder=sep_folder)
            utils.save_images(save_dir, canon_depth, suffix='canonical_depth', sep_folder=sep_folder)
            utils.save_images(save_dir, recon_depth, suffix='recon_depth', sep_folder=sep_folder)
            utils.save_images(save_dir, canon_diffuse_shading, suffix='canonical_diffuse_shading', sep_folder=sep_folder)
            utils.save_images(save_dir, canon_normal, suffix='canonical_normal', sep_folder=sep_folder)
            utils.save_images(save_dir, recon_normal, suffix='recon_normal', sep_folder=sep_folder)
            if self.use_conf_map:
                utils.save_images(save_dir, conf_map_l1, suffix='conf_map_l1', sep_folder=sep_folder)
                utils.save_images(save_dir, conf_map_l1_flip, suffix='conf_map_l1_flip', sep_folder=sep_folder)
                utils.save_images(save_dir, conf_map_percl, suffix='conf_map_percl', sep_folder=sep_folder)
                utils.save_images(save_dir, conf_map_percl_flip, suffix='conf_map_percl_flip', sep_folder=sep_folder)
            utils.save_txt(save_dir, canon_light, suffix='canonical_light', sep_folder=sep_folder)
            utils.save_txt(save_dir, view, suffix='viewpoint', sep_folder=sep_folder)
    
            utils.save_videos(save_dir, canon_im_rotate_grid, suffix='image_video', sep_folder=sep_folder, cycle=True)
            utils.save_videos(save_dir, canon_normal_rotate_grid, suffix='normal_video', sep_folder=sep_folder, cycle=True)
    
            # save scores if gt is loaded
            if self.load_gt_depth:
                depth_gt = ((self.depth_gt[:b] -self.min_depth)/(self.max_depth-self.min_depth)).clamp(0,1).detach().cpu().unsqueeze(1).numpy()
                normal_gt = self.normal_gt[:b].permute(0,3,1,2).detach().cpu().numpy() /2+0.5
                utils.save_images(save_dir, depth_gt, suffix='depth_gt', sep_folder=sep_folder)
                utils.save_images(save_dir, normal_gt, suffix='normal_gt', sep_folder=sep_folder)
    
                all_scores = torch.stack([
                    self.acc_mae_masked.detach().cpu(),
                    self.acc_mse_masked.detach().cpu(),
                    self.acc_sie_masked.detach().cpu(),
                    self.acc_normal_masked.detach().cpu()], 1)
                if not hasattr(self, 'all_scores'):
                    self.all_scores = torch.FloatTensor()
                self.all_scores = torch.cat([self.all_scores, all_scores], 0)
    
        def save_scores(self, path):
            # save scores if gt is loaded
            if self.load_gt_depth:
                header = 'MAE_masked, MSE_masked, SIE_masked, NorErr_masked'
                mean = self.all_scores.mean(0)
                std = self.all_scores.std(0)
                header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean])
                header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std])
                utils.save_scores(path, self.all_scores, header=header)
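
    To make the lighting/shading step in forward() concrete: it is the Lambertian model canon_im = canon_albedo * (k_a + k_d * max(0, <n, l>)). Below is a minimal, self-contained sketch (not repository code; all tensor names are made up for illustration) that reproduces this computation on random tensors.

    import torch
    import torch.nn.functional as F

    B, H, W = 2, 64, 64
    albedo  = torch.rand(B, 3, H, W) * 2 - 1                      # canonical albedo in [-1, 1]
    normal  = F.normalize(torch.rand(B, H, W, 3) * 2 - 1, dim=3)  # unit normals, BxHxWx3
    light_a = torch.rand(B, 1)                                    # ambient coefficient k_a
    light_b = torch.rand(B, 1)                                    # diffuse coefficient k_d
    light_d = F.normalize(torch.tensor([[0.2, 0.1, 1.0]]).repeat(B, 1), dim=1)  # unit light direction

    # Lambertian shading: k_a + k_d * max(0, <n, l>)
    diffuse = (normal * light_d.view(B, 1, 1, 3)).sum(3).clamp(min=0).unsqueeze(1)
    shading = light_a.view(B, 1, 1, 1) + light_b.view(B, 1, 1, 1) * diffuse
    canon_im = (albedo / 2 + 0.5) * shading * 2 - 1               # multiply by albedo, map back to [-1, 1]
    print(canon_im.shape)  # torch.Size([2, 3, 64, 64])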

    renderer.py

    import torch
    import math
    import neural_renderer as nr
    from .utils import *
    
    
    EPS = 1e-7
    
    
    class Renderer():
        def __init__(self, cfgs):
            self.device = cfgs.get('device', 'cpu')
            self.image_size = cfgs.get('image_size', 64)
            self.min_depth = cfgs.get('min_depth', 0.9)
            self.max_depth = cfgs.get('max_depth', 1.1)
            self.rot_center_depth = cfgs.get('rot_center_depth', (self.min_depth+self.max_depth)/2)
            self.fov = cfgs.get('fov', 10)
            self.tex_cube_size = cfgs.get('tex_cube_size', 2)
            self.renderer_min_depth = cfgs.get('renderer_min_depth', 0.1)
            self.renderer_max_depth = cfgs.get('renderer_max_depth', 10.)
    
            #### camera intrinsics
            #             (u)   (x)
            #    d * K^-1 (v) = (y)
            #             (1)   (z)
    
            ## renderer for visualization
            R = [[[1.,0.,0.],
                  [0.,1.,0.],
                  [0.,0.,1.]]]
            R = torch.FloatTensor(R).to(self.device)
            t = torch.zeros(1,3, dtype=torch.float32).to(self.device)
            fx = (self.image_size-1)/2/(math.tan(self.fov/2 *math.pi/180))
            fy = (self.image_size-1)/2/(math.tan(self.fov/2 *math.pi/180))
            cx = (self.image_size-1)/2
            cy = (self.image_size-1)/2
            K = [[fx, 0., cx],
                 [0., fy, cy],
                 [0., 0., 1.]]
            K = torch.FloatTensor(K).to(self.device)
            self.inv_K = torch.inverse(K).unsqueeze(0)
            self.K = K.unsqueeze(0)
            self.renderer = nr.Renderer(camera_mode='projection',
                                        light_intensity_ambient=1.0,
                                        light_intensity_directional=0.,
                                        K=self.K, R=R, t=t,
                                        near=self.renderer_min_depth, far=self.renderer_max_depth,
                                        image_size=self.image_size, orig_size=self.image_size,
                                        fill_back=True,
                                        background_color=[1,1,1])
    
        def set_transform_matrices(self, view):
            self.rot_mat, self.trans_xyz = get_transform_matrices(view)
    
        def rotate_pts(self, pts, rot_mat):
            centroid = torch.FloatTensor([0.,0.,self.rot_center_depth]).to(pts.device).view(1,1,3)
            pts = pts - centroid  # move to centroid
            pts = pts.matmul(rot_mat.transpose(2,1))  # rotate
            pts = pts + centroid  # move back
            return pts
    
        def translate_pts(self, pts, trans_xyz):
            return pts + trans_xyz
    
        def depth_to_3d_grid(self, depth): #### 3: lift the 2D pixel grid plus depth to 3D points (x, y, z)
            b, h, w = depth.shape
            grid_2d = get_grid(b, h, w, normalize=False).to(depth.device)  # Nxhxwx2, torch.Size([128, 64, 64, 2])
            depth = depth.unsqueeze(-1) # torch.Size([128, 64, 64, 1])
            grid_3d = torch.cat((grid_2d, torch.ones_like(depth)), dim=3) # torch.Size([128, 64, 64, 3])
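            # back-project: grid_3d = depth * K^-1 [u, v, 1]^T (cf. the intrinsics note in __init__)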
            grid_3d = grid_3d.matmul(self.inv_K.to(depth.device).transpose(2,1)) * depth
            return grid_3d
    
        def grid_3d_to_2d(self, grid_3d):
            b, h, w, _ = grid_3d.shape
            grid_2d = grid_3d / grid_3d[...,2:]
            grid_2d = grid_2d.matmul(self.K.to(grid_3d.device).transpose(2,1))[:,:,:,:2]
            WH = torch.FloatTensor([w-1, h-1]).to(grid_3d.device).view(1,1,1,2)
            grid_2d = grid_2d / WH *2.-1.  # normalize to -1~1
            return grid_2d
    
        def get_warped_3d_grid(self, depth):  #### 2
            b, h, w = depth.shape #torch.Size([128, 64, 64])
            grid_3d = self.depth_to_3d_grid(depth).reshape(b,-1,3) #torch.Size([128, 4096, 3]), has K_inv, 64x64 3D points
            grid_3d = self.rotate_pts(grid_3d, self.rot_mat)       # rotate and translate the canonical points
            grid_3d = self.translate_pts(grid_3d, self.trans_xyz)  # to the predicted pose of each image
            return grid_3d.reshape(b,h,w,3) # return 3d vertices
    
        def get_inv_warped_3d_grid(self, depth):
            b, h, w = depth.shape
            grid_3d = self.depth_to_3d_grid(depth).reshape(b,-1,3)
            grid_3d = self.translate_pts(grid_3d, -self.trans_xyz)          
            grid_3d = self.rotate_pts(grid_3d, self.rot_mat.transpose(2,1))
            return grid_3d.reshape(b,h,w,3) # return 3d vertices
    
        def get_warped_2d_grid(self, depth):
            b, h, w = depth.shape
            grid_3d = self.get_warped_3d_grid(depth)
            grid_2d = self.grid_3d_to_2d(grid_3d)
            return grid_2d
    
        def get_inv_warped_2d_grid(self, depth):
            b, h, w = depth.shape
            grid_3d = self.get_inv_warped_3d_grid(depth)
            grid_2d = self.grid_3d_to_2d(grid_3d)
            return grid_2d
    
        def warp_canon_depth(self, canon_depth):  #### 1
            b, h, w = canon_depth.shape
            grid_3d = self.get_warped_3d_grid(canon_depth).reshape(b,-1,3)  # 3d vertices: back-project the depth map, then apply R and t
            faces = get_face_idx(b, h, w).to(canon_depth.device)            # triangle faces connecting the grid vertices into a mesh
            warped_depth = self.renderer.render_depth(grid_3d, faces)       # new depth
    
            # allow some margin out of valid range
            margin = (self.max_depth - self.min_depth) /2
            warped_depth = warped_depth.clamp(min=self.min_depth-margin, max=self.max_depth+margin)
            return warped_depth
    
        def get_normal_from_depth(self, depth):
            b, h, w = depth.shape
            grid_3d = self.depth_to_3d_grid(depth)
    
            tu = grid_3d[:,1:-1,2:] - grid_3d[:,1:-1,:-2]  # tangent along u (width)
            tv = grid_3d[:,2:,1:-1] - grid_3d[:,:-2,1:-1]  # tangent along v (height)
            normal = tu.cross(tv, dim=3)                   # surface normal = tu x tv
    
            zero = torch.FloatTensor([0,0,1]).to(depth.device)
            normal = torch.cat([zero.repeat(b,h-2,1,1), normal, zero.repeat(b,h-2,1,1)], 2)
            normal = torch.cat([zero.repeat(b,1,w,1), normal, zero.repeat(b,1,w,1)], 1)
            normal = normal / (((normal**2).sum(3, keepdim=True))**0.5 + EPS)
            return normal
    
        def render_yaw(self, im, depth, v_before=None, v_after=None, rotations=None, maxr=90, nsample=9, crop_mesh=None):
            b, c, h, w = im.shape
            grid_3d = self.depth_to_3d_grid(depth)
    
            if crop_mesh is not None:
                top, bottom, left, right = crop_mesh  # pixels from border to be cropped
                if top > 0:
                    grid_3d[:,:top,:,1] = grid_3d[:,top:top+1,:,1].repeat(1,top,1)
                    grid_3d[:,:top,:,2] = grid_3d[:,top:top+1,:,2].repeat(1,top,1)
                if bottom > 0:
                    grid_3d[:,-bottom:,:,1] = grid_3d[:,-bottom-1:-bottom,:,1].repeat(1,bottom,1)
                    grid_3d[:,-bottom:,:,2] = grid_3d[:,-bottom-1:-bottom,:,2].repeat(1,bottom,1)
                if left > 0:
                    grid_3d[:,:,:left,0] = grid_3d[:,:,left:left+1,0].repeat(1,1,left)
                    grid_3d[:,:,:left,2] = grid_3d[:,:,left:left+1,2].repeat(1,1,left)
                if right > 0:
                    grid_3d[:,:,-right:,0] = grid_3d[:,:,-right-1:-right,0].repeat(1,1,right)
                    grid_3d[:,:,-right:,2] = grid_3d[:,:,-right-1:-right,2].repeat(1,1,right)
    
            grid_3d = grid_3d.reshape(b,-1,3)
            im_trans = []
    
            # inverse warp
            if v_before is not None:
                rot_mat, trans_xyz = get_transform_matrices(v_before)
                grid_3d = self.translate_pts(grid_3d, -trans_xyz)
                grid_3d = self.rotate_pts(grid_3d, rot_mat.transpose(2,1))
    
            if rotations is None:
                rotations = torch.linspace(-math.pi/180*maxr, math.pi/180*maxr, nsample)
            for i, ri in enumerate(rotations):
                ri = torch.FloatTensor([0, ri, 0]).to(im.device).view(1,3)
                rot_mat_i, _ = get_transform_matrices(ri)
                grid_3d_i = self.rotate_pts(grid_3d, rot_mat_i.repeat(b,1,1))
    
                if v_after is not None:
                    if len(v_after.shape) == 3:
                        v_after_i = v_after[i]
                    else:
                        v_after_i = v_after
                    rot_mat, trans_xyz = get_transform_matrices(v_after_i)
                    grid_3d_i = self.rotate_pts(grid_3d_i, rot_mat)
                    grid_3d_i = self.translate_pts(grid_3d_i, trans_xyz)
    
                faces = get_face_idx(b, h, w).to(im.device)
                textures = get_textures_from_im(im, tx_size=self.tex_cube_size)
                warped_images = self.renderer.render_rgb(grid_3d_i, faces, textures).clamp(min=-1., max=1.)
                im_trans += [warped_images]
            return torch.stack(im_trans, 1)  # b x t x c x h x w
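
    The intrinsics used throughout Renderer follow the convention sketched in __init__: d * K^-1 [u, v, 1]^T gives a 3D point, and dividing by z and multiplying by K projects it back. Below is a small standalone sanity check (not repository code; variable names are illustrative) that lifts a pixel grid to 3D at constant depth and reprojects it, recovering the original pixel coordinates.

    import math
    import torch

    image_size, fov = 64, 10.
    fx = (image_size - 1) / 2 / math.tan(fov / 2 * math.pi / 180)
    cx = (image_size - 1) / 2
    K = torch.tensor([[fx, 0., cx],
                      [0., fx, cx],
                      [0., 0., 1.]])
    inv_K = torch.inverse(K)

    # pixel grid (u, v, 1), lifted to 3D with a constant depth of 1.0
    v, u = torch.meshgrid(torch.arange(image_size, dtype=torch.float32),
                          torch.arange(image_size, dtype=torch.float32))
    uv1 = torch.stack([u, v, torch.ones_like(u)], -1)  # HxWx3 homogeneous pixel coordinates
    pts3d = uv1.matmul(inv_K.t()) * 1.0                # d * K^-1 [u, v, 1]^T

    # project back: perspective divide, then apply K
    proj = (pts3d / pts3d[..., 2:]).matmul(K.t())[..., :2]
    print(torch.allclose(proj, torch.stack([u, v], -1), atol=1e-4))  # True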

    utils.py

    import torch
    
    
    def mm_normalize(x, min=0, max=1):
        x_min = x.min()
        x_max = x.max()
        x_range = x_max - x_min
        x_z = (x - x_min) / x_range
        x_out = x_z * (max - min) + min
        return x_out
    
    
    def rand_range(size, min, max):
        return torch.rand(size)*(max-min)+min
    
    
    def rand_posneg_range(size, min, max):
        i = (torch.rand(size) > 0.5).type(torch.float)*2.-1.
        return i*rand_range(size, min, max)
    
    
    def get_grid(b, H, W, normalize=True):  ####
        if normalize:
            h_range = torch.linspace(-1,1,H)
            w_range = torch.linspace(-1,1,W)
        else:
            h_range = torch.arange(0,H)
            w_range = torch.arange(0,W)
        grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b,1,1,1).flip(3).float() # flip h,w to x,y
        return grid
    
    
    def get_rotation_matrix(tx, ty, tz):
        m_x = torch.zeros((len(tx), 3, 3)).to(tx.device)
        m_y = torch.zeros((len(tx), 3, 3)).to(tx.device)
        m_z = torch.zeros((len(tx), 3, 3)).to(tx.device)
    
        m_x[:, 1, 1], m_x[:, 1, 2] = tx.cos(), -tx.sin()
        m_x[:, 2, 1], m_x[:, 2, 2] = tx.sin(), tx.cos()
        m_x[:, 0, 0] = 1
    
        m_y[:, 0, 0], m_y[:, 0, 2] = ty.cos(), ty.sin()
        m_y[:, 2, 0], m_y[:, 2, 2] = -ty.sin(), ty.cos()
        m_y[:, 1, 1] = 1
    
        m_z[:, 0, 0], m_z[:, 0, 1] = tz.cos(), -tz.sin()
        m_z[:, 1, 0], m_z[:, 1, 1] = tz.sin(), tz.cos()
        m_z[:, 2, 2] = 1
        return torch.matmul(m_z, torch.matmul(m_y, m_x))
    
    
    def get_transform_matrices(view):
        b = view.size(0)
        if view.size(1) == 6:
            rx = view[:,0]
            ry = view[:,1]
            rz = view[:,2]
            trans_xyz = view[:,3:].reshape(b,1,3)
        elif view.size(1) == 5:
            rx = view[:,0]
            ry = view[:,1]
            rz = view[:,2]
            delta_xy = view[:,3:].reshape(b,1,2)
            trans_xyz = torch.cat([delta_xy, torch.zeros(b,1,1).to(view.device)], 2)
        elif view.size(1) == 3:
            rx = view[:,0]
            ry = view[:,1]
            rz = view[:,2]
            trans_xyz = torch.zeros(b,1,3).to(view.device)
        rot_mat = get_rotation_matrix(rx, ry, rz)
        return rot_mat, trans_xyz
    
    
    def get_face_idx(b, h, w):
        idx_map = torch.arange(h*w).reshape(h,w)
        faces1 = torch.stack([idx_map[:h-1,:w-1], idx_map[1:,:w-1], idx_map[:h-1,1:]], -1).reshape(-1,3)
        faces2 = torch.stack([idx_map[:h-1,1:], idx_map[1:,:w-1], idx_map[1:,1:]], -1).reshape(-1,3)
        return torch.cat([faces1,faces2], 0).repeat(b,1,1).int()
    
    
    def vcolor_to_texture_cube(vcolors):
        # input bxcxnx3
        b, c, n, f = vcolors.shape
        coeffs = torch.FloatTensor(
            [[ 0.5,  0.5,  0.5],
             [ 0. ,  0. ,  1. ],
             [ 0. ,  1. ,  0. ],
             [-0.5,  0.5,  0.5],
             [ 1. ,  0. ,  0. ],
             [ 0.5, -0.5,  0.5],
             [ 0.5,  0.5, -0.5],
             [ 0. ,  0. ,  0. ]]).to(vcolors.device)
        return coeffs.matmul(vcolors.permute(0,2,3,1)).reshape(b,n,2,2,2,c)
    
    
    def get_textures_from_im(im, tx_size=1):
        b, c, h, w = im.shape
        if tx_size == 1:
            textures = torch.cat([im[:,:,:h-1,:w-1].reshape(b,c,-1), im[:,:,1:,1:].reshape(b,c,-1)], 2)
            textures = textures.transpose(2,1).reshape(b,-1,1,1,1,c)
        elif tx_size == 2:
            textures1 = torch.stack([im[:,:,:h-1,:w-1], im[:,:,:h-1,1:], im[:,:,1:,:w-1]], -1).reshape(b,c,-1,3)
            textures2 = torch.stack([im[:,:,1:,:w-1], im[:,:,:h-1,1:], im[:,:,1:,1:]], -1).reshape(b,c,-1,3)
            textures = vcolor_to_texture_cube(torch.cat([textures1, textures2], 2)) # bxnx2x2x2xc
        else:
            raise NotImplementedError("Currently support texture size of 1 or 2 only.")
        return textures
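
    As a quick usage check of the helpers above (a hypothetical snippet, assuming get_transform_matrices and get_face_idx from this utils.py are in scope): the rotation built from a 6-dim view vector is orthonormal, and get_face_idx produces two triangles per grid cell.

    import math
    import torch

    view = torch.tensor([[0., math.pi/6, 0., 0.05, 0., 0.]])  # 30-degree yaw plus a small x-translation
    rot_mat, trans_xyz = get_transform_matrices(view)
    print(rot_mat.shape, trans_xyz.shape)                     # torch.Size([1, 3, 3]) torch.Size([1, 1, 3])
    print(torch.allclose(rot_mat @ rot_mat.transpose(2, 1),
                         torch.eye(3), atol=1e-6))            # True: the rotation is orthonormal

    faces = get_face_idx(b=1, h=3, w=3)                       # 3x3 grid -> 2x2 cells -> 8 triangles
    print(faces.shape)                                        # torch.Size([1, 8, 3])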