Attention in transformers, step-by-step by 3Blue1Brown (2024). Uses GPT-3 as its basis.
For inference on a single prompt, things appear to be very clearly memory bound, i.e. bound by the transfer speed from VRAM to the GPU cache when loading model parameters so they can be used. This supposes that the model fits in VRAM, which is the case for many popular models.
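To make the memory-bound claim concrete, here is a rough roofline-style estimate for one matrix-vector product, assuming fp16 weights (2 bytes each) that are all read from VRAM, and ignoring the traffic for the activations. The GPU throughput and bandwidth figures are hypothetical round numbers chosen for illustration, not the specs of any real card, and `matvec_intensity` is a made-up helper name.

```python
def matvec_intensity(d_in: int, d_out: int, bytes_per_weight: int = 2) -> float:
    """Multiplications per byte of weights loaded for one matrix-vector product.

    Assumes every weight is read from VRAM once (fp16: 2 bytes) and ignores the
    much smaller traffic for the input and output vectors.
    """
    mults = d_in * d_out  # one multiplication per weight
    bytes_loaded = d_in * d_out * bytes_per_weight
    return mults / bytes_loaded


# Hypothetical round numbers for illustration, not the specs of a real GPU.
PEAK_MULTS_PER_S = 100e12  # assumed compute throughput
VRAM_BYTES_PER_S = 1e12    # assumed VRAM bandwidth

intensity = matvec_intensity(768, 3072)       # a GPT-2-small-sized fully connected layer, one prompt
needed = PEAK_MULTS_PER_S / VRAM_BYTES_PER_S  # mults per byte needed to keep compute busy
print(intensity, needed)                      # 0.5 100.0
```

With a single prompt, each weight that is loaded feeds exactly one multiplication, so under these assumed figures the GPU would need about 200 times more work per byte loaded before compute, rather than memory, became the bottleneck.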
It is however possible to make fuller use of the GPU's compute power by running multiple independent queries in parallel: you load the subset of model weights that you need once, and then use it to do part of the inference for multiple input prompts. With this it should be possible to reach full utilization.
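The batching idea can be sketched in pure Python with a hypothetical `batched_matvec` helper: each weight is read once and then reused for every prompt in the batch, so the number of multiplications per weight read grows with the batch size while the weight traffic stays constant. This is only an illustration of the counting; real GPU kernels achieve the reuse with tiled matrix-matrix products, not Python loops.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def batched_matvec(W: Matrix, xs: List[Vector]) -> Tuple[List[Vector], int, int]:
    """Apply the same weight matrix W to every input vector in xs.

    Each weight W[i][j] is read once and reused for all prompts in the batch,
    mimicking loading a chunk of weights into cache and using it for many queries.
    Returns (outputs, weight_reads, multiplications).
    """
    d_out, d_in = len(W), len(W[0])
    outs = [[0.0] * d_out for _ in xs]
    weight_reads = 0
    mults = 0
    for i in range(d_out):
        for j in range(d_in):
            w = W[i][j]  # one "load" of this weight
            weight_reads += 1
            for b, x in enumerate(xs):  # reuse it for every prompt in the batch
                outs[b][i] += w * x[j]
                mults += 1
    return outs, weight_reads, mults


W = [[1.0, 2.0], [3.0, 4.0]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three independent "prompts"
outs, reads, mults = batched_matvec(W, xs)
print(outs)          # [[1.0, 3.0], [2.0, 4.0], [3.0, 7.0]]
print(reads, mults)  # 4 12
```

Here 4 weight reads serve 12 multiplications, i.e. 3 per read for a batch of 3 prompts, whereas running the prompts one at a time would need 12 reads for the same work.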
The following estimates the number of multiplications needed to process one new token in a "classic" GPT-2-style model, counting both the attention and the fully connected parts.
For each layer (L):
- for each attention head (h):
  - K = d_model * d_head (takes the embedding of one token and converts it to a vector of length d_head)
  - Q = d_model * d_head (same)
  - K Q dot products for the attention pattern: n_ctx * d_head (n_ctx dot products of vectors of size d_head, one for the new Q against every K. The new K against every earlier Q is zeroed out by causality.)
  - new value vector for the new token: d_model * d_model
  - new updates: n_ctx * d_model (multiply each value vector by its scalar in the new attention column)
- fully connected: d_model * d_ff + d_ff * d_model (converts the embedding to the hidden layer size and then back)
So the total sum is:
L * (
  h * (
    2 * d_model * d_head +
    n_ctx * d_head +
    d_model * d_model +
    n_ctx * d_model
  ) +
  2 * d_model * d_ff
)
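The total can be evaluated with a small helper, a sketch written for illustration and not the contents of llm_count_mults.py. It is plugged with the published GPT-2 small hyperparameters: 12 layers, 12 heads, d_model = 768, d_head = 64, d_ff = 3072 and n_ctx = 1024.

```python
def count_mults(L: int, h: int, d_model: int, d_head: int, d_ff: int, n_ctx: int) -> int:
    """Multiplications for one new token, following the per-layer breakdown above."""
    per_head = (
        2 * d_model * d_head   # K and Q vectors of the new token
        + n_ctx * d_head       # the new Q dotted with every K
        + d_model * d_model    # value vector of the new token
        + n_ctx * d_model      # weighting each value vector by its attention scalar
    )
    per_layer = h * per_head + 2 * d_model * d_ff  # all heads + fully connected up and down
    return L * per_layer


print(count_mults(L=12, h=12, d_model=768, d_head=64, d_ff=3072, n_ctx=1024))  # 278396928
```

That is about 278 million multiplications per generated token at full context. At this size the per-head d_model * d_model value term and the n_ctx * d_model update term dominate, together outweighing the fully connected layer by roughly a factor of four.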
The multiplication-count formula above is implemented in: llm_count_mults.py.
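The K Q dot product item relies on causality: when a token is appended, the outputs of earlier tokens do not change, so only the new token's query has to be dotted with every key. The following pure-Python sketch checks this with standard scaled dot-product attention on tiny made-up vectors. All helper names are hypothetical, and the usual 1/sqrt(d_head) scaling it applies costs an extra n_ctx scalar multiplications per head that the estimate above does not count.

```python
import math
from typing import List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: Vec) -> Vec:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Attention output for one query: len(keys) dot products of size d_head,
    then a weighted sum of the value vectors."""
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([dot(q, k) * scale for k in keys])
    out = [0.0] * len(values[0])
    for w, v in zip(weights, values):  # scale each value vector by its attention weight
        for i, vi in enumerate(v):
            out[i] += w * vi
    return out


def causal_attention(Q: List[Vec], K: List[Vec], V: List[Vec]) -> List[Vec]:
    """Token t only sees keys and values of tokens 0..t (causal mask)."""
    return [attend(Q[t], K[: t + 1], V[: t + 1]) for t in range(len(Q))]


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [0.5, -1.0], [0.0, 1.0]]
V = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

full = causal_attention(Q, K, V)
new_only = attend(Q[-1], K, V)  # only the newest token's query against every key
print(new_only == full[-1])                               # True
print(causal_attention(Q[:2], K[:2], V[:2]) == full[:2])  # True: earlier outputs unchanged
```

Because the earlier rows are unaffected, an incremental implementation only needs the new token's n_ctx dot products and its n_ctx weighted value vectors per head, matching the n_ctx * d_head and n_ctx * d_model terms.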