zoukankan      html  css  js  c++  java
  • 学习笔记GAN002:DCGAN

    Ian J. Goodfellow 论文:https://arxiv.org/abs/1406.2661

    两个网络:G(Generator),生成网络,接收随机噪声Z,通过噪声生成样本,G(z)。D(Dicriminator),判别网络,判别样本是否真实,输入样本x,输出D(x)代表x真实概率,如果1,100%真实样本,如果0,代表不可能是真实样本。

    训练过程,生成网络G尽量生成真实样本欺骗判别网络D,判别网络D尽量把G生成样本和真实样本分别开。理想状态下,G生成样本G(z),使D难以判断真假,D(G(z))=0.5。此时,生成模型G,可以用来生成样本。

    数学公式:minG maxDV(D,G)=Ex~pdata(x)[logD(x)]+Ez~pz(z)[log(1-D(G(z)))]
    二项式。x真实样本,z输入G网噪声,G(z) G网生成样本。D(x) D网判断真实样本是否真实概率,越接近1越好。D(G(z)) D网判断G网生成样本真实概率。G网,D(G(z))尽可能大,V(D,G)变小,min_G。D网,D(x)越大,D(G(x))越小,V(D,G)越大,max_D。
    1、x sampled from data -> Differentiable function D -> D(x) tries to be near 1
    2、Input noise z -> Differntiable function G -> x sampled from model -> D -> D tries to make D(G(z)) near 0,G tries to make D(G(z)) near 1

    随机梯度下降法训练D、G。

    Algorithm 1 Minibatch stochastic gradient descent training of genegative adversarial nets. The number of steps to apply to the discriminator,k,is a hyperparameter, we used k = 1,the least expensive option, in our experiments.
    for number of training iterations do
    for k steps do
    Sample minibatch of m noise samples {z(1),...,z(m)} from noise prior pg(z)
    Sample minibatch of m examples {x(1),...,x(m)} from data generating distribution pdata(x)
    Update the discriminator by ascending its stochastic gradient:
    end for
    Sample mninbatch of m noise samples {z(1),...,z(m)} from noise prior pg(z)
    Update the generator by descending its stochastic gradient:
    end for
    The gradient-based updates can use any standard gradient-based learning rule.We used momenttum in our experiments.
    第一步训练D,V(G,D)越大越好,上升(增加)梯度(ascending)。第二步训练G,V(G,D)越小越好,下降(减少)梯度(descending)。交替进行。

    DCGAN原理。https://arxiv.org/abs/1511.06434 。Alec Radford, Luke Metz, Soumith Chintala,《Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks》。G、D换成卷积神经网络(CNN)。取消所有pooling层。G网络使用转置卷积(transposed convolutional layer)上采样,D网络加入stride卷积代替pooling。D、G都batch normalization。去掉FC层,全卷积网络。G网用ReLU激活函数,最后一层用tanh。D网用LeakyRelu激活函数。

    G网络。Project reshape 100 z -> 4X4X1024 -> 8X8X512 -> 16X16X256 -> 32X32X128 -> 64X64X3

    用DCGAN生成动漫人物头像:
    http://qiita.com/mattya/items/e5bfe5e04b9d2f0bbd47 。

    原始数据搜集。http://safebooru.donmai.us 。http://konachan.net 。
    爬虫代码:

    import requests
    from bs4 import BeautifulSoup
    import os
    import traceback
    def download(url, filename):
    if os.path.exists(filename):
    print('file exists!')
    return
    try:
    r = requests.get(url, stream=True, timeout=60)
    r.raise_for_status()
    with open(filename, 'wb') as f:
    for chunk in r.iter_content(chunk_size=1024):
    if chunk: # filter out keep-alive new chunks
    f.write(chunk)
    f.flush()
    return filename
    except KeyboardInterrupt:
    if os.path.exists(filename):
    os.remove(filename)
    raise KeyboardInterrupt
    except Exception:
    traceback.print_exc()
    if os.path.exists(filename):
    os.remove(filename)
    if os.path.exists('imgs') is False:
    os.makedirs('imgs')
    start = 1
    end = 8000
    for i in range(start, end + 1):
    url = 'http://konachan.net/post?page=%d&tags=' % i
    html = requests.get(url).text
    soup = BeautifulSoup(html, 'html.parser')
    for img in soup.find_all('img', class_="preview"):
    target_url = 'http:' + img['src']
    filename = os.path.join('imgs', target_url.split('/')[-1])
    download(target_url, filename)
    print('%d / %d' % (i, end))

    头像截取:
    https://github.com/nagadomi/lbpcascade_animeface 。
    封装:

    import cv2
    import sys
    import os.path
    from glob import glob
    def detect(filename, cascade_file="lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):
    raise RuntimeError("%s: not found" % cascade_file)
    cascade = cv2.CascadeClassifier(cascade_file)
    image = cv2.imread(filename)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)
    faces = cascade.detectMultiScale(gray,
    # detector options
    scaleFactor=1.1,
    minNeighbors=5,
    minSize=(48, 48))
    for i, (x, y, w, h) in enumerate(faces):
    face = image[y: y + h, x:x + w, :]
    face = cv2.resize(face, (96, 96))
    save_filename = '%s-%d.jpg' % (os.path.basename(filename).split('.')[0], i)
    cv2.imwrite("faces/" + save_filename, face)
    if __name__ == '__main__':
    if os.path.exists('faces') is False:
    os.makedirs('faces')
    file_list = glob('imgs/*.jpg')
    for filename in file_list:
    detect(filename)

    训练:
    https://github.com/carpedm20/DCGAN-tensorflow 。
    model.py:

    if config.dataset == 'mnist':
    data_X, data_y = self.load_mnist()
    else:
    data = glob(os.path.join("./data", config.dataset, "*.jpg"))
    data文件夹新建anime文件夹放图片,运行时指定 --dataset anime。

    python main.py --image_size 96 --output_size 48 --dataset anime --is_crop True --is_train True --epoch 300 --input_fname_pattern "*.jpg"

    GAN论文:https://github.com/zhangqianhui/AdversarialNetsPapers 。

    SGD优化。目标函数(objective function)判断、监视学习成果。J(D) 判别网络目标函数,交叉熵(cross entropy)函数。左边D判别真实数据,右边D判别G生成噪音数据。J(G) 生成网络目标函数。
    最小最大博弈,minimax game。均衡点(纳什均衡),J(D)鞍点(saddle point)。
    真实数据(data)和模型生成伪数据(model distribution z映射)。 学习D,区分data、model分布。data、model分布相加做分母,分子是真实data分布。目标,D无限接近常数1/2.Pmodel、Pdata无限相似。生成模型与源数据拟合后,无法再学习,常数y=1/2求导永远0。
    非饱和博弈(Non-Saturating)。G伪装成功率表示目标函数,均衡不由损失(loss)决定。D完美后,G还可以继续优化。
    DCGAN(深度卷积生成对抗网络 Deep Convolutional Generative Adversarial Network),反向CNN。
    卷积神经网络原理,convolutinoal filter 卷积过滤器(滤镜),把图片过滤(转化)成各种样式。不同过滤器,把图片转化成不同样式。不同样式为原图片不同特征表达。特征学习。
    DCGAN创造图片。把一组特征值慢慢恢复成一张图片。
    每一个滤镜层,CNN把大图片重要特征提取出来,一步一步减小图片尺寸。DCGAN把小图片(小数组)特征放大,排列成新图片。DCGAN输入最初小数据是噪声数据。图片RGB矩阵,可以向量加减。戴墨镜的男人-不戴墨镜的男人+不戴墨镜的女人=戴墨镜的女人。NLP,word2vec,king-man+woman=queen。向量/矩阵加减后,还原成“图义”代表的图片。NLP,word2vec,向量对应有意义的词;DCGAN,矩阵对应有意义的图片。
    统计学科,JS距离(minimax),KL距离,散度(divergence)方程。创造目标函数。DKL(P||Q)=S∞-∞p(x)log(p(x)/q(x))dx 。
    GAN 神经网络构造,通过类SGD方法优化模型。目标函数重要。Q噪声数据分布。P目标分布。求最大似然(Maximum Likelihood),使KL距离最小化。
    KL[P||Q]=SPlog(P/Q)dx=SPlogPdx-SPlogQdx
    P、Q都以x为变量,P是真实数据分类。SPlogPdx 是常数。SPlogQdx 只有logQ是变量。正比于-logQ。
    KL[P||Q]=-常数-S另一常数·logQdx
    Q,P(x|θ)。P模型,θ参数。负的最大似然。
    类GAN算法最小化任何f-divergence方程。
    面对无限多数据,都可以学到真实数据分布P。现实,数据有限。KL公式理论,KL(P||Q),Q拟合真实数据P,极大解释全部P内涵(overgeneralization)。多模态(multimodal),数据不够多,KL(P||Q)覆盖不完整。KL(Q||P),undergeneralization 。先覆盖较大,再覆盖较小。
    G目标函数改造成最大似然。J(G)求导,得到最大似然表达形式。Maximal Likelihood跑得最快。
    GAN,生成(复刻)样本,还可以转为强化学习模型(Reinforcement Learning)。上海交大 SeqGAN[Yu et al. 2016]论文。
    把数据标签给GAN。学习条件概率p(y|x)远比单独p(x)容易。部分有标签就能大幅提升GAN训练效果,半监督(semi-supervising)学习。半监督学习,三类数据,真实无标签数据,有标签数据,噪音生成数据。目标函数,监督方法和无监督方法结合。标签平滑(smooth),把0、1离散标签,转变成更加平滑的0.1(beta)、0.9(alpha)等。分子混进beta系数假数据分布。假数据建议保留标签0,一个平滑,另一个不平滑,one-sided label amoothing(单边标签平滑)。平滑,GAN判别函数不会给出太大梯度信号(gradient signal),防止算法走向极端样本陷阱。
    Batch Norm,取一批数据,规范化(normalise,减平均值,除以标准差)。数据更集中,不太大太小,学习效率更高。同一批(batch)数据太相似,无监督GAN,容易被带偏,认为数据都一样,最终生成模型混杂很多其它特征。
    Reference Batch Norm,取一批数据(固定)作参照数据集R,新数据batch依据r平均值、标准差规范化。R取得不好,效果也不会好,或R过拟合。
    Virtual Batch Norm,取R,新数据x规范化,x加入R形成virtual batch V,用V的平均值、标准差来标准化x,极大减少R风险。
    平衡好G、D。通常对抗网络,判别模型D赢,D比G深。用非饱和(non-saturating)博弈写目标函数,保证D学完后,G可以继续学习。
    GAN问题。不收敛(non-convergence),容易只找到局部最优点,非全局最优点,或根本无法收敛。模式崩溃(mode collapse),minmaxV(G,D)不等于maxminV(G,D),如果maxD放在内圈,算法可以收敛到应有位置,如果maxG放在内圈,算法扑向聚集区,看不到全局分布。Reverse KL,保守损失(loss)。
    Minibatch GAN,原数据分成小batch,保证太相惟数据样本不被放到一个小batch,数据足够多样,避免模式崩溃。
    空间理解错误。图片用2D表示3D。图片样本生成图片,空间表达不好。
    Unrolled(不滚) GAN,每一步不把判别模型D滚起,把K次D存起,根据损失(loss)选择最好。
    无法科学评估,无法量化标准。
    离散输出,无法微分(differentiate)。Williams(1992) REINFORCE。Jang et al.(2016) Gumbel-softmax。用连续数值训练,框定范围,输出离散值。
    强化学习连接,无法收敛,有限步数,穷举更简单粗暴效果好。
    PPGN(Plug and Play Generative Models 即插即用生成模型),Nguten et al,2016。生成模型领域新State-of-the-art(当前最佳)。
    GAN用可以利用监督学习估测复杂目标函数生成模型,GAN内部自己拿真假样本对照。高维连续非凸找纳什均衡,有待研究。

    参考资料:
    https://zhuanlan.zhihu.com/p/24767059
    http://www.sohu.com/a/121189842_465975

    欢迎付费咨询(150元每小时),我的微信:qingxingfengzi

    群体经过适当组织,可以互相促进。我一直相,信良性互动可以使彼此更快速地成长,帮助更多人进入自己感兴趣的领域。我在创建一个一起学习GAN的微信群,我们以每天各报各的学习进度为主。加我微信,我会把你拉进群里,加的时候请注明:加入GAN日报群。

  • 相关阅读:
    Action<T>和Func<T>的使用(转)
    javascript将网页表格导出Word(转载)
    jQuery扩展方法2(转载)
    Microsoft SQL Server 2008 安装说明
    INSERT INTO 语句的语法错误 access 保留字段
    .NET 验证控件
    TFS强制撤销用户的签出
    .net 操作XML
    TFS 恢复删除文件
    checkboxlist详细用法、checkboxlist用法、checkboxlist
  • 原文地址:https://www.cnblogs.com/libinggen/p/7476259.html
Copyright © 2011-2022 走看看