On Using Very Large Target Vocabulary for Neural Machine Translation

Sébastien Jean, Kyunghyun Cho, Roland Memisevic, Yoshua Bengio

ACL 2015 | arXiv

Problem:

Seq2seq models struggle with large target vocabularies: training and decoding complexity increases proportionally with the vocabulary size. Moreover, training with a shortlist of the k most frequent words hurts the model's performance, especially for morphologically rich languages (e.g., German, Arabic).

Existing approximations (e.g., hierarchical softmax, sampling-based estimation) reduce the computational complexity of training but not that of decoding.

Contributions:

Propose an approximate training method based on (biased) importance sampling.

The computational complexity due to the vocabulary size stems from the dot products \(w_k^Th_t\) with every target word, needed at each time step to compute the probability distribution over the vocabulary.

The proposed approach renders training complexity constant wrt. the size of the vocabulary.

At each parameter update, use only a small subset \(V'\) of the vocabulary.

\[ p(y_t| h_t) = \frac{1}{Z} \exp\left( w_t^Th_t + b_t\right), \qquad Z = \sum_{k:y_k\in V} \exp\left( w_k^Th_t + b_k\right) \]

Let \(\mathcal E\) be the energy defined as:

\[ \mathcal E(y_j) = w_j^Th_t + b_j \]

Thus:

\[ \begin{align} \nabla \log p(y_t| h_t) & = \nabla\mathcal E(y_t) - \sum_{k:y_k\in V} p(y_k|h_t)\,\nabla\mathcal E(y_k) \\ & = \nabla \mathcal E(y_t) - E_P\left[\nabla \mathcal E(y) \right] \end{align} \]
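As a sanity check of the identity above, here is a small NumPy sketch (toy sizes and made-up weights, not the paper's setup): the gradient of \(\log p(y_t|h_t)\) with respect to the energies is the one-hot vector of the target minus the full softmax distribution, and computing that softmax requires a dot product with every word in \(V\), which is exactly the bottleneck.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 1000, 16                      # toy vocabulary and hidden sizes (illustrative only)
W = rng.normal(size=(V, d)) * 0.1    # output word embeddings w_k
b = np.zeros(V)                      # output biases b_k
h_t = rng.normal(size=d)             # decoder state at time t
y_t = 42                             # index of the correct target word

energies = W @ h_t + b               # E(y_k) = w_k^T h_t + b_k for every word: O(|V| d)
p = np.exp(energies - energies.max())
p /= p.sum()                         # softmax over the full vocabulary

# d log p(y_t|h_t) / dE = one_hot(y_t) - p,
# i.e. grad E(y_t) minus the expectation of grad E under the model distribution P.
grad_energies = -p.copy()
grad_energies[y_t] += 1.0

# Finite-difference check on a few coordinates.
def log_p(e):
    m = e.max()
    return e[y_t] - m - np.log(np.exp(e - m).sum())

eps = 1e-6
for k in (0, y_t, V - 1):
    e_hi, e_lo = energies.copy(), energies.copy()
    e_hi[k] += eps
    e_lo[k] -= eps
    fd = (log_p(e_hi) - log_p(e_lo)) / (2 * eps)
    assert abs(fd - grad_energies[k]) < 1e-5
```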

The main idea is to approximate the expectation term above via importance sampling:

Given a proposal distribution \(Q\) and a subset \(V'\):

\[ E_P\left[\nabla \mathcal E(y) \right] \approx \sum_{y_k\in V'} \frac{\omega_k}{\sum_{y_{k'}\in V'} \omega_{k'}} \nabla \mathcal E(y_k), \] where \(\omega_k = \exp(\mathcal E(y_k) -\log Q(y_k))\).
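A minimal NumPy sketch of that estimator (the function and variable names are mine, not the paper's): only the energies of the words in \(V'\) are needed, and the weights are self-normalized over \(V'\), which is what makes the estimator biased.

```python
import numpy as np

def normalized_importance_weights(energies_Vp, log_q_Vp):
    """Self-normalized importance weights over the sampled subset V'.

    energies_Vp : E(y_k) for y_k in V'      (shape [|V'|])
    log_q_Vp    : log Q(y_k) for y_k in V'  (same shape)
    Returns omega_k / sum_k' omega_k', used in place of p(y_k | h_t).
    """
    log_w = energies_Vp - log_q_Vp           # log omega_k = E(y_k) - log Q(y_k)
    log_w -= log_w.max()                     # shift for numerical stability
    w = np.exp(log_w)
    return w / w.sum()

# E_P[grad E(y)] is then approximated by sum_k weights[k] * grad E(y_k),
# where the sum runs only over V' instead of the full vocabulary V.
```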

Choice of \(V'\):

We partition the training corpus and define a subset \(V'_i\) for each partition prior to training. \(V'_i\) is simply the vocabulary you would build if the \(i\)-th partition were the full training corpus.

For each \(V'_i\) assigned to the \(i\)-th partition: \[ Q_i(y_k) = \begin{cases} 1/|V'_i| & \text{if } y_k \in V'_i\\ 0 & \text{otherwise}. \end{cases} \]

This choice makes the \(-\log Q_i(y_k)\) correction a constant, so it cancels out in the normalized importance weights, leaving \(\omega_k \propto \exp(\mathcal E(y_k))\).
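A sketch of the sequential partitioning (my own helper names; `tau` stands for the target size of each \(V'_i\)). Because \(Q_i\) is uniform over \(V'_i\), the resulting update is just an ordinary softmax restricted to \(V'_i\), which by construction contains the correct target words of its partition.

```python
def build_partitions(target_sentences, tau):
    """Sequentially group sentences until their target vocabulary reaches ~tau words.

    target_sentences : iterable of lists of target-word ids
    Returns a list of (sentence_indices, vocab_set) pairs, one per partition.
    """
    partitions, sent_ids, vocab = [], [], set()
    for i, sent in enumerate(target_sentences):
        new_words = set(sent) - vocab
        if sent_ids and len(vocab) + len(new_words) > tau:
            partitions.append((sent_ids, vocab))   # close the current partition V'_i
            sent_ids, vocab, new_words = [], set(), set(sent)
        sent_ids.append(i)
        vocab |= new_words
    if sent_ids:
        partitions.append((sent_ids, vocab))
    return partitions

# With Q_i uniform over V'_i, omega_k is proportional to exp(E(y_k)):
# each update is a softmax/cross-entropy computed only over V'_i.
```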

Decoding time:

Either decode over the full target vocabulary (exact, but expensive), or restrict each step's softmax to a small per-sentence candidate list: the \(K\) most frequent target words plus the top \(K'\) candidate target words for each source word, taken from a word-alignment dictionary (see the sketch below).
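A sketch of the candidate-list construction (the data structures are hypothetical: `frequent_ids` is the list of target ids sorted by frequency, `dict_candidates` maps a source word to its most likely aligned target ids):

```python
def candidate_list(source_sentence, frequent_ids, dict_candidates, K, K_prime):
    """Per-sentence candidate target words: the K most frequent target words
    plus up to K' dictionary candidates for each source word."""
    candidates = set(frequent_ids[:K])
    for src_word in source_sentence:
        candidates.update(dict_candidates.get(src_word, [])[:K_prime])
    return sorted(candidates)

# At decoding time, the output softmax is computed only over this candidate list,
# so the per-step cost no longer grows with the full 500k-word vocabulary.
```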

Experiments:

WMT'14 English-French and English-German. The full vocabulary has 500k words (99.5% coverage, compared to 96% for the usual 30k-word shortlist on En-Fr).

\(V'\) contains either 15k or 30k words.

Experimented with shuffling the training corpus after every epoch, so that the correct target words are contrasted with a different set of words each time.

Issues & comments:

Check: