01 Training a model is all about efficiency
Lecture 01 of CS336 Language Modeling from Scratch at Stanford.
Introduction
Before writing any training code, it helps to recognize that training a large language model is fundamentally an exercise in resource allocation. We have a compute budget, and we need to decide how to spend it. To make that decision, we need two numbers: how many parameters the model has, and how many FLOPs it takes to train. Specifically, we need to answer two basic questions:
- What does “large” mean?
- How much does it cost to train?
Everything else follows from those.
What makes a model “large”?
The powerful AI systems we use today, such as ChatGPT, Claude, and Gemini, are all built on large language models. Much of their capability is attributed to emergent abilities: behaviors that appear as models scale up and are largely absent in smaller models. Abilities such as multi-step reasoning and in-context learning were not explicitly trained for; they emerged with scale.
So what exactly is a “large” model? The short answer is parameter count: frontier models today have hundreds of billions to trillions of parameters. GPT-4, for example, reportedly has about 1.8T parameters, although OpenAI has not confirmed this. Parameter count matters because it determines memory footprint, inference cost, and, up to a point, capability. So the natural first question is: for a Transformer of a given shape, how many parameters does it have?
Once we have that, we can ask the second question.
How much compute does it take to train and run a large model?
Before talking about training costs, we need a unit. The standard one is FLOPs — floating point operations. Specifically, in neural networks we count a fused multiply-add as two operations (one multiply, one add).
For a single linear layer mapping a vector $(1 \times d)$ through a weight matrix $(d \times o)$:
Each of the $o$ outputs is a dot product of length $d$, so we do $d \cdot o$ multiplications and (about) $d \cdot o$ additions, giving $\text{FLOPs} = 2do$.
For a full matrix multiply $X_{[n \times d]} \times W_{[d \times o]}$, that’s $n$ such vectors, so:
\[\text{FLOPs} = 2ndo\]
i.e., $\text{FLOPs} = 2 \times (\text{rows of left matrix}) \times (\text{inner dim}) \times (\text{cols of right matrix})$.
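To sanity-check the $2ndo$ rule, here is a minimal sketch that multiplies two small matrices with a naive triple loop and counts one multiply plus one add for every inner-loop step (the FMA convention above). The helper names `matmul_flops` and `naive_matmul_with_count` are illustrative, not from any library.

```python
def matmul_flops(n: int, d: int, o: int) -> int:
    """FLOPs for X[n x d] @ W[d x o], counting each multiply-add as 2."""
    return 2 * n * d * o


def naive_matmul_with_count(A: list[list[float]], B: list[list[float]]) -> tuple[list[list[float]], int]:
    """Multiply A (n x d) by B (d x o) with a triple loop, counting FLOPs as we go."""
    n, d, o = len(A), len(B), len(B[0])
    assert len(A[0]) == d, "inner dimensions must match"
    C = [[0.0] * o for _ in range(n)]
    flops = 0
    for i in range(n):
        for j in range(o):
            acc = 0.0
            for k in range(d):
                acc += A[i][k] * B[k][j]  # one multiply and one add
                flops += 2
            C[i][j] = acc
    return C, flops


A = [[1.0] * 4 for _ in range(3)]  # 3 x 4
B = [[1.0] * 5 for _ in range(4)]  # 4 x 5
C, counted = naive_matmul_with_count(A, B)
print(counted, matmul_flops(3, 4, 5))  # 120 120
```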
Model Size and Memory Estimation
How many parameters does a Transformer have?
Almost every large language model today is built on the Transformer architecture.
Let $L$ = number of layers, $d$ = model dimension, $V$ = vocabulary size. Biases and layer norms are small enough to ignore.
Each transformer layer has four linear projections in attention and two in the MLP:
| Component | Parameters |
|---|---|
| Q, K, V projections | $3d^2$ |
| Output projection | $d^2$ |
| MLP up-projection ($d \to 4d$) | $4d^2$ |
| MLP down-projection ($4d \to d$) | $4d^2$ |
| Total per layer | $\mathbf{12d^2}$ |
Summing over all layers and adding embeddings:
\[\boxed{N \approx 12Ld^2 + 2Vd}\]
The $2Vd$ term is the input embedding matrix plus the output (unembedding) matrix. These are often weight-tied in practice, bringing the term down to $Vd$. The ratio of the two terms is $12Ld^2 / 2Vd = 6Ld/V$, so for large models, where $6Ld \gg V$, the embedding term is negligible compared to $12Ld^2$.
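A minimal sketch of the formula, summing the per-layer table row by row. The example plugs in the GPT-3 175B shape reported by Brown et al. (2020): 96 layers, $d = 12288$, and a 50257-token vocabulary. The function name is hypothetical, and like the formula it ignores biases and layer norms.

```python
def transformer_params(n_layers: int, d_model: int, vocab_size: int, tied_embeddings: bool = False) -> int:
    """Approximate parameter count N ≈ 12*L*d^2 + 2*V*d (biases and norms ignored)."""
    attn = 3 * d_model**2 + d_model**2        # Q, K, V projections + output projection
    mlp = 4 * d_model**2 + 4 * d_model**2     # up-projection (d -> 4d) + down-projection (4d -> d)
    embeddings = (1 if tied_embeddings else 2) * vocab_size * d_model
    return n_layers * (attn + mlp) + embeddings


n = transformer_params(96, 12288, 50257)
print(f"{n / 1e9:.1f}B parameters")  # 175.2B parameters
```

Here the embedding term contributes about 1.2B of the 175B, under 1%, which matches the $6Ld \gg V$ argument.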
Memory Estimation
Knowing the model size $N$ also tells us how much memory the model occupies. Each parameter is stored as a floating-point number, and the size depends on the precision:
| Precision | Bytes per parameter |
|---|---|
| float32 (fp32) | 4 bytes |
| float16 / bfloat16 (fp16/bf16) | 2 bytes |
| int8 quantization | 1 byte |
So for a model with $N$ parameters in fp16:
\[\text{Memory (bytes)} = 2N\]
For example, a 70B parameter model in fp16 requires roughly $2 \times 70 \times 10^9 = 140\text{ GB}$ just to store the weights.
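The same arithmetic across the precisions in the table, as a small sketch. It uses decimal gigabytes ($10^9$ bytes) to match the 140 GB figure; the names are illustrative.

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def weight_memory_gb(n_params: int, precision: str) -> float:
    """Memory needed just to hold the weights, in decimal GB (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[precision] / 1e9


for p in ("fp32", "fp16", "int8"):
    print(p, weight_memory_gb(70_000_000_000, p), "GB")
# fp32 280.0 GB
# fp16 140.0 GB
# int8 70.0 GB
```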
During training, memory pressure is much higher. With the Adam optimizer, we need to store:
- Weights: $2N$ bytes (fp16)
- Gradients: $2N$ bytes (fp16), same as weights because each parameter has a gradient.
- Optimizer states (fp32 master weights + Adam’s first and second moments): $4N + 4N + 4N = 12N$ bytes
Adam (Adaptive Moment Estimation) is the standard optimizer for LLMs. It maintains two extra states per parameter:
- $m$: exponential moving average of gradients (first moment)
- $v$: exponential moving average of squared gradients (second moment)
The update rule, omitting the bias-correction terms that Adam applies to $m$ and $v$ early in training, is:
\[\theta \leftarrow \theta - \frac{\eta}{\sqrt{v} + \epsilon} \cdot m\]
where $\eta$ is the learning rate and $\epsilon$ is a small constant for numerical stability. Dividing by $\sqrt{v}$ shrinks the step size for parameters that have recently seen large gradients, giving each parameter an adaptive learning rate.
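A minimal NumPy sketch of one Adam step. Unlike the simplified rule, it includes the bias correction from Kingma and Ba's original Adam, which rescales $m$ and $v$ because both start at zero. The hyperparameter defaults ($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$) are the common choices, not values taken from the lecture, and the toy objective is purely illustrative.

```python
import numpy as np


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad          # EMA of gradients (first moment)
    v = beta2 * v + (1 - beta2) * grad**2       # EMA of squared gradients (second moment)
    m_hat = m / (1 - beta1**t)                  # bias correction: m and v start at zero
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


# Minimize f(theta) = sum(theta**2), whose gradient is 2 * theta.
theta = np.array([1.0, -2.0])
m, v = np.zeros_like(theta), np.zeros_like(theta)
for t in range(1, 2001):
    theta, m, v = adam_step(theta, 2 * theta, m, v, t, lr=1e-2)
print(theta)  # both coordinates move toward the minimum at 0
```

Note that on the first step the update is about `lr * sign(grad)` for every parameter regardless of gradient magnitude: the per-parameter normalization described above.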
To preserve numerical precision during updates, mixed-precision training keeps an fp32 master copy of the weights alongside $m$ and $v$, all three in fp32 (4 bytes each), hence $12N$ bytes total for optimizer states. The fp16 weights used in the forward pass are cast from this master copy.
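One way to see why the master copy must stay in fp32: fp16 has only 10 mantissa bits, so just above 1.0 the gap between representable values is about $10^{-3}$, and an update of $10^{-4}$ rounds away entirely. The NumPy sketch below applies the same tiny update 1,000 times to an fp16 weight and to an fp32 weight; the update size is an arbitrary illustrative choice.

```python
import numpy as np

update = 1e-4  # a small step, e.g. learning rate times a normalized gradient
w16 = np.float16(1.0)
w32 = np.float32(1.0)
for _ in range(1000):
    w16 = np.float16(w16 + np.float16(update))  # rounds back to 1.0 every time
    w32 = np.float32(w32 + np.float32(update))  # accumulates the updates

print(np.finfo(np.float16).eps)  # gap between 1.0 and the next fp16 value
print(w16, w32)                  # 1.0 vs. roughly 1.1
```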
This is the standard mixed precision training workflow: use fp16 for the compute-heavy parts (forward and backward passes) to save memory and leverage fast fp16 tensor cores, while keeping fp32 for the numerically sensitive parts (weight updates). The full cycle per training step is:
- Cast fp32 master weights → fp16 ($2N$)
- Forward pass in fp16
- Backward pass in fp16, accumulate gradients ($2N$)
- Upcast the fp16 gradients to fp32 and update the fp32 master weights and Adam states ($12N$)
- Repeat
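Total training memory is therefore about $2N + 2N + 12N = 16N$ bytes, roughly $8\times$ the $2N$ bytes needed to hold the fp16 weights for inference. This count excludes activations, whose memory grows with batch size and sequence length.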
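The NumPy sketch below walks through the cycle above on a toy linear-regression problem, used here as an illustrative stand-in for a Transformer; the data, sizes, and hyperparameters are arbitrary. NumPy has no tensor cores, so this only shows the precision bookkeeping: which arrays live in fp16 and which stay in fp32. Real fp16 training usually also applies loss scaling so small gradients do not underflow; that is omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 8)).astype(np.float32)
true_w = rng.standard_normal(8).astype(np.float32)
y = X @ true_w
X16, y16 = X.astype(np.float16), y.astype(np.float16)

master_w = np.zeros(8, dtype=np.float32)  # fp32 master weights
m = np.zeros_like(master_w)               # Adam first moment, fp32
v = np.zeros_like(master_w)               # Adam second moment, fp32
lr, beta1, beta2, eps = 1e-2, 0.9, 0.999, 1e-8


def mse(w):
    return float(np.mean((X @ w - y) ** 2))


initial_loss = mse(master_w)
for t in range(1, 1001):
    w16 = master_w.astype(np.float16)                    # 1. cast master weights to fp16
    err = X16 @ w16 - y16                                 # 2. forward pass in fp16
    grad16 = (X16.T @ err) * np.float16(2.0 / len(y16))   # 3. backward pass in fp16 (MSE gradient)
    g = grad16.astype(np.float32)                         # upcast gradients for the update
    m = beta1 * m + (1 - beta1) * g                       # 4. Adam update entirely in fp32
    v = beta2 * v + (1 - beta2) * g * g
    m_hat, v_hat = m / (1 - beta1**t), v / (1 - beta2**t)
    master_w -= lr * m_hat / (np.sqrt(v_hat) + eps)

final_loss = mse(master_w)
print(initial_loss, final_loss)
```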
Adam 8-bit is a variant that reduces memory further by storing $m$ and $v$ in int8 instead of fp32. Since int8 has only 256 discrete values, quantizing a whole tensor with a single scale factor would lose too much information: one large outlier would force a coarse scale onto every other value. The fix is blockwise quantization: each optimizer-state tensor is split into small blocks, and each block finds its own scale factor and maps its values into $[-127, 127]$, keeping quantization error local to the block.
At each update step:
- Dequantize: int8 → fp32 (approximate $m$ and $v$)
- Compute Adam update in fp32 using the approximated moments
- Update fp32 master weights as usual
- Requantize: fp32 → int8 (store updated $m$ and $v$)
The fp32 master weights are untouched — only the storage format of $m$ and $v$ changes. This works because $m$ and $v$ are slow-moving exponential averages, int8 is precise enough to capture their trend, and the fp32 master weights absorb any accumulated error over time.
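A minimal NumPy sketch of the blockwise scheme described above, assuming simple linear absmax scaling per block and an arbitrary block size of 256. The actual 8-bit Adam of Dettmers et al. uses a non-linear ("dynamic") quantization map instead, so treat this as an illustration of blockwise scaling only; the function names are hypothetical.

```python
import numpy as np


def quantize_blockwise(x: np.ndarray, block_size: int = 256):
    """Split x into blocks; scale each block by its own absmax into [-127, 127] int8."""
    flat = x.astype(np.float32).ravel()
    pad = (-flat.size) % block_size                    # pad so the length divides evenly
    blocks = np.pad(flat, (0, pad)).reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 127.0
    scales[scales == 0] = 1.0                          # all-zero block: any scale works
    q = np.round(blocks / scales).astype(np.int8)
    return q, scales, x.shape, pad


def dequantize_blockwise(q, scales, shape, pad):
    """Invert quantize_blockwise, up to a rounding error of at most scale / 2 per value."""
    flat = (q.astype(np.float32) * scales).ravel()
    return flat[: flat.size - pad].reshape(shape)


rng = np.random.default_rng(0)
m_state = rng.standard_normal(1000).astype(np.float32) * 1e-3  # stand-in for Adam's m
m_state[0] = 1.0                                               # one outlier in block 0

q, scales, shape, pad = quantize_blockwise(m_state)
m_approx = dequantize_blockwise(q, scales, shape, pad)
err = np.abs(m_approx - m_state)
print(err[:256].max(), err[256:].max())  # the outlier only coarsens its own block
```

In an 8-bit Adam step, the dequantize/update/requantize cycle above would call these two functions around the fp32 moment update.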
| Component | Standard Adam | Adam 8-bit |
|---|---|---|
| fp16 weights | $2N$ | $2N$ |
| fp16 gradients | $2N$ | $2N$ |
| fp32 master weights | $4N$ | $4N$ |
| $m$ (first moment) | $4N$ (fp32) | $N$ (int8) |
| $v$ (second moment) | $4N$ (fp32) | $N$ (int8) |
| Total | $\mathbf{16N}$ | $\mathbf{10N}$ |
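The table as code, applied to a 70B model. This is a sketch: it counts only weights, gradients, and optimizer state, and it ignores both activations and the small overhead of storing one scale factor per quantization block. The dictionary keys are illustrative labels.

```python
BYTES_PER_PARAM = {
    "adam": {"fp16 weights": 2, "fp16 gradients": 2, "fp32 master weights": 4, "m": 4, "v": 4},
    "adam_8bit": {"fp16 weights": 2, "fp16 gradients": 2, "fp32 master weights": 4, "m": 1, "v": 1},
}


def training_memory_gb(n_params: int, optimizer: str) -> float:
    """Weights + gradients + optimizer state in decimal GB; activations are not included."""
    return n_params * sum(BYTES_PER_PARAM[optimizer].values()) / 1e9


for opt in BYTES_PER_PARAM:
    print(opt, training_memory_gb(70_000_000_000, opt), "GB")
# adam 1120.0 GB
# adam_8bit 700.0 GB
```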
Running Cost Estimation
How many FLOPs does a forward pass take?
Now that we know what the model looks like, we can count the actual compute. For a sequence of $T$ tokens and hidden dimension $d$, we apply the matmul rule to each operation in a layer.
In an attention module:
- Q, K, V projections: three linear layers, so $3 \times (2 \times T \times d \times d) = 6Td^2$
- $Q \times K^\top$ (attention scores): matrix shape $(T, d) \times (d, T)$, so $2T^2d$
- Score $\times V$ (weighted sum): matrix shape $(T, T) \times (T, d)$, so $2T^2d$
- Output projection: one linear layer, so $2 \times T \times d \times d = 2Td^2$
The feed-forward network (FFN) is the most compute-intensive part of a Transformer layer. A typical FFN has two linear layers: the first up-projects from $d$ to $4d$, and the second down-projects back to $d$. The expansion to a higher dimension gives the model more capacity to learn complex, non-linear transformations: the wider intermediate space lets the network represent richer features before compressing back into the residual stream. The FLOP counts are:
- MLP up-projection: $2 \times T \times d \times 4d = 8Td^2$
- MLP down-projection: $2 \times T \times 4d \times d = 8Td^2$
- Total for FFN: $16Td^2$
Putting it all together:
| Component | FLOPs |
|---|---|
| Q/K/V projections: $X_{[T\times d]} W_{[d\times d]}$ | $6Td^2$ |
| Attention scores $QK^\top$ (all heads) | $2T^2d$ |
| Weighted sum $AV$ (all heads) | $2T^2d$ |
| Output projection | $2Td^2$ |
| MLP up: $X_{[T\times d]} W_{[d\times 4d]}$ | $8Td^2$ |
| MLP down: $X_{[T\times 4d]} W_{[4d\times d]}$ | $8Td^2$ |
| Total per layer | $24Td^2 + 4T^2d$ |
Multiplying by $L$ layers:
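\[\text{FLOPs}_{\text{forward}} = L(24Td^2 + 4T^2d)\]
The attention term is a fraction $4T^2d / 24Td^2 = T/(6d)$ of the projection term, so it is negligible when $T \ll 6d$. For GPT-3 ($T = 2048$, $d = 12288$) it is under 3%; for very long contexts it is not. Plugging in $N \approx 12Ld^2$: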
\[\text{FLOPs}_{\text{forward}} \approx 24LTd^2 \approx 2NT\]
Each token costs roughly $2N$ FLOPs in the forward pass. For example, a 70B parameter model costs ~140B FLOPs per token.
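A small sketch that sums the per-layer table and compares it with the $2NT$ shortcut. The GPT-3 175B shape (96 layers, $d = 12288$, 2048-token context) is used only as an example; the function name is hypothetical, and $N$ here is the non-embedding count $12Ld^2$.

```python
def forward_flops(n_layers: int, d_model: int, seq_len: int) -> int:
    """Per-sequence forward FLOPs from the table: L * (24*T*d^2 + 4*T^2*d)."""
    T, d = seq_len, d_model
    projections = 6 * T * d**2 + 2 * T * d**2   # Q/K/V + output projection
    attention = 2 * T**2 * d + 2 * T**2 * d     # QK^T scores + weighted sum AV
    mlp = 8 * T * d**2 + 8 * T * d**2           # up-projection + down-projection
    return n_layers * (projections + attention + mlp)


L, d, T = 96, 12288, 2048
N = 12 * L * d**2                 # non-embedding parameters
exact = forward_flops(L, d, T)
approx = 2 * N * T
print(f"exact / 2NT = {exact / approx:.4f}")  # exact / 2NT = 1.0278, i.e. 1 + T/(6d)
```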
Inference is just a forward pass. Training is harder — we also need the backward pass.
Training vs Inference
The number of FLOPs differs significantly between inference and training:
- Inference (forward pass only): one pass through the network, costing $C \approx 2NT$ FLOPs.
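- Training (forward + backward): for each matmul, the backward pass computes two gradients, one with respect to the weights and one with respect to the inputs, and each is a matmul of the same size as the forward one. The backward pass therefore costs roughly $2\times$ the forward pass, so training costs about $3\times$ the forward pass (1 forward + 2 backward), or $3 \times 2N = 6N$ FLOPs per token.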
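The factor of 2 can be checked on a single linear layer $Y = XW$. In this NumPy sketch, the input gradient $dX = dY\,W^\top$ and the weight gradient $dW = X^\top dY$ are each a matmul with the same FLOP count as the forward pass. The upstream gradient `dY` is random, which corresponds to the illustrative loss $\sum Y \odot dY$; the sizes are arbitrary.

```python
import numpy as np


def matmul_flops(a: np.ndarray, b: np.ndarray) -> int:
    """2 * rows(a) * inner dim * cols(b), the matmul rule from above."""
    (n, k), (k2, o) = a.shape, b.shape
    assert k == k2, "inner dimensions must match"
    return 2 * n * k * o


rng = np.random.default_rng(0)
n, d, o = 4, 6, 3
X, W = rng.standard_normal((n, d)), rng.standard_normal((d, o))

Y = X @ W                          # forward pass
dY = rng.standard_normal((n, o))   # upstream gradient dLoss/dY (here Loss = sum(Y * dY))
dX = dY @ W.T                      # gradient w.r.t. the inputs, passed to the previous layer
dW = X.T @ dY                      # gradient w.r.t. the weights, used by the optimizer

forward = matmul_flops(X, W)
backward = matmul_flops(dY, W.T) + matmul_flops(X.T, dY)
print(forward, backward)           # 144 288
```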
The 6ND rule for training
For a full training run on $D$ tokens:
\[\boxed{\text{FLOPs}_{\text{train}} \approx 6ND}\]
This is the 6ND rule: $6N$ FLOPs per token ($2N$ forward plus $4N$ backward) times $D$ tokens. The approximation appears in Kaplan et al. (2020) and is used throughout Hoffmann et al. (Chinchilla, 2022). It’s the standard estimate used when planning training runs: if you know your model size $N$ and token budget $D$, you can estimate the compute cost before writing a single line of training code.
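A back-of-the-envelope sketch of that planning step, using Chinchilla's published configuration (70B parameters, 1.4T tokens). Converting FLOPs to GPU-hours needs hardware assumptions that are not from the lecture: here a peak of 312 TFLOP/s (the A100's dense bf16 figure) and an assumed 40% utilization. Both are illustrative inputs, not measurements.

```python
def train_flops(n_params: float, n_tokens: float) -> float:
    """The 6ND rule: ~2N FLOPs/token forward + ~4N FLOPs/token backward."""
    return 6 * n_params * n_tokens


def gpu_hours(total_flops: float, peak_flops_per_s: float, utilization: float) -> float:
    """GPU-hours needed at a given fraction of one GPU's peak throughput."""
    return total_flops / (peak_flops_per_s * utilization) / 3600


C = train_flops(70e9, 1.4e12)
print(f"{C:.2e} FLOPs")  # 5.88e+23 FLOPs
# Assumed hardware: 312e12 FLOP/s peak at 40% utilization (illustrative values).
print(f"{gpu_hours(C, 312e12, 0.40):.2e} GPU-hours")  # 1.31e+06 GPU-hours
```

Dividing the GPU-hours by the number of GPUs in the cluster gives a rough wall-clock estimate, which is why utilization matters as much as raw FLOP counts.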
The Bitter Lesson
Now that we have the tools to reason about model size and compute cost, a natural question is: what actually drives progress in AI?
Rich Sutton’s The Bitter Lesson (2019) argues that the biggest lesson from 70 years of AI research is that general methods that leverage computation ultimately win: human-designed domain knowledge helps in the short term but loses out as compute scales. A common misreading is that scale is all that matters, so one should just throw more compute at the problem. The right interpretation is subtler: what matters is algorithms that scale. A good way to think about the goal of AI research, then, is:
\[\text{accuracy} = \text{efficiency} \times \text{resources}\]
Raw resources without efficient algorithms quickly reach a plateau, while efficient algorithms without sufficient resources cannot realize their full potential. The two are inseparable.
Given a fixed compute and data budget, the question becomes: what is the best model one can build? In other words, maximize efficiency.
Hernandez and Brown (2020) showed that on ImageNet, algorithmic improvements alone delivered a 44× efficiency gain between 2012 and 2019: reaching AlexNet-level accuracy took 44× less compute in 2019, purely from better algorithms. At larger scale, this matters even more: we can’t afford to be wasteful when a single training run costs millions of dollars.
Thus, efficiency should be the central consideration when making design decisions for training models.
Design Decisions
So what are those design decisions? Building a language model from scratch involves choices across five areas:
- Basics: architecture, loss function, optimizer, learning rate
- Systems: parallelism, quantization, activation checkpointing, CPU offloading, inference
- Scaling laws: model complexity, loss metric, parametric form
- Data: curation, transformation, filtering, deduplication, mixing
- Alignment: reinforcement learning, preference data, synthetic data, verifiers
Each of these is a lever for efficiency. The rest of the course goes through them one by one.