Why Increasing Batch Size Doesn’t Always Speed Up Training
When training a language model, each training step feeds the model a batch of token sequences and uses the resulting gradients to update the parameters. Bigger batches sound appealing: they average gradients over more examples (less noise) and keep the GPU busier by processing multiple sequences in parallel.
The catch is memory. Large batches quickly exhaust GPU memory, so we often have to simulate them with gradient accumulation: run several small micro-batches (e.g., 4 examples at a time), sum their gradients, then apply a single update. Four micro-batches of 4 give you an effective batch of 16 without ever fitting 16 examples on the GPU at once.
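To make the accumulation idea concrete, here is a minimal sketch on a toy linear model with a mean-squared-error loss, written in plain Python rather than any training framework's API. The helper `grad_mse` and the synthetic data are hypothetical. One detail worth noting: when each micro-batch loss is a mean (a common default), the accumulated gradients have to be divided by the number of micro-batches to match the full-batch mean; with a sum-reduced loss, plain summation is already exact.

```python
import random

def grad_mse(w, b, xs, ys):
    """Gradient of the mean squared error of y_hat = w*x + b over (xs, ys)."""
    n = len(xs)
    gw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return gw, gb

# Hypothetical synthetic data: 16 examples from y = 3x + 0.5 plus noise.
random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(16)]
ys = [3 * x + 0.5 + random.gauss(0, 0.1) for x in xs]
w, b = 0.0, 0.0

# One "real" batch of 16 examples.
full_gw, full_gb = grad_mse(w, b, xs, ys)

# Four micro-batches of 4, accumulated before a single update.
# Each micro-batch gradient is a mean over 4 examples, so dividing by the
# number of micro-batches recovers the mean over all 16 examples.
micro_bs, n_micro = 4, 4
acc_gw = acc_gb = 0.0
for i in range(n_micro):
    sl = slice(i * micro_bs, (i + 1) * micro_bs)
    gw, gb = grad_mse(w, b, xs[sl], ys[sl])
    acc_gw += gw / n_micro
    acc_gb += gb / n_micro

print(abs(full_gw - acc_gw) < 1e-12, abs(full_gb - acc_gb) < 1e-12)  # True True
```

Mathematically the two gradients are identical; only the amount of data resident at once differs, which is exactly why accumulation trades memory for extra sequential steps.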
Intuitively, if you have enough memory, you might expect that simply raising the micro-batch size to 16 would be almost 4x faster than “4 x 4 with accumulation.” In practice, it often isn’t. Past a certain point, larger per-device batches can slow training down, for several reasons. The result is counterintuitive but common: higher “utilization” on paper, worse throughput on the clock. The effect is even more pronounced on smaller devices such as consumer GPUs.
This article explains why. We’ll walk through the five main reasons why big batches slow down training and show when it actually makes sense to increase micro-batch size versus relying on gradient accumulation.
Reason #1: Padding Blow-Up with Variable Lengths
When a micro-batch of size B has sequences with lengths L1, L2, ..., LB, let:
Lmax = max(L1, ..., LB)
Most training stacks pad every sequence in that micro-batch up to Lmax.
The token counts are then:
Useful tokens (what you care about): sum(Li)
Actual number of tokens (because of padding): B * Lmax
Padding waste (tokens): B * Lmax - sum(Li)
Waste fraction: 1 - sum(Li) / (B * Lmax)
With batch_size = 1, Lmax = L1, so there is no padding waste at all. With batch_size = 8, a single long sample forces the seven shorter ones to be padded to its length and processed at that length in every layer.
Let’s look at two examples.
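The formulas above can be checked directly by padding a micro-batch and counting pad positions. This is a minimal sketch in plain Python, assuming right-padding with a hypothetical pad id of 0 and an attention mask where 1 marks a real token; `pad_batch` and `padding_stats` are illustrative helpers, not functions from any particular training stack, and the lengths 100, 100, 400 are made up for the demonstration.

```python
PAD_ID = 0  # hypothetical padding token id

def pad_batch(seqs, pad_id=PAD_ID):
    """Right-pad every sequence to the longest length in the micro-batch (Lmax).

    Returns the padded token ids and an attention mask (1 = real token, 0 = pad).
    """
    l_max = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (l_max - len(s)) for s in seqs]
    mask = [[1] * len(s) + [0] * (l_max - len(s)) for s in seqs]
    return padded, mask

def padding_stats(lengths):
    """Return (useful tokens, computed tokens, padding waste, waste fraction)."""
    b, l_max = len(lengths), max(lengths)
    useful = sum(lengths)            # sum(Li)
    computed = b * l_max             # B * Lmax
    waste = computed - useful        # B * Lmax - sum(Li)
    return useful, computed, waste, 1 - useful / computed

# Hypothetical micro-batch of three sequences (token id 1 as a placeholder).
seqs = [[1] * 100, [1] * 100, [1] * 400]
padded, mask = pad_batch(seqs)
pad_positions = sum(row.count(0) for row in mask)
print(padding_stats([len(s) for s in seqs]))  # (600, 1200, 600, 0.5)
print(pad_positions)                          # 600, matches the waste above

# With batch size 1 there is nothing to pad to, so the waste is zero.
print(padding_stats([400]))                   # (400, 400, 0, 0.0)
```

Counting zeros in the mask and evaluating B * Lmax - sum(Li) give the same number, which is the point: every pad position is a token the model still pushes through every layer.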
Example A: 4k outlier at batch size 8
One micro-batch of 8 sequences with the following lengths (in tokens):
512, 520, 530, 560, 600, 640, 650, 4000
Useful tokens: 8012 (the sum of all sequence lengths)
Computed tokens (with padding): 32000 (8 * 4000)
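Plugging Example A into the padding formulas gives the full picture. The sketch below redefines a small hypothetical `padding_stats` helper so it runs on its own, and, purely as an illustrative comparison, also evaluates the same micro-batch with the 4000-token outlier removed to isolate how much of the waste that single sequence causes.

```python
def padding_stats(lengths):
    """Return (useful tokens, computed tokens, padding waste, waste fraction)."""
    b, l_max = len(lengths), max(lengths)
    useful = sum(lengths)         # sum(Li)
    computed = b * l_max          # B * Lmax
    return useful, computed, computed - useful, 1 - useful / computed

example_a = [512, 520, 530, 560, 600, 640, 650, 4000]
useful, computed, waste, frac = padding_stats(example_a)
print(useful, computed, waste, f"{frac:.1%}")  # 8012 32000 23988 75.0%

# Illustrative comparison: the same micro-batch without the 4000-token outlier.
useful, computed, waste, frac = padding_stats(example_a[:-1])
print(useful, computed, waste, f"{frac:.1%}")  # 4012 4550 538 11.8%
```

With the outlier present, roughly three out of every four computed tokens are padding; without it, the seven remaining sequences are close enough in length that waste drops to about 12%.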
