
Why Increasing Batch Size Doesn’t Always Speed Up Training


When training a language model, each training step feeds the model a batch of token sequences and uses the resulting gradients to update the parameters. Bigger batches sound appealing: they average gradients over more examples (less noise) and keep the GPU busier by processing multiple sequences in parallel.

The catch is memory. Large batches quickly exhaust GPU memory, so we often have to simulate them with gradient accumulation: run several small micro-batches (e.g., 4 examples at a time), sum their gradients, then apply a single update. Four micro-batches of 4 gives you an effective batch of 16, without ever fitting 16 examples on the GPU at once.
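To make the accumulation loop concrete, here is a minimal, self-contained PyTorch sketch; the toy model, random token batches, and the 4 x 4 configuration are illustrative stand-ins, not any particular training codebase.

```python
# Minimal, self-contained sketch of gradient accumulation in PyTorch.
# The tiny model and random token batches are illustrative stand-ins.
import torch
import torch.nn as nn

vocab_size, hidden = 100, 32
model = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, vocab_size))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

micro_batch_size, seq_len = 4, 16
accumulation_steps = 4  # 4 micro-batches of 4 examples -> effective batch of 16

optimizer.zero_grad()
for _ in range(accumulation_steps):
    # One small micro-batch of token IDs; its gradients are accumulated, not applied.
    tokens = torch.randint(0, vocab_size, (micro_batch_size, seq_len))
    logits = model(tokens)
    loss = nn.functional.cross_entropy(
        logits.view(-1, vocab_size), tokens.view(-1)  # toy objective, just to produce gradients
    )
    # Divide by the number of micro-batches so the summed gradients equal the
    # gradient of the average loss over the effective batch of 16.
    (loss / accumulation_steps).backward()

# A single parameter update for the whole effective batch.
optimizer.step()
optimizer.zero_grad()
```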

Intuitively, if you have enough memory, you might expect that just cranking the micro-batch up to 16 would be almost 4x faster than “4 x 4 with accumulation.” In reality, it often isn’t. Past a point, larger per-device batches can actually slow training down, for several reasons. The result is counterintuitive but common: higher “utilization” on paper, worse throughput on the clock. The effect is even more apparent on smaller devices, such as consumer GPUs.

This article explains why. We’ll walk through the five main reasons why big batches slow down training and show when it actually makes sense to increase micro-batch size versus relying on gradient accumulation.

Reason #1: Padding Blow-Up with Variable Lengths

When a micro-batch of size B has sequences with lengths L1, L2, ..., LB, let:

Lmax = max(L1, ..., LB)

Most training stacks pad every sequence in that micro-batch up to Lmax.

We have the following:

  • Useful tokens (what you care about): sum(Li)

  • Computed tokens (what the GPU actually processes, because of padding): B * Lmax

  • Padding waste (tokens): B * Lmax - sum(Li)

  • Waste fraction: 1 - sum(Li) / (B * Lmax)

With batch_size = 1, Lmax = L1, so the waste is exactly zero. With batch_size = 8, a single long sample forces seven shorter ones to run at its length in every layer.
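As a quick sketch of this arithmetic (padding_waste is an illustrative helper name, not part of any library):

```python
# Sketch of the padding-waste arithmetic described above.
def padding_waste(lengths):
    """Return useful tokens, computed (padded) tokens, and the waste fraction
    for one micro-batch whose sequences have the given lengths."""
    b = len(lengths)
    l_max = max(lengths)
    useful = sum(lengths)        # sum(Li): tokens you actually care about
    computed = b * l_max         # B * Lmax: tokens processed after padding
    waste_fraction = 1 - useful / computed
    return useful, computed, waste_fraction
```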

Let’s see with two examples.

Example A: 4k outlier at batch size 8

8 sequences with the following lengths (in tokens) in one micro-batch:
512, 520, 530, 560, 600, 640, 650, 4000

  • Useful tokens: 8012 (summing the lengths of all the sequences)

  • Computed tokens: 8 * 4000 = 32000 (B * Lmax)
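Plugging Example A’s lengths into the hypothetical padding_waste helper sketched above reproduces these totals:

```python
lengths = [512, 520, 530, 560, 600, 640, 650, 4000]
useful, computed, waste = padding_waste(lengths)
print(useful, computed, round(waste, 2))  # 8012 32000 0.75
```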

...
Read full article on The Kaitchup →