Pytorch-GAN

    Task: generate data from a mixture of 8 Gaussians, learn its distribution with a GAN, and compare the learned distribution against the real one.

    GAN algorithm (text version):

    GAN algorithm (formula version):
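
    For reference, the standard minimax objective behind both versions of the algorithm (Goodfellow et al., 2014) is

    $$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))],$$

    where the discriminator $D$ is typically updated for $k$ steps between generator updates.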

    Run the following command in a terminal (for details on using Visdom, see https://www.cnblogs.com/cxq1126/p/13285150.html):

    python -m visdom.server

    Import the required packages first:

        import torch
        from torch import nn, optim, autograd
        import numpy as np
        import visdom
        from torch.nn import functional as F
        from matplotlib import pyplot as plt
        import random

        h_dim = 400
        batchsz = 512
        viz = visdom.Visdom()

    1. Implementing the Generator

    Input: a 2-D point; output: a 2-D point.

        class Generator(nn.Module):

            def __init__(self):
                super(Generator, self).__init__()

                self.net = nn.Sequential(
                    nn.Linear(2, h_dim),
                    nn.ReLU(True),
                    nn.Linear(h_dim, h_dim),
                    nn.ReLU(True),
                    nn.Linear(h_dim, h_dim),
                    nn.ReLU(True),
                    nn.Linear(h_dim, 2),
                )

            def forward(self, z):
                output = self.net(z)
                return output

    2. Implementing the Discriminator

    Input: a 2-D point; output: a scalar score judging whether the input lies near the real data (a quick shape check follows the class definition below).

        class Discriminator(nn.Module):

            def __init__(self):
                super(Discriminator, self).__init__()

                self.net = nn.Sequential(
                    nn.Linear(2, h_dim),
                    nn.ReLU(True),
                    nn.Linear(h_dim, h_dim),
                    nn.ReLU(True),
                    nn.Linear(h_dim, h_dim),
                    nn.ReLU(True),
                    nn.Linear(h_dim, 1),
                    nn.Sigmoid()
                )

            def forward(self, x):
                output = self.net(x)
                return output.view(-1)

    3. Weight initialization

        def weights_init(m):
            # apply Kaiming initialization to every Linear layer
            if isinstance(m, nn.Linear):
                # m.weight.data.normal_(0.0, 0.02)
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.fill_(0)

    4. Generating the dataset: an 8-Gaussian mixture

    For intuition about the Gaussian mixture model:
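
    Each sample below is drawn by picking one of 8 centers uniformly at random and adding small isotropic Gaussian noise, so the data density is the mixture

    $$p(x) = \frac{1}{8}\sum_{i=1}^{8} \mathcal{N}\big(x;\ \mu_i,\ \sigma^2 I\big),$$

    where the $\mu_i$ are the 8 centers on a circle of radius 2 and $\sigma = 0.02$ (both subsequently shrunk by the final division by $1.414 \approx \sqrt{2}$).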

        def data_generator():

            scale = 2.
            centers = [
                (1, 0),
                (-1, 0),
                (0, 1),
                (0, -1),
                (1. / np.sqrt(2), 1. / np.sqrt(2)),
                (1. / np.sqrt(2), -1. / np.sqrt(2)),
                (-1. / np.sqrt(2), 1. / np.sqrt(2)),
                (-1. / np.sqrt(2), -1. / np.sqrt(2))
            ]
            centers = [(scale * x, scale * y) for x, y in centers]
            while True:
                dataset = []
                for i in range(batchsz):
                    point = np.random.randn(2) * .02
                    center = random.choice(centers)

                    # sample a point from N(0, 1) (scaled by 0.02), then shift it to a randomly chosen center
                    point[0] += center[0]
                    point[1] += center[1]
                    dataset.append(point)
                dataset = np.array(dataset, dtype='float32')
                dataset /= 1.414                                # stdev
                yield dataset
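
    A short sketch to eyeball one batch from this generator (the scatter plot should show 8 tight clusters on a circle):

        # Pull one batch and inspect its shape and spread.
        gen = data_generator()
        batch = next(gen)                            # (512, 2) float32
        print(batch.shape, batch.std(axis=0))        # std reflects the cluster spread

        plt.scatter(batch[:, 0], batch[:, 1], s=2)   # 8 clusters on a circle of radius ~1.414
        plt.show()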

    5. Visualization

        def generate_image(D, G, xr, epoch):         # xr is a batch of real samples
            """
            Generates and saves a plot of the true distribution, the generator, and the
            critic.
            """
            N_POINTS = 128
            RANGE = 3
            plt.clf()

            points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
            points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
            points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
            points = points.reshape((-1, 2))         # (16384, 2)

            # draw contour
            with torch.no_grad():
                points = torch.Tensor(points)        # [16384, 2]
                disc_map = D(points).cpu().numpy()   # [16384]
            x = y = np.linspace(-RANGE, RANGE, N_POINTS)
            cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
            plt.clabel(cs, inline=1, fontsize=10)
            # plt.colorbar()

            # draw samples
            with torch.no_grad():
                z = torch.randn(batchsz, 2)          # [b, 2]
                samples = G(z).cpu().numpy()         # [b, 2]
            plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
            plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

            viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))

    6. Training

    Unlike the gradient-ascent formulation in the figure above, the training below uses gradient descent: a negative sign is added to every quantity that should be maximized, which turns the maximization into a minimization (spelled out below).
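
    Concretely, the losses minimized in the code below are

    $$\text{loss}_D = -\mathbb{E}_{x}[D(x)] + \mathbb{E}_{z}[D(G(z))], \qquad \text{loss}_G = -\mathbb{E}_{z}[D(G(z))],$$

    so gradient descent on $\text{loss}_D$ raises $D$'s score on real points and lowers it on fakes, while descent on $\text{loss}_G$ raises $D$'s score on generated points.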

    The betas argument of optim.Adam() holds the coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999)); see the update equations below.
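
    For reference, these coefficients enter Adam's running moment estimates as

    $$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2,$$

    where $g_t$ is the current gradient; the code below picks $(\beta_1, \beta_2) = (0.5, 0.9)$, a common choice for GAN training.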

        def main():

            torch.manual_seed(23)
            np.random.seed(23)

            G = Generator()
            D = Discriminator()
            G.apply(weights_init)
            D.apply(weights_init)

            optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
            optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

            data_iter = data_generator()
            print('batch:', next(data_iter).shape)       # [b, 2]

            viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

            for epoch in range(1000):

                # 1. train discriminator for k steps
                for _ in range(5):

                    # 1.1 train on real data
                    x = next(data_iter)
                    xr = torch.from_numpy(x)       # real data
                    predr = (D(xr))                # [b, 2] -> [b]
                    # max predr.mean(), i.e. min -predr.mean()
                    lossr = - (predr.mean())

                    # 1.2 train on fake data
                    z = torch.randn(batchsz, 2)    # [b, 2] random noise
                    xf = G(z).detach()             # [b, 2] G is frozen here: we update D only, so no gradients flow into G
                    predf = (D(xf))                # [b]
                    # min predf
                    lossf = (predf.mean())

                    loss_D = lossr + lossf

                    optim_D.zero_grad()
                    loss_D.backward()
                    optim_D.step()

                # 2. train Generator
                z = torch.randn(batchsz, 2)       # [b, 2] random noise
                xf = G(z)
                predf = (D(xf))
                # max predf, i.e. min -predf
                loss_G = - (predf.mean())

                optim_G.zero_grad()               # D is frozen here: we update G only, so D's parameters are not stepped
                loss_G.backward()
                optim_G.step()

                if epoch % 100 == 0:
                    viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
                    generate_image(D, G, xr, epoch)
                    print(loss_D.item(), loss_G.item())


        if __name__ == '__main__':
            main()

    The output below shows the training result for the Generator and the Discriminator: both losses collapse toward 0, and samples pushed through the Generator cover the real data.

        -0.6137181520462036 -0.0665668323636055
        7.620922672430937e-23 -6.780519236493468e-22
        2.8811570836759978e-37 -3.576674200342663e-41
        2.7326036398933606e-12 -5.0371435361452164e-33
        1.5811193901845998e-21 -1.796846778152569e-15
        2.1587619070328268e-20 -0.0
        2.0948376092776535e-32 -2.429269052203856e-16
        6.822592214491066e-14 -0.0
        9.122023851176224e-35 -2.9085079939065756e-34
        0.0 -4.5381907731517276e-14

    Not every run ends this way. Because GAN training is unstable (the real and generated distributions may have no overlap, in which case the JS divergence provides no useful gradient), the Generator's loss often sits at some nonzero value and stops improving.

    Solution: WGAN, a GAN that replaces the JS divergence with the Wasserstein distance to stabilize training.
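
    Specifically, the WGAN-GP critic loss (Gulrajani et al., 2017) is

    $$L = \mathbb{E}_{\tilde{x} \sim p_g}[D(\tilde{x})] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \lambda\, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big],$$

    where $\hat{x} = t\,x + (1-t)\,\tilde{x}$ with $t \sim U[0, 1]$ is a random interpolation between a real and a generated sample; the last term is the gradient penalty implemented below (with $\lambda = 0.2$ in this post's code).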

    Add a gradient penalty term to the code:

        def gradient_penalty(D, xr, xf):     # xr and xf have shape [b, 2]

            t = torch.rand(batchsz, 1)       # sample t from a uniform distribution, [b, 1]
            t = t.expand_as(xr)              # [b, 1] -> [b, 2]

            mid = t * xr + (1 - t) * xf      # linear interpolation between real and fake data, i.e. the xhat in the figure
            mid.requires_grad_()             # track gradients w.r.t. the interpolated points

            pred = D(mid)
            grads = autograd.grad(outputs=pred, inputs=mid,
                                  grad_outputs=torch.ones_like(pred),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]   # create_graph enables second-order differentiation

            gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
            return gp

    When computing the Discriminator's loss in the main function, add the gp term:

                    # 1.3 gradient penalty
                    gp = gradient_penalty(D, xr, xf.detach())    # detach so the penalty does not backpropagate into G

                    loss_D = lossr + lossf + 0.2 * gp
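
    Pieced together, the discriminator's inner training loop with the penalty becomes the following (a sketch of how the snippet slots into main(), using the same names as the code above):

        for _ in range(5):
            # 1.1 train on real data
            x = next(data_iter)
            xr = torch.from_numpy(x)
            lossr = -D(xr).mean()

            # 1.2 train on fake data (G frozen)
            z = torch.randn(batchsz, 2)
            xf = G(z).detach()
            lossf = D(xf).mean()

            # 1.3 gradient penalty on interpolated points (xf is already detached above)
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + 0.2 * gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()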

    The results after 2000 iterations:

        -0.5074049830436707 -0.11114723235368729
        -0.4972265362739563 -0.3060861825942993
        -0.5360251665115356 -0.23698316514492035
        -0.3480537533760071 -0.3882521390914917
        -0.22527582943439484 -0.5057252645492554
        -0.13060471415519714 -0.5396959185600281
        -0.07626903057098389 -0.6366142630577087
        -0.09713903069496155 -0.6304153203964233
        -0.1190759465098381 -0.5412021279335022
        -0.1230357214808464 -0.5588557124137878
        -0.04560390114784241 -0.6632308959960938
        -0.06906679272651672 -0.6173125505447388
        -0.04104984924197197 -0.7628952860832214
        -0.0408158078789711 -0.7121548652648926
        -0.04687119275331497 -0.7424123287200928
        -0.024066904559731483 -0.7196884751319885
        -0.04576507583260536 -0.7208324670791626
        -0.02462894842028618 -0.7012563943862915
        -0.01230126153677702 -0.7875514030456543
        -0.02122686244547367 -0.7108622193336487
