zoukankan      html  css  js  c++  java
  • DCGAN

    # -*- coding: UTF-8 -*-
    import torch
    import torch.nn as nn
    import numpy as np
    import torch.nn.init as init
    import os
    import test
    from GAN_model import Generator,Discriminator
    
    print("data loading ...")

    # ---- optimisation hyper-parameters -----------------------------------
    G_LR = 2e-4    # Adam learning rate of the generator
    D_LR = 2e-4    # Adam learning rate of the discriminator
    BATCHSIZE = 50
    EPOCHES = 3000  # number of passes over the training set

    # ---- execution device --------------------------------------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # ---- file locations ----------------------------------------------------
    # Training images written with torch.save (produced by data_loader.py or
    # data_loader_sketch.py).
    data_pt = "data51200.pt"
    # Directory that receives the per-epoch "<epoch>g.pkl" / "<epoch>d.pkl" files.
    train_para_save_path = "./pkl/"
    # Plain-text log with one "<d_loss> <g_loss>" line per training batch.
    loss_save_file = "loss.txt"
    
    
    def init_ws_bs(m, std=0.02):
        """Re-initialise the weights (and bias, if present) of a convolution layer.

        Intended to be passed to ``nn.Module.apply`` so that it visits every
        sub-module. Modules that are not ``nn.Conv2d`` / ``nn.ConvTranspose2d``
        are left untouched.

        Args:
            m: the module currently being visited.
            std: standard deviation of the zero-mean normal distribution.
                The DCGAN paper uses 0.02; the previous hard-coded 0.2 was
                ten times larger.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            init.normal_(m.weight.data, mean=0.0, std=std)
            # The convs in GAN_model.py are built with bias=False, so m.bias is
            # None there; dereferencing it unconditionally would crash.
            if m.bias is not None:
                init.normal_(m.bias.data, mean=0.0, std=std)

    g = Generator().to(device)
    d = Discriminator().to(device)

    # Calling init_ws_bs(g) directly only inspects the top-level Generator /
    # Discriminator object (never a conv layer), so nothing was initialised.
    # Module.apply recurses into every sub-module.
    g.apply(init_ws_bs)
    d.apply(init_ws_bs)
    
    # To resume from checkpoints written by the loop below (e.g. epoch 29),
    # load generator and discriminator from their *own* files:
    # g = torch.load(train_para_save_path + "29g.pkl")
    # d = torch.load(train_para_save_path + "29d.pkl")
    # NOTE: the checkpoint directory is emptied further down, so copy the files
    # elsewhere before resuming.

    # Adam with beta1 = 0.5, the setting used for DCGAN.
    g_optimizer = torch.optim.Adam(g.parameters(), betas=(.5, 0.999), lr=G_LR)
    d_optimizer = torch.optim.Adam(d.parameters(), betas=(.5, 0.999), lr=D_LR)

    # The discriminator ends in a Sigmoid (GAN_model.py), so plain BCE applies.
    g_loss_func = nn.BCELoss()
    d_loss_func = nn.BCELoss()

    # Targets: 1 = "real", 0 = "fake"; one entry per sample of a full batch.
    label_real = torch.ones(BATCHSIZE).to(device)
    label_fake = torch.zeros(BATCHSIZE).to(device)

    # Load the training set *before* deleting anything, so a missing data file
    # no longer leaves real_img undefined (NameError) or wipes the previous
    # run's loss log and checkpoints.
    if not os.path.exists(data_pt):
        print("fail to load data")
        raise FileNotFoundError("training data not found: %s" % data_pt)
    real_img = torch.load(data_pt)
    print("load data successfully")

    # Start from a fresh loss log and an empty checkpoint directory.
    if os.path.exists(loss_save_file):
        os.remove(loss_save_file)
    if not os.path.exists(train_para_save_path):
        os.makedirs(train_para_save_path)
    for file in os.listdir(train_para_save_path):
        file_path = os.path.join(train_para_save_path, file)
        if os.path.isfile(file_path):  # os.remove cannot delete directories
            os.remove(file_path)

    print("start training")
    num_samples = len(real_img)
    # Only full batches are used because the label tensors have exactly
    # BATCHSIZE entries; which samples are left over changes every epoch
    # because of the shuffle, and nothing leaks into the next epoch.
    num_batches = num_samples // BATCHSIZE
    for epoch in range(EPOCHES):
        # Shuffle via a random index permutation. If real_img is a torch
        # tensor, np.random.shuffle swaps rows through views
        # (x[i], x[j] = x[j], x[i]) and silently duplicates some images while
        # losing others.
        order = torch.randperm(num_samples)
        loss_epoch = []
        for batch_num in range(num_batches):
            idx = order[batch_num * BATCHSIZE:(batch_num + 1) * BATCHSIZE].tolist()
            # Works whether real_img is one stacked tensor or a list of
            # tensors; .float() matches the float32 that torch.Tensor(...)
            # produced before.
            batch_real = torch.stack([real_img[k] for k in idx]).float().to(device)

            # ---- discriminator step: real -> 1, fake -> 0 ----
            d_optimizer.zero_grad()
            # view(-1) instead of squeeze(): stays 1-D even when BATCHSIZE == 1.
            pre_real = d(batch_real).view(-1)
            d_real_loss = d_loss_func(pre_real, label_real)
            d_real_loss.backward()

            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).to(device)
            img_fake = g(batch_fake)
            # detach(): the discriminator update must not push gradients into g.
            pre_fake = d(img_fake.detach()).view(-1)
            d_fake_loss = d_loss_func(pre_fake, label_fake)
            d_fake_loss.backward()
            d_optimizer.step()

            # ---- generator step: make d classify fresh fakes as real ----
            g_optimizer.zero_grad()
            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).to(device)
            img_fake = g(batch_fake)
            pre_fake = d(img_fake).view(-1)
            g_loss = g_loss_func(pre_fake, label_real)
            # This also fills d's .grad, but d_optimizer.zero_grad() clears it
            # before the next discriminator update; only g_optimizer steps here.
            g_loss.backward()
            g_optimizer.step()

            d_loss_value = (d_real_loss + d_fake_loss).item()
            g_loss_value = g_loss.item()
            print("epoch%d batch%d:" % (epoch, batch_num), d_loss_value, g_loss_value)
            loss_epoch.append([d_loss_value, g_loss_value])

        # After finishing an epoch, save both models and append this epoch's losses.
        g_path = train_para_save_path + str(epoch) + "g.pkl"
        torch.save(g, g_path)
        torch.save(d, train_para_save_path + str(epoch) + "d.pkl")
        with open(loss_save_file, 'a') as f:
            for d_loss_epoch, g_loss_epoch in loss_epoch:
                f.write(str(d_loss_epoch) + ' ' + str(g_loss_epoch) + '\n')

        test.draw(g_path, str(epoch))
    print("finish the train")
    

      GAN_model.py

    import torch.nn as nn
    class Generator(nn.Module):
        """DCGAN generator: maps a latent tensor to an image with values in [-1, 1].

        With the default arguments the input is (batch, 100, 1, 1) and the
        output is (batch, 3, 96, 96). Each stage is a transposed convolution,
        whose output side is  stride * (in - 1) + kernel - 2 * padding,
        so the spatial size goes 1 -> 4 -> 8 -> 16 -> 32 -> 96.

        Args:
            nz: number of latent input channels (the training script feeds 100).
            ngf: base feature width; the stages use 8*ngf, 4*ngf, 2*ngf and ngf
                channels.
            nc: number of output image channels.
        """

        def __init__(self, nz=100, ngf=64, nc=3):
            super(Generator, self).__init__()
            # Sub-module names are unchanged so existing checkpoints still load.
            self.deconv1 = self._up_block(nz, ngf * 8, stride=1, padding=0)       # 1  -> 4
            self.deconv2 = self._up_block(ngf * 8, ngf * 4, stride=2, padding=1)  # 4  -> 8
            self.deconv3 = self._up_block(ngf * 4, ngf * 2, stride=2, padding=1)  # 8  -> 16
            self.deconv4 = self._up_block(ngf * 2, ngf, stride=2, padding=1)      # 16 -> 32
            self.deconv5 = nn.Sequential(
                # 3 * (32 - 1) + 5 - 2 * 1 = 96
                nn.ConvTranspose2d(ngf, nc, kernel_size=5, stride=3, padding=1, bias=False),
                nn.Tanh(),
            )

        @staticmethod
        def _up_block(in_channels, out_channels, stride, padding):
            """ConvTranspose2d(kernel 4, no bias) -> BatchNorm2d -> ReLU."""
            return nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=4,
                    stride=stride,
                    padding=padding,
                    bias=False,  # the following BatchNorm supplies the shift
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def forward(self, x):
            x = self.deconv1(x)
            x = self.deconv2(x)
            x = self.deconv3(x)
            x = self.deconv4(x)
            x = self.deconv5(x)
            return x
    class Discriminator(nn.Module):
        """DCGAN discriminator: maps an image to the probability that it is real.

        With the default arguments the input is (batch, 3, 96, 96) and the
        output is (batch, 1, 1, 1) with values in (0, 1). Each stage is a
        convolution, whose output side is  (in + 2 * padding - kernel) // stride + 1,
        so the spatial size goes 96 -> 32 -> 16 -> 8 -> 4 -> 1.

        Args:
            nc: number of input image channels.
            ndf: base feature width; the stages use ndf, 2*ndf, 4*ndf and 8*ndf
                channels.
        """

        def __init__(self, nc=3, ndf=64):
            super(Discriminator, self).__init__()
            # Sub-module names are unchanged so existing checkpoints still load.
            self.conv1 = self._down_block(nc, ndf, kernel_size=5, stride=3)           # 96 -> 32
            self.conv2 = self._down_block(ndf, ndf * 2, kernel_size=4, stride=2)      # 32 -> 16
            self.conv3 = self._down_block(ndf * 2, ndf * 4, kernel_size=4, stride=2)  # 16 -> 8
            self.conv4 = self._down_block(ndf * 4, ndf * 8, kernel_size=4, stride=2)  # 8  -> 4
            self.output = nn.Sequential(
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),  # 4 -> 1
                nn.Sigmoid(),  # probability, to be used with nn.BCELoss
            )

        @staticmethod
        def _down_block(in_channels, out_channels, kernel_size, stride):
            """Conv2d(padding 1, no bias) -> BatchNorm2d -> LeakyReLU(0.2)."""
            return nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=1,
                    bias=False,  # the following BatchNorm supplies the shift
                ),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(.2, inplace=True),
            )

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = self.output(x)
            return x
    

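      GAN的精髓在于对抗。生成损失和判别损失的反向传播方式是一样的，区别只在于：生成损失只用来更新生成器的参数，判别损失只用来更新判别器的参数（这由各自优化器所管理的参数决定：g_optimizer 只包含 g.parameters()，d_optimizer 只包含 d.parameters()）。注意 g_loss.backward() 也会在判别器参数上累积梯度，但下一次迭代开始时的 d_optimizer.zero_grad() 会把它们清零；而训练判别器时对假图片使用了 img_fake.detach()，所以梯度不会传回生成器。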

      生成器的训练目标只有一个：让生成的假图片被判别器判为真（即用“真”标签计算损失）：g_loss=g_loss_func(pre_fake,label_real)

      而判别器的目标有两个。一是把真图片判为真：d_real_loss=d_loss_func(pre_real,label_real)

                 二是把假图片判为假：d_fake_loss=d_loss_func(pre_fake,label_fake)

      

  • 相关阅读:
    计算机网络基础 汇总
    指针与数组
    卡特兰数
    Leetcode Sort Colors
    Leetcode Group Shifted Strings
    Leetcode Summary Ranges
    Leetcode Count Primes
    Leetcode Reverse Words in a String II
    Leetcode Reverse Words in a String
    Leetcode Rotate Array
  • 原文地址:https://www.cnblogs.com/jiangnanyanyuchen/p/11995681.html
Copyright © 2011-2022 走看看