Generative Adversarial Networks #

Since their conception in 2014, the use of generative adversarial networks (GANs) has exploded throughout machine learning, vision and robotics. Alongside novel application domains, much time and effort has been devoted to developing new architectures and training approaches. Modern instantiations of the GAN training paradigm are responsible for some truly remarkable results. However, the sheer quantity of material out there can make the field difficult to penetrate. This blog is intended to be a brief introduction to generative adversarial networks and their development over the last few years. It is by no means exhaustive but simply communicates my ideas of what GANs are and my interpretation of how we got here. So without further ado...

$ \def\x{\hat{\mathbf{x}}} \def\xx{x} \def\y{y} \def\z{\epsilon} \def\yy{y} \def\zz{\z} \def\o{\theta} \def\oo{\phi} \def\p{p} \def\q{q_{\oo}} \def\L{L_\o} \def\g{g_\oo} \def\d{d_\o} \def\D{\Delta} \def\E{\mathbb{E}} \def\k{\lambda} \def\bern{\text{Bernoulli}} \def\py{p_\o} \def\jsd{\text{JSD}} \def\kl{\text{KL}} $

Generative Models #

Generative Modeling #

GANs are a type of generative model: given samples $\{ \xx_n \}_{n =1}^N$ from the true data generating process $\xx \sim \p(\xx)$, we aim to find the distribution $q(\xx)$ that best approximates $\p$ within a family of candidate distributions $\mathcal{Q}$.

Whilst for low dimensional data points it might be possible to define a model based on some underlying physics or our own intuitions and insights about the data (e.g. by assuming a particular parametric form for the model), this becomes significantly more challenging when we start working with high dimensional data points such as images. In this case the distribution $\p(\xx)$ is highly complex and difficult to specify.

Deep Implicit Models #

Deep implicit models allow a flexible and general class of models to be specified without relying on human insight, bringing the benefits of deep learning to the generative modeling setting. Samples $\xx \sim \q(\xx)$ are generated by transforming draws from a simpler base distribution $p(\zz)=\text{Normal}(\zz)$ as,

\[\xx = \g(\zz) \quad \text{with} \quad \zz \sim p(\zz)\]

where $\g(\zz)$ is a neural network with parameters $\oo$. Whilst this is evidently an attractive alternative, learning the parameters in this case is challenging. Though it is possible to generate samples from $\q$, evaluating the likelihood $\q(\xx)$ of a particular data point is generally intractable, which makes training by maximum likelihood infeasible.
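To make the sampling procedure concrete, here is a minimal sketch in plain Python. The tiny one-hidden-layer network and its parameter values are hypothetical stand-ins for $\g$, chosen purely for illustration. Note that drawing samples is trivial, yet nothing in the code gives us a way to evaluate $\q(\xx)$ for a given $\xx$.

```python
import math
import random

def generator(eps, phi):
    """Toy stand-in for g_phi: a one-hidden-layer network mapping scalar noise to a scalar sample."""
    w1, b1, w2, b2 = phi
    hidden = [math.tanh(w * eps + b) for w, b in zip(w1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2

def sample(phi, n, rng):
    """Draw n samples x = g_phi(eps) with eps ~ Normal(0, 1)."""
    return [generator(rng.gauss(0.0, 1.0), phi) for _ in range(n)]

rng = random.Random(0)
# Hypothetical parameters (hidden weights, hidden biases, output weights, output bias).
phi = ([1.5, -0.7, 0.3], [0.0, 0.5, -1.0], [2.0, 1.0, -0.5], 0.1)
samples = sample(phi, 1000, rng)
```

In a real GAN the generator is a deep network and $\zz$ is typically a vector, but the structure is the same: push noise through a deterministic function.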

It is worth briefly discussing the role that the variable $\zz \sim \p(\zz)$ plays, to head off a potential source of confusion. This variable is often referred to as a latent variable in the literature, but its role is distinctly different from a latent variable in the Bayesian sense. We never, for example, consider performing inference over $\zz$ by inferring $p(\zz \mid \xx)$. Instead it is better to think of $\zz$ as inducing the underlying uncertainty in the data generating process $\p(\xx)$. For this reason it might be better to refer to $\zz$ as an inducing variable rather than a latent one.

The Generative Adversarial Paradigm #

An Adversarial Approach to Likelihood Free Learning #

If the likelihood $\q$ were tractable, we might consider training our model using the maximum-likelihood estimate $\oo^* = \arg\max_\oo \E_{\xx \sim \p(\xx)}[\log \q(\xx)]$, using a dataset of observations $\{\xx_n\}_{n=1}^N$ sampled from the true data generating process $\xx \sim \p(\xx)$ (see footnote 1). Unfortunately, as we have already discussed, this is not the case.

Instead, in a generative-adversarial paradigm we introduce a discriminator $\py(\yy \mid \xx) =\bern\left(\d(\xx)\right)$, where the discriminator network $\d(\xx)$ is tasked with predicting the probability that the example $\xx$ is real rather than generated. The generator $\q(\xx)$ and the discriminator $\py(\yy \mid \xx)$ are trained together, seeking the Nash equilibrium of a min-max game

\[G = \min_\oo \max_\o \E_{\p(\xx)}[\log{(\d(\xx))}] + \E_{\q(\xx)}[\log{(1 - \d(\xx))}]\]

Here, the discriminator is trained to distinguish between real and simulated data whilst the generator is trained to fool the discriminator.
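As a rough illustration of the objective, the sketch below estimates the two expectations by Monte Carlo for a toy one dimensional problem. Everything here is an assumption made for the example: the Gaussian stand-ins for $\p$ and $\q$, and the logistic-regression discriminator with hand-picked parameters. It simply shows that a discriminator which separates the two sample sets achieves a higher value than one that always outputs 0.5.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def discriminator(x, theta):
    """Toy d_theta(x): logistic regression on a scalar input."""
    w, b = theta
    return sigmoid(w * x + b)

def value_estimate(real, fake, theta):
    """Monte Carlo estimate of E_p[log d(x)] + E_q[log(1 - d(x))]."""
    real_term = sum(math.log(discriminator(x, theta)) for x in real) / len(real)
    fake_term = sum(math.log(1.0 - discriminator(x, theta)) for x in fake) / len(fake)
    return real_term + fake_term

rng = random.Random(0)
real = [rng.gauss(2.0, 1.0) for _ in range(2000)]   # stand-in for x ~ p(x)
fake = [rng.gauss(-2.0, 1.0) for _ in range(2000)]  # stand-in for x ~ q_phi(x)

uninformed = value_estimate(real, fake, (0.0, 0.0))  # d(x) = 0.5 everywhere
informed = value_estimate(real, fake, (2.0, 0.0))    # d(x) increases with x
```

The discriminator's update pushes this estimate up; the generator's update moves the `fake` samples so as to push it back down.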

As an aside, this is not the only way to learn an implicit generative model without evaluating the log-likelihood. This paper provides an excellent discussion of several other approaches.

Theoretical Guarantees #

Whilst this training setup might make intuitive sense, ideally we need some guarantee that in the limit $\q \rightarrow \p$. Luckily the original GAN paper shows that if the discriminator is optimal for the current generator, that is $\d(\xx) = \p(\xx) / (\p(\xx) + \q(\xx))$, then optimising the generator amounts to minimising the Jensen-Shannon divergence between $\p$ and $\q$,

\[\jsd(\p, \q) = \frac{1}{2}\kl(\p \| m) + \frac{1}{2}\kl(\q \| m)\]

where $m = (\p + \q)/2$. This will be minimised if (and only if) $\q ^*(\xx) = \p(\xx)$.
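For intuition, here is a small sketch that computes the Jensen-Shannon divergence for discrete distributions (the example distributions are made up). It also checks numerically the identity from the original GAN paper: with the optimal discriminator $p / (p + q)$ plugged in, the value of the game equals $-\log 4 + 2\,\jsd(p, q)$.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def jsd(p, q):
    """Jensen-Shannon divergence: average KL from p and q to their mixture m."""
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def value_at_optimal_discriminator(p, q):
    """E_p[log d*(x)] + E_q[log(1 - d*(x))] with d*(x) = p(x) / (p(x) + q(x))."""
    total = 0.0
    for pi, qi in zip(p, q):
        d_star = pi / (pi + qi)
        if pi > 0.0:
            total += pi * math.log(d_star)
        if qi > 0.0:
            total += qi * math.log(1.0 - d_star)
    return total

p = [0.5, 0.3, 0.2]
q = [0.4, 0.4, 0.2]
same = jsd(p, p)                        # zero when the distributions match
close = jsd(p, q)                       # small but positive
disjoint = jsd([1.0, 0.0], [0.0, 1.0])  # the maximum value, log 2
game_value = value_at_optimal_discriminator(p, q)
```

Because the divergence is bounded above by $\log 2$, it stops changing once the two distributions no longer overlap, which is one way to see why gradients can vanish.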

If Only Life were That Simple… #

In reality, we optimise $G$ using Monte Carlo integration in combination with stochastic gradient optimisation, training the discriminator and generator iteratively. A couple of problems emerge here:

  1. Sub Optimal $\d$: In reality the discriminator $\py(\yy \mid \xx)$ is likely to be imperfect. But for a sub-optimal discriminator we no longer have any guarantee that in the limit $\q^*(\xx)=\p(\xx)$.
  2. Loss Saturation: Conversely, if we can train the discriminator to optimality, the $\log{(1 - \d(\xx))}$ term saturates, providing almost zero gradient signal. This can happen early on in training, when generated samples are easy to tell apart from real ones, leading to a generator that never learns. One solution that is commonly used in practice is to rely on a sub-optimal discriminator, performing only a few discriminator optimisation steps for each optimisation step of the generator, but this takes us back to the problem in 1.
  3. Mode Collapse: Whilst the generator's objective is convex in $\q$ (given the optimal discriminator), it is not convex in the network parameters $\oo$. This means that multiple local minima exist in practice, which can hinder training when using first order gradient methods. A particularly troublesome local minimum can occur if the generator is able to produce a single realistic example that the discriminator predicts as being real. In this case the gradient signal continually encourages the generator to carry on generating this single image, a phenomenon referred to as mode collapse.
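The saturation problem in point 2 is easy to see numerically. The sketch below assumes the discriminator output is a sigmoid of some logit (a common choice, but an assumption here) and evaluates the gradient of $\log(1 - \d(\xx))$ with respect to that logit, which works out to $-\d(\xx)$. As the discriminator becomes confident that a generated sample is fake, the gradient reaching the generator shrinks towards zero.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def generator_term(logit):
    """log(1 - d(x)) where d(x) = sigmoid(logit) is the discriminator output."""
    return math.log(1.0 - sigmoid(logit))

def generator_term_grad(logit):
    """Analytic derivative of log(1 - sigmoid(logit)) with respect to the logit."""
    return -sigmoid(logit)

# More negative logits mean the discriminator is more certain the sample is fake.
for logit in [0.0, -2.0, -5.0, -10.0]:
    d = sigmoid(logit)
    print(f"d(x) = {d:.5f}   gradient = {generator_term_grad(logit):.5f}")
```

The gradient magnitude equals $\d(\xx)$ itself, so a discriminator that assigns generated samples a probability near zero passes back almost no learning signal.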

Modern Developments #

Since the original formulation, significant developments have taken place that have allowed GAN approaches to flourish. Several research strands aimed at overcoming the problems mentioned above have paid significant dividends…

Network Architectures #

GAN architectures have received a great deal of interest during this time. The use of CNNs for both the generator and discriminator resulted in significant gains in the first instance. More recently, there has been considerable research into the use of self-attention mechanisms, particularly within the discriminator model. Such approaches have proven key in generating high resolution images, as has progressively growing the networks, training at gradually higher resolutions whilst adding more layers.

Much research has also been devoted to how best to allow the generative model to utilise the random variable $\zz$. By injecting $\zz$ at deeper levels of the network, for example, it is possible to control higher level variations in the data.

Alongside conditioning on random inputs $\zz$, conditioning on other inputs has also received significant interest (like the number when generating pictures of handwritten digits). This is the central idea behind domain transfer approaches such as Pix2Pix, in which the aim is to recast an image from one domain (e.g. a map) to that of another (e.g. an aerial photo). More recent approaches such as CycleGAN take this a step further, learning from unaligned examples only. In this case the forward and backward generative processes are learnt side-by-side, with a cyclical consistency loss imposed to further constrain the learning problem.

Optimisation Criteria #

In the original paper it was shown that, given the optimal discriminator, the optimal generator is found by minimising the Jensen-Shannon divergence between $\p$ and $\q$. But this is by no means the only divergence we might consider. This got people wondering whether alternative divergences might result in better optimisation guarantees and properties, helping to overcome some of the limitations of the original formulation.

As perhaps one of the earliest examples of this strand of research, the f-GAN paper generalises the results from the original formulation to all f-divergences (and is an excellent read). Multiple other optimisation criteria were also proposed over the following years. The least squares GAN is a notable and popular example, in which a least squares criterion replaces binary cross entropy to avoid the loss saturating. This can be shown to be the same as minimising the Pearson $\chi^2$ divergence (a similar formulation had also been proposed in the f-GAN paper previously) and was widely adopted within the domain transfer literature (e.g. Pix2Pix, CycleGAN).
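A minimal sketch of the least squares criterion is given below. It assumes target values of 1 for real and 0 for generated examples, which is one common coding and an assumption on my part, and it operates on lists of raw discriminator outputs rather than on a real network. The point to notice is that the gradient of the generator loss with respect to the discriminator output is $d - 1$, which stays large precisely when the discriminator confidently rejects a sample.

```python
def lsgan_discriminator_loss(d_real, d_fake):
    """0.5 * E[(d(x) - 1)^2] + 0.5 * E[d(g(eps))^2], targets 1 (real) and 0 (fake)."""
    real_term = sum((d - 1.0) ** 2 for d in d_real) / len(d_real)
    fake_term = sum(d ** 2 for d in d_fake) / len(d_fake)
    return 0.5 * real_term + 0.5 * fake_term

def lsgan_generator_loss(d_fake):
    """0.5 * E[(d(g(eps)) - 1)^2]: the generator wants its samples scored as 1."""
    return 0.5 * sum((d - 1.0) ** 2 for d in d_fake) / len(d_fake)

def lsgan_generator_grad(d):
    """Derivative of 0.5 * (d - 1)^2 with respect to the discriminator output d."""
    return d - 1.0

perfect = lsgan_discriminator_loss([1.0, 1.0], [0.0, 0.0])  # discriminator fully correct
rejected = lsgan_generator_loss([0.0, 0.0])                 # generator fully rejected
signal = lsgan_generator_grad(0.0)                          # gradient is still -1 here
```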

However, perhaps the greatest innovation in this area (at least judged by popularity of uptake) was the Wasserstein GAN formulation, which is worthy of a blog post in its own right. Briefly, the authors reconsider the GAN learning problem from an optimal transport perspective, proposing to minimise the Wasserstein distance between the generative and real data distributions. This formulation provides a gradient over the entire input space (something which is not guaranteed when minimising the KL or Jensen-Shannon divergence) and a loss curve which actually decreases as training progresses and is far better aligned to sample quality (the importance of which should not be understated, for the practitioner's sanity if nothing else!).

To achieve this the discriminator network must remain Lipschitz continuous, which means that the gradient of the network must remain bounded. Whilst the original paper enforced this using weight clipping, subsequent papers proposed alternatives such as the gradient penalty, which have proven superior.
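The pieces involved can be sketched in a few lines. The functions below work on plain lists of numbers (critic outputs, weights, gradient norms) rather than a real network, and the function names are my own. The default values `c=0.01` and `lam=10.0` follow the values I recall from the Wasserstein GAN and gradient penalty papers, so treat them as assumptions rather than recommendations.

```python
def critic_loss(f_real, f_fake):
    """Negative of E_p[f(x)] - E_q[f(x)]; minimising it widens the critic's gap."""
    return -(sum(f_real) / len(f_real) - sum(f_fake) / len(f_fake))

def clip_weights(weights, c=0.01):
    """Weight clipping: force every weight into [-c, c] after each critic update."""
    return [max(-c, min(c, w)) for w in weights]

def gradient_penalty(grad_norms, lam=10.0):
    """Gradient penalty: lam * E[(||grad_x f(x)|| - 1)^2] over the supplied norms."""
    return lam * sum((g - 1.0) ** 2 for g in grad_norms) / len(grad_norms)

loss = critic_loss([1.0, 3.0], [0.0, 1.0])
clipped = clip_weights([-0.5, 0.005, 0.2])
no_penalty = gradient_penalty([1.0, 1.0])    # unit gradient norms incur no penalty
some_penalty = gradient_penalty([2.0, 0.0])  # norms away from 1 are penalised
```

Clipping constrains the weights directly, which is crude but simple, whereas the penalty acts on the quantity we actually care about, the norm of the critic's gradient with respect to its input.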

Spectral Normalisation and Lipschitz Continuity #

At around the same time the Wasserstein GAN formulation was released, there was a growing body of research linking training instability to the spectral norm (the largest singular value) of the weight matrices throughout the network. For example, BigGAN identified that at the point of training collapse the singular values of the weight matrices could be seen to explode. Spectral normalisation takes this one step further, proposing to fix training instability by limiting the spectral norm of each weight matrix. These approaches were a key step in generalising GAN approaches to high resolution images.

Interestingly, limiting the spectral norm of each weight matrix in the discriminator can be shown to be a way of imposing a limit on its Lipschitz constant with respect to the 2-norm, and so spectral normalisation provides another alternative to gradient penalty in the Wasserstein setup.
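As a rough sketch of the idea, the largest singular value of a weight matrix can be estimated by power iteration and then divided out. The code below does this in plain Python for a small matrix stored as a list of rows. The function names are my own, and running many iterations from scratch is a simplification for clarity; practical implementations usually amortise the iteration across training steps.

```python
import math
import random

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def transpose(W):
    return [list(col) for col in zip(*W)]

def normalise(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def spectral_norm(W, n_iters=50, seed=0):
    """Estimate the largest singular value of W by power iteration on W^T W."""
    rng = random.Random(seed)
    Wt = transpose(W)
    v = normalise([rng.gauss(0.0, 1.0) for _ in W[0]])
    for _ in range(n_iters):
        u = normalise(matvec(W, v))
        v = normalise(matvec(Wt, u))
    return math.sqrt(sum(x * x for x in matvec(W, v)))

def spectrally_normalise(W, **kwargs):
    """Rescale W so that its largest singular value is approximately 1."""
    sigma = spectral_norm(W, **kwargs)
    return [[w / sigma for w in row] for row in W]

W = [[3.0, 0.0], [0.0, 1.0]]  # singular values are 3 and 1
sigma = spectral_norm(W)
W_sn = spectrally_normalise(W)
```

After normalisation each linear layer can stretch its input by at most a factor of one in the 2-norm, which is what bounds the Lipschitz constant of the network when the activations are themselves 1-Lipschitz.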

Solved? #

Not quite. These modifications have allowed GANs to achieve some truly remarkable results, that is true. But the GAN training setup is still far from being fully understood. A large scale study comparing different adversarial approaches found that there is no clear winner amongst the proposed solutions, with the best performing approach being highly sensitive to the problem setup and data type. Part of the problem is that evaluating the performance of generative models remains an open problem. This makes it challenging to compare different approaches, and difficult to lay down concrete conclusions that can be built on at the next stage. There is still a lot of work to be done here for sure!

Footnotes #

  1. I went down a bit of a rabbit hole trying to justify why maximum likelihood is often our preferred approach. After some digging, from a theoretical standpoint not only is the maximum-likelihood estimate consistent, meaning that in the limit as $N \rightarrow \infty$ we have $\q \rightarrow \p$, it can also be shown to be the estimator that converges to $\p$ most quickly. Of course, as is often the case, there are several reasons why things are slightly more complicated in reality. For consistency to hold we require $\p$ to belong to $\mathcal{Q}$, which might not hold in practice. Similarly, the log-likelihood is a non-convex function of the network parameters $\oo$, making finding the optimum $\q$ significantly more challenging. Nonetheless, the success of maximum-likelihood estimation across machine learning provides compelling evidence in its favour. If anyone else has any other insights into why maximum-likelihood estimation is the preferred approach I would love to hear them. Just drop me a message.