Transformer (Oversimplified)

It’s hard to pin down the prerequisites for this blog. Many people understand these concepts in some capacity, at a high level. This is an attempt to turn those abstract, vague definitions into a more tangible, traceable illustration. It was written for the freshman version of myself who was trying to implement Andrej Karpathy’s micrograd course (free on YouTube).

A transformer has 3 components:

  • Attention
  • MLP (dense feed-forward neural network)
  • Normalisation

But before that, let’s trace a simple neural network:

Sequence Length (L): 50 (50 tokens in each sentence)

Embedding Dimension (D): 512 (each token is a vector of 512 numbers)

tracing embedding:

Assume the vocab size is 50000 tokens.

[50000, 512] is the embedding matrix’s dimensions (one row of 512 numbers for each token in the vocab)

→ [50, 512] (pluck out the rows for the 50 tokens of the sentence)
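A minimal NumPy sketch of the embedding lookup. It assumes random weights (a trained model learns them), a toy vocab of 1,000 instead of 50,000 so it runs quickly, and made-up token ids for the sentence:

```python
import numpy as np

rng = np.random.default_rng(0)

VOCAB = 1_000   # toy vocab size (the text uses 50000); shapes work the same way
L, D = 50, 512  # sequence length and embedding dimension

# Embedding matrix: one 512-number row per vocab token (random here, learned in practice).
embedding = rng.standard_normal((VOCAB, D)) * 0.02

# Hypothetical token ids for one 50-token sentence.
token_ids = rng.integers(0, VOCAB, size=L)

# "Pluck out" one row per token: [VOCAB, 512] -> [50, 512]
x = embedding[token_ids]
print(embedding.shape, x.shape)  # (1000, 512) (50, 512)
```

The lookup is just row indexing, which is why an embedding layer is often described as a lookup table.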

tracing network:

[50, 512] Input dimensions into the neural net (each input token is embedded and stacked)

[512, 1024] Let’s set 1024 neurons in the hidden layer. Each has a weight vector of length 512.

[50, 512] x [512, 1024] → [50, 1024] is the hidden layer output.

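i.e. each of the 50 tokens gets a 1024-length vector of activations (in practice a non-linearity such as ReLU or GELU is applied here; without one, the two layers would collapse into a single linear map)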

[50, 1024] x [1024, 512] → [50, 512] is the output layer.

i.e. each of the 50 tokens gets a 512-length vector of output activations.
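The same two-layer trace in NumPy, as a sketch: random weights stand in for trained ones, biases are left out, and the ReLU between the layers is an assumed choice of non-linearity. Each neuron’s weight vector is one column of its weight matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
L, D, H = 50, 512, 1024  # tokens, embedding dim, hidden neurons

x = rng.standard_normal((L, D))          # stand-in for the [50, 512] embedded input
W1 = rng.standard_normal((D, H)) * 0.02  # 1024 neurons, each a column of 512 weights
W2 = rng.standard_normal((H, D)) * 0.02  # 512 neurons, each a column of 1024 weights

hidden = np.maximum(0, x @ W1)  # [50, 512] x [512, 1024] -> [50, 1024], then ReLU
out = hidden @ W2               # [50, 1024] x [1024, 512] -> [50, 512]
print(hidden.shape, out.shape)  # (50, 1024) (50, 512)
```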

Point 1: To visualise it more easily, just consider 1 token. Then it becomes 512 (input) → 1024 → 512 (output), which matches the usual simple neural net diagram. But just saying “1024” misses the point: it’s actually 1024 neurons, each with a weight vector of length 512. Similarly, “512” (the output layer) is actually 512 neurons, each with a weight vector of length 1024 (one weight per neuron in the previous layer).

Point 2: Notice the 50 survives the whole trace. Once we add attention, this 50 becomes the context window that attention operates over.

Point 3: The 50 words never actually talk to each other. The training objective is predicting the next word (the error on that prediction is what backpropagates). But each of the 50 words in the sequence independently tries to predict the word that follows it, with no interaction between them. So this is effectively a bigram model.
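Point 3 can be checked numerically. In this sketch (hypothetical random weights, same two-layer network as the trace, ReLU assumed), changing only the first token changes only the first row of the output:

```python
import numpy as np

rng = np.random.default_rng(0)
L, D, H = 50, 512, 1024
W1 = rng.standard_normal((D, H)) * 0.02
W2 = rng.standard_normal((H, D)) * 0.02


def mlp(x):
    """Apply the same two-layer network to every row (token) independently."""
    return np.maximum(0, x @ W1) @ W2


x = rng.standard_normal((L, D))
x_changed = x.copy()
x_changed[0] = rng.standard_normal(D)  # replace only the first token

diff = np.abs(mlp(x) - mlp(x_changed)).max(axis=1)  # per-token change in output
print(diff[0] > 0, np.allclose(diff[1:], 0))  # True True
```

Rows 2 to 50 are unaffected, so no token can use information from any other token.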

Point 4: There’s usually also a batch dimension (e.g. 8 sentences processed at once, giving [8, 50, 512]) so the GPU can parallelise training. I have excluded it here.
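A short sketch of the batch dimension, assuming a batch of 8 and random values. NumPy’s `@` broadcasts over the leading axis, so the same weights are applied to every sentence in the batch:

```python
import numpy as np

rng = np.random.default_rng(0)
B, L, D, H = 8, 50, 512, 1024  # batch of 8 sentences

x = rng.standard_normal((B, L, D))       # [8, 50, 512]
W1 = rng.standard_normal((D, H)) * 0.02  # one set of weights shared by all sentences
hidden = x @ W1                          # matmul broadcasts over the batch axis
print(hidden.shape)  # (8, 50, 1024)
```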

tracing generation:

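[512, 50000] is the transposed embedding matrix = “Language model head” or “LM head” (reusing the embedding this way is called weight tying; some models learn a separate matrix of this shape instead)

[50, 512] x [512, 50000] → [50, 50000] is a score for every vocab word at each of the 50 positions, i.e. the 50 next-token predictions

i.e. row 1 predicts word 2 in the sequence, row 2 predicts word 3… and thus row 50 predicts the 51st word (but with no attention, it’s based only on the 50th word)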
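A sketch of the LM head with tied weights, assuming a toy vocab of 1,000 and random values. The softmax and the greedy `argmax` pick are standard additions not covered in the trace above:

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, L, D = 1_000, 50, 512  # toy vocab (the text uses 50000)

embedding = rng.standard_normal((VOCAB, D)) * 0.02
out = rng.standard_normal((L, D))  # stand-in for the [50, 512] output layer

logits = out @ embedding.T  # [50, 512] x [512, VOCAB] -> [50, VOCAB]

# Softmax turns each row of scores into a probability distribution over the vocab.
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

next_ids = probs.argmax(axis=1)       # greedy pick: row i's guess for word i+1
print(logits.shape, next_ids.shape)   # (50, 1000) (50,)
```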

now add in attention:

[50, 1024] are the activations of the hidden layer.

Instead of throwing them into a 512-neuron output layer, let’s throw them into 3 separate linear layers:

[1024, 512], [1024, 512], [1024, 512]

→ [50, 512], [50, 512], [50, 512] output activations of the 3 linear layers

[50, 512] x [512, 50]: transpose one of them (“Keys”) and multiply another (“Queries”) by it, i.e. Queries × Keysᵀ

[50, 50] attention matrix is found. Entry (i, j) scores how much token i attends to token j.

[50, 50] x [50, 512]: multiply the attention matrix with the third matrix = “Values”

→ [50, 512] output activations, in which the tokens have “spoken” to each other via attention.
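The attention trace above, sketched in NumPy with random weights. After `Q @ K.T` it also applies three steps the trace deliberately skips but that GPT-style models use: scaling by √512, a causal mask (so token i can’t peek at later tokens, which would let it cheat at next-word prediction), and a row-wise softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
L, H, D = 50, 1024, 512

hidden = rng.standard_normal((L, H))  # the [50, 1024] hidden activations
W_q, W_k, W_v = (rng.standard_normal((H, D)) * 0.02 for _ in range(3))

Q, K, V = hidden @ W_q, hidden @ W_k, hidden @ W_v  # each [50, 512]

scores = Q @ K.T  # [50, 512] x [512, 50] -> [50, 50]

# Extra steps not in the simplified trace:
scores = scores / np.sqrt(D)                      # keep scores from growing with D
mask = np.triu(np.ones((L, L), dtype=bool), k=1)  # True above the diagonal
scores[mask] = -np.inf                            # token i can't see tokens after i
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)     # each row sums to 1

attn_out = weights @ V  # [50, 50] x [50, 512] -> [50, 512]
print(weights.shape, attn_out.shape)  # (50, 50) (50, 512)
```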

“Normalisation” = This layer standardises each token’s features (its 512 numbers) to a mean of 0 and a variance of 1, usually followed by a learned scale and shift. This helps keep the internal values from exploding or vanishing, which makes training faster and more stable.
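A minimal layer normalisation sketch, assuming the common formulation with a learned scale `gamma` and shift `beta` (initialised here to 1 and 0) and a small `eps` for numerical safety:

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """Standardise each row (token) across its features, then rescale."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)  # mean 0, variance ~1 per token
    return gamma * x_hat + beta              # learned scale and shift


rng = np.random.default_rng(0)
x = rng.standard_normal((50, 512)) * 10 + 3  # badly scaled activations
y = layer_norm(x, gamma=np.ones(512), beta=np.zeros(512))
print(y.mean(axis=-1)[:3], y.var(axis=-1)[:3])
```

The statistics are computed per token (along the last axis), so each of the 50 rows is normalised on its own.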

“Attention Heads” = In the trace above, I used a single 512-wide block. In reality, the model splits that 512 into "heads."

  • If we have 8 heads, the 512 dimensions are split into 8 chunks of 64.
  • New shape: [8, 50, 64] (8 heads, 50 tokens, 64 features per head). Each head computes its own [50, 50] attention matrix.
  • The heads are concatenated at the end to return to [50, 512].
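A sketch of splitting into 8 heads and concatenating back, using random stand-ins for the Queries, Keys and Values. `split_heads` is a hypothetical helper name. The causal mask is omitted for brevity, and real models typically follow the concatenation with one more [512, 512] linear layer, which is not shown.

```python
import numpy as np

rng = np.random.default_rng(0)
L, D, N_HEADS = 50, 512, 8
HEAD_DIM = D // N_HEADS  # 64

# Random stand-ins for the [50, 512] Queries, Keys and Values.
Q, K, V = (rng.standard_normal((L, D)) for _ in range(3))


def split_heads(m):
    """[50, 512] -> [50, 8, 64] -> [8, 50, 64]"""
    return m.reshape(L, N_HEADS, HEAD_DIM).transpose(1, 0, 2)


q, k, v = split_heads(Q), split_heads(K), split_heads(V)

# One [50, 50] attention matrix per head -> [8, 50, 50]
scores = q @ k.transpose(0, 2, 1) / np.sqrt(HEAD_DIM)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
heads_out = weights @ v  # [8, 50, 64]

# Concatenate: [8, 50, 64] -> [50, 8, 64] -> [50, 512]
out = heads_out.transpose(1, 0, 2).reshape(L, D)
print(scores.shape, heads_out.shape, out.shape)  # (8, 50, 50) (8, 50, 64) (50, 512)
```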

“Residual Connections” = The model doesn't just take that "Attention Output" and move on. It uses a Residual Connection (Add).

  • The Math: Output = Input [50, 512] + Attention Output [50, 512].
  • Intuition: This ensures that if the attention mechanism fails or learns nothing, the original identity of the token is preserved. It also gives gradients a direct path backward, so the signal doesn't get "lost" in deep networks.
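A tiny sketch of the residual idea; `residual` and `dead_attention` are hypothetical helpers for illustration. If the sublayer outputs zeros (it "learns nothing"), the input passes through untouched:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((50, 512))  # input to the attention block, [50, 512]


def residual(x, sublayer):
    """Add the sublayer's output back onto its input (both [50, 512])."""
    return x + sublayer(x)


def dead_attention(t):
    """Stand-in for an attention block that has learned nothing."""
    return np.zeros_like(t)


y = residual(x, dead_attention)
print(np.array_equal(y, x))  # True
```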

P.S. This is an extremely oversimplified explanation of attention. There are all sorts of details around softmax, normalisation, batching, etc. that have been deliberately excluded for simplicity.