William Fedus, Ian Goodfellow and Andrew M. Dai
ICLR 2018 | arxiv | openreview | code* | slides
Autoregressive seq2seq models trained with teacher forcing suffer from exposure bias at sampling time. GANs were originally designed to output differentiable values, so it is challenging to use them for discrete language generation.
GANs also suffer from issues such as training instability and mode dropping, which occurs when the generator rarely generates certain modalities. This is exacerbated in text generation, which has many complex modes: bigrams, phrases, idioms, etc.
Besides, in the autoregressive setting, the discriminator loss is only observed after the full sequence has been sampled, which makes training unstable, especially with longer sequences.
An actor-critic conditional GAN that fills in missing text conditioned on the surrounding context.
Fill-in-the-blank task. It sharpens the error signal sent to the generator, since the discriminator only needs to judge the in-filled tokens rather than the whole sequence.
For each sequence \(x=(x_1, \dots, x_T)\), a binary mask \(m=(m_1, \dots, m_T)\) is generated to select which tokens remain. If a token is dropped, i.e. \(m_t=0\), then \(x_t\) is replaced with a special token \(<m>\). The masked sequence is referred to as \(m(x)\).
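The masking step above can be sketched as follows; the i.i.d. Bernoulli mask and the `drop_prob` rate are illustrative choices (the paper also experiments with contiguous-block masks):

```python
import random

MASK = "<m>"

def mask_sequence(tokens, drop_prob=0.5, seed=None):
    """Draw a binary mask m over tokens: m_t = 0 means x_t is dropped
    and replaced with the special <m> token; returns (m, m(x))."""
    rng = random.Random(seed)
    mask = [0 if rng.random() < drop_prob else 1 for _ in tokens]
    masked = [tok if keep == 1 else MASK for tok, keep in zip(tokens, mask)]
    return mask, masked

# Example: mask a toy sequence reproducibly.
mask, m_x = mask_sequence(["the", "cat", "sat", "on", "the", "mat"], seed=0)
```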
The decoder is now conditioned on what it has generated so far as well as the masked input, i.e. \[ p(\hat x_1, \dots, \hat x_T| m(x)) = \prod_{t=1}^T G(\hat x_t)\\ G(\hat x_t) \equiv p(\hat x_t| \hat x_{1:t-1}, m(x)) \]
The discriminator \(D_\phi\) has a similar architecture with binary output instead of a distribution over the vocabulary. \(D_\phi\) is given the filled in sequence as well as the original input \(x\) and computes the probability of each token \(\tilde x_t\) being real: \[ D_\phi(\tilde x_t | \tilde x, m(x)) = p(\tilde x_t = x_t^* |\tilde x, m(x)) \]
The logarithm of this probability is regarded as the reward: \[ r_t \equiv \log D_\phi(\tilde x_t | \tilde x, m(x)) \]
The critic, on top of \(D_\phi\), estimates the value function, i.e. the discounted total return of the filled-in sequence \(R_t = \sum_{s=t}^T \gamma^s r_s\), where \(\gamma\) is the discount factor (a token generated at time \(t\) influences the rewards of that time step and subsequent time steps).
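Given per-token rewards \(r_t = \log D_\phi(\cdot)\), the returns can be computed with a backward pass; note the discount exponent is the absolute time step \(s\), matching the paper's indexing, not the more common \(s-t\):

```python
def discounted_returns(rewards, gamma=0.99):
    """Compute R_t = sum_{s=t}^T gamma^s * r_s for every t
    by accumulating from the end of the sequence."""
    T = len(rewards)
    returns = [0.0] * T
    acc = 0.0
    for t in reversed(range(T)):
        acc += (gamma ** t) * rewards[t]
        returns[t] = acc
    return returns
```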
The model is not differentiable due to the sampling performed by the generator. Therefore, the generator's gradient is estimated via policy gradient.
\(G_\theta\) seeks to maximize the cumulative reward \(R=\sum_t R_t\), thus we perform gradient ascent on \(E_{G_\theta}[R]\).
An unbiased estimator of the gradient for a single token would be:
\[\nabla_\theta E_{G_\theta}[R_t] = R_t \nabla_\theta \log G_\theta(\hat x_t), \]
the variance of which may be reduced by subtracting a learned baseline \(b_t=V^G(x_{1:t})\), the value function produced by the critic. Hence,
\[\nabla_\theta E_{G_\theta}[R_t] = (R_t -b_t) \nabla_\theta \log G_\theta(\hat x_t), \]
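For a softmax policy, \(\partial \log p(a)/\partial \text{logit}_k = \mathbb{1}[k=a] - p_k\), so the per-token gradient above has a closed form. A minimal sketch (the function name and pure-Python softmax are illustrative, not from the paper):

```python
import math

def reinforce_with_baseline_grad(logits, action, advantage):
    """Gradient of advantage * log softmax(logits)[action] w.r.t. logits,
    where advantage = R_t - b_t. Uses d log p(a)/d logit_k = 1[k=a] - p_k."""
    m = max(logits)                      # shift for numerical stability
    exps = [math.exp(l - m) for l in logits]
    Z = sum(exps)
    probs = [e / Z for e in exps]
    return [advantage * ((1.0 if k == action else 0.0) - p)
            for k, p in enumerate(probs)]
```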
The discriminator is updated using the gradient \[ \nabla_\phi \frac{1}{m} \sum_{i=1}^m \left[ \log D(x^i) + \log (1 - D(G(z^i)))\right] \]
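Equivalently, the discriminator minimizes the negative of that objective. A sketch, assuming `d_real` and `d_fake` hold \(D\)'s outputs on a minibatch of real and generated sequences:

```python
import math

def disc_loss(d_real, d_fake):
    """Negative of the minibatch GAN objective, to be minimized:
    -(1/m) * sum_i [log D(x_i) + log(1 - D(G(z_i)))]."""
    m = len(d_real)
    return -sum(math.log(r) + math.log(1.0 - f)
                for r, f in zip(d_real, d_fake)) / m
```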
Modify the core algorithm with a dynamic task: train on sequences up to a maximum length \(T\), then increment the length upon convergence.
Intuition: capture dependencies over short sequences before moving to longer ones (cf. curriculum learning).
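The length schedule could be sketched as a generator; the starting length, step size, and the convergence test (left abstract here) are all illustrative choices:

```python
def curriculum_lengths(max_len, start=2, step=1):
    """Yield the sequence lengths to train on: stay at length L until
    convergence (decided by the caller), then move to L + step."""
    L = start
    while L <= max_len:
        yield L
        L += step
```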
Instead of using only the sampled token, use the full information output by the generator, i.e. the probability distribution over the whole vocabulary, and evaluate the reward of each possible token. Computationally costly but potentially beneficial.
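In other words, replace the single-sample reward with the per-step expectation over the vocabulary; a minimal sketch, assuming `probs` is the generator's distribution and `rewards` the reward of each candidate token:

```python
def expected_reward(probs, rewards):
    """Per-step expected reward E_{v~G}[r(v)] = sum_v p(v) * r(v),
    computed over the full vocabulary instead of one sampled token."""
    return sum(p * r for p, r in zip(probs, rewards))
```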
A language model is trained with MLE; this pretrained model then initializes a seq2seq encoder and decoder. The seq2seq model is further trained on the fill-in task with MLE (MaskMLE).
The model with the lowest validation perplexity is selected via a hyperparameter sweep over 500 runs!
The Penn Treebank (PTB) and IMDB datasets.
The model can be run in unconditional mode where all the context is masked.
MaskGAN generates samples that are more likely under the pretrained language model than those of MaskMLE, which translates into a lower perplexity under that LM.
Qualitatively, human evaluation favors MaskGAN as well:
Mode collapse:
Mode dropping is less extreme than with SeqGAN but still noticeable:
Check
Designing error attribution per time step in prior NLP GANs (Yu et al. 2017 & Li et al. 2017)