= Number of multiplications per token in a <GPT> model
The following estimates the number of multiplications needed to generate one new token in a "classic" <GPT-2>-style model.
For each layer (L):
* for each attention head (h):
  * K projection: d_model * d_head (takes the embedding of one token and converts it to a vector of length d_head)
  * Q projection: d_model * d_head (same)
  * Q K dot products for the attention pattern: n_ctx * d_head (n_ctx dot products of vectors of size d_head, the new Q against every K. The new K against every earlier Q is zeroed out by causality.)
  * new value vector for the new token: d_model * d_model (the value vector is taken to have length d_model here)
  * new updates: n_ctx * d_model (multiply each of the n_ctx value vectors by its new attention weight)
* fully connected layer: d_model * d_ff + d_ff * d_model (converts the embedding to the hidden layer size and then back)
So the total sum is:
``
L * (
h * (
2 * d_model * d_head +
n_ctx * d_head +
d_model * d_model +
n_ctx * d_model
) +
2 * d_model * d_ff
)
``
This is coded at: \a[llm_count_mults.py].
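For illustration, here is a minimal standalone Python sketch of the same sum (not necessarily identical to the linked script). The hyperparameter values used in the example are an assumption, taken from GPT-2 small:
``
def count_mults(n_layers, n_heads, d_model, d_head, d_ff, n_ctx):
    """Multiplications to generate one new token, following the sum above."""
    per_head = (
        2 * d_model * d_head   # K and Q projections
        + n_ctx * d_head       # new Q dotted against every K
        + d_model * d_model    # new value vector for the new token
        + n_ctx * d_model      # scale each value vector by its attention weight
    )
    # attention heads plus the fully connected layer, per layer
    per_layer = n_heads * per_head + 2 * d_model * d_ff
    return n_layers * per_layer

# Assumed example values: GPT-2 small at full context length.
print(count_mults(
    n_layers=12,
    n_heads=12,
    d_model=768,
    d_head=64,
    d_ff=3072,
    n_ctx=1024,
))  # 278396928, i.e. roughly 0.28 billion multiplications per token
``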
Bibliography:
* https://www.reddit.com/r/theydidthemath/comments/1fzrs1k/request_how_many_individual/
* https://www.gaohongnan.com/playbook/training/how_to_calculate_flops_in_transformer_based_models.html#sanity-check-with-palm-paper-s-flops-calculation