The main idea of XLNet is to model language autoregressively, like the GPT models, while allowing for all possible permutations of the order in which the words of a sentence are predicted.[2] Concretely, consider the following sentence:
My dog is cute.
In standard autoregressive language modeling, the model would be tasked with predicting the probability of each word, conditioned on the previous words as its context:
We factorize the joint probability of a sequence of words $x_1, \ldots, x_T$ using the chain rule:
$$\Pr(x_1, \ldots, x_T) = \Pr(x_1)\,\Pr(x_2 \mid x_1)\,\Pr(x_3 \mid x_1, x_2) \cdots \Pr(x_T \mid x_1, \ldots, x_{T-1}).$$
For example, the sentence "My dog is cute" is factorized as:
$$\Pr(\text{My}, \text{dog}, \text{is}, \text{cute}) = \Pr(\text{My})\,\Pr(\text{dog} \mid \text{My})\,\Pr(\text{is} \mid \text{My}, \text{dog})\,\Pr(\text{cute} \mid \text{My}, \text{dog}, \text{is}).$$
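As a minimal sketch of this chain-rule factorization, the following Python snippet multiplies the conditional probabilities for "My dog is cute" from left to right. The probabilities in `TOY_TABLE` are made-up illustrative numbers, and `sequence_log_prob` and `toy_cond_prob` are hypothetical helper names that do not come from any XLNet or GPT codebase.

```python
import math

# Made-up conditional probabilities Pr(token | context), for illustration only.
TOY_TABLE = {
    ((), "My"): 0.5,
    (("My",), "dog"): 0.4,
    (("My", "dog"), "is"): 0.5,
    (("My", "dog", "is"), "cute"): 0.1,
}


def toy_cond_prob(token, context):
    """Look up Pr(token | context) in the toy table (hypothetical model)."""
    return TOY_TABLE[(context, token)]


def sequence_log_prob(tokens, cond_prob):
    """Sum log Pr(x_t | x_1, ..., x_{t-1}) over all positions t (chain rule)."""
    total = 0.0
    for t, token in enumerate(tokens):
        context = tuple(tokens[:t])  # all words strictly to the left
        total += math.log(cond_prob(token, context))
    return total


log_p = sequence_log_prob(["My", "dog", "is", "cute"], toy_cond_prob)
joint = math.exp(log_p)  # product of the four conditionals in the table
```

Summing log-probabilities rather than multiplying raw probabilities is the usual choice because it avoids numerical underflow on long sequences.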
Schematically, this left-to-right prediction can be written as:
<MASK> <MASK> <MASK> <MASK> → My <MASK> <MASK> <MASK> → My dog <MASK> <MASK> → My dog is <MASK> → My dog is cute
In XLNet, by contrast, the model is required to predict the words in a randomly sampled order. Suppose the sampled order is 3241, that is, position 3 is predicted first, followed by positions 2, 4, and 1. Schematically, the model is then required to perform the following prediction task:
<MASK> <MASK> <MASK> <MASK> → <MASK> <MASK> is <MASK> → <MASK> dog is <MASK> → <MASK> dog is cute → My dog is cute
By considering all permutations, XLNet is able to capture longer-range dependencies and better model the bidirectional context of words.
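The prediction task for a sampled order can be sketched in a few lines of Python. The snippet below is only an illustration of the schematic, not XLNet's training code: `permutation_steps` is a hypothetical helper that, for each step of a given order (1-based positions, as in the 3241 example), records the target word, the words already revealed, and the partially revealed sentence.

```python
MASK = "<MASK>"


def permutation_steps(tokens, order):
    """Return one (target, visible_context, sentence_view) tuple per step.

    `order` lists 1-based positions in the order they are predicted.
    `visible_context` holds the already revealed words, kept in their
    original sentence positions.
    """
    revealed = [MASK] * len(tokens)
    steps = []
    for pos in order:
        idx = pos - 1
        # Context = everything revealed before this step.
        context = [tok for tok in revealed if tok != MASK]
        revealed[idx] = tokens[idx]
        steps.append((tokens[idx], context, " ".join(revealed)))
    return steps


steps = permutation_steps(["My", "dog", "is", "cute"], [3, 2, 4, 1])
for target, context, view in steps:
    print(f"predict {target!r} given {context} -> {view}")
```

With the order 3241, the last step predicts "My" with "dog", "is", and "cute" all visible, so the word is conditioned on context to its right, which a strictly left-to-right model never sees.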
To implement permutation language modeling, XLNet uses a two-stream self-attention mechanism. The two streams are:
The content stream uses the causal mask
$$M_{\text{causal}} = \begin{bmatrix} 0 & -\infty & -\infty & \dots & -\infty \\ 0 & 0 & -\infty & \dots & -\infty \\ 0 & 0 & 0 & \dots & -\infty \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \dots & 0 \end{bmatrix}$$
permuted by a random permutation matrix $P$ to $P M_{\text{causal}} P^{-1}$.
The query stream uses the cross-attention mask $P(M_{\text{causal}} - \infty I)P^{-1}$, where the diagonal is subtracted away specifically to prevent the model from "cheating" by looking at the content stream for what the current masked token is.
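The two masks can be built directly from a sampled order, as in the following pure-Python sketch. It assumes that entry (i, j) equal to 0 means position i may attend to position j, and that conjugating by $P$ amounts to comparing positions by their place in the prediction order. `two_stream_masks` is a hypothetical helper, not a function from the XLNet codebase.

```python
NEG_INF = float("-inf")


def two_stream_masks(order):
    """Build (content_mask, query_mask) for a 1-based prediction order.

    Entry [i][j] is 0.0 if position i may attend to position j and
    -inf otherwise (an additive attention mask).
    """
    n = len(order)
    # rank[i] = step at which 0-based position i is predicted.
    rank = [0] * n
    for step, pos in enumerate(order):
        rank[pos - 1] = step
    # Content stream: attend to earlier-predicted positions and to itself.
    content = [[0.0 if rank[j] <= rank[i] else NEG_INF for j in range(n)]
               for i in range(n)]
    # Query stream: earlier-predicted positions only (diagonal removed).
    query = [[0.0 if rank[j] < rank[i] else NEG_INF for j in range(n)]
             for i in range(n)]
    return content, query


content_mask, query_mask = two_stream_masks([3, 2, 4, 1])
```

For the identity order 1234 the content mask reduces to the ordinary lower-triangular causal mask. The query-mask row of the first predicted position is entirely $-\infty$, since that token has no context, so a practical implementation has to handle that row without producing NaNs in the softmax.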
Like the causal masking for GPT models, this two-stream masked architecture allows the model to train on all tokens in one forward pass.
Two models were released.[3][4]
XLNet was trained on a dataset that amounted to 32.89 billion tokens after tokenization with SentencePiece. The dataset was composed of BooksCorpus, English Wikipedia, Giga5, ClueWeb 2012-B, and Common Crawl.
Training ran on 512 TPU v3 chips for 5.5 days. At the end of training, the model still under-fitted the data, meaning it could have achieved lower loss with more training. Training took 0.5 million steps with an Adam optimizer, linear learning rate decay, and a batch size of 8192.[5]
"xlnet". GitHub. Retrieved 2 January 2024. https://github.com/zihangdai/xlnet/
"Pretrained models — transformers 2.0.0 documentation". huggingface.co. Retrieved 5 August 2024. https://huggingface.co/transformers/v2.0.0/pretrained_models.html
Yang, Zhilin; Dai, Zihang; Yang, Yiming; Carbonell, Jaime; Salakhutdinov, Ruslan; Le, Quoc V. (2 January 2020). "XLNet: Generalized Autoregressive Pretraining for Language Understanding". arXiv:1906.08237 [cs.CL].