= Number of multiplications per token in a <GPT> model

The following estimates the number of multiplications needed to generate one new token in a "classic" <GPT-2>-style model.

For each layer (L):
* for each attention head (h):
  * K = d_model * d_head (takes embedding of one token and converts to vector of length d_head)
  * Q = d_model * d_head (same)
  * Q K dot products for the attention pattern: n_ctx * d_head (n_ctx dot products of vectors of size d_head: the new Q against every existing K. The new K against earlier Qs is zeroed out by causality, so it costs nothing.)
  * new value vector for new token: d_model * d_model (the value is produced directly in the d_model-dimensional residual space, i.e. value and output projections taken together per head)
  * new updates: n_ctx * d_model (multiply each of the n_ctx value vectors by its scalar weight in the new attention column)
* fully connected: d_model * d_ff + d_ff * d_model (converts the embedding to the hidden layer size and then back)
So the total sum is:
``
L * (
  h * (
    2 * d_model * d_head +
    n_ctx * d_head +
    d_model * d_model +
    n_ctx * d_model
  ) +
  2 * d_model * d_ff
)
``
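For concreteness, here is a minimal standalone sketch of the same count (an illustration, not necessarily identical to \a[llm_count_mults.py]), evaluated with GPT-2 small hyperparameters (L = 12, h = 12, d_model = 768, d_head = 64, d_ff = 3072, n_ctx = 1024):
``
def mults_per_token(L, h, d_model, d_head, d_ff, n_ctx):
    """Multiplications to generate one new token with a context of n_ctx tokens."""
    per_head = (
        2 * d_model * d_head  # K and Q projections for the new token
        + n_ctx * d_head      # new Q dotted against every existing K
        + d_model * d_model   # value vector for the new token
        + n_ctx * d_model     # scale each value vector by its attention weight
    )
    fully_connected = 2 * d_model * d_ff  # up-projection to d_ff and back down
    return L * (h * per_head + fully_connected)


if __name__ == '__main__':
    # GPT-2 small hyperparameters.
    n = mults_per_token(L=12, h=12, d_model=768, d_head=64, d_ff=3072, n_ctx=1024)
    print(n)  # 278396928
``
Plugging in those numbers gives 278,396,928, i.e. roughly 2.8e8 multiplications per generated token at full context.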

This is coded at: \a[llm_count_mults.py].

Bibliography:
* https://www.reddit.com/r/theydidthemath/comments/1fzrs1k/request_how_many_individual/
* https://www.gaohongnan.com/playbook/training/how_to_calculate_flops_in_transformer_based_models.html#sanity-check-with-palm-paper-s-flops-calculation