  • Implicit Neural Representations with Periodic Activation Functions (SIREN)

    Code: https://github.com/vsitzmann/siren

    Let's look at one of the examples, which runs on an image: experiment_scripts/train_img.py
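
    Per the repo README, the script can be launched with:

    python experiment_scripts/train_img.py --model_type=sine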

    This script implements the following part of the paper:

    A simple example: fitting an image. Consider the case of finding a function Φ that parameterizes a given discrete image f in a continuous fashion. The image defines a dataset D = {(x_i, f(x_i))} of pixel coordinates x_i associated with their RGB colors f(x_i). The only constraint enforced is that Φ should output image colors at the pixel coordinates; it depends only on Φ (none of its derivatives) and f(x_i), has the form C(f(x_i), Φ(x)) = Φ(x_i) − f(x_i), and can be translated into the loss L = Σ_i ‖Φ(x_i) − f(x_i)‖².

    In Fig. 1 of the paper, Φ_θ is fit to a natural image using comparable network architectures with different activation functions. Training supervises only the image values, but the gradient ∇f and the Laplacian ∆f are also visualized. Only two approaches, a ReLU network with positional encoding (P.E.) [5] and SIREN, represent the ground-truth image f(x) accurately, and SIREN is the only network that also represents the signal's derivatives.

    In other words, a network is trained to take image coordinates as input and output the corresponding pixel values, thereby fitting a single image.
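
    Before stepping through the repo's code, here is a minimal self-contained sketch of that idea; SimpleSiren and the random stand-in data are illustrative only, not the repo's SingleBVPNet:

    #coding:utf-8
    import torch
    import torch.nn as nn

    class SimpleSiren(nn.Module):
        '''A toy two-layer sine network: (x, y) coordinate in, gray value out.'''
        def __init__(self, hidden=64):
            super().__init__()
            self.l1 = nn.Linear(2, hidden)
            self.l2 = nn.Linear(hidden, 1)

        def forward(self, coords):
            return self.l2(torch.sin(30 * self.l1(coords))) #sin(30x), as in the paper

    coords = torch.rand(1024, 2) * 2 - 1 #stand-in pixel coordinates in [-1, 1]
    pixels = torch.rand(1024, 1) * 2 - 1 #stand-in gray values in [-1, 1]
    model = SimpleSiren()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    for step in range(100):
        loss = ((model(coords) - pixels) ** 2).mean() #the loss from the paper
        optim.zero_grad()
        loss.backward()
        optim.step()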

    1. Data processing

    The example uses skimage's built-in photo of a man holding a camera (the cameraman image). Let's take a look at it:

    #coding:utf-8
    import skimage
    import skimage.io #explicit import; some skimage versions do not lazy-load the io submodule
    import os
    os.environ['KMP_DUPLICATE_LIB_OK']='True'
    
    img = skimage.data.camera() #a grayscale image, a single photo
    print(img.shape) #(512, 512)
    skimage.io.imsave('./camera_people.jpg',img)
    
    img = skimage.data.chelsea() #a color photo of a cat, also a single image
    print(img.shape) #(300, 451, 3)
    skimage.io.imsave('./cat.jpg',img)

    This saves two photos: camera_people.jpg (the cameraman) and cat.jpg (the cat).

    dataio.py:

    The get_mgrid() function, explored step by step first:

    import numpy as np 
    import torch
    
    sidelen = 512
    dim = 2
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,) #an int side length becomes the tuple (512, 512)
        print(sidelen)
    
    grid_1 = np.mgrid[:sidelen[0], :sidelen[1]] #two stacked arrays of row/column indices
    print(grid_1.shape)
    
    grid_2 = np.stack(grid_1, axis=-1) #move the two index channels to the last axis
    print(grid_2.shape)
    
    grid_3 = grid_2[None, ...].astype(np.float32) #add a leading batch axis
    print(grid_3.shape)
    
    grid_4 = torch.Tensor(grid_3).view(-1, dim) #flatten to one (row, col) pair per pixel
    print(grid_4.shape)

    Output:

    (512, 512)
    (2, 512, 512)
    (512, 512, 2)
    (1, 512, 512, 2)
    torch.Size([262144, 2])

    The complete function:

    def get_mgrid(sidelen, dim=2):
        '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.'''
        if isinstance(sidelen, int):
            sidelen = dim * (sidelen,) #(512, 512)

        if dim == 2:
            pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32) #(1, 512, 512, 2)
            # values are in [0, 511] here; dividing by 511 maps them to [0, 1]
            pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1)
            pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1)
        elif dim == 3:
            pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
            pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
            pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1)
            pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1)
        else:
            raise NotImplementedError('Not implemented for dim=%d' % dim)

        pixel_coords -= 0.5
        pixel_coords *= 2. # these two steps map the values into [-1, 1]
        #the result is a grid: pixel_coords holds the 262144 (x, y) coordinate points
        pixel_coords = torch.Tensor(pixel_coords).view(-1, dim) #torch.Size([262144, 2])
        return pixel_coords

    print(get_mgrid(512))

    Output:

    tensor([[-1.0000, -1.0000],
            [-1.0000, -0.9961],
            [-1.0000, -0.9922],
            ...,
            [ 1.0000,  0.9922],
            [ 1.0000,  0.9961],
            [ 1.0000,  1.0000]])
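
    Note that get_mgrid() also accepts an (H, W) tuple, so rectangular images work too; for example, with the cat photo's size from above:

    coords = get_mgrid((300, 451))
    print(coords.shape) #torch.Size([135300, 2]), i.e. 300*451 coordinate rows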

    Error encountered:

    OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
    OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.

    Fix: add the following at the top:

    import os
    os.environ['KMP_DUPLICATE_LIB_OK']='True'

    Testing the full data pipeline:

    #coding:utf-8
    import numpy as np 
    import torch
    from torch.utils.data import Dataset
    from PIL import Image
    import skimage
    from torchvision.transforms import Resize, Compose, ToTensor, Normalize
    import scipy.ndimage
    
    import os
    os.environ['KMP_DUPLICATE_LIB_OK']='True'
    
    def get_mgrid(sidelen, dim=2):
        '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.'''
        if isinstance(sidelen, int):
            sidelen = dim * (sidelen,) #(512, 512)
    
        if dim == 2:
            pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32) #(1, 512, 512, 2)
            # values are in [0, 511] here; dividing by 511 maps them to [0, 1]
            pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1)
            pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1)
        elif dim == 3:
            pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
            pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
            pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1)
            pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1)
        else:
            raise NotImplementedError('Not implemented for dim=%d' % dim)
    
        pixel_coords -= 0.5
        pixel_coords *= 2. # these two steps map the values into [-1, 1]
        #the result is a grid: pixel_coords holds the 262144 (x, y) coordinate points
        pixel_coords = torch.Tensor(pixel_coords).view(-1, dim) #torch.Size([262144, 2])
        return pixel_coords
    
    # print(get_mgrid(512))
    
    class Camera(Dataset):
        def __init__(self, downsample_factor=1):
            super().__init__()
            self.downsample_factor = downsample_factor
            self.img = Image.fromarray(skimage.data.camera()) #skimage's built-in cameraman photo
            self.img_channels = 1
    
            if downsample_factor > 1:
                size = (int(512 / downsample_factor),) * 2
                self.img_downsampled = self.img.resize(size, Image.ANTIALIAS)
    
        def __len__(self):
            return 1
    
        def __getitem__(self, idx):
            if self.downsample_factor > 1:
                return self.img_downsampled
            else:
                return self.img
    
    class Implicit2DWrapper(torch.utils.data.Dataset):
        def __init__(self, dataset, sidelength=None, compute_diff=None):
    
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            self.sidelength = sidelength
    
            self.transform = Compose([
                Resize(sidelength),
                ToTensor(),
                Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
            ])
    
            self.compute_diff = compute_diff
            self.dataset = dataset
            self.mgrid = get_mgrid(sidelength)
    
        def __len__(self):
            return len(self.dataset)
    
        def __getitem__(self, idx):
            img = self.transform(self.dataset[idx])
    
            if self.compute_diff == 'gradients':
                img *= 1e1
                gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
                grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
            elif self.compute_diff == 'laplacian':
                img *= 1e4
                laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
            elif self.compute_diff == 'all':
                gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
                # print(gradx.shape) #(512, 512, 1)
                grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
                # print(grady.shape) #(512, 512, 1)
                laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
                # print(laplace.shape) #(512, 512, 1)
    
            # print(img.shape) #torch.Size([1, 512, 512])
            img = img.permute(1, 2, 0).view(-1, self.dataset.img_channels)
            # print(img.shape) #torch.Size([262144, 1])
    
    
            in_dict = {'idx': idx, 'coords': self.mgrid}
            gt_dict = {'img': img}
    
            if self.compute_diff == 'gradients':
                gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                       torch.from_numpy(grady).reshape(-1, 1)),
                                      dim=-1)
                gt_dict.update({'gradients': gradients})
    
            elif self.compute_diff == 'laplacian':
                gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})
    
            elif self.compute_diff == 'all':
                gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                       torch.from_numpy(grady).reshape(-1, 1)),
                                      dim=-1)
                # print(gradients.shape) #torch.Size([262144, 2])
                gt_dict.update({'gradients': gradients})
                gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})
    
            return in_dict, gt_dict
    
    
    img_dataset = Camera()
    coord_dataset = Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all')
    in_dict, gt_dict = coord_dataset[0]
    print(in_dict)
    print(gt_dict)
    
    print(in_dict['coords'].shape)
    print(gt_dict['img'].shape)
    print(gt_dict['gradients'].shape)
    print(gt_dict['laplace'].shape)

    Output:

    {'idx': 0, 'coords': tensor([[-1.0000, -1.0000],
            [-1.0000, -0.9961],
            [-1.0000, -0.9922],
            ...,
            [ 1.0000,  0.9922],
            [ 1.0000,  0.9961],
            [ 1.0000,  1.0000]])}
    {'img': tensor([[ 0.2235],
            [ 0.2314],
            [ 0.2549],
            ...,
            [-0.0510],
            [-0.1137],
            [-0.1294]]), 'gradients': tensor([[ 0.0000,  0.1255],
            [-0.0314,  0.4706],
            [-0.0941,  0.2196],
            ...,
            [ 0.0000, -2.1333],
            [-0.0000, -1.2549],
            [-0.0000, -0.2510]]), 'laplace': tensor([[ 0.0078],
            [ 0.0157],
            [-0.0392],
            ...,
            [ 0.0078],
            [ 0.0471],
            [ 0.0157]])}
    torch.Size([262144, 2])
    torch.Size([262144, 1])
    torch.Size([262144, 2])
    torch.Size([262144, 1])
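
    As a sanity check (a sketch reusing gt_dict from the listing above), the flattened ground truth can be reshaped back into an image, undoing the Normalize(0.5, 0.5):

    img_flat = gt_dict['img'] #torch.Size([262144, 1])
    img = img_flat.view(512, 512) * 0.5 + 0.5 #undo (x - 0.5) / 0.5, back to [0, 1]
    print(img.min().item(), img.max().item()) #both should lie in [0, 1]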

    2. The model

    modules.py

    FCBlock (note: this particular printout shows in_features=1 and out_features=2, which does not match the image example; for the cameraman image the block is built with in_features=2 and out_features=1, as in the SingleBVPNet printout below):

    MetaSequential(
      (0): MetaSequential(
        (0): BatchLinear(in_features=1, out_features=256, bias=True)
        (1): Sine()
      )
      (1): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (2): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (3): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (4): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=2, bias=True)
      )
    )

    SingleBVPNet():

    SingleBVPNet(
      (image_downsampling): ImageDownsampling()
      (net): FCBlock(
        (net): MetaSequential(
          (0): MetaSequential(
            (0): BatchLinear(in_features=2, out_features=256, bias=True)
            (1): Sine()
          )
          (1): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=256, bias=True)
            (1): Sine()
          )
          (2): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=256, bias=True)
            (1): Sine()
          )
          (3): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=256, bias=True)
            (1): Sine()
          )
          (4): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=1, bias=True)
          )
        )
      )
    )
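
    For reference, a sketch of how these printouts can be produced, assuming modules.py from the repo is importable (as in the listing below, the SingleBVPNet constructor ends with a print(self), which is what emits the structure above):

    import modules
    model = modules.SingleBVPNet(type='sine', mode='mlp', sidelength=(512, 512))
    print(model.net) #the inner FCBlock on its own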
    

    3. Loss function

    loss_functions.py

    def image_mse(mask, model_output, gt):
        if mask is None:
            return {'img_loss': ((model_output['model_out'] - gt['img']) ** 2).mean()}
        else:
            return {'img_loss': (mask * (model_output['model_out'] - gt['img']) ** 2).mean()}

    So the objective is a plain mean-squared error on the pixel values, exactly the loss L = Σ_i ‖Φ(x_i) − f(x_i)‖² from the paper; a usage sketch follows.
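
    A sketch of one loss evaluation, assuming a model and a batch (model_input, gt) from the dataloader in the summary script below:

    model_output = model(model_input) #{'model_in': ..., 'model_out': ...}
    losses = image_mse(None, model_output, gt) #mask=None -> plain MSE over all pixels
    loss = losses['img_loss']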

    4. Summary

    The code relevant to this simple example:

    • experiment_scripts/train_img.py
    • dataio.py
    • modules.py
    • loss_functions.py

    Roughly putting the main pieces together to see it run:

    #coding:utf-8
    import numpy as np 
    import torch
    import torch.nn as nn
    import math #needed by init_weights_selu / init_weights_elu below
    from torch.utils.data import Dataset
    from PIL import Image
    import skimage
    # from skimage import io #importing this triggers OMP: Error #15
    from torchvision.transforms import Resize, Compose, ToTensor, Normalize
    import scipy.ndimage
    from torch.utils.data import DataLoader
    from collections import OrderedDict
    from torchmeta.modules.utils import get_subdict
    
    ############################################################## data processing ##############################
    
    import os
    os.environ['KMP_DUPLICATE_LIB_OK']='True'
    
    def get_mgrid(sidelen, dim=2):
        '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.'''
        if isinstance(sidelen, int):
            sidelen = dim * (sidelen,) #(512, 512)
    
        if dim == 2:
            pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32) #(1, 512, 512, 2)
            # values are in [0, 511] here; dividing by 511 maps them to [0, 1]
            pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1)
            pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1)
        elif dim == 3:
            pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
            pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
            pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1)
            pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1)
        else:
            raise NotImplementedError('Not implemented for dim=%d' % dim)
    
        pixel_coords -= 0.5
        pixel_coords *= 2. # these two steps map the values into [-1, 1]
        #the result is a grid: pixel_coords holds the 262144 (x, y) coordinate points
        pixel_coords = torch.Tensor(pixel_coords).view(-1, dim) #torch.Size([262144, 2])
        return pixel_coords
    
    
    class Camera(Dataset):
        def __init__(self, downsample_factor=1):
            super().__init__()
            self.downsample_factor = downsample_factor
            self.img = Image.fromarray(skimage.data.camera()) #skimage's built-in cameraman photo
            self.img_channels = 1
    
            if downsample_factor > 1:
                size = (int(512 / downsample_factor),) * 2
                self.img_downsampled = self.img.resize(size, Image.ANTIALIAS)
    
        def __len__(self):
            return 1
    
        def __getitem__(self, idx):
            if self.downsample_factor > 1:
                return self.img_downsampled
            else:
                return self.img
    
    class Implicit2DWrapper(torch.utils.data.Dataset):
        def __init__(self, dataset, sidelength=None, compute_diff=None):
    
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            self.sidelength = sidelength
    
            self.transform = Compose([
                Resize(sidelength),
                ToTensor(),
                Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
            ])
    
            self.compute_diff = compute_diff
            self.dataset = dataset
            self.mgrid = get_mgrid(sidelength)
    
        def __len__(self):
            return len(self.dataset)
    
        def __getitem__(self, idx):
            img = self.transform(self.dataset[idx])
            # self.dataset[idx].save('./camera_people_2.jpg')
    
            if self.compute_diff == 'gradients':
                img *= 1e1
                gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
                grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
            elif self.compute_diff == 'laplacian':
                img *= 1e4
                laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
            elif self.compute_diff == 'all':
                gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
                # print(gradx.shape) #(512, 512, 1)
                grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
                # print(grady.shape) #(512, 512, 1)
                laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
                # print(laplace.shape) #(512, 512, 1)
    
            # print(img.shape) #torch.Size([1, 512, 512])
            #flatten the image so every pixel value becomes one of 262144 rows
            img = img.permute(1, 2, 0).view(-1, self.dataset.img_channels)
            # print(img.shape) #torch.Size([262144, 1])
    
    
            in_dict = {'idx': idx, 'coords': self.mgrid}
            gt_dict = {'img': img}
    
            if self.compute_diff == 'gradients':
                gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                       torch.from_numpy(grady).reshape(-1, 1)),
                                      dim=-1)
                gt_dict.update({'gradients': gradients})
    
            elif self.compute_diff == 'laplacian':
                gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})
    
            elif self.compute_diff == 'all':
                gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                       torch.from_numpy(grady).reshape(-1, 1)),
                                      dim=-1)
                # print(gradients.shape) #torch.Size([262144, 2])
                gt_dict.update({'gradients': gradients})
                gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})
    
            return in_dict, gt_dict
    
    
    img_dataset = Camera()
    coord_dataset = Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all')
    # in_dict, gt_dict = coord_dataset[3]
    # print(in_dict)
    # print(gt_dict)
    
    # print(in_dict['coords'].shape)
    # print(gt_dict['img'].shape)
    
    # print(gt_dict['gradients'].shape)
    # print(gt_dict['laplace'].shape)
    
    #num_workers=0 means single-process data loading
    dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)
    
    ############################################################## data processing ##############################
    
    ############################################################## the model ##############################
    
    from torchmeta.modules import (MetaModule, MetaSequential)
    
    class Sine(nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, input):
            # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
            return torch.sin(30 * input)
    
    def sine_init(m):
        with torch.no_grad():
            if hasattr(m, 'weight'):
                num_input = m.weight.size(-1)
                # See supplement Sec. 1.5 for discussion of factor 30
                m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)
    
    def first_layer_sine_init(m):
        with torch.no_grad():
            if hasattr(m, 'weight'):
                num_input = m.weight.size(-1)
                # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
                m.weight.uniform_(-1 / num_input, 1 / num_input)
    
    def init_weights_normal(m):
        if type(m) == BatchLinear or type(m) == nn.Linear:
            if hasattr(m, 'weight'):
                nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
    
    def init_weights_xavier(m):
        if type(m) == BatchLinear or type(m) == nn.Linear:
            if hasattr(m, 'weight'):
                nn.init.xavier_normal_(m.weight)
    
    
    def init_weights_selu(m):
        if type(m) == BatchLinear or type(m) == nn.Linear:
            if hasattr(m, 'weight'):
                num_input = m.weight.size(-1)
                nn.init.normal_(m.weight, std=1 / math.sqrt(num_input))
    
    def init_weights_elu(m):
        if type(m) == BatchLinear or type(m) == nn.Linear:
            if hasattr(m, 'weight'):
                num_input = m.weight.size(-1)
                nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(num_input))
      
    # a reimplementation of nn.Linear that can also take externally supplied (possibly batched) parameters
    class BatchLinear(nn.Linear, MetaModule):
        '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
        hypernetwork.'''
        __doc__ = nn.Linear.__doc__
    
        def forward(self, input, params=None):
            if params is None:
                params = OrderedDict(self.named_parameters()) #fetch this layer's own parameters
    
            bias = params.get('bias', None)
            weight = params['weight']
    
            # print('BatchLinear list :', [i for i in range(len(weight.shape) - 2)]) #[]
            # for a plain 2-D weight this is identical to nn.Linear's forward, shown commented out below; the permute generalizes the transpose to batched weights
            # output = input.matmul(weight.t())
            # output += bias
            # print('weight.shape before : ', weight.shape) #torch.Size([256, 2])
            print('input.shape : ', input.shape) #torch.Size([1, 262144, 2])
            # print('weight permute :', weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2).shape) #transposes the last two dims of weight
            
            # i.e. the usual x @ W^T + b
            output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2)) 
            # print('weight.shape after : ', weight.shape) #torch.Size([256, 2])
            print('output.shape : ', output.shape) #torch.Size([1, 262144, 256])
            output += bias.unsqueeze(-2)
            return output
    
    class ImageDownsampling(nn.Module):
        '''Generate samples in u,v plane according to downsampling blur kernel'''
    
        def __init__(self, sidelength, downsample=False):
            super().__init__()
            if isinstance(sidelength, int):
                self.sidelength = (sidelength, sidelength)
            else:
                self.sidelength = sidelength
    
            if self.sidelength is not None:
                # self.sidelength = torch.Tensor(self.sidelength).cuda().float()
                self.sidelength = torch.Tensor(self.sidelength).float()
            else:
                assert downsample is False
            self.downsample = downsample
    
        def forward(self, coords):
            if self.downsample:
                return coords + self.forward_bilinear(coords)
            else:
                return coords
    
        def forward_box(self, coords):
            return 2 * (torch.rand_like(coords) - 0.5) / self.sidelength
    
        def forward_bilinear(self, coords):
            Y = torch.sqrt(torch.rand_like(coords)) - 1 #torch.rand_like(coords) draws uniform [0,1) noise with the same shape as coords
            Z = 1 - torch.sqrt(torch.rand_like(coords))
            b = torch.rand_like(coords) < 0.5
    
            Q = (b * Y + ~b * Z) / self.sidelength
            return Q
    
    class FCBlock(MetaModule):
        '''A fully connected neural network that also allows swapping out the weights when used with a hypernetwork.
        Can be used just as a normal neural network though, as well.
        '''
    
        def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                     outermost_linear=False, nonlinearity='relu', weight_init=None):
            super().__init__()
    
            self.first_layer_init = None
    
            # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
            # special first-layer initialization scheme
            nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                             'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                             'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                             'tanh':(nn.Tanh(), init_weights_xavier, None),
                             'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                             'softplus':(nn.Softplus(), init_weights_normal, None),
                             'elu':(nn.ELU(inplace=True), init_weights_elu, None)}
    
            nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]
    
            if weight_init is not None:  # Overwrite weight init if passed
                self.weight_init = weight_init
            else:
                self.weight_init = nl_weight_init
    
            self.net = []
            self.net.append(MetaSequential( #a BatchLinear followed by the chosen nonlinearity (a Sine here)
                BatchLinear(in_features, hidden_features), nl
            ))
    
            for i in range(num_hidden_layers):
                self.net.append(MetaSequential(
                    BatchLinear(hidden_features, hidden_features), nl
                ))
    
            if outermost_linear:
                self.net.append(MetaSequential(BatchLinear(hidden_features, out_features)))
            else:
                self.net.append(MetaSequential(
                    BatchLinear(hidden_features, out_features), nl
                ))
    
            # with sine, the first layer gets a different initialization from the later layers
            self.net = MetaSequential(*self.net)
            if self.weight_init is not None:
                self.net.apply(self.weight_init)
    
            if first_layer_init is not None: # Apply special initialization to first layer, if applicable.
                self.net[0].apply(first_layer_init)
    
        def forward(self, coords, params=None, **kwargs):
            if params is None:
                params = OrderedDict(self.named_parameters())
    
            output = self.net(coords, params=get_subdict(params, 'net'))
            return output
    
        def forward_with_activations(self, coords, params=None, retain_grad=False):
            '''Returns not only model output, but also intermediate activations.'''
            if params is None:
                params = OrderedDict(self.named_parameters())
    
            activations = OrderedDict()
    
            x = coords.clone().detach().requires_grad_(True)
            activations['input'] = x
            for i, layer in enumerate(self.net):
                subdict = get_subdict(params, 'net.%d' % i)
                for j, sublayer in enumerate(layer):
                    if isinstance(sublayer, BatchLinear):
                        x = sublayer(x, params=get_subdict(subdict, '%d' % j))
                    else:
                        x = sublayer(x)
    
                    if retain_grad:
                        x.retain_grad()
                    activations['_'.join((str(sublayer.__class__), "%d" % i))] = x
            return activations
    
    class SingleBVPNet(MetaModule):
        '''A canonical representation network for a BVP.'''
    
        def __init__(self, out_features=1, type='sine', in_features=2,
                     mode='mlp', hidden_features=256, num_hidden_layers=3, **kwargs):
            super().__init__()
            self.mode = mode
    
            if self.mode == 'rbf':
                self.rbf_layer = RBFLayer(in_features=in_features, out_features=kwargs.get('rbf_centers', 1024))
                in_features = kwargs.get('rbf_centers', 1024)
            elif self.mode == 'nerf':
                self.positional_encoding = PosEncodingNeRF(in_features=in_features,
                                                           sidelength=kwargs.get('sidelength', None),
                                                           fn_samples=kwargs.get('fn_samples', None),
                                                           use_nyquist=kwargs.get('use_nyquist', True))
                in_features = self.positional_encoding.out_dim
    
            self.image_downsampling = ImageDownsampling(sidelength=kwargs.get('sidelength', None),
                                                        downsample=kwargs.get('downsample', False))
            self.net = FCBlock(in_features=in_features, out_features=out_features, num_hidden_layers=num_hidden_layers,
                               hidden_features=hidden_features, outermost_linear=True, nonlinearity=type)
            print(self)
    
        def forward(self, model_input, params=None):
            if params is None:
                params = OrderedDict(self.named_parameters())
    
            # Enables us to compute gradients w.r.t. coordinates
            coords_org = model_input['coords'].clone().detach().requires_grad_(True)
            coords = coords_org
    
            # various input processing methods for different applications
            if self.image_downsampling.downsample:
                coords = self.image_downsampling(coords)
            if self.mode == 'rbf':
                coords = self.rbf_layer(coords)
            elif self.mode == 'nerf':
                coords = self.positional_encoding(coords)
    
            output = self.net(coords, get_subdict(params, 'net'))
            return {'model_in': coords_org, 'model_out': output}
    
    
    # For a (512, 512) image the model takes the pixel coordinates model_input['coords'], of shape [batch_size, 262144, 2],
    # and outputs the corresponding pixel values output['model_out'], of shape [batch_size, 262144, 1].
    # SingleBVPNet is the parameterized function Phi_theta being fitted.
    # The MSE loss then measures the error between the predicted pixel values output['model_out'] and the true pixel values gt['img'],
    # and training reduces that error.
    model = SingleBVPNet(type='sine', mode='mlp', sidelength=(512, 512))
    # for i in model.children():
    #     print(i)
    
    # the input is a single image, the cameraman;
    # the network is fitted to reproduce it
    for step, (model_input, gt) in enumerate(dataloader):
        print('-'*30)
        print('step : ', step)
        print(model_input['coords'].shape)
        print(gt['img'].shape)
    
        output = model(model_input)
        print('model in : ', output['model_in'].shape)
        print('model out : ', output['model_out'].shape)

    Output:

    SingleBVPNet(
      (image_downsampling): ImageDownsampling()
      (net): FCBlock(
        (net): MetaSequential(
          (0): MetaSequential(
            (0): BatchLinear(in_features=2, out_features=256, bias=True)
            (1): Sine()
          )
          (1): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=256, bias=True)
            (1): Sine()
          )
          (2): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=256, bias=True)
            (1): Sine()
          )
          (3): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=256, bias=True)
            (1): Sine()
          )
          (4): MetaSequential(
            (0): BatchLinear(in_features=256, out_features=1, bias=True)
          )
        )
      )
    )
    ------------------------------
    step :  0
    torch.Size([1, 262144, 2])
    torch.Size([1, 262144, 1])
    input.shape :  torch.Size([1, 262144, 2])
    output.shape :  torch.Size([1, 262144, 256])
    input.shape :  torch.Size([1, 262144, 256])
    output.shape :  torch.Size([1, 262144, 256])
    input.shape :  torch.Size([1, 262144, 256])
    output.shape :  torch.Size([1, 262144, 256])
    input.shape :  torch.Size([1, 262144, 256])
    output.shape :  torch.Size([1, 262144, 256])
    input.shape :  torch.Size([1, 262144, 256])
    output.shape :  torch.Size([1, 262144, 1])
    model in :  torch.Size([1, 262144, 2])
    model out :  torch.Size([1, 262144, 1])
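
    To go from these forward passes to an actual fit, train_img.py adds an optimizer and backpropagates the MSE. A minimal sketch of that loop (the lr and epoch count are illustrative; the script exposes them as --lr and --num_epochs):

    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(10000):
        for model_input, gt in dataloader:
            output = model(model_input)
            loss = ((output['model_out'] - gt['img']) ** 2).mean() #image_mse with mask=None
            optim.zero_grad()
            loss.backward()
            optim.step()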

    As the trace above shows, the sine activation path is implemented on top of BatchLinear:

    # a reimplementation of nn.Linear that can also take externally supplied (possibly batched) parameters
    class BatchLinear(nn.Linear, MetaModule):
        '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
        hypernetwork.'''
        __doc__ = nn.Linear.__doc__
    
        def forward(self, input, params=None):
            if params is None:
                params = OrderedDict(self.named_parameters()) #fetch this layer's own parameters
    
            bias = params.get('bias', None)
            weight = params['weight']
    
            # print('BatchLinear list :', [i for i in range(len(weight.shape) - 2)]) #[]
            # for a plain 2-D weight this reduces to nn.Linear's usual forward:
            # output = input.matmul(weight.t())
            # output += bias
            # print('weight.shape before : ', weight.shape) #torch.Size([256, 2])
            print('input.shape : ', input.shape) #torch.Size([1, 262144, 2])
            # print('weight permute :', weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2).shape) #transposes the last two dims of weight
            
            # i.e. the usual x @ W^T + b
            output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2)) 
            # print('weight.shape after : ', weight.shape) #torch.Size([256, 2])
            print('output.shape : ', output.shape) #torch.Size([1, 262144, 256])
            # print('bias before:', bias.shape) #torch.Size([256])
            # print('bias after:', bias.unsqueeze(-2).shape)
            output += bias.unsqueeze(-2) #torch.Size([1, 256])
            return output

    The parameters w (weight) and b (bias) both live in this layer, which produces w^T x + b, the input to sine().
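
    For a plain 2-D weight, the permute in BatchLinear is just a transpose; the extra leading dimensions only matter for batched weights emitted by a hypernetwork. A quick check (a sketch, assuming torch is imported):

    w = torch.randn(256, 2)
    assert torch.equal(w.permute(-1, -2), w.t()) #2-D case: identical to .t()
    wb = torch.randn(4, 256, 2) #batched weights, e.g. from a hypernetwork
    assert wb.permute(0, -1, -2).shape == (4, 2, 256) #only the last two dims are swapped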

    The sine() activation is then applied to the BatchLinear output w^T x + b, together with the initialization schemes from the paper:

    class Sine(nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, input):
            # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
            return torch.sin(30 * input) #w0=30
    
    def sine_init(m):
        with torch.no_grad():
            if hasattr(m, 'weight'):
                num_input = m.weight.size(-1) #num_input is the layer's in_features
                # See supplement Sec. 1.5 for discussion of factor 30
                m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)
    
    def first_layer_sine_init(m):
        with torch.no_grad():
            if hasattr(m, 'weight'):
                num_input = m.weight.size(-1)
                # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
                m.weight.uniform_(-1 / num_input, 1 / num_input)
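
    Putting the pieces together, one SIREN layer computes sin(30 * (w^T x + b)). A sketch reusing BatchLinear, Sine, first_layer_sine_init, and get_mgrid from above (the forward pass also triggers the debug prints left in BatchLinear):

    layer = BatchLinear(2, 256)
    first_layer_sine_init(layer) #first-layer init: U(-1/n, 1/n)
    coords = get_mgrid(512).unsqueeze(0) #torch.Size([1, 262144, 2])
    out = Sine()(layer(coords))
    print(out.shape) #torch.Size([1, 262144, 256])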
    Original post: https://www.cnblogs.com/wanghui-garcia/p/13215031.html