zoukankan      html  css  js  c++  java
  • LEP+低秩+神经网络去噪

    from __future__ import print_function
    import matplotlib
    import matplotlib.pyplot as plt
    %matplotlib inline
    import scipy.misc
    import os
    import numpy as np
    
    from models.resnet import ResNet
    from models.unet import UNet
    from models.skip import skip
    import torch
    import torch.optim
    
    from utils.inpainting_utils import *
    # --- Runtime / device settings -------------------------------------------
    # Keep cuDNN on, and let it benchmark convolution algorithms
    # (torch.backends.cudnn.benchmark).
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # Every tensor and module below is cast to this CUDA float32 tensor type.
    dtype = torch.cuda.FloatTensor

    # --- Display / size settings ---------------------------------------------
    PLOT = True
    figsize = 5
    imsize = -1
    dim_div_by = 64
    NET_TYPE = 'skip_depth6'

    # --- Data locations --------------------------------------------------------
    # The directory literals below are kept verbatim.
    iteation_LEP = '/home/hxj/桌面/PG/test/iteation+LEP/'
    LEP = '/home/hxj/桌面/PG/test/LEP-only/'
    ORI = '/home/hxj/gluon-tutorials/GAN/MultiPIE/YaleB_test_crop_gray/'
    img_name = 'yaleB38_P00A-130E+20.png'
    real_face_name = 'data/face/reSVD10.png'

    # --- Network input and optimisation ----------------------------------------
    pad = 'reflection'  # alternative: 'zero'
    OPT_OVER = 'net'
    OPTIMIZER = 'adam'
    INPUT = 'noise'
    input_depth = 32  # an earlier variant used input_depth = 4
    num_iter = 600
    LR = 0.01
    param_noise = False
    reg_noise_std = 0.03

    # Mean-squared-error criterion shared by every loss term.
    mse = torch.nn.MSELoss().type(dtype)
    def closure():
        """Run one forward/backward pass of the network and return the loss.

        Passed as the closure argument to ``optimize`` by the driver loop, so it
        is presumably invoked once per optimisation step. It reads module-level
        state set up by that loop: ``net``, ``net_input_saved``, ``noise``,
        ``img_name`` and the target tensors ``itLEP_var``, ``ORI_var``,
        ``LEP_var`` and ``RF_var``.

        Returns:
            The scalar loss tensor, after ``backward()`` has accumulated
            gradients into the network parameters.
        """
        if param_noise:
            # Perturb the 4-D (convolution) weights in place. The previous
            # ``n = n + ...`` only rebound the loop variable and changed nothing.
            with torch.no_grad():
                for n in [x for x in net.parameters() if len(x.size()) == 4]:
                    n.add_(torch.randn_like(n) * n.std() / 50)

        net_input = net_input_saved
        if reg_noise_std > 0:
            # Jitter the fixed input. ``noise`` is refilled in place on every call.
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out = net(net_input)

        # Weighted MSE against four references: the iterative-LEP result (1.0),
        # the original image (0.1), the LEP-only result (0.2) and the shared
        # "real face" image (0.5).
        total_loss = (mse(out, itLEP_var)
                      + mse(out, ORI_var) * 0.1
                      + mse(out, LEP_var) * 0.2
                      + mse(out, RF_var) * 0.5)
        total_loss.backward()

        # Note: the %s slot receives the current image name, not an iteration
        # count. '\r' with end='' keeps the report on a single console line.
        print('Iteration %s     Loss %f' % (img_name, total_loss.item()), '\r', end='')

        return total_loss
    # Shared "real face" reference image; it is the same target for every file.
    RF_pil, RF_np = get_image(real_face_name, imsize)
    RF_var = np_to_torch(RF_np).type(dtype)

    # Output folder. The '0.01' component presumably records LR -- confirm.
    out_dir = 'result/noise_input/0.01/'
    # Create it up front. Otherwise a missing folder only makes the save fail
    # after num_iter optimisation steps have been spent on the first image.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Every file name found in iteation_LEP is also read from LEP and ORI.
    # The names are sorted so images are processed in a reproducible order.
    files = sorted(os.listdir(iteation_LEP))
    for img_name in files:  # rebinds the module-level img_name that closure() prints
        # Per-image targets: iterative-LEP result, LEP-only result, original.
        itLEP_pil, itLEP_np = get_image(iteation_LEP + img_name, imsize)
        LEP_pil, LEP_np = get_image(LEP + img_name, imsize)
        ORI_pil, ORI_np = get_image(ORI + img_name, imsize)

        itLEP_var = np_to_torch(itLEP_np).type(dtype)
        LEP_var = np_to_torch(LEP_np).type(dtype)
        ORI_var = np_to_torch(ORI_np).type(dtype)

        # A freshly initialised network is built for each image.
        # itLEP_np.shape[0] is used as the output channel count, so get_image
        # presumably returns channels-first arrays -- confirm.
        net = skip(input_depth, itLEP_np.shape[0],
                   num_channels_down=[128] * 5,
                   num_channels_up=[128] * 5,
                   num_channels_skip=[128] * 5,
                   filter_size_up=3, filter_size_down=3,
                   upsample_mode='nearest', filter_skip_size=1,
                   need_sigmoid=True, need_bias=True, pad=pad,
                   act_fun='LeakyReLU').type(dtype)

        # Fixed random input with the image's spatial size. An earlier experiment
        # instead fed the four reference images as input channels
        # (input_depth = 4).
        net_input = get_noise(input_depth, INPUT, itLEP_np.shape[1:]).type(dtype)

        # closure() reads both of these as globals. It adds a fresh draw of
        # `noise` to `net_input_saved` on every step.
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()

        p = get_params(OPT_OVER, net, net_input)
        optimize(OPTIMIZER, p, closure, LR, num_iter)

        # Final prediction. No gradients are needed, so skip building the
        # autograd graph.
        with torch.no_grad():
            out_np = torch_to_np(net(net_input))
        # Clip to [0, 1] and keep only the first channel. That is fine for the
        # single-channel (gray) inputs used here, but would drop colour otherwise.
        img_save = np.clip(out_np, 0, 1)[0]
        # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2 (and it
        # requires Pillow), so this save needs SciPy < 1.2 or a port.
        scipy.misc.toimage(img_save, cmin=0.0, cmax=1.0).save(os.path.join(out_dir, img_name))
        
  • 相关阅读:
    如何提高Java并行程序性能??
    《实战Java虚拟机》,最简单的JVM入门书,京东活动,满200就减100了,该出手了
    看JVM就推荐这本书
    【Java】实战Java虚拟机之五“开启JIT编译”
    实战Java虚拟机之四:提升性能,禁用System.gc() ?
    实战Java虚拟机之三“G1的新生代GC”
    实战Java虚拟机之二“虚拟机的工作模式”
    实战Java虚拟机之一“堆溢出处理”
    实战java虚拟机的学习计划图(看懂java虚拟机)
    aspose.cells 复制单元格
  • 原文地址:https://www.cnblogs.com/hxjbc/p/10817751.html
Copyright © 2011-2022 走看看