Self-Attention in Transformers

Problem: Explain how the transformer architecture works at a mathematical level (e.g. as outlined in the Attention Is All You Need paper).

Solution:

  1. (Tokenization) Partition the input natural language text into a sequence of tokens \(\tau_1,…,\tau_N\) (here \(N\) is bounded from above by the LLM's context size).
  2. (Embedding) Each token \(\tau_i\) is embedded as a vector \(\tau_i\mapsto\mathbf x_i\in\mathbf R^{n_e}\) that encodes both the token's generic meaning and its position in the input text. Here, the hyperparameter \(n_e\) is the dimension of the embedding space.
  3. (Single Self-Attention Head) For each embedding vector \(\mathbf x_i\), compute its query vector \(\mathbf q_i=W_{\mathbf q}\mathbf x_i\), its key vector \(\mathbf k_i=W_{\mathbf k}\mathbf x_i\), and its value vector \(\mathbf v_i=W_{\mathbf v}\mathbf x_i\). Here, \(W_{\mathbf q},W_{\mathbf k}\in\mathbf R^{n_{qk}\times n_e}\) are weight matrices that map from the embedding space \(\mathbf R^{n_e}\) to the query/key space \(\mathbf R^{n_{qk}}\) of dimension \(n_{qk}\), and \(W_{\mathbf v}\in\mathbf R^{n_e\times n_e}\) is the weight matrix of values (in practice decomposed into a low-rank product \(W_{\mathbf v}=W_{\mathbf v\uparrow}W_{\mathbf v\downarrow}\), where typically \(W_{\mathbf v\downarrow}\in\mathbf R^{n_{qk}\times n_e}\) and \(W_{\mathbf v\uparrow}\in\mathbf R^{n_e\times n_{qk}}\)). For each \(\mathbf x_i\), one then computes an update vector \(\Delta\mathbf x_i\) to be added to it, given by a convex linear combination of the value vectors \(\mathbf v_1,…,\mathbf v_N\) of all the embeddings \(\mathbf x_1,…,\mathbf x_N\) in the context, specifically:

\[\Delta\mathbf x_i=V\text{softmax}\left(\frac{K^T\mathbf q_i}{\sqrt{n_{qk}}}\right)\]

where \(K=(\mathbf k_1,…,\mathbf k_N)\in\mathbf R^{n_{qk}\times N}\) and \(V=(\mathbf v_1,…,\mathbf v_N)\in\mathbf R^{n_e\times N}\) are the key and value matrices associated with the input context (filled with column vectors here rather than the ML convention of row vectors). This map, which takes the initial, generic token embeddings and nudges them towards more contextualized embeddings \(\mathbf x_i\mapsto\mathbf x'_i=\mathbf x_i+\Delta\mathbf x_i\), is called a head of self-attention. The \(1/\sqrt{n_{qk}}\) scaling inside the softmax is justified on the grounds that if \(\mathbf k\) and \(\mathbf q\) are random vectors whose independent components each have mean \(0\) and variance \(1\), then \(\mathbf k\cdot\mathbf q\) has mean \(0\) and variance \(n_{qk}\), hence the need to normalize by \(\sqrt{n_{qk}}\) so that \(\mathbf k\cdot\mathbf q/\sqrt{n_{qk}}\) continues to have variance \(1\).
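
To make the formula concrete, here is a minimal NumPy sketch of a single self-attention head in the column-vector convention above. The toy dimensions and the randomly initialized weight matrices are illustrative stand-ins for a trained model's parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
N, n_e, n_qk = 5, 16, 4                   # toy sizes: context length, embedding dim, query/key dim

# Randomly initialized weights stand in for trained parameters.
W_q = rng.normal(size=(n_qk, n_e))
W_k = rng.normal(size=(n_qk, n_e))
W_v = rng.normal(size=(n_e, n_e))         # value matrix (in practice a low-rank product)

# Column-vector convention: X holds the embeddings x_1, ..., x_N as columns.
X = rng.normal(size=(n_e, N))

def softmax_cols(z):
    z = z - z.max(axis=0, keepdims=True)  # subtract the column max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

Q, K, V = W_q @ X, W_k @ X, W_v @ X       # queries, keys, values for every token

# Column i of A is softmax(K^T q_i / sqrt(n_qk)).
A = softmax_cols(K.T @ Q / np.sqrt(n_qk)) # shape (N, N); each column sums to 1

Delta_X = V @ A                           # column i is the update Delta x_i
X_prime = X + Delta_X                     # contextualized embeddings x'_i
```

Since each column of \(A\) sums to \(1\), every update \(\Delta\mathbf x_i\) is indeed a convex combination of the value vectors.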

  4. (Multi-Headed Self-Attention) Since context can influence meaning in different ways, repeat the above procedure in parallel across several heads of self-attention, each with its own weight matrices. Each head proposes a displacement update for each of the \(N\) original embeddings \(\mathbf x_i\); these updates are summed and added to \(\mathbf x_i\).
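
Below is a hedged sketch of the multi-headed version, following the view above of each head proposing a displacement that is summed into the embedding. (Many implementations instead concatenate the head outputs and apply an output projection; this is equivalent once the relevant block of the output projection is absorbed into each head's value map.) All sizes and weights are again toy values.

```python
import numpy as np

rng = np.random.default_rng(1)
N, n_e, n_qk, H = 5, 16, 4, 3            # toy sizes: context length, embedding dim, key dim, heads

def softmax_cols(z):
    z = z - z.max(axis=0, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

X = rng.normal(size=(n_e, N))            # embeddings x_1, ..., x_N as columns

Delta_total = np.zeros_like(X)
for _ in range(H):                       # each head has its own weight matrices
    W_q = rng.normal(size=(n_qk, n_e))
    W_k = rng.normal(size=(n_qk, n_e))
    # Low-rank value map W_v = W_v_up @ W_v_down, as described in step 3.
    W_v = rng.normal(size=(n_e, n_qk)) @ rng.normal(size=(n_qk, n_e))
    Q, K, V = W_q @ X, W_k @ X, W_v @ X
    A = softmax_cols(K.T @ Q / np.sqrt(n_qk))
    Delta_total += V @ A                 # accumulate the displacement proposed by this head

X_prime = X + Delta_total                # embeddings after one multi-headed attention block
```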

  5. (Multilayer Perceptron) Pass each embedding independently through a small feed-forward network: a linear map, a ReLU nonlinearity, and another linear map, whose output is again added to the embedding. It is hypothesized that much of the model's factual knowledge is stored in this part of the transformer.
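
A minimal sketch of such an MLP block, assuming (as is standard) that it acts on each position independently and that its output is added back to the embedding; the hidden width of \(4n_e\) is a common convention, not a requirement.

```python
import numpy as np

rng = np.random.default_rng(2)
N, n_e = 5, 16
n_hidden = 4 * n_e                        # common (but not required) choice of hidden width

W_up = rng.normal(size=(n_hidden, n_e))   # first linear map
b_up = np.zeros(n_hidden)
W_down = rng.normal(size=(n_e, n_hidden)) # second linear map
b_down = np.zeros(n_e)

X = rng.normal(size=(n_e, N))             # contextualized embeddings as columns

def mlp(x):
    """Linear -> ReLU -> Linear on a single embedding vector."""
    h = np.maximum(W_up @ x + b_up, 0.0)  # ReLU nonlinearity
    return W_down @ h + b_down

# The same MLP is applied to every position independently; its output is
# added to the embedding, just like the attention updates above.
X_prime = X + np.stack([mlp(X[:, i]) for i in range(N)], axis=1)
```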

  6. (Layers) Alternate between multi-headed self-attention blocks and MLP blocks. Finally, make a probabilistic prediction of the next token \(\hat{\tau}_{N+1}\) using only the final, context-rich embedding \(\mathbf x'_N\) of the last token \(\tau_N\) in the context: apply an unembedding matrix \(W_{\mathbf u}\) to obtain logits \(\mathbf u=W_{\mathbf u}\mathbf x'_N\) and run them through a softmax, \(\text{softmax}(\mathbf u)\), to obtain a probability distribution over the vocabulary.
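
A small sketch of this last step, with a toy vocabulary size and random values standing in for the trained unembedding matrix \(W_{\mathbf u}\) and the final embedding \(\mathbf x'_N\):

```python
import numpy as np

rng = np.random.default_rng(3)
n_e, vocab_size = 16, 100                 # toy sizes: embedding dim, vocabulary size

W_u = rng.normal(size=(vocab_size, n_e))  # unembedding matrix
x_last = rng.normal(size=n_e)             # final contextualized embedding x'_N

u = W_u @ x_last                          # logits: one score per vocabulary token

p = np.exp(u - u.max())                   # softmax turns logits into probabilities
p /= p.sum()

print(p.sum())                            # 1.0: a probability distribution over the next token
```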

Problem: Based on the above discussion of the transformer architecture, explain how a large language model (LLM) like Gemini, ChatGPT, Claude, Grok, DeepSeek, etc. works (at a high level).

Solution: Essentially, an LLM is a neural network that takes some string of text as input and probabilistically predicts the next token. Seeded with some corpus of text \(T\), the LLM samples from the probability distribution it generates for the next token \(\tau\) and appends it, \(T\mapsto T+\tau\). Then it simply repeats this, treating \(T+\tau\) as if it had been the seed all along. This is how generative AI models such as ChatGPT (where GPT stands for generative pre-trained transformer) work. In practice, it is also helpful to provide a system prompt like “What follows is a conversation between a user and a knowledgeable AI assistant:”.
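
A sketch of this sampling loop is below. The function next_token_distribution is a placeholder for the full transformer described above; here it just returns random probabilities so the loop runs on its own.

```python
import numpy as np

rng = np.random.default_rng(4)
vocab_size = 100                                # toy vocabulary size

def next_token_distribution(tokens):
    """Placeholder for the transformer: returns a probability
    distribution over the next token given the context so far."""
    logits = rng.normal(size=vocab_size)        # a real model computes these from `tokens`
    p = np.exp(logits - logits.max())
    return p / p.sum()

tokens = [42, 7, 13]                            # token ids of the seed text T
for _ in range(10):                             # generate ten more tokens
    p = next_token_distribution(tokens)
    tau = int(rng.choice(vocab_size, p=p))      # sample the next token tau from p
    tokens.append(tau)                          # T -> T + tau, then repeat

print(tokens)
```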
