Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, William W. Cohen
ICLR 2018 | openreview | code* (pytorch)
The expressiveness of softmax-based models is limited by a *softmax bottleneck*: they lack the capacity to model natural language in its full complexity.
Formulating language modeling as a matrix factorization problem.
The RNN autoregressive approach consists of predicting the next token conditioned on a fixed-size vector (the context vector) encoding the previously generated tokens.
The probability distribution over the vocabulary is given by a softmax over the logits, obtained as the dot product between the word embeddings and the context vector.
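A minimal PyTorch sketch of this standard softmax output layer (sizes and variable names are illustrative, not the paper's code):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (not from the paper): N contexts, vocabulary of M words, hidden size d.
N, M, d = 32, 10000, 256

h = torch.randn(N, d)               # context vectors produced by the RNN
W = torch.randn(M, d)               # word embeddings, one row per vocabulary item

logits = h @ W.t()                  # (N, M): dot product between each context and each embedding
probs = F.softmax(logits, dim=-1)   # next-token distribution P_theta(x | c) for each context
```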
With this matrix factorization formulation, the paper studies the expressiveness of the softmax-based RNN, introduces discrete latent variables into the RNN LM, and formulates the next-token probability as a Mixture of Softmaxes (MoS).
Consider a finite set of possible contexts and the pairs (context, conditional next-token distribution), i.e. \(\{(c_i, P^*(X|c_i))\}_{1\leq i\leq N}\). We assume the data distribution \(P^* > 0\) everywhere.
Given a family of parametric distributions \(\{P_\theta\}\), is there a \(\theta\) such that \(P_\theta\) matches the data distribution \(P^*\)?
For the softmax models we define the following matrices: \(H_\theta \in \mathbb R^{N\times d}\), whose rows are the context vectors \(h_{c_i}\); \(W_\theta \in \mathbb R^{M\times d}\), whose rows are the word embeddings \(w_{x_j}\); and \(A \in \mathbb R^{N\times M}\), the true log probabilities, with \(A_{i,j} = \log P^*(x_j|c_i)\).
A set of matrices is then formed by applying a row-wise shift to \(A\):
\[ F(A) = \{ A + \Lambda J_{N, M} \mid \Lambda \;\text{is diagonal},\; \Lambda \in \mathbb R^{N\times N}\}, \] where \(J_{N,M}\) is the all-ones matrix. This simply adds an arbitrary scalar \(\lambda_i\) to each row of \(A\).
For \(A'\) a matrix of logits, we derive the underlying distribution \(P_{A'}\) by applying a row-wise softmax, i.e. \(P_{A'}(x_j|c_i)=\frac{\exp(A'_{i,j})}{\sum_k\exp (A'_{i,k})}\).
\[\begin{align} & (P1) \; A' \in F(A) \iff \text{Softmax}(A')=P^*\\ &(P2)\; \forall A_1\neq A_2 \in F(A),\; |\text{rank}(A_1) - \text{rank}(A_2)| \leq 1. \end{align}\] For (P1), the constant added to each row is simply canceled out by the softmax normalization. (P2) holds because any two members of \(F(A)\) differ by the rank-one matrix \((\Lambda_1-\Lambda_2) J_{N, M}\).
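A quick numerical sanity check of (P1) and (P2), using a random matrix as a stand-in for \(A\) (sizes arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 20, 30
A = rng.standard_normal((N, M))              # stand-in for the true log-probability matrix

def softmax_rows(B):
    B = B - B.max(axis=1, keepdims=True)     # numerical stability
    E = np.exp(B)
    return E / E.sum(axis=1, keepdims=True)

# (P1): adding a per-row constant leaves the row-wise softmax unchanged.
lam = rng.standard_normal((N, 1))
A_shifted = A + lam                          # same as A + Lambda @ J_{N,M}
assert np.allclose(softmax_rows(A), softmax_rows(A_shifted))

# (P2): any two members of F(A) have ranks differing by at most 1.
r1 = np.linalg.matrix_rank(A)
r2 = np.linalg.matrix_rank(A_shifted)
assert abs(r1 - r2) <= 1
```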
The *expressiveness* issue becomes: \[ \exists?\, \theta,\; \exists A'\in F(A)\;\text{s.t.}\; H_\theta W_\theta^T = A'. \]
Some rank constraints arise: first, since \(H_\theta W_\theta^T\) must equal some \(A'\in F(A)\), its rank must, by (P2), be at least \(\text{rank}(A) - 1\). Second, given the dimensions of \(H_\theta\) and \(W_\theta\), the rank of \(H_\theta W_\theta^T\) is upper bounded by \(d\). So if \(d < \text{rank}(A) - 1\), the model is simply not expressive enough.
The softmax bottleneck is expressed as follows:
if \(d < \text{rank}(A) - 1\), then for any function family \(\mathcal U\) (mapping contexts to hidden states) and any model parameters \(\theta\), there exists a context \(c\) s.t. \(P_\theta(X|c)\neq P^*(X|c)\)
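The bound is easy to see numerically: whatever the parameters, the logit matrix \(H_\theta W_\theta^T\) can never exceed rank \(d\), while a generic log-probability matrix has rank \(\min(N, M)\). A small sketch with random matrices standing in for trained parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, d = 200, 500, 32                 # many contexts and words, small hidden size

H = rng.standard_normal((N, d))        # stand-in for the context vectors H_theta
W = rng.standard_normal((M, d))        # stand-in for the word embeddings W_theta

print(np.linalg.matrix_rank(H @ W.T))  # <= d = 32, no matter how H and W are chosen

A = rng.standard_normal((N, M))        # a generic "true" log-probability matrix
print(np.linalg.matrix_rank(A))        # typically min(N, M) = 200 >> d
```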
It is hypothesized that the true log-probability matrix of natural language is high-rank, since the problem is highly context-dependent (fair enough).
Naively increasing the expressiveness by increasing \(d\) is not a fix; it blows up the number of parameters and makes the model prone to overfitting.
Introduce Mixture of softmaxes:
\[ P_\theta(x|c) = \sum_{k=1}^K \pi_{c,k} \frac{\exp(h_{c,k}^Tw_x)}{\sum_{x'}\exp(h_{c,k}^Tw_{x'})}, \] where \(\pi_{c,k}\) is the prior or mixture weight s.t. \(\sum_k\pi_{c,k} = 1\).
The prior \(\pi_{c,k}\) and the \(k\)-th context vector \(h_{c,k}\) are obtained by applying a stack of recurrent layers on top of \(X\) to obtain hidden states \((g_1, ..., g_T)\) and then:
\[\begin{align} \pi_{c_t,k} &= \frac{\exp(w_{\pi, k}^T g_t)}{\sum_{k'=1}^K \exp(w_{\pi, k'}^T g_t)},\\ h_{c_t, k} &= \tanh(W_{h,k} g_t), \end{align}\] where \(W_{h,k}\) and \(w_{\pi, k}\) are model parameters.
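A minimal PyTorch sketch of such an MoS head, following the equations above (module and variable names are mine; the official code differs in details such as dropout and weight tying):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoSHead(nn.Module):
    """Mixture-of-Softmaxes output layer: K softmax components mixed by context-dependent priors."""
    def __init__(self, d_hidden, d_embed, vocab_size, n_mix):
        super().__init__()
        self.n_mix = n_mix
        self.prior = nn.Linear(d_hidden, n_mix)               # w_{pi,k} for all k
        self.latent = nn.Linear(d_hidden, n_mix * d_embed)    # W_{h,k} for all k, computed at once
        self.decoder = nn.Linear(d_embed, vocab_size)         # shared word embeddings w_x

    def forward(self, g):                                     # g: (batch, d_hidden) RNN hidden states
        pi = F.softmax(self.prior(g), dim=-1)                 # (batch, K) mixture weights
        h = torch.tanh(self.latent(g))                        # (batch, K * d_embed)
        h = h.view(-1, self.n_mix, self.decoder.in_features)  # (batch, K, d_embed) context vectors h_{c,k}
        softmaxes = F.softmax(self.decoder(h), dim=-1)        # (batch, K, vocab) per-component softmax
        probs = torch.einsum('bk,bkv->bv', pi, softmaxes)     # mix in probability space: sum_k pi_k * softmax_k
        return probs                                          # (batch, vocab); take log + NLLLoss to train
```

Note that the mixture is taken over probabilities, not logits; taking the log afterwards is what introduces the nonlinearity exploited below.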
In matrix form:
\[ \hat A_{\text{MoS}} = \log \sum_{k=1}^K \Pi_k \exp(H_{\theta, k}W_\theta^T), \] where \(\Pi_k\) is the diagonal matrix of priors \(\pi_{c_i,k}\) and \(H_{\theta,k}\) stacks the context vectors \(h_{c_i,k}\).
Thanks to the nonlinearity (log-sum-exp), this matrix can be arbitrarily high-rank.
An alternative to mixing the softmaxes is to simply mix the context vectors prior to applying the softmax (Mixture of Contexts, MoC). However, this suffers from the same lack of expressiveness as a single softmax.
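A small numerical illustration of both points: with random parameters, the log-probability matrices of a single softmax and of MoC stay bounded by \(d+1\) in rank, while the MoS one does not (sizes arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, d, K = 200, 300, 16, 8

W = rng.standard_normal((M, d))              # shared word embeddings
Hk = rng.standard_normal((K, N, d))          # K context-vector matrices H_{theta,k}
Pi = rng.dirichlet(np.ones(K), size=N)       # (N, K) priors, rows sum to 1

def log_softmax_rows(B):
    B = B - B.max(axis=1, keepdims=True)
    return B - np.log(np.exp(B).sum(axis=1, keepdims=True))

# Single softmax: log-prob matrix is a row-shifted H W^T, rank <= d + 1.
single = log_softmax_rows(Hk[0] @ W.T)
print(np.linalg.matrix_rank(single))         # ~ d + 1

# MoC: mix the context vectors first, then one softmax -> still rank <= d + 1.
H_mix = np.einsum('nk,knd->nd', Pi, Hk)
moc = log_softmax_rows(H_mix @ W.T)
print(np.linalg.matrix_rank(moc))            # ~ d + 1

# MoS: mix the softmaxes, then take the log -> rank escapes the d + 1 bound.
softmaxes = np.exp(np.stack([log_softmax_rows(Hk[k] @ W.T) for k in range(K)]))  # (K, N, M)
mos = np.log(np.einsum('nk,knm->nm', Pi, softmaxes))
print(np.linalg.matrix_rank(mos))            # typically far larger than d + 1
```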
Language modeling on PTB, WikiText-2, 1B Word, and the Switchboard dialog dataset.
Beautifully written and argued with extensive experiments.