GRDN网络结构代码实现
subNets.py
import torch
import torch.nn as nn
import torch.nn.functional as F
def weights_init(m):
    """Initialise one module in place with the DCGAN weight scheme.

    Meant to be used as ``model.apply(weights_init)``:

    * modules whose class name contains ``'Conv'`` get their weight drawn from
      N(0, 0.02) (conv biases are left untouched);
    * modules whose class name contains ``'BatchNorm'`` get their weight drawn
      from N(1, 0.02) and their bias set to 0.

    Modules that have no ``weight`` of their own are skipped instead of
    raising. Examples are a wrapper whose name contains ``'Conv'`` but which
    keeps the real conv as a child, or a BatchNorm built with
    ``affine=False``. ``Module.apply`` still visits such a module's children.

    Adapted from https://github.com/pytorch/examples/blob/master/dcgan/main.py
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        # Nothing to initialise on this module itself.
        return
    if 'Conv' in classname:
        weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
####################################################################################################################
class make_dense(nn.Module):
    """Single densely-connected layer: appends ReLU(conv(x)) to its input.

    :param nChannels: channel count of the enclosing block's input (stored only)
    :param nChannels_: number of channels of ``x`` reaching this layer
    :param growthRate: number of new channels produced by the conv
    :param kernel_size: conv kernel size; padding ``(kernel_size - 1) // 2``
        keeps H and W unchanged for odd kernel sizes
    """

    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size,
                              padding=pad, bias=False)
        self.nChannels = nChannels

    def forward(self, x):
        new_features = F.relu(self.conv(x))
        # Dense connectivity: keep the input and stack the new features after it.
        return torch.cat([x, new_features], dim=1)
class make_residual_dense_ver1(nn.Module):
    """Dense layer that also adds its new features onto the first channels.

    With ``o = ReLU(conv(x))`` the output is
    ``cat(x[:, :nChannels] + o, x[:, nChannels:], o)`` along the channel axis,
    so ``o`` must broadcast against ``x[:, :nChannels]`` (typically
    ``growthRate == nChannels``).

    :param nChannels: number of leading channels that receive the residual add
    :param nChannels_: number of channels of ``x`` reaching this layer
    :param growthRate: number of channels produced by the conv
    :param kernel_size: conv kernel size (padding keeps H/W for odd sizes)
    """

    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size,
                              padding=(kernel_size - 1) // 2, bias=False)
        self.nChannels_ = nChannels_
        self.nChannels = nChannels
        self.growthrate = growthRate

    def forward(self, x):
        split = self.nChannels
        new_features = F.relu(self.conv(x))
        head = x[:, :split, :, :] + new_features
        tail = x[:, split:, :, :]
        return torch.cat((head, tail, new_features), 1)
class make_residual_dense_ver2(nn.Module):
    """Dense layer variant whose new features are also added onto one channel slice.

    With ``o = ReLU(conv(x))``:

    * if ``x`` has exactly ``nChannels`` channels the output is
      ``cat(x, x + o, o)``;
    * otherwise it is
      ``cat(x[:, :n], x[:, n:n+g] + o, x[:, n+g:], o)`` with
      ``n = nChannels`` and ``g = growthRate``.

    The conv reads ``nChannels_`` channels when ``nChannels == nChannels_``
    and ``nChannels_ + growthRate`` channels otherwise.

    :param nChannels: channel count of the enclosing block's input
    :param nChannels_: running channel count used to size the conv input
    :param growthRate: number of channels produced by the conv
    :param kernel_size: conv kernel size (padding keeps H/W for odd sizes)
    """

    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super().__init__()
        in_channels = nChannels_ if nChannels == nChannels_ else nChannels_ + growthRate
        self.conv = nn.Conv2d(in_channels, growthRate, kernel_size=kernel_size,
                              padding=(kernel_size - 1) // 2, bias=False)
        self.nChannels_ = nChannels_
        self.nChannels = nChannels
        self.growthrate = growthRate

    def forward(self, x):
        n, g = self.nChannels, self.growthrate
        new_features = F.relu(self.conv(x))
        if x.shape[1] == n:
            pieces = (x, x + new_features, new_features)
        else:
            pieces = (
                x[:, :n, :, :],
                x[:, n:n + g, :, :] + new_features,
                x[:, n + g:, :, :],
                new_features,
            )
        return torch.cat(pieces, 1)
class make_dense_LReLU(nn.Module):
    """Dense layer using LeakyReLU: appends leaky_relu(conv(x)) to its input.

    :param nChannels: number of channels of ``x``
    :param growthRate: number of new channels produced by the conv
    :param kernel_size: conv kernel size (padding keeps H/W for odd sizes)
    """

    def __init__(self, nChannels, growthRate, kernel_size=3):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(nChannels, growthRate, kernel_size=kernel_size,
                              padding=pad, bias=False)

    def forward(self, x):
        activated = F.leaky_relu(self.conv(x))
        return torch.cat([x, activated], dim=1)
# Residual dense block (RDB) architecture
class RDB(nn.Module):
    """Residual dense block (RDB).

    ``nDenselayer`` densely connected conv layers (each adding ``growthRate``
    channels) are followed by a bias-free 1x1 conv that maps the accumulated
    ``nChannels + nDenselayer * growthRate`` channels back to ``nChannels``.
    The block input is then added (local residual connection).

    Adapted from https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, nChannels, nDenselayer, growthRate):
        """
        :param nChannels: number of channels of the input feature map
        :param nDenselayer: number of conv layers inside the RDB
        :param growthRate: number of output channels of each conv layer
        """
        super(RDB, self).__init__()
        # Layer i sees the input plus the i feature maps produced before it.
        self.dense_layers = nn.Sequential(*[
            make_dense(nChannels, nChannels + i * growthRate, growthRate)
            for i in range(nDenselayer)
        ])
        fused_channels = nChannels + nDenselayer * growthRate
        self.conv_1x1 = nn.Conv2d(fused_channels, nChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        fused = self.conv_1x1(self.dense_layers(x))
        # Local residual connection.
        return fused + x
def RDB_Blocks(channels, size, nDenselayer=8, growthRate=64):
    """Build ``size`` RDBs of width ``channels`` chained in an ``nn.Sequential``.

    :param channels: number of input/output channels of every RDB
    :param size: number of RDBs to stack
    :param nDenselayer: conv layers per RDB (default 8, the previously hard-coded value)
    :param growthRate: output channels of each conv inside an RDB
        (default 64, the previously hard-coded value)
    :return: ``nn.Sequential`` of the RDBs
    """
    return nn.Sequential(*[
        RDB(channels, nDenselayer=nDenselayer, growthRate=growthRate)
        for _ in range(size)
    ])
####################################################################################################################
# Group of Residual dense block (GRDB) architecture
class GRDB(nn.Module):
    """Group of residual dense blocks (GRDB).

    Runs ``numforrg`` RDBs one after another and concatenates every RDB's
    output along the channel axis. The concatenation is fused back to
    ``numofkernels`` channels with a 1x1 conv (with bias), and the group input
    is added (group-level residual connection).

    Adapted from https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, numofkernels, nDenselayer, growthRate, numforrg):
        """
        :param numofkernels: number of channels of the input (and output) feature map
        :param nDenselayer: number of conv layers inside each RDB
        :param growthRate: number of output channels of each conv inside an RDB
        :param numforrg: number of RDBs in the group
        """
        super(GRDB, self).__init__()
        self.rdbs = nn.Sequential(*[
            RDB(numofkernels, nDenselayer=nDenselayer, growthRate=growthRate)
            for _ in range(numforrg)
        ])
        self.conv_1x1 = nn.Conv2d(numofkernels * numforrg, numofkernels,
                                  kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        stage_outputs = []
        current = x
        for rdb in self.rdbs:
            current = rdb(current)
            stage_outputs.append(current)
        fused = self.conv_1x1(torch.cat(stage_outputs, 1))
        return x + fused
# Group of groups of residual dense blocks (GGRDB) architecture
class GGRDB(nn.Module):
    """Group of GRDBs with one additional residual connection around the whole chain.

    Adapted from https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, numofmodules, numofkernels, nDenselayer, growthRate, numforrg):
        """
        :param numofmodules: number of GRDBs chained inside this module
        :param numofkernels: number of channels of the input (and output) feature map
        :param nDenselayer: number of conv layers inside each RDB
        :param growthRate: number of output channels of each conv inside an RDB
        :param numforrg: number of RDBs in each GRDB
        """
        super(GGRDB, self).__init__()
        self.grdbs = nn.Sequential(*[
            GRDB(numofkernels, nDenselayer=nDenselayer, growthRate=growthRate, numforrg=numforrg)
            for _ in range(numofmodules)
        ])

    def forward(self, x):
        # nn.Sequential applies the GRDBs in order, one feeding the next.
        return x + self.grdbs(x)
####################################################################################################################
class ResidualBlock(nn.Module):
    """Pre-activation residual unit (the res-unit described in the one_to_many paper).

    Computes ``x + conv2(relu(bn2(conv1(relu(bn1(x))))))``. Both convs are 3x3
    with stride 1 and padding 1, so the spatial size and channel count are kept.

    :param channels: number of input (and output) channels
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        branch = self.conv1(self.relu1(self.bn1(x)))
        branch = self.conv2(self.relu2(self.bn2(branch)))
        return x + branch
def ResidualBlocks(channels, size):
    """Chain ``size`` ResidualBlock units of width ``channels`` in an ``nn.Sequential``."""
    return nn.Sequential(*[ResidualBlock(channels) for _ in range(size)])
DenoisingModels.py
from models.subNets import *
from models.cbam import *
class ntire_rdb_gd_rir_ver1(nn.Module):
    """Denoiser: conv -> strided-conv downsample -> GRDB stack -> transposed-conv upsample -> CBAM -> conv.

    The stack holds ``numofrdb // numforrg`` GRDBs. Each GRDB is applied ``t``
    times in a row, reusing its weights. The network output is added to the
    input (global residual connection).

    The downsample (k=4, s=2, p=1) halves even spatial sizes and the upsample
    doubles them. Odd H or W therefore yields an output whose size differs
    from ``x``, and the final residual addition fails.
    """

    def __init__(self, input_channel, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64, t=1):
        super(ntire_rdb_gd_rir_ver1, self).__init__()
        self.numforrg = numforrg  # RDBs per residual group (GRDB)
        self.numofrdb = numofrdb  # total RDBs across all groups
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters
        self.t = t
        width = self.numofkernels
        self.layer1 = nn.Conv2d(input_channel, width, kernel_size=3, stride=1, padding=1)
        self.layer3 = nn.Conv2d(width, width, kernel_size=4, stride=2, padding=1)
        group_count = self.numofrdb // self.numforrg
        self.rglayer = nn.Sequential(*[
            GRDB(width, self.nDenselayer, width, self.numforrg)
            for _ in range(group_count)
        ])
        self.layer7 = nn.ConvTranspose2d(width, width, kernel_size=4, stride=2, padding=1)
        self.layer9 = nn.Conv2d(width, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(width, 16)

    def forward(self, x):
        feat = self.layer3(self.layer1(x))
        for grdb in self.rglayer:
            # Recursive application: the same GRDB weights are used t times.
            for _ in range(self.t):
                feat = grdb(feat)
        feat = self.cbam(self.layer7(feat))
        # Global residual connection.
        return self.layer9(feat) + x
class ntire_rdb_gd_rir_ver2(nn.Module):
    """Variant of ver1 whose body is built from GGRDBs plus leftover GRDBs.

    With ``group = numofmodules * numforrg``, the body contains
    ``numofrdb // group`` GGRDBs, followed by ``(numofrdb % group) // numforrg``
    plain GRDBs. Each body module is applied ``t`` times, and the output is
    added to the input (global residual connection).

    As in ver1, H and W must be even for the residual addition to line up.
    """

    def __init__(self, input_channel, numofmodules=2, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64, t=1):
        super(ntire_rdb_gd_rir_ver2, self).__init__()
        self.numofmodules = numofmodules  # GRDBs per GGRDB
        self.numforrg = numforrg  # RDBs per residual group (GRDB)
        self.numofrdb = numofrdb  # total RDBs across all groups
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters
        self.t = t
        width = self.numofkernels
        self.layer1 = nn.Conv2d(input_channel, width, kernel_size=3, stride=1, padding=1)
        self.layer3 = nn.Conv2d(width, width, kernel_size=4, stride=2, padding=1)
        group = self.numofmodules * self.numforrg
        body = [
            GGRDB(self.numofmodules, width, self.nDenselayer, width, self.numforrg)
            for _ in range(self.numofrdb // group)
        ]
        body += [
            GRDB(width, self.nDenselayer, width, self.numforrg)
            for _ in range((self.numofrdb % group) // self.numforrg)
        ]
        self.rglayer = nn.Sequential(*body)
        self.layer7 = nn.ConvTranspose2d(width, width, kernel_size=4, stride=2, padding=1)
        self.layer9 = nn.Conv2d(width, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(numoffilters, 16)

    def forward(self, x):
        feat = self.layer3(self.layer1(x))
        for block in self.rglayer:
            for _ in range(self.t):
                feat = block(feat)
        feat = self.cbam(self.layer7(feat))
        # Global residual connection.
        return self.layer9(feat) + x
class Generator_one2many_gd_rir_old(nn.Module):
    """Older denoiser layout: conv+ReLU -> downsample -> GRDB stack -> upsample -> CBAM -> ReLU -> conv.

    The ``numofrdb // numforrg`` GRDBs are applied once each (no recursion).
    The output is added to the input (global residual connection). H and W
    must be even for that addition to line up.
    """

    def __init__(self, input_channel, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64):
        super(Generator_one2many_gd_rir_old, self).__init__()
        self.numforrg = numforrg  # RDBs per residual group (GRDB)
        self.numofrdb = numofrdb  # total RDBs across all groups
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters
        width = self.numofkernels
        self.layer1 = nn.Conv2d(input_channel, width, kernel_size=3, stride=1, padding=1)
        self.layer2 = nn.ReLU()
        self.layer3 = nn.Conv2d(width, width, kernel_size=4, stride=2, padding=1)
        self.rglayer = nn.Sequential(*[
            GRDB(width, self.nDenselayer, width, self.numforrg)
            for _ in range(self.numofrdb // self.numforrg)
        ])
        self.layer7 = nn.ConvTranspose2d(width, width, kernel_size=4, stride=2, padding=1)
        self.layer8 = nn.ReLU()
        self.layer9 = nn.Conv2d(width, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(width, 16)

    def forward(self, x):
        feat = self.layer3(self.layer2(self.layer1(x)))
        feat = self.rglayer(feat)
        feat = self.layer8(self.cbam(self.layer7(feat)))
        # Global residual connection.
        return self.layer9(feat) + x
cbam.py
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and ReLU.

    :param in_planes: input channels
    :param out_planes: output channels (also exposed as ``out_channels``)
    :param kernel_size, stride, padding, dilation, groups, bias: forwarded to ``nn.Conv2d``
    :param relu: append a ReLU when True (``self.relu`` is None otherwise)
    :param bn: append a BatchNorm2d(eps=1e-5, momentum=0.01, affine=True) when True
        (``self.bn`` is None otherwise)
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=False, bn=False, bias=True):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        for stage in (self.conv, self.bn, self.relu):
            if stage is not None:
                x = stage(x)
        return x
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension (via ``view``, so no copy is made)."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
class ChannelGate(nn.Module):
    """CBAM channel-attention gate.

    Each requested pooling reduces ``x`` over H and W to one value per channel
    and is passed through a shared MLP
    (``gate_channels -> gate_channels // reduction_ratio -> gate_channels``).
    The MLP outputs are summed, squashed with a sigmoid, and used to rescale
    every channel of ``x``.

    :param gate_channels: number of channels of ``x``
    :param reduction_ratio: hidden-layer reduction factor of the MLP
    :param pool_types: non-empty iterable drawn from ``'avg'``, ``'max'``, ``'lp'``, ``'lse'``
    :raises ValueError: if ``pool_types`` is empty or contains an unsupported name
    """

    _SUPPORTED_POOLS = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        # Copy so callers' sequences are never shared between instances
        # (the old default was a mutable list shared by every gate).
        pool_types = list(pool_types)
        if not pool_types:
            raise ValueError('ChannelGate requires at least one pool type')
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if unknown:
            # Previously an unknown name silently reused the previous
            # iteration's attention (or raised NameError if it came first).
            raise ValueError(f'unsupported pool type(s) {unknown}; '
                             f'expected names from {self._SUPPORTED_POOLS}')
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types

    @staticmethod
    def _global_pool(x, pool_type):
        """Reduce ``x`` over its last two (spatial) dimensions with the named pooling."""
        window = (x.size(2), x.size(3))
        if pool_type == 'avg':
            return F.avg_pool2d(x, window, stride=window)
        if pool_type == 'max':
            return F.max_pool2d(x, window, stride=window)
        if pool_type == 'lp':
            return F.lp_pool2d(x, 2, window, stride=window)
        if pool_type == 'lse':
            return logsumexp_2d(x)
        raise ValueError(f'unsupported pool type {pool_type!r}')

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(self._global_pool(x, pool_type))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        # torch.sigmoid replaces the deprecated F.sigmoid.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
def logsumexp_2d(tensor):
    """Log-sum-exp of each channel over all spatial positions.

    :param tensor: tensor whose first two dims are (batch, channel); all
        remaining dims are flattened together (typically (B, C, H, W))
    :return: tensor of shape (B, C, 1)
    """
    # reshape (not view) also accepts non-contiguous inputs.
    tensor_flatten = tensor.reshape(tensor.size(0), tensor.size(1), -1)
    # torch.logsumexp applies the same max-shift stabilisation as the former
    # hand-written formula. It also yields -inf instead of NaN for rows that
    # are entirely -inf.
    return torch.logsumexp(tensor_flatten, dim=2, keepdim=True)
class ChannelPool(nn.Module):
    """Stack the per-pixel channel maximum and channel mean into a 2-channel map."""

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat((channel_max, channel_mean), dim=1)
class SpatialGate(nn.Module):
    """CBAM spatial-attention gate.

    Compresses ``x`` to a 2-channel (max, mean) map and convolves it down to
    one channel. A sigmoid of that map rescales every channel of ``x`` at
    each pixel.

    :param kernel_size: size of the square conv kernel (default 7, the
        previously hard-coded value). It should be odd: the padding
        ``(kernel_size - 1) // 2`` only preserves H and W for odd sizes.
    """

    def __init__(self, kernel_size=7):
        super(SpatialGate, self).__init__()
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        # torch.sigmoid replaces the deprecated F.sigmoid; the single-channel
        # scale broadcasts over all channels of x.
        scale = torch.sigmoid(x_out)
        return x * scale
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then (optionally) spatial gate.

    :param gate_channels: number of channels of the input feature map
    :param reduction_ratio: MLP reduction factor of the channel gate
    :param pool_types: poolings used by the channel gate. This is now an
        immutable tuple default instead of a list shared across instances.
    :param no_spatial: when True the spatial gate is not created or applied
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
def weights_init_rcan(m):
    """Initialise one module in place with DCGAN-style weights; also understands BasicConv.

    Meant to be used as ``model.apply(weights_init_rcan)``:

    * ``BasicConv`` wrappers: inner conv weight from N(0, 0.02), and inner
      BatchNorm bias set to 0 if a BatchNorm is present;
    * other modules whose class name contains ``'Conv'``: weight from N(0, 0.02);
    * modules whose class name contains ``'BatchNorm'``: weight from
      N(1, 0.02), bias 0. Non-affine BatchNorm (no weight) is skipped
      instead of raising.

    Adapted from https://github.com/pytorch/examples/blob/master/dcgan/main.py
    """
    classname = m.__class__.__name__
    # 'BasicConv' also contains 'Conv', so it must be tested first.
    if 'BasicConv' in classname:
        m.conv.weight.data.normal_(0.0, 0.02)
        if m.bn is not None and m.bn.bias is not None:
            m.bn.bias.data.fill_(0)
    elif 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname and m.weight is not None:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
DGU-3DMlab1_track1.py
import numpy as np
import cv2
import torch
from models.DenoisingModels import *
from utils.utils import *
from utils.transforms import *
import scipy.io as sio
import time
import tqdm
if __name__ == '__main__':
    print('********************Test code for NTIRE challenge******************')
    # Path of the input .mat file holding the noisy benchmark blocks.
    mat_dir = 'mats/BenchmarkNoisyBlocksRaw.mat'
    mat_file = sio.loadmat(mat_dir)
    # Indexed below as [image][patch][...]; the product of the first four
    # dims is counted as the pixel total, so presumably (N, P, H, W) — TODO confirm.
    noisyblock = mat_file['BenchmarkNoisyBlocksRaw']
    print('input shape', noisyblock.shape)

    # Saved checkpoint (.pkl) of the model, and the experiment name used for outputs.
    modelpath = 'checkpoints/DGU-3DMlab1_track1.pkl'
    expname = 'DGU-3DMlab1_track1'

    # Prefer the first GPU; fall back to CPU so the script still runs without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters presumably mirror those the checkpoint was trained with
    # (load_state_dict below fails on any shape mismatch).
    model = Generator_one2many_gd_rir_old(input_channel=1, numforrg=4, numofrdb=16,
                                          numofconv=8, numoffilters=67).to(device)

    # Output array with the same shape as the input, filled patch by patch.
    resultNP = np.ones(noisyblock.shape)
    print('resultNP.shape', resultNP.shape)
    submitpath = f'results_folder/{expname}'
    make_dirs(submitpath)

    # map_location lets a checkpoint saved on another device load onto `device`.
    checkpoint = torch.load(modelpath, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    transform = ToTensor()
    revtransform = ToImage()

    with torch.no_grad():
        model.eval()
        starttime = time.time()
        for imgidx in tqdm.tqdm(range(noisyblock.shape[0])):
            for patchidx in range(noisyblock.shape[1]):
                img = noisyblock[imgidx][patchidx]
                inp = transform(img).float()
                # Add a batch dimension; assumes transform returns (C, H, W) — TODO confirm.
                inp = inp.view(1, -1, inp.shape[1], inp.shape[2]).to(device)
                output = model(inp)
                outimg = revtransform(output)
                resultNP[imgidx][patchidx] = outimg
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        endtime = time.time()

    elapsedTime = endtime - starttime
    print('ended', elapsedTime)
    num_of_pixels = noisyblock.shape[0] * noisyblock.shape[1] * noisyblock.shape[2] * noisyblock.shape[3]
    print('number of pixels', num_of_pixels)
    # Seconds per megapixel. The previous formula divided megapixels by
    # seconds, i.e. it printed throughput (MP/s) under this label.
    runtime_per_mega_pixels = elapsedTime / (num_of_pixels / 1000000)
    print('Runtime per mega pixel', runtime_per_mega_pixels)
    # scipy.io.savemat appends the '.mat' extension by default.
    sio.savemat(f'{submitpath}/{expname}', {'results': resultNP})