zoukankan      html  css  js  c++  java
  • Sym-GAN

    import sys; 
    sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
    from __future__ import print_function
    import os
    import matplotlib as mpl
    import tarfile
    import matplotlib.image as mpimg
    from matplotlib import pyplot as plt
    import cv2
    import mxnet as mx
    from mxnet import gluon
    from mxnet import ndarray as nd
    from mxnet.gluon import nn, utils
    from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, 
        BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout
    from mxnet import autograd
    import numpy as np
    
    epochs = 500
    batch_size = 10
    
    use_gpu = True
    ctx = mx.gpu() if use_gpu else mx.cpu()
    
    lr = 0.0002
    beta1 = 0.5
    #lambda1 = 100
    lambda1 = 10
    
    pool_size = 50
    img_horizon = mx.image.HorizontalFlipAug(1)
    def load_retinex(batch_size, lighting_dir='MultiPIE/MultiPIE_Lighting/',
                     ground_dir='MultiPIE/MultiPIE_Lighting/'):
        """Build an NDArrayIter of (lit face, reference-lit face) MultiPIE pairs.

        NOTE(review): ``load_retinex`` is defined again further down in this
        module, and that later definition replaces this one before it is
        ever called.

        Every ``.png`` found under *lighting_dir* whose characters [14:16] are
        not ``'07'`` is an input.  Its target is the file with the same first
        14 characters and the suffix ``'07.png'``, looked up in the matching
        sub-directory of *ground_dir*.  Pixels are mapped with ``x / 127.5 - 1``
        (uint8 [0, 255] -> [-1, 1]).  Each pair is added twice: as read, and
        mirrored with ``img_horizon``.

        Args:
            batch_size: batch size of the returned iterator.
            lighting_dir: root directory of the input images.
            ground_dir: root directory of the ``...07.png`` target images.

        Returns:
            ``mx.io.NDArrayIter`` whose ``data`` is ``[inputs, targets]``, each
            shaped (N, C, H, W).

        Raises:
            ValueError: if no input/target pair was found.
        """
        img_in_list = []
        img_out_list = []

        def _to_nchw(img):
            # (H, W, C) -> (1, C, H, W) so images can be concatenated on axis 0.
            chw = nd.transpose(img, (2, 0, 1))
            return chw.reshape((1,) + chw.shape)

        for dirpath, _, fnames in os.walk(lighting_dir):
            # Mirror the sub-directory layout when locating the target image.
            ground_subdir = os.path.normpath(
                os.path.join(ground_dir, os.path.relpath(dirpath, lighting_dir)))
            for fname in fnames:
                # Skip non-images and the reference ('07') images themselves.
                if not fname.endswith('.png') or fname[14:16] == '07':
                    continue
                lighting_img = os.path.join(dirpath, fname)
                ground_img = os.path.join(ground_subdir, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lighting_img).astype(np.float32) / 127.5 - 1
                img_arr_gnema = mx.image.imread(ground_img).astype(np.float32) / 127.5 - 1
                pairs = ((img_arr_fname, img_arr_gnema),
                         (img_horizon(img_arr_fname), img_horizon(img_arr_gnema)))
                for img_in, img_out in pairs:
                    img_in_list.append(_to_nchw(img_in))
                    img_out_list.append(_to_nchw(img_out))

        if not img_in_list:
            raise ValueError('load_retinex: no image pairs found under %r' % lighting_dir)
        return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0),
                                       nd.concat(*img_out_list, dim=0)],
                                 batch_size=batch_size)
        
    # Size (pixels) of each half of a side-by-side paired image after resizing.
    img_wd, img_ht = 256, 256
    # Directories of paired edges2handbags .jpg images read by load_data().
    train_img_path, val_img_path = ('../data/edges2handbags/train_mini/',
                                    '../data/edges2handbags/val/')
    
    def load_data(path, batch_size, is_reversed=False):
        """Load side-by-side paired ``.jpg`` images into an NDArrayIter.

        Each image under *path*, searched recursively, is processed as follows:

        * Pixels are mapped with ``x / 127.5 - 1``.
        * The image is resized to ``img_wd * 2`` x ``img_ht``.
        * It is split into a left half (input) and a right half (output).
        * Both halves are stored as (1, C, H, W).

        Args:
            path: root directory to search.
            batch_size: batch size of the returned iterator.
            is_reversed: if True, swap the roles of the left and right halves.

        Returns:
            ``mx.io.NDArrayIter`` whose ``data`` is ``[inputs, outputs]``.

        Raises:
            ValueError: if no ``.jpg`` image was found under *path*.
        """
        img_in_list = []
        img_out_list = []
        # Walk with a distinct loop name so the ``path`` argument is not shadowed.
        for dirpath, _, fnames in os.walk(path):
            for fname in fnames:
                if not fname.endswith('.jpg'):
                    continue
                img_arr = mx.image.imread(os.path.join(dirpath, fname)).astype(np.float32) / 127.5 - 1
                img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
                halves = []
                for x0 in (0, img_wd):
                    half = nd.transpose(mx.image.fixed_crop(img_arr, x0, 0, img_wd, img_ht), (2, 0, 1))
                    halves.append(half.reshape((1,) + half.shape))
                img_arr_in, img_arr_out = halves
                if is_reversed:
                    img_arr_in, img_arr_out = img_arr_out, img_arr_in
                img_in_list.append(img_arr_in)
                img_out_list.append(img_arr_out)

        if not img_in_list:
            raise ValueError('load_data: no .jpg images found under %r' % path)
        return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                                 batch_size=batch_size)
    
    
    # Mirror augmentation used by load_retinex (flip probability 1: every call flips).
    img_horizon = mx.image.HorizontalFlipAug(1)
    # edges2handbags pairs. NOTE(review): train_data is reassigned to
    # load_retinex(...) later in this file before anything reads it, and
    # val_data is not read anywhere else in this file.
    train_data = load_data(train_img_path, batch_size)
    val_data = load_data(val_img_path, batch_size)
    def load_retinex(batch_size, lighting_dir='MultiPIE/MultiPIE_Lighting_128/',
                     ground_dir='MultiPIE/MultiPIE_Lighting_128/'):
        """Build an NDArrayIter of (lit face, reference-lit face) MultiPIE pairs.

        This is the definition in effect at run time; it replaces the earlier
        ``load_retinex`` in this module.

        Every ``.png`` found under *lighting_dir* whose characters [14:16] are
        not ``'07'`` is an input.  Its target is the file with the same first
        14 characters and the suffix ``'07.png'``, looked up in the matching
        sub-directory of *ground_dir*.  Pixels are mapped with ``x / 127.5 - 1``
        (uint8 [0, 255] -> [-1, 1]).  Each pair is added twice: as read, and
        mirrored with ``img_horizon``.

        Args:
            batch_size: batch size of the returned iterator.
            lighting_dir: root directory of the input images.
            ground_dir: root directory of the ``...07.png`` target images.

        Returns:
            ``mx.io.NDArrayIter`` whose ``data`` is ``[inputs, targets]``, each
            shaped (N, C, H, W).

        Raises:
            ValueError: if no input/target pair was found.
        """
        img_in_list = []
        img_out_list = []

        def _to_nchw(img):
            # (H, W, C) -> (1, C, H, W) so images can be concatenated on axis 0.
            chw = nd.transpose(img, (2, 0, 1))
            return chw.reshape((1,) + chw.shape)

        for dirpath, _, fnames in os.walk(lighting_dir):
            # Mirror the sub-directory layout when locating the target image.
            ground_subdir = os.path.normpath(
                os.path.join(ground_dir, os.path.relpath(dirpath, lighting_dir)))
            for fname in fnames:
                # Skip non-images and the reference ('07') images themselves.
                if not fname.endswith('.png') or fname[14:16] == '07':
                    continue
                lighting_img = os.path.join(dirpath, fname)
                ground_img = os.path.join(ground_subdir, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lighting_img).astype(np.float32) / 127.5 - 1
                img_arr_gnema = mx.image.imread(ground_img).astype(np.float32) / 127.5 - 1
                pairs = ((img_arr_fname, img_arr_gnema),
                         (img_horizon(img_arr_fname), img_horizon(img_arr_gnema)))
                for img_in, img_out in pairs:
                    img_in_list.append(_to_nchw(img_in))
                    img_out_list.append(_to_nchw(img_out))

        if not img_in_list:
            raise ValueError('load_retinex: no image pairs found under %r' % lighting_dir)
        return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0),
                                       nd.concat(*img_out_list, dim=0)],
                                 batch_size=batch_size)
        
    def visualize(img_arr):
        """Draw a (C, H, W) NDArray on the current matplotlib axes without axis ticks.

        Values are mapped back with (x + 1) * 127.5 and cast to uint8, i.e. the
        inverse of the x / 127.5 - 1 scaling the loaders apply.
        """
        hwc = img_arr.asnumpy().transpose(1, 2, 0)
        pixels = ((hwc + 1.0) * 127.5).astype(np.uint8)
        plt.imshow(pixels)
        plt.axis('off')
    def preview_train_data(train_data, num_images=4):
        """Plot input/target pairs from the next batch of *train_data*.

        Inputs go in the top row and targets in the bottom row of a
        2 x *num_images* grid.  At most ``num_images`` pairs are shown; fewer if
        the batch is smaller, which previously raised an index error.  Note that
        this consumes one batch from the iterator.

        Args:
            train_data: iterator whose ``next().data`` is ``[inputs, targets]``.
            num_images: number of columns / pairs to display (default 4).
        """
        img_in_list, img_out_list = train_data.next().data
        count = min(num_images, img_in_list.shape[0])
        for i in range(count):
            plt.subplot(2, num_images, i + 1)
            visualize(img_in_list[i])
            plt.subplot(2, num_images, i + 1 + num_images)
            visualize(img_out_list[i])
        plt.show()
    
    
    # Replace train_data with the MultiPIE relighting pairs.  Use the shared
    # batch_size: the training loops divide by batch_size when logging
    # samples/s, so a hard-coded literal here could silently disagree with it.
    train_data = load_retinex(batch_size)
    preview_train_data(train_data)
    # Define Unet generator skip block
    class UnetSkipUnit(HybridBlock):
        """One recursive level of a pix2pix-style U-Net generator.

        The encoder is a stride-2, 4x4 Conv2D (outer_channels -> inner_channels).
        The decoder is a stride-2, 4x4 Conv2DTranspose producing outer_channels.
        Except for the outermost unit, forward() returns the decoder output
        concatenated with the unit's own input along axis 1 (the skip connection),
        i.e. 2 * outer_channels channels.

        Args:
            inner_channels: channels produced by the encoder convolution.
            outer_channels: channels of this unit's input and of the decoder output.
            inner_block: nested unit placed between encoder and decoder (ignored
                when innermost=True).  Its output must have 2 * inner_channels
                channels, which a non-outermost UnetSkipUnit produces.
            innermost: bottom level; no nested block, decoder input is inner_channels.
            outermost: top level; encoder has no activation/norm before the conv,
                decoder ends in tanh, and there is no skip concatenation.
            use_dropout: append Dropout(rate=0.5) after the decoder.
            use_bias: bias flag for the convolutions (the outermost transposed conv
                does not pass it and so uses Gluon's default).
        """
        def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False, outermost=False,
                     use_dropout=False, use_bias=False):
            super(UnetSkipUnit, self).__init__()
    
            with self.name_scope():
                self.outermost = outermost
                # NOTE: every candidate layer is created here even if the chosen
                # branch does not use it.  Creating a block inside name_scope
                # advances Gluon's name counters, so this creation order fixes the
                # parameter names (which param_init matches on and saved weights use).
                en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1,
                                 in_channels=outer_channels, use_bias=use_bias)
                en_relu = LeakyReLU(alpha=0.2)
                en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)
                de_relu = Activation(activation='relu')
                de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)
    
                if innermost:
                    # No nested block: decoder consumes the encoder output directly.
                    de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                              in_channels=inner_channels, use_bias=use_bias)
                    encoder = [en_relu, en_conv]
                    decoder = [de_relu, de_conv, de_norm]
                    model = encoder + decoder
                elif outermost:
                    # Nested block returns 2 * inner_channels (its skip concat);
                    # final tanh maps the output to [-1, 1].
                    de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                              in_channels=inner_channels * 2)
                    encoder = [en_conv]
                    decoder = [de_relu, de_conv, Activation(activation='tanh')]
                    model = encoder + [inner_block] + decoder
                else:
                    # Middle level: decoder input is the nested block's
                    # 2 * inner_channels skip-concatenated output.
                    de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                              in_channels=inner_channels * 2, use_bias=use_bias)
                    encoder = [en_relu, en_conv, en_norm]
                    decoder = [de_relu, de_conv, de_norm]
                    model = encoder + [inner_block] + decoder
                if use_dropout:
                    model += [Dropout(rate=0.5)]
    
                self.model = HybridSequential()
                with self.model.name_scope():
                    for block in model:
                        self.model.add(block)
    
        def hybrid_forward(self, F, x):
            """Run the unit; non-outermost units append their input on the channel axis."""
            if self.outermost:
                return self.model(x)
            else:
                return F.concat(self.model(x), x, dim=1)
    
    # Define Unet generator
    class UnetGenerator(HybridBlock):
        """pix2pix-style U-Net generator assembled from nested UnetSkipUnit levels.

        Levels are built innermost first.  In the (outer -> inner) channel
        notation below, the levels are:

        * one innermost ngf*8 -> ngf*8 unit;
        * ``num_downs - 5`` further ngf*8 -> ngf*8 units (with dropout when
          use_dropout; none when num_downs <= 5);
        * ngf*4 -> ngf*8, ngf*2 -> ngf*4 and ngf -> ngf*2;
        * the outermost in_channels -> ngf unit.

        For num_downs >= 5 that makes num_downs stride-2 levels.  The input
        height and width therefore need to be divisible by 2 ** num_downs for
        the skip concatenations to line up.

        Args:
            in_channels: channels of the input image; the output has the same count.
            num_downs: number of down-sampling levels (values below 5 behave as 5).
            ngf: base filter count.
            use_dropout: add Dropout(0.5) to the extra ngf*8 levels.
        """
        def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True):
            super(UnetGenerator, self).__init__()
    
            #Build unet generator structure
            # NOTE(review): the units are created outside self.name_scope(), so
            # their parameter names presumably do not carry this generator's prefix.
            unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True)
            for _ in range(num_downs - 5):
                unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout)
            unet = UnetSkipUnit(ngf * 8, ngf * 4, unet)
            unet = UnetSkipUnit(ngf * 4, ngf * 2, unet)
            unet = UnetSkipUnit(ngf * 2, ngf * 1, unet)
            unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True)
    
            with self.name_scope():
                self.model = unet
    
        def hybrid_forward(self, F, x):
            """Return the tanh-activated U-Net output for input x."""
            return self.model(x)
    
    # Define the PatchGAN discriminator
    class Discriminator(HybridBlock):
        """PatchGAN discriminator producing a 1-channel map of patch scores.

        Layers:

        * a stride-2 Conv2D(ndf) followed by LeakyReLU(0.2);
        * ``n_layers - 1`` stride-2 Conv2D + BatchNorm + LeakyReLU stages, with
          channels ndf * min(2**n, 8);
        * one stride-1 stage of the same kind;
        * a final stride-1 Conv2D with 1 output channel, optionally followed by
          a sigmoid.

        set_network() builds it with in_channels=6 (presumably an input/output
        image pair concatenated on channels — not shown being called here).

        Args:
            in_channels: channels of the input tensor.
            ndf: base filter count.
            n_layers: number of stride-2 convolutions.
            use_sigmoid: append a sigmoid to the output.
            use_bias: bias flag for the intermediate convolutions.
        """
        def __init__(self, in_channels, ndf=64, n_layers=3, use_sigmoid=False, use_bias=False):
            super(Discriminator, self).__init__()
    
            with self.name_scope():
                self.model = HybridSequential()
                kernel_size = 4
                # 2 under Python 3 true division (ceil(1.5)); would be 1 under Python 2.
                padding = int(np.ceil((kernel_size - 1)/2))
                self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=2,
                                      padding=padding, in_channels=in_channels))
                self.model.add(LeakyReLU(alpha=0.2))
    
                nf_mult = 1
                for n in range(1, n_layers):
                    nf_mult_prev = nf_mult
                    # Channel multiplier doubles per stage, capped at 8.
                    nf_mult = min(2 ** n, 8)
                    self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=2,
                                          padding=padding, in_channels=ndf * nf_mult_prev,
                                          use_bias=use_bias))
                    self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
                    self.model.add(LeakyReLU(alpha=0.2))
    
                nf_mult_prev = nf_mult
                nf_mult = min(2 ** n_layers, 8)
                self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=1,
                                      padding=padding, in_channels=ndf * nf_mult_prev,
                                      use_bias=use_bias))
                self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
                self.model.add(LeakyReLU(alpha=0.2))
                # One score per spatial location (patch).
                self.model.add(Conv2D(channels=1, kernel_size=kernel_size, strides=1,
                                      padding=padding, in_channels=ndf * nf_mult))
                if use_sigmoid:
                    self.model.add(Activation(activation='sigmoid'))
    
        def hybrid_forward(self, F, x):
            """Return the patch-score map for x."""
            out = self.model(x)
            #print(out)
            return out
    def param_init(param):
        """Initialise one Gluon Parameter on ``ctx`` according to its name.

        * Convolution weights (name contains 'conv' and 'weight') get N(0, 0.02).
        * Other convolution parameters (biases) get 0.
        * BatchNorm parameters (name contains 'batchnorm') get 0, except that
          ``running_var`` starts at 1 and ``gamma`` is drawn from N(1, 0.02).
        * Parameters matching neither pattern are left untouched.

        Initialising running_var to 0 (as an explicit Zero initializer would)
        makes inference-mode BatchNorm divide by sqrt(eps), blowing up
        activations.
        """
        name = param.name
        if 'conv' in name:
            if 'weight' in name:
                param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
            else:
                param.initialize(init=mx.init.Zero(), ctx=ctx)
        elif 'batchnorm' in name:
            if 'running_var' in name:
                param.initialize(init=mx.init.One(), ctx=ctx)
            else:
                param.initialize(init=mx.init.Zero(), ctx=ctx)
            # Initialize gamma from a normal distribution with mean 1 and std 0.02.
            if 'gamma' in name:
                param.set_data(nd.random_normal(1, 0.02, param.data().shape))
    
    def network_init(net):
        """Run param_init over every parameter returned by net.collect_params()."""
        params = net.collect_params()
        with net.name_scope():
            for _, param in params.items():
                param_init(param)
    
    def set_network():
        """Create and initialise two pix2pix generator/discriminator pairs with Adam trainers.

        Returns:
            (netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2)
        """
        # Build and initialise in the order G1, D1, G2, D2: Gluon's automatic
        # block names and the random initial weights both depend on it.
        nets = [UnetGenerator(in_channels=3, num_downs=6),
                Discriminator(in_channels=6),
                UnetGenerator(in_channels=3, num_downs=6),
                Discriminator(in_channels=6)]
        for net in nets:
            network_init(net)

        # One Adam trainer per network, each with its own options dict.
        trainers = [gluon.Trainer(net.collect_params(), 'adam',
                                  {'learning_rate': lr, 'beta1': beta1})
                    for net in nets]

        netG1, netD1, netG2, netD2 = nets
        trainerG1, trainerD1, trainerG2, trainerD2 = trainers
        return netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2
    
    # Loss objects.  GAN_loss is an L2 loss (SigmoidBinaryCrossEntropyLoss was the
    # earlier alternative).  NOTE(review): neither GAN_loss nor L2_loss is used
    # by the training loops in this file; only L1_loss is.
    GAN_loss = gluon.loss.L2Loss()
    L1_loss = gluon.loss.L1Loss()
    L2_loss = gluon.loss.L2Loss()

    # Both generator/discriminator pairs and their Adam trainers.
    (netG1, netD1, trainerG1, trainerD1,
     netG2, netD2, trainerG2, trainerD2) = set_network()
    class ImagePool:
        """Fixed-capacity history of generated images.

        query() returns, for each incoming image, either that image or (with
        probability 0.5 once the pool is full) a randomly chosen stored image,
        which is then replaced by the incoming one.  A non-positive pool_size
        disables the pool and query() returns its input unchanged.
        """

        def __init__(self, pool_size):
            self.pool_size = pool_size
            # Always defined so a disabled (<= 0) pool never hits a missing attribute.
            self.num_imgs = 0
            self.images = []

        def query(self, images):
            """Return a batch mixing *images* with previously stored ones.

            Args:
                images: NDArray batch with the sample axis first.

            Returns:
                *images* itself if the pool is disabled, otherwise a new NDArray
                of the same batch size.
            """
            if self.pool_size <= 0:
                return images
            ret_imgs = []
            for i in range(images.shape[0]):
                image = nd.expand_dims(images[i], axis=0)
                if self.num_imgs < self.pool_size:
                    # Still filling the pool: store and pass through.
                    self.num_imgs += 1
                    self.images.append(image)
                    ret_imgs.append(image)
                    continue
                p = nd.random_uniform(0, 1, shape=(1,)).asscalar()
                if p > 0.5:
                    # Sample over [0, pool_size) so the last slot is reachable, and
                    # use int() rather than a uint8 cast, which wraps for pools > 256.
                    draw = nd.random_uniform(0, self.pool_size, shape=(1,)).asscalar()
                    random_id = min(int(draw), self.pool_size - 1)
                    tmp = self.images[random_id].copy()
                    self.images[random_id] = image
                    ret_imgs.append(tmp)
                else:
                    ret_imgs.append(image)
            return nd.concat(*ret_imgs, dim=0)

    # Retinex image-enhancement code

    def singleScaleRetinex(img, sigma):
        """Single-scale Retinex: log10(img) minus log10 of its Gaussian-blurred version.

        sigma is passed to cv2.GaussianBlur together with a (0, 0) kernel size.
        img must be strictly positive for the logarithms to be finite; the
        callers in this file add 1.0 beforehand.
        """
        log_img = np.log10(img)
        surround = cv2.GaussianBlur(img, (0, 0), sigma)
        return log_img - np.log10(surround)
    
    def multiScaleRetinex(img, sigma_list):
        """Average of singleScaleRetinex(img, sigma) over every sigma in sigma_list.

        The accumulator has img's shape and dtype (np.zeros_like), so img is
        expected to be floating point.
        """
        total = np.zeros_like(img)
        scales = iter(sigma_list)
        for scale in scales:
            total += singleScaleRetinex(img, scale)
        return total / len(sigma_list)
    
    def colorRestoration(img, alpha, beta):
        """MSRCR colour-restoration term beta * (log10(alpha * I_c) - log10(sum_c I_c)).

        The channel sum is taken over axis 2 with keepdims, so the result has
        img's shape.
        """
        log_scaled = np.log10(alpha * img)
        log_total = np.log10(np.sum(img, axis=2, keepdims=True))
        return beta * (log_scaled - log_total)
    
    def simplestColorBalance(img, low_clip, high_clip):
        """Clip each channel of *img* in place to bounds chosen by pixel fractions.

        Per channel, let u_0 < u_1 < ... be the distinct values and f_k the
        fraction of pixels strictly below u_k.  The lower bound is the largest
        u_k with f_k < low_clip, and the upper bound is the largest u_k with
        f_k < high_clip.  When no u_k qualifies (e.g. a clip of 0), u_0 is used.
        Previously that case raised UnboundLocalError.

        Args:
            img: H x W x C array; modified in place.
            low_clip: fraction in [0, 1] selecting the lower bound.
            high_clip: fraction in [0, 1] selecting the upper bound.

        Returns:
            img (the same array object).
        """
        total = img.shape[0] * img.shape[1]
        if total == 0:
            return img
        for i in range(img.shape[2]):
            unique, counts = np.unique(img[:, :, i], return_counts=True)
            # Fraction of pixels strictly below each distinct value (increasing).
            below = (np.cumsum(counts) - counts) / float(total)
            # searchsorted(..., 'left') counts entries < clip; minus one gives the
            # last qualifying index, clamped to 0 when none qualifies.
            low_idx = max(np.searchsorted(below, low_clip, side='left') - 1, 0)
            high_idx = max(np.searchsorted(below, high_clip, side='left') - 1, 0)
            img[:, :, i] = np.maximum(np.minimum(img[:, :, i], unique[high_idx]), unique[low_idx])
        return img
    
    def MSRCR(img, sigma_list, G, b, alpha, beta, low_clip, high_clip):
        """Multi-Scale Retinex with Colour Restoration.

        Args:
            img: H x W x C image; converted to float64 and offset by +1 so the
                logarithms stay finite.
            sigma_list: Gaussian sigmas for multiScaleRetinex.
            G, b: gain and offset in G * (retinex * colour + b).
            alpha, beta: parameters of colorRestoration.
            low_clip, high_clip: fractions passed to simplestColorBalance.

        Returns:
            uint8 H x W x C image.
        """
        img = np.float64(img) + 1.0
        img_retinex = multiScaleRetinex(img, sigma_list)
        img_color = colorRestoration(img, alpha, beta)
        img_msrcr = G * (img_retinex * img_color + b)

        # Stretch each channel independently to [0, 255].  A constant channel
        # (max == min) would divide by zero and produce NaN; map it to 0 instead.
        for i in range(img_msrcr.shape[2]):
            channel = img_msrcr[:, :, i]
            lo, hi = np.min(channel), np.max(channel)
            if hi > lo:
                img_msrcr[:, :, i] = (channel - lo) / (hi - lo) * 255
            else:
                img_msrcr[:, :, i] = 0

        img_msrcr = np.uint8(np.minimum(np.maximum(img_msrcr, 0), 255))
        return simplestColorBalance(img_msrcr, low_clip, high_clip)
    
    def automatedMSRCR(img, sigma_list):
        """Multi-Scale Retinex with automatic per-channel clipping, as uint8.

        For each channel, the retinex values are multiplied by 100 and
        truncated to int32 buckets, and zero_count is the size of the 0 bucket.
        The clip bounds are chosen from those buckets (scanned in ascending
        order):

        * lower bound: the last negative bucket holding fewer than
          0.1 * zero_count pixels;
        * upper bound: the first positive bucket holding fewer than
          0.1 * zero_count pixels;
        * either bound defaults to the channel extreme when no bucket qualifies.

        The clipped channel is then stretched to [0, 255].

        If no pixel falls in the 0 bucket, zero_count is 0 and no bucket can
        qualify, so the channel keeps its full range.  Previously that case
        raised UnboundLocalError, or reused the previous channel's count.

        Args:
            img: H x W x C image; converted to float64 and offset by +1.
            sigma_list: Gaussian sigmas for multiScaleRetinex.

        Returns:
            uint8 H x W x C image.
        """
        img = np.float64(img) + 1.0
        img_retinex = multiScaleRetinex(img, sigma_list)
        for i in range(img_retinex.shape[2]):
            unique, count = np.unique(np.int32(img_retinex[:, :, i] * 100), return_counts=True)
            zero_hits = count[unique == 0]
            zero_count = zero_hits[0] if zero_hits.size else 0

            low_val = unique[0] / 100.0
            high_val = unique[-1] / 100.0
            for u, c in zip(unique, count):
                if u < 0 and c < zero_count * 0.1:
                    low_val = u / 100.0
                if u > 0 and c < zero_count * 0.1:
                    high_val = u / 100.0
                    break
            clipped = np.maximum(np.minimum(img_retinex[:, :, i], high_val), low_val)

            # Stretch to [0, 255]; a constant channel would divide by zero.
            lo, hi = np.min(clipped), np.max(clipped)
            if hi > lo:
                img_retinex[:, :, i] = (clipped - lo) / (hi - lo) * 255
            else:
                img_retinex[:, :, i] = 0
        return np.uint8(img_retinex)
    
    def MSRCP(img, sigma_list, low_clip, high_clip):
        """Multi-Scale Retinex with Chromaticity Preservation.

        The retinex is computed on the mean intensity over channels, then
        colour-balanced and stretched to [1, 256].  Each pixel's first three
        channels are multiplied by a common factor:

            A = min(256 / max_c(img), new_intensity / intensity)

        so the channel ratios are preserved.

        Args:
            img: H x W x C image with C >= 3; converted to float64 and offset by +1.
            sigma_list: Gaussian sigmas for multiScaleRetinex.
            low_clip, high_clip: fractions passed to simplestColorBalance.

        Returns:
            uint8 H x W x C image (channels beyond the third are not scaled).
        """
        img = np.float64(img) + 1.0
        intensity = np.sum(img, axis=2) / img.shape[2]
        retinex = multiScaleRetinex(intensity, sigma_list)
        intensity = np.expand_dims(intensity, 2)
        retinex = np.expand_dims(retinex, 2)
        intensity1 = simplestColorBalance(retinex, low_clip, high_clip)

        # Stretch to [1, 256]; a constant map would divide by zero, so it maps to 1.
        lo, hi = np.min(intensity1), np.max(intensity1)
        if hi > lo:
            intensity1 = (intensity1 - lo) / (hi - lo) * 255.0 + 1.0
        else:
            intensity1 = np.ones_like(intensity1)

        # Per-pixel amplification, vectorised instead of a Python loop over pixels.
        B = np.max(img, axis=2)
        A = np.minimum(256.0 / B, intensity1[:, :, 0] / intensity[:, :, 0])
        img_msrcp = np.zeros_like(img)
        # Only the first three channels are scaled; any further channels stay 0.
        img_msrcp[:, :, :3] = A[:, :, np.newaxis] * img[:, :, :3]

        return np.uint8(img_msrcp - 1.0)

    # Pre-training

    from datetime import datetime
    import time
    import logging
    
    def facc(label, pred):
        """Binary accuracy: fraction of flattened positions where (pred > 0.5) equals label."""
        flat_label = label.ravel()
        flat_pred = pred.ravel()
        return ((flat_pred > 0.5) == flat_label).mean()
    def pre_train():
        """Supervised pre-training of netG1 and netG2 on the module-level train_data.

        For every batch with ``(x, y) = batch.data``:

        * netG1 is updated on ``lambda1 * L1(netG1(x), y)``;
        * then netG2 is updated on ``lambda1 * L1(netG2(y), x)``.

        The discriminators are not touched.  Progress is logged every 10
        batches, and after each epoch the first image of the last netG1 output
        is displayed.

        NOTE(review): ``metric`` is never given any predictions, so the logged
        "binary training acc" carries no information (an empty MXNet metric
        presumably reports NaN) — confirm or remove it.
        """
        metric = mx.metric.CustomMetric(facc)
        logging.basicConfig(level=logging.DEBUG)
        fake_out = None

        for epoch in range(epochs):
            tic = time.time()
            btic = time.time()
            train_data.reset()
            for batch_idx, batch in enumerate(train_data):
                real_in = batch.data[0].as_in_context(ctx)
                real_out = batch.data[1].as_in_context(ctx)
                num_samples = batch.data[0].shape[0]

                # (1) netG1: x -> y, L1 reconstruction only.
                with autograd.record():
                    fake_out = netG1(real_in)
                    errG1 = L1_loss(fake_out, real_out) * lambda1
                    errG1.backward()
                trainerG1.step(num_samples)

                # (2) netG2: y -> x, L1 reconstruction only.
                with autograd.record():
                    fake_out2 = netG2(real_out)
                    errG2 = L1_loss(fake_out2, real_in) * lambda1
                    errG2.backward()
                trainerG2.step(num_samples)

                # Print log information every ten batches.
                if batch_idx % 10 == 0:
                    name, acc = metric.get()
                    logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                    logging.info('G1generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errG1).asscalar(), acc, batch_idx, epoch))
                    logging.info('G1generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errG2).asscalar(), acc, batch_idx, epoch))
                btic = time.time()

            name, acc = metric.get()
            metric.reset()
            logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
            logging.info('time: %f' % (time.time() - tic))

            # Visualize one generated image per epoch (nothing to show if no batch ran).
            if fake_out is not None:
                visualize(fake_out[0])
                plt.show()
    
    # Supervised L1 pre-training of netG1 / netG2 on the current train_data.
    pre_train()
    def save_data(path, tpath):
        """Run netG1 over paired .jpg images under *path* and save the outputs to *tpath*.

        Each image is prepared like load_data():

        * pixels are mapped with x / 127.5 - 1;
        * the image is resized to img_wd * 2 x img_ht;
        * only the right half (x from img_wd) is kept.

        That half goes through netG1, and the first output is mapped back to
        uint8 and written under the original file name.

        NOTE(review): the right half is the *output* side in load_data's
        convention; confirm that feeding it to netG1 is intended.

        Args:
            path: root directory to search (recursively) for .jpg files.
            tpath: destination directory; created if missing.

        Raises:
            IOError: if OpenCV fails to write an image.
        """
        os.makedirs(tpath, exist_ok=True)
        for dirpath, _, fnames in os.walk(path):
            for fname in fnames:
                if not fname.endswith('.jpg'):
                    continue
                img_arr = mx.image.imread(os.path.join(dirpath, fname)).astype(np.float32) / 127.5 - 1
                img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
                img_arr_out = nd.transpose(mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht), (2, 0, 1))
                img_arr_out = img_arr_out.reshape((1,) + img_arr_out.shape)
                img_out = netG1(img_arr_out.as_in_context(ctx))[0]
                rgb = ((img_out.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)

                # mx.image.imread yields RGB, but cv2.imwrite expects BGR.
                save_name = os.path.join(tpath, fname)
                if not cv2.imwrite(save_name, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
                    raise IOError('save_data: failed to write %s' % save_name)
    
    save_data("../data/edges2handbags/val/", "../data/edges2handbags/G1andG2/")

    # Fresh discriminators (and Adam trainers) replacing the ones from set_network().
    # NOTE(review): neither pre_train nor dual_pre_train updates the discriminators.
    netD1, netD2 = Discriminator(in_channels=6), Discriminator(in_channels=6)
    network_init(netD1)
    network_init(netD2)
    trainerD1, trainerD2 = [gluon.Trainer(disc.collect_params(), 'adam',
                                          {'learning_rate': lr, 'beta1': beta1})
                            for disc in (netD1, netD2)]
    from datetime import datetime
    import time
    import logging
    def facc(label, pred):
        """Return mean((pred > 0.5) == label) over the flattened arrays."""
        predicted_positive = pred.ravel() > 0.5
        return np.mean(predicted_positive == label.ravel())
    
    def dual_pre_train():
        """Jointly train netG1 / netG2 with reconstruction and cycle L1 losses.

        Iterates ``zip(retinex_data, PIE_normal_to_lighting)``.  Per batch pair:

        * x, y = batch1.data[0], batch1.data[1]
        * g = batch2.data[1]
        * errG1 = L1(x, netG1(x)) + L1(netG1(netG2(g)), g)
        * errG2 = L1(g, netG2(y)) + L1(netG2(netG1(x)), x)

        NOTE(review): ``retinex_data`` and ``PIE_normal_to_lighting`` are not
        defined anywhere in this file as shown; they must exist as module-level
        iterators before this runs.

        NOTE(review): the first term of errG1 compares netG1(x) with its own
        input x rather than with y (a commented-out variant used y); confirm
        which is intended.
        """
        metric = mx.metric.CustomMetric(facc)
        logging.basicConfig(level=logging.DEBUG)
        fake_out = None
        for epoch in range(epochs):
            tic = time.time()
            btic = time.time()
            # zip() stops when either iterator is exhausted, so both must be
            # rewound every epoch; otherwise retinex_data stays exhausted after
            # the first epoch and no further training happens.
            retinex_data.reset()
            PIE_normal_to_lighting.reset()
            for batch_idx, (batch1, batch2) in enumerate(zip(retinex_data, PIE_normal_to_lighting)):
                real_in = batch1.data[0].as_in_context(ctx)
                real_out = batch1.data[1].as_in_context(ctx)
                lighting_good = batch2.data[1].as_in_context(ctx)

                # netG1: self-reconstruction term plus a cycle through netG2.
                with autograd.record():
                    fake_out = netG1(real_in)
                    errG1 = (L1_loss(real_in, fake_out)
                             + L1_loss(netG1(netG2(lighting_good)), lighting_good))
                    errG1.backward()
                trainerG1.step(batch1.data[0].shape[0])

                # netG2: mapping to the well-lit target plus a cycle through netG1.
                with autograd.record():
                    fake_out3 = netG2(real_out)
                    errG2 = (L1_loss(lighting_good, fake_out3)
                             + L1_loss(netG2(netG1(real_in)), real_in))
                    errG2.backward()
                trainerG2.step(batch2.data[0].shape[0])

                # Print log information every ten batches.
                if batch_idx % 10 == 0:
                    name, acc = metric.get()
                    logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                    logging.info('G1generator loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errG1).asscalar(), acc, batch_idx, epoch))
                    logging.info('G2generator loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errG2).asscalar(), acc, batch_idx, epoch))
                btic = time.time()

            name, acc = metric.get()
            metric.reset()
            logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
            logging.info('time: %f' % (time.time() - tic))

            # Visualize one generated image per epoch (nothing to show if no batch ran).
            if fake_out is not None:
                visualize(fake_out[0])
                plt.show()
    
    # Run the dual pre-training routine (dual_pre_train is defined above this point).
    dual_pre_train()
    def test_netG(Spath, Tpath):
        """Run netG1 over every ``.png`` under ``Spath`` and save the outputs in ``Tpath``.

        Each image is read with ``mx.image.imread`` and scaled from [0, 255] to [-1, 1].
        It is resized to 128x128 and reordered from HWC into a (1, C, H, W) batch.
        The batch is passed through ``netG1`` on ``ctx``. The first output sample is
        mapped back to [0, 255] and written with ``plt.imsave``.

        Args:
            Spath: root directory, searched recursively with ``os.walk``.
            Tpath: output directory (created if missing). Outputs keep the input
                file name, so equal names found in different sub-directories of
                ``Spath`` overwrite each other.
        """
        if Tpath:
            os.makedirs(Tpath, exist_ok=True)
        for path, _, fnames in os.walk(Spath):
            for fname in fnames:
                if not fname.endswith('.png'):
                    continue
                test_img = os.path.join(path, fname)
                img_arr_fname = mx.image.imread(test_img).astype(np.float32) / 127.5 - 1
                img_arr_fname = mx.image.imresize(img_arr_fname, 128, 128)
                # HWC -> CHW, then add a batch axis of size 1.
                img_arr_in = nd.transpose(img_arr_fname, (2, 0, 1))
                img_arr_in = img_arr_in.reshape((1,) + img_arr_in.shape)
                img_out = netG1(img_arr_in.as_in_context(ctx))[0]
                # Clip before the uint8 cast. Without it, outputs slightly outside
                # [-1, 1] wrap around instead of saturating.
                img_u8 = np.clip((img_out.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5,
                                 0, 255).astype(np.uint8)
                # os.path.join works whether or not Tpath ends with a separator.
                plt.imsave(os.path.join(Tpath, fname), img_u8)
                
    # Disabled alternative run (directory name suggests a grayscale MultiPIE test set).
    #test_netG('MultiPIE/MultiPIE_test_128_Gray/','MultiPIE/relighting/')
    # Run netG1 on every .png under MultiPIE/Bio_relighing2/ and write the outputs to MultiPIE/Bio_color/.
    test_netG('MultiPIE/Bio_relighing2/','MultiPIE/Bio_color/')

    # Use facial landmark points as a loss term. The original note says "OpenCV",
    # but the code below uses dlib through openface's AlignDlib.

    # Make the local openface checkout importable. This must run before the
    # `from openface.align_dlib import AlignDlib` line below.
    fileDir = '/home/hxj/gluon-tutorials/GAN/openface/'
    sys.path.append(os.path.join(fileDir))
    import argparse
    import cv2
    import dlib
    import matplotlib.pyplot as plt
    from pylab import plot  
    from openface.align_dlib import AlignDlib
    # Model locations inside the openface checkout.
    modelDir = os.path.join(fileDir, 'models')
    openfaceModelDir = os.path.join(modelDir, 'openface')
    dlibModelDir = os.path.join(modelDir, 'dlib')
    # dlib 68-point facial landmark predictor file (passed to AlignDlib in land_mark_errs).
    dlibFacePredictor= os.path.join(dlibModelDir, "shape_predictor_68_face_landmarks.dat")
    
    def _landmark_aligner():
        """Return a shared ``AlignDlib`` built from ``dlibFacePredictor``, creating it on first use."""
        # Building AlignDlib presumably loads the dlib predictor file from disk.
        # It is therefore built once and reused, not rebuilt on every call.
        if _landmark_aligner.instance is None:
            _landmark_aligner.instance = AlignDlib(dlibFacePredictor)
        return _landmark_aligner.instance
    _landmark_aligner.instance = None

    def land_mark_errs(batch1, batch2):
        """Per-sample facial-landmark distance between two image batches.

        Each sample is converted from channel-first floats in [-1, 1] to an HWC
        uint8 image. Landmarks come from ``AlignDlib.findLandmarks``, searched inside
        the fixed box ``dlib.rectangle(-19, -19, 124, 125)``. No face detection is
        done; the box presumes roughly 128x128 aligned faces (TODO confirm). The
        per-sample error is the sum of absolute x/y landmark differences divided by
        the number of landmarks.

        Args:
            batch1, batch2: NDArray batches, presumably (N, 3, H, W) in [-1, 1].

        Returns:
            NDArray of shape (N,) on ``ctx``, where N is the smaller batch size.
            Entries for pairs whose landmarks could not be computed stay 0.
            NOTE: computed through ``asnumpy``, so it carries no gradient and
            cannot train a network when added to a loss.
        """
        align = _landmark_aligner()
        num_pairs = min(batch1.shape[0], batch2.shape[0])
        # Sized to the batch rather than a fixed 10. It then broadcasts against
        # per-sample losses and cannot overflow for larger batches.
        sum_err = nd.zeros((num_pairs,)).as_in_context(ctx)
        # Fixed landmark search box, used for every image.
        bbox = dlib.rectangle(-19, -19, 124, 125)
        for i, (x, y) in enumerate(zip(batch1, batch2)):
            x1 = np.clip((x.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5, 0, 255).astype(np.uint8)
            y1 = np.clip((y.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5, 0, 255).astype(np.uint8)
            points_x = align.findLandmarks(x1, bbox)
            points_y = align.findLandmarks(y1, bbox)
            # Check the raw results. Once wrapped in nd.array(), a None check could never fire.
            if not points_x or not points_y:
                continue
            landmarks_x = nd.array(points_x)
            landmarks_y = nd.array(points_y)
            sum_err[i] = nd.sum(nd.abs(landmarks_x - landmarks_y)) / len(points_x)
        return sum_err
    from datetime import datetime
    import time
    import logging
    def facc(label, pred):
        """Binary accuracy: share of flattened ``pred > 0.5`` decisions equal to the flattened ``label``."""
        predicted = pred.ravel() > 0.5
        return (predicted == label.ravel()).mean()
    
    def generate_train_single():
        """Train netG1/netG2 with paired L1 and cycle-consistency losses (no discriminators).

        For each batch from ``train_data``, ``data[0]`` is real_in and ``data[1]`` is real_out.
          * netG1 loss: 20 * L1(real_out, G1(real_in)) + 10 * L1(G1(G2(real_out)), real_out)
          * netG2 loss: 20 * L1(real_in, G2(real_out)) + 10 * L1(G2(G1(real_in)), real_in)
        Each generator is stepped with its own trainer. Losses are logged every
        10 iterations. ``fake_out[0]`` is shown after every epoch.
        """
        # No discriminator here, so `metric` is never updated. Its accuracy is
        # presumably reported as nan; it is kept so the log format stays the same.
        metric = mx.metric.CustomMetric(facc)
        logging.basicConfig(level=logging.DEBUG)

        for epoch in range(epochs):
            tic = time.time()
            btic = time.time()
            train_data.reset()
            batch_idx = 0
            for batch1 in train_data:
                real_in = batch1.data[0].as_in_context(ctx)
                real_out = batch1.data[1].as_in_context(ctx)
                n = batch1.data[0].shape[0]

                # G1: the forward pass must happen inside autograd.record.
                # Otherwise the direct L1 term sends no gradient to netG1.
                with autograd.record():
                    fake_out = netG1(real_in)
                    errG1 = (L1_loss(real_out, fake_out) * 20
                             + L1_loss(netG1(netG2(real_out)), real_out) * 10)
                    errG1.backward()
                trainerG1.step(n)

                # G2: same structure in the opposite direction. backward() is needed
                # before step(). Without it, trainerG2 would apply the gradients left
                # on netG2 by errG1's cycle term.
                with autograd.record():
                    fake_out2 = netG2(real_out)
                    errG2 = (L1_loss(real_in, fake_out2) * 20
                             + L1_loss(netG2(netG1(real_in)), real_in) * 10)
                    errG2.backward()
                trainerG2.step(n)

                # Print log information every ten batches
                if batch_idx % 10 == 0:
                    name, acc = metric.get()
                    logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                    logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errG1).asscalar(), acc, batch_idx, epoch))
                    logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errG2).asscalar(), acc, batch_idx, epoch))
                batch_idx += 1
                btic = time.time()

            name, acc = metric.get()
            metric.reset()
            logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
            logging.info('time: %f' % (time.time() - tic))

            # Visualize one generated image for each epoch
            fake_img = fake_out[0]
            visualize(fake_img)
            plt.show()
            
            
    
    # Run the discriminator-free generator training defined above.
    generate_train_single()
    from skimage import io
    # Quick visual check: feed the L channel (LAB) of one CAS test image through netG1.
    test_path = 'CAS/test_aligned_128/FM_000046_IFD+90_PM+00_EN_A0_D0_T0_BW_M0_R1_S0.png'
    bgrImg = cv2.imread(test_path)
    # cv2.imread returns None rather than raising when the file is missing or unreadable.
    if bgrImg is None:
        raise FileNotFoundError(test_path)
    rgbImg = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2RGB)
    plt.imshow(rgbImg)
    plt.show()
    lab = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2LAB)
    plt.imshow(lab)
    plt.show()

    # Scale the 8-bit L channel to [-1, 1] and shape it as a (1, 1, H, W) batch.
    img_test = lab[:, :, 0].astype(np.float32) / 127.5 - 1
    img_test = nd.array(img_test)
    img_arr_in = img_test.reshape((1, 1) + img_test.shape).as_in_context(ctx)
    test1 = netG1(img_arr_in)
    test2 = test1[0][0]
    # cv2.imshow needs a window name (the one-argument call raised TypeError).
    # Plot inline with matplotlib, like the previews above, and clip before the
    # uint8 cast so out-of-range values do not wrap.
    plt.imshow(np.clip((test2.asnumpy() + 1.0) * 127.5, 0, 255).astype(np.uint8), cmap='gray')
    plt.show()
    from datetime import datetime
    import time
    import logging
    def facc(label, pred):
        """Return the fraction of thresholded predictions (> 0.5) that match the binary labels."""
        flat_label = label.ravel()
        hits = (pred.ravel() > 0.5) == flat_label
        return hits.mean()
    
    def _dual_update_discriminator(netD, trainerD, pool, metric, real_pair, fake_pair, batch_len):
        """Take one optimizer step of ``netD`` on a real pair and a history-pooled fake pair.

        ``fake_pair`` goes through ``pool.query``; ``real_pair`` is used as-is. The
        loss is 0.5 * (GAN(real, 1) + GAN(fake, 0)). Both predictions are fed to
        ``metric``. Returns the discriminator loss NDArray.
        """
        fake_pair = pool.query(fake_pair)
        with autograd.record():
            output = netD(fake_pair)
            fake_label = nd.zeros(output.shape, ctx=ctx)
            errD_fake = GAN_loss(output, fake_label)
            metric.update([fake_label, ], [output, ])

            output = netD(real_pair)
            real_label = nd.ones(output.shape, ctx=ctx)
            errD_real = GAN_loss(output, real_label)

            errD = (errD_real + errD_fake) * 0.5
            errD.backward()
            metric.update([real_label, ], [output, ])
        trainerD.step(batch_len)
        return errD

    def Dual_train_single():
        """Dual conditional-GAN training of netG1/netG2 with discriminators netD1/netD2.

        For each batch from ``train_data``, ``data[0]`` is real_in and ``data[1]`` is real_out.
          * D1 scores channel-concatenated pairs: (real_in, real_out) as real and
            (real_in, G1(real_in)) as fake.
          * G1 loss: GAN(D1) + 20 * L1(real_out, G1(real_in)) + 10 * L1(G1(G2(real_out)), real_out)
          * D2 scores (real_out, real_in) as real and (real_out, G2(real_out)) as fake.
          * G2 loss: GAN(D2) + 20 * L1(real_in, G2(real_out)) + 10 * L1(G2(G1(real_in)), real_in)
        Losses are logged every 10 iterations. ``fake_out[0]`` is shown after every epoch.
        """
        # One history pool per discriminator, used only for generated pairs.
        # ImagePool (defined elsewhere) presumably replays stored pairs. A shared
        # pool, or pooling the real pairs too, would let real pairs and the other
        # direction's pairs be shown to a discriminator labelled "fake".
        image_pool1 = ImagePool(pool_size)
        image_pool2 = ImagePool(pool_size)
        metric = mx.metric.CustomMetric(facc)
        logging.basicConfig(level=logging.DEBUG)

        for epoch in range(epochs):
            tic = time.time()
            btic = time.time()
            train_data.reset()
            batch_idx = 0
            for batch1 in train_data:
                real_in = batch1.data[0].as_in_context(ctx)
                real_out = batch1.data[1].as_in_context(ctx)
                n = batch1.data[0].shape[0]

                # ---- D1 ----
                # The fake is generated outside autograd.record, so the D1 update
                # does not backpropagate into netG1.
                fake_out = netG1(real_in)
                errD1 = _dual_update_discriminator(
                    netD1, trainerD1, image_pool1, metric,
                    nd.concat(real_in, real_out, dim=1),
                    nd.concat(real_in, fake_out, dim=1),
                    n)

                # ---- G1 ----
                with autograd.record():
                    # Recompute under record so the GAN and L1 terms reach netG1.
                    # The pair is not pooled, so D1's gradient flows back to this fake.
                    fake_out = netG1(real_in)
                    output = netD1(nd.concat(real_in, fake_out, dim=1))
                    real_label = nd.ones(output.shape, ctx=ctx)
                    errG1 = (GAN_loss(output, real_label)
                             + L1_loss(real_out, fake_out) * 20
                             + L1_loss(netG1(netG2(real_out)), real_out) * 10)
                    errG1.backward()
                trainerG1.step(n)

                # ---- D2 ----
                fake_out2 = netG2(real_out)
                errD2 = _dual_update_discriminator(
                    netD2, trainerD2, image_pool2, metric,
                    nd.concat(real_out, real_in, dim=1),
                    nd.concat(real_out, fake_out2, dim=1),
                    n)

                # ---- G2 ----
                with autograd.record():
                    fake_out2 = netG2(real_out)
                    output2 = netD2(nd.concat(real_out, fake_out2, dim=1))
                    real_label2 = nd.ones(output2.shape, ctx=ctx)
                    errG2 = (GAN_loss(output2, real_label2)
                             + L1_loss(real_in, fake_out2) * 20
                             + L1_loss(netG2(netG1(real_in)), real_in) * 10)
                    errG2.backward()
                trainerG2.step(n)

                # Print log information every ten batches
                if batch_idx % 10 == 0:
                    name, acc = metric.get()
                    logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                    logging.info('discriminator1 loss = %f, generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errD1).asscalar(),
                                    nd.mean(errG1).asscalar(), acc, batch_idx, epoch))
                    logging.info('discriminator2 loss = %f, generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errD2).asscalar(),
                                    nd.mean(errG2).asscalar(), acc, batch_idx, epoch))
                batch_idx += 1
                btic = time.time()

            name, acc = metric.get()
            metric.reset()
            logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
            logging.info('time: %f' % (time.time() - tic))

            # Visualize one generated image for each epoch
            fake_img = fake_out[0]
            visualize(fake_img)
            plt.show()
            
            
    
    # Run the conditional dual-GAN training (netD1/netD2 + netG1/netG2) defined above.
    Dual_train_single()
    from datetime import datetime
    import time
    import logging
    def facc(label, pred):
        """Accuracy for a CustomMetric: threshold ``pred`` at 0.5 and compare with ``label``."""
        truth = label.ravel()
        guesses = pred.ravel() > 0.5
        return (guesses == truth).mean()
    
    def train():
        """Dual GAN training with unconditional discriminators over two data streams.

        Batches from ``retinex_data`` (batch1) and ``PIE_normal_to_lighting`` (batch2)
        are consumed in lockstep, and an epoch ends when the shorter one is exhausted.
        real_in is ``batch1.data[0]`` and lighing_good is ``batch2.data[1]``.
          * D1 scores lighing_good as real and G1(real_in) as fake.
          * G1 loss: GAN(D1) + lambda1 * L1(real_in, G1(real_in))
                     + lambda1 * L1(G1(G2(lighing_good)), lighing_good)
          * D2 scores real_in as real and G2(lighing_good) as fake.
          * G2 loss: GAN(D2) + lambda1 * L1(lighing_good, G2(lighing_good))
                     + lambda1 * L1(G2(G1(real_in)), real_in)
        Losses are logged every 10 iterations. ``fake_out[0]`` is shown after every epoch.
        """
        metric = mx.metric.CustomMetric(facc)
        logging.basicConfig(level=logging.DEBUG)

        for epoch in range(epochs):
            tic = time.time()
            btic = time.time()
            retinex_data.reset()
            PIE_normal_to_lighting.reset()
            batch_idx = 0
            for batch1, batch2 in zip(retinex_data, PIE_normal_to_lighting):
                real_in = batch1.data[0].as_in_context(ctx)
                lighing_good = batch2.data[1].as_in_context(ctx)
                n1 = batch1.data[0].shape[0]
                n2 = batch2.data[0].shape[0]

                # ---- D1 ----
                # The fake is generated outside autograd.record, so the D1 update
                # does not backpropagate into netG1.
                fake_out = netG1(real_in)
                with autograd.record():
                    output = netD1(fake_out)
                    fake_label = nd.zeros(output.shape, ctx=ctx)
                    errD_fake = GAN_loss(output, fake_label)
                    metric.update([fake_label, ], [output, ])

                    output = netD1(lighing_good)
                    real_label = nd.ones(output.shape, ctx=ctx)
                    errD_real = GAN_loss(output, real_label)

                    errD1 = (errD_real + errD_fake) * 0.5
                    errD1.backward()
                    metric.update([real_label, ], [output, ])
                trainerD1.step(n1)

                # ---- G1 ----
                with autograd.record():
                    fake_out = netG1(real_in)
                    output = netD1(fake_out)
                    real_label = nd.ones(output.shape, ctx=ctx)
                    # Parenthesized so the multi-line sum parses as one expression.
                    errG1 = (GAN_loss(output, real_label)
                             + L1_loss(real_in, fake_out) * lambda1
                             + L1_loss(netG1(netG2(lighing_good)), lighing_good) * lambda1)
                    errG1.backward()
                trainerG1.step(n1)

                # ---- D2 ----
                fake_out2 = netG2(lighing_good)
                with autograd.record():
                    output2 = netD2(fake_out2)
                    fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                    errD_fake2 = GAN_loss(output2, fake_label2)
                    metric.update([fake_label2, ], [output2, ])

                    output2 = netD2(real_in)
                    real_label2 = nd.ones(output2.shape, ctx=ctx)
                    errD_real2 = GAN_loss(output2, real_label2)

                    errD2 = (errD_real2 + errD_fake2) * 0.5
                    errD2.backward()
                    metric.update([real_label2, ], [output2, ])
                trainerD2.step(n2)

                # ---- G2 ----
                with autograd.record():
                    fake_out2 = netG2(lighing_good)
                    output2 = netD2(fake_out2)
                    real_label2 = nd.ones(output2.shape, ctx=ctx)
                    errG2 = (GAN_loss(output2, real_label2)
                             + L1_loss(lighing_good, fake_out2) * lambda1
                             + L1_loss(netG2(netG1(real_in)), real_in) * lambda1)
                    errG2.backward()
                trainerG2.step(n2)

                # Print log information every ten batches
                if batch_idx % 10 == 0:
                    name, acc = metric.get()
                    logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                    logging.info('discriminator1 loss = %f, generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errD1).asscalar(),
                                    nd.mean(errG1).asscalar(), acc, batch_idx, epoch))
                    logging.info('discriminator2 loss = %f, generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                                 % (nd.mean(errD2).asscalar(),
                                    nd.mean(errG2).asscalar(), acc, batch_idx, epoch))
                batch_idx += 1
                btic = time.time()

            name, acc = metric.get()
            metric.reset()
            logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
            logging.info('time: %f' % (time.time() - tic))

            # Visualize one generated image for each epoch
            fake_img = fake_out[0]
            visualize(fake_img)
            plt.show()
            
            
    
    # Run the training loop that pairs retinex_data and PIE_normal_to_lighting batches (defined above).
    train()
  • 相关阅读:
    TCP原理简介
    zabbix_get [109064]: Check access restrictions in Zabbix agent configuration
    Log4j2:异步日志中打印方法名和行号信息
    高仿腾讯QQ最终版
    启动TOMCAT报错 java.util.zip.ZipException: invalid LOC header (bad signature)
    修改hosts立刻生效不必重启
    MyEclipse的Debug模式启动缓慢
    SpringBatch配置数据库
    SpringBatch的核心组件JobLauncher和JobRepository
    SpringBatch前言
  • 原文地址:https://www.cnblogs.com/hxjbc/p/9480112.html
Copyright © 2011-2022 走看看