Sébastien Jean, Kyunghyun Cho, Roland Memisevic, Yoshua Bengio
ACL 2015 | arxiv
Seq2seq models struggle with large vocabularies: training and decoding complexity increases proportionally to the vocabulary size. Moreover, training with a shortlist of the k most frequent words hurts the model's performance, especially for morphologically rich languages (German, Arabic).
Stochastic approximation via noise contrastive estimation
Words are clustered into hierarchical classes and the probability is factorized as \(p(y|h) = p(c|h)\cdot p(y|c, h)\) (class probability × intra-class probability).
Both reduce the computational complexity of training but not that of decoding.
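As a rough back-of-the-envelope illustration (my numbers, not the paper's): with \(|V| = 500\text{k}\) split into \(\approx 707\) classes of \(\approx 707\) words each, the two-level factorization needs about \(707 + 707 \approx 1.4\text{k}\) dot products per prediction instead of 500k, i.e. roughly \(O(\sqrt{|V|})\) instead of \(O(|V|)\).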
Propose an approximate training based on (biased) importance sampling.
The computational complexity due to the vocabulary size stems from the normalization of the output distribution: the dot product \(w_k^Th_t\) must be computed for every word in the vocabulary at each time step to retrieve the probability distribution.
The proposed approach renders training complexity constant wrt. the size of the vocabulary.
At each update, use only a small subset of the vocabulary \(V'\)
\[ p(y_t| h_t) = \frac{1}{Z} \exp\left( w_t^Th_t + b_t\right), \quad Z = \sum_{k:y_k\in V} \exp\left( w_k^Th_t + b_k\right) \] Let \(\mathcal E\) be the energy defined as: \[ \mathcal E(y_j) = w_j^Th_t + b_j \] Thus: \[ \begin{align} \nabla \log p(y_t| h_t) & = \nabla\mathcal E(y_t) - \sum_{k:y_k\in V} p(y_k|h_t)\nabla\mathcal E(y_k) \\ & = \nabla \mathcal E(y_t) - E_P\left[\nabla \mathcal E(y) \right] \end{align} \]
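A minimal numpy sketch of this full-softmax gradient (illustrative shapes and values, not the paper's code), showing where the \(O(|V|)\) cost per time step comes from:

```python
import numpy as np

V, d = 500_000, 100                # vocabulary size, hidden size (illustrative values)
W = np.random.randn(V, d) * 0.01   # output word embeddings w_j
b = np.zeros(V)                    # output biases b_j
h_t = np.random.randn(d)           # decoder hidden state at time t

energies = W @ h_t + b             # E(y_j) = w_j^T h_t + b_j for every j: O(|V| d)
p = np.exp(energies - energies.max())
p /= p.sum()                       # softmax over the full vocabulary

y_t = 42                           # index of the observed target word
grad_E = -p                        # d log p(y_t|h_t) / dE(y_j) = 1[j = y_t] - p(y_j|h_t)
grad_E[y_t] += 1.0
```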
The main idea is to approximate the expectation term above via importance sampling:
Given a proposal distribution \(Q\) and a subset \(V'\):
\[ E_P\left[\nabla \mathcal E(y) \right] \approx \sum_{y_k\in V'} \frac{\omega_k}{\sum_{k':y_{k'}\in V'} \omega_{k'}} \nabla \mathcal E(y_k), \] where \(\omega_k = \exp(\mathcal E(y_k) -\log Q(y_k))\).
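A minimal sketch of this sampled gradient estimate, reusing `W`, `b`, `h_t` from the snippet above (function and argument names are mine, not the paper's):

```python
import numpy as np

def sampled_grad_E(h_t, W, b, y_t, V_prime, logQ):
    """Approximate the gradient of log p(y_t | h_t) w.r.t. the energies,
    using only the sampled subset V_prime (assumed to contain y_t).

    V_prime : 1-D array of word indices in the sampled subset V'
    logQ    : log proposal probabilities log Q(y_k) for the words in V_prime
    """
    energies = W[V_prime] @ h_t + b[V_prime]      # E(y_k) only for y_k in V'
    log_w = energies - logQ                       # log omega_k = E(y_k) - log Q(y_k)
    log_w -= log_w.max()                          # numerical stability
    w = np.exp(log_w)
    p_hat = w / w.sum()                           # normalized importance weights

    grad_E = -p_hat                               # approximates -E_P[grad E] over V'
    target_pos = int(np.where(V_prime == y_t)[0][0])
    grad_E[target_pos] += 1.0                     # + grad E(y_t)
    return grad_E                                 # one entry per word in V'
```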
We partition the training corpus and define a subset \(V'\) for each partition prior to training. \(V'\) is simply the vocabulary you would build if that partition were the full training corpus.
For each \(V'_i\) assigned to the \(i\)-th partition: \[ Q_i(y_k) = \begin{cases} 1/|V'_i| \;\text{if} \;y_k \in V'_i\\ 0\; \text{otherwise}. \end{cases} \]
This choice cancels out the \(Q\) term in the importance weights.
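A one-line check of why the proposal drops out (my own spelling-out of the step): since \(-\log Q_i(y_k) = \log |V'_i|\) is the same constant for every \(y_k \in V'_i\),
\[
\frac{\omega_k}{\sum_{k':y_{k'}\in V'_i} \omega_{k'}}
= \frac{\exp(\mathcal E(y_k) + \log |V'_i|)}{\sum_{k'} \exp(\mathcal E(y_{k'}) + \log |V'_i|)}
= \frac{\exp(\mathcal E(y_k))}{\sum_{k'} \exp(\mathcal E(y_{k'}))},
\]
i.e. the approximate distribution is just a softmax restricted to \(V'_i\).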
At decoding time, either:
Limit the output to the K most frequent terms in the full vocabulary (no gain).
Use an existing word alignment model (e.g. the attention weights of seq2seq-attn, or fast_align) to align the source and target words in the training corpus. For each source sentence, build a candidate target word set consisting of the K most frequent words plus the K' most likely target words for each source word (see the sketch after this list).
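A minimal sketch of building such a per-sentence candidate vocabulary (hypothetical names and data structures; `align_table` is assumed to come from an alignment model such as fast_align):

```python
def candidate_vocab(source_sentence, align_table, freq_list, K=15000, K_prime=10):
    """Union of the K most frequent target words and, for each source word,
    its K' most likely aligned target words."""
    candidates = set(freq_list[:K])                     # global shortlist of frequent words
    for src_word in source_sentence:
        translations = align_table.get(src_word, [])    # [(tgt_word, p(tgt | src)), ...]
        top = sorted(translations, key=lambda t: -t[1])[:K_prime]
        candidates.update(tgt for tgt, _ in top)
    return candidates
```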
WMT'14 English-French and English-German. The full vocabulary contains 500k words (99.5% coverage, compared to 96% for the usual 30k-word shortlist on EN-FR).
\(V'\) is either 15k or 30k
Experimented with shuffling the training corpus every epoch, allowing the words to be contrasted with different sets every time.
The decoding dependency on external alignment models (fast_align) is not convincing.
The choice of \(Q\) as the uniform distribution is computationally efficient but doesn't make much sense semantically speaking.
Isn't there a better way to partition the training corpus?
A good trick: at the end of the training stage, freeze the word embeddings and continue training the model > helps increase BLEU.
Beam search in NMT: normalize the probability by the length of the candidate (this didn't work well for captioning).
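For concreteness, the length-normalized score I understand this to mean (my formulation, not spelled out in the paper):
\[
\mathrm{score}(y) = \frac{1}{|y|} \log p(y \mid x),
\]
so longer hypotheses are not penalized simply for accumulating more negative log-probability terms.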
I face the same issues of dot products in the size of the vocabulary with my word-level smoothing.
Check:
Stochastic approximation of the vocabulary distribution: Mnih et al. 2013 and Mikolov et al. 2013, based on noise contrastive estimation.
A translation-specific method to handle OOV (Luong et al. 2015)
The parent paper Bengio and Sénécal 2008.