Read-through: Wasserstein GAN
I really, really like the Wasserstein GAN paper. I know it’s already gotten a lot of hype, but I feel like it could use more.
I also think the theory in the paper scared off a lot of people, which is a bit of a shame. This is my contribution to make the paper more accessible, while hopefully retaining the thrust of the argument.
Why Is This Paper Important?
There’s a giant firehose of machine learning papers - how do you decide which ones are worth reading closely?
For Wasserstein GAN, it was mostly compelling word of mouth.
- The paper proposes a new GAN training algorithm that works well on the common GAN datasets.
- Said training algorithm is backed up by theory. In deep learning, not all theory-justified papers have good empirical results, but theory-justified papers with good empirical results have really good empirical results. For those papers, it’s very important to understand their theory, because the theory usually explains why they perform so much better.
- I heard that in Wasserstein GAN, you can (and should) train the discriminator to convergence. If true, it would remove the need to balance generator updates with discriminator updates, which feels like one of the big sources of black magic for making GANs train.
- The paper shows a correlation between discriminator loss and perceptual quality. This is actually huge if it holds up well. In my limited GAN experience, one of the big problems is that the loss doesn’t really mean anything, thanks to adversarial training, which makes it hard to judge if models are training or not. Reinforcement learning has a similar problem with its loss functions, but there we at least get mean episode reward. Even a rough quantitative measure of training progress could be good enough to use automated hyperparam optimization tricks, like Bayesian optimization. (See this post and this post for nice introductions to automatic hyperparam tuning.)
Additionally, I buy the argument that GANs have close connections to actor-critic reinforcement learning. (See Pfau & Vinyals 2017.) RL is definitely one of my research interests. Also, GANs are taking over the world; I should probably keep an eye on GAN papers anyways.
At this point, you may want to download the paper yourself, especially if you want more of the theoretical details. To aid anyone who takes me up on this, the section names in this post will match the ones in the paper.
Introduction
The paper begins with background on generative models.
When learning generative models, we assume the data we have comes from some unknown distribution $P_r$. (The r stands for real.) We want to learn a distribution $P_\theta$ that approximates $P_r$, where $\theta$ are the parameters of the distribution.
You can imagine two approaches for doing this.
- The parameters $\theta$ directly describe a probability density. Meaning, $P_\theta$ is a function such that $P_\theta(x) \ge 0$ and $\int_x P_\theta(x) \, dx = 1$. We optimize $P_\theta$ through maximum likelihood estimation.
- The parameters $\theta$ describe a way to transform an existing distribution $Z$. Here, $g_\theta$ is some differentiable function, $Z$ is a common distribution (usually uniform or Gaussian), and $P_\theta = g_\theta(Z)$. (See the sketch right after this list.)
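To make the second approach concrete, here is a minimal sketch in PyTorch. The architecture, layer sizes, and noise dimension are my own arbitrary choices for illustration, not anything from the paper; the point is just that $\theta$ is the weights of a differentiable network $g_\theta$, and sampling from $P_\theta$ means pushing samples of $Z$ through it.

```python
import torch
import torch.nn as nn

# Approach 2, sketched: theta = the weights of g_theta, and P_theta = g_theta(Z).
# The architecture and sizes here are arbitrary, illustrative choices.
g_theta = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 2),       # generated points live in R^2
)

z = torch.randn(64, 32)      # Z: a common distribution (here, a standard Gaussian)
samples = g_theta(z)         # 64 samples from P_theta, no explicit density required
```

Note that this gives you samples from $P_\theta$ without ever writing down its density, which is why maximum likelihood doesn't apply as directly to this parameterization.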
The paper starts by explaining why the first approach runs into problems.
Given a function $P_\theta$, the MLE objective is
$$\max_{\theta \in \mathbb{R}^d} \frac{1}{m}\sum_{i=1}^m \log P_\theta(x^{(i)})$$

In the limit, this is equivalent to minimizing the KL divergence $KL(P_r \| P_\theta)$.
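As a toy example of the first approach (entirely my own, not from the paper), here's a sketch that fits an explicit density, a univariate Gaussian with learnable mean and standard deviation, by minimizing the average negative log-likelihood, which is the same as maximizing the objective above:

```python
import torch

# Toy MLE for approach 1: P_theta is an explicit density (a univariate Gaussian
# with learnable mean and log-std), fit to samples from a "real" P_r = N(3, 2^2).
x = torch.randn(10_000) * 2.0 + 3.0

mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=0.05)

for _ in range(500):
    dist = torch.distributions.Normal(mu, log_sigma.exp())
    loss = -dist.log_prob(x).mean()   # negative of (1/m) sum_i log P_theta(x^(i))
    opt.zero_grad()
    loss.backward()
    opt.step()

print(mu.item(), log_sigma.exp().item())  # should end up near 3.0 and 2.0
```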
Aside: Why Is This True?
Recall that for continuous distributions $P$ and $Q$, the KL divergence is
$$KL(P \| Q) = \int_x P(x) \log \frac{P(x)}{Q(x)} \, dx$$

In the limit (as $m \to \infty$), samples will appear based on the data distribution $P_r$, so
$$\begin{aligned}
\lim_{m \to \infty} \max_{\theta \in \mathbb{R}^d} \frac{1}{m}\sum_{i=1}^m \log P_\theta(x^{(i)}) &= \max_{\theta \in \mathbb{R}^d} \int_x P_r(x) \log P_\theta(x) \, dx \\
&= \min_{\theta \in \mathbb{R}^d} -\int_x P_r(x) \log P_\theta(x) \, dx \\
&= \min_{\theta \in \mathbb{R}^d} \int_x P_r(x) \log P_r(x) \, dx - \int_x P_r(x) \log P_\theta(x) \, dx \\
&= \min_{\theta \in \mathbb{R}^d} KL(P_r \| P_\theta)
\end{aligned}$$

(Adding $\int_x P_r(x) \log P_r(x) \, dx$ in the third line is fine because it's a constant with respect to $\theta$, so it doesn't change the minimizer.)
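If you want to sanity-check this numerically, here's a small sketch (again my own toy setup, not from the paper) where both $P_r$ and $P_\theta$ are Gaussians, so the KL divergence and the entropy of $P_r$ have closed forms. By the derivation above, the empirical MLE objective should approach $-H(P_r) - KL(P_r \| P_\theta)$, where $H(P_r)$ is the entropy of $P_r$; it differs from $-KL(P_r \| P_\theta)$ only by a constant that doesn't depend on $\theta$.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# "Real" distribution P_r and model P_theta, both univariate Gaussians (toy choice).
mu_r, sigma_r = 0.0, 1.0
mu_t, sigma_t = 0.5, 1.5

m = 1_000_000
x = rng.normal(mu_r, sigma_r, size=m)            # samples from P_r

# Empirical MLE objective: (1/m) * sum_i log P_theta(x^(i))
avg_loglik = norm.logpdf(x, loc=mu_t, scale=sigma_t).mean()

# Closed forms for Gaussians
kl = np.log(sigma_t / sigma_r) + (sigma_r**2 + (mu_r - mu_t)**2) / (2 * sigma_t**2) - 0.5
entropy_r = 0.5 * np.log(2 * np.pi * np.e * sigma_r**2)

# The two numbers below should nearly match (about -1.60 for these parameters).
print(avg_loglik, -entropy_r - kl)
```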