This is a practical analysis of how gradient checkpointing is implemented in PyTorch, and how to use it in Transformer models like BERT and GPT-2.

Recently, OpenAI published their work on the Sparse Transformer. Besides its contribution of sparse attention, the paper mentions a practical way to reduce the memory usage of deep Transformers. This method is called gradient checkpointing, and it was first introduced in the paper “Training Deep Nets with Sublinear Memory Cost”.

Gradient checkpointing claims to reduce the memory cost to O(sqrt(n)) when training an n-layer network, at the price of one extra forward pass per mini-batch. This fits Transformers well, because they often contain many layers, and the most memory-intensive part is storing the attention matrices produced by multi-head attention in every intermediate layer. Below is the attention memory usage of OpenAI’s Transformer with 64 layers and 4 heads:

| Data Type | Stored | Recomputed |
|---|---|---|
| 1024 text tokens (several paragraphs) | 1.0 GB | 16 MB |
| 32x32x3 pixels (CIFAR-10 image) | 9.6 GB | 151 MB |
| 64x64x3 pixels (Imagenet 64 image) | 154 GB | 2.4 GB |
| 24,000 samples (~2 seconds of 12 kHz audio) | 590 GB | 9.2 GB |
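
The numbers in this table can be reproduced with simple arithmetic. The sketch below is my own reconstruction, not code from the paper: it assumes one seq_len x seq_len attention matrix per head per layer, stored as 32-bit floats (4 bytes per value), and the function name `attention_bytes` is made up for illustration. "Stored" keeps the matrices of all 64 layers, while "recomputed" only needs those of one layer at a time.

```python
def attention_bytes(seq_len, n_layers, n_heads, bytes_per_value=4):
    """Bytes needed for one seq_len x seq_len attention matrix per head per layer.

    bytes_per_value=4 assumes float32 storage (an assumption, not stated in the table).
    """
    return n_layers * n_heads * seq_len ** 2 * bytes_per_value


# Sequence lengths taken from the table rows.
settings = {
    "1024 text tokens": 1024,
    "32x32x3 pixels": 32 * 32 * 3,
    "64x64x3 pixels": 64 * 64 * 3,
    "24,000 audio samples": 24_000,
}

for name, seq_len in settings.items():
    stored = attention_bytes(seq_len, n_layers=64, n_heads=4)     # keep every layer
    recomputed = attention_bytes(seq_len, n_layers=1, n_heads=4)  # keep one layer at a time
    print(f"{name}: stored {stored / 1e9:.1f} GB, recomputed {recomputed / 1e6:.0f} MB")
```

Under these assumptions the results land close to the table, up to rounding and the difference between decimal and binary gigabytes, and the ratio between the two columns is exactly the number of layers.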

From the table, we can observe that recomputation cuts attention memory by roughly 64x in this model, which matches its number of layers. For ordinary researchers who only have standard GPUs with 12 GB of memory, this is good news: potentially we do not need expensive devices like TPUs or Tesla V100s to train big neural networks like BERT. (I wonder why the gradient checkpointing paper has only 80 citations, compared to papers like Dropout.)
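
To see where the sublinear cost in Chen et al. comes from, it helps to count activations under a simplified model. This is my own simplification, not the paper's exact accounting: split the n layers into k segments, keep only the k segment inputs as checkpoints, and recompute one segment at a time during the backward pass. Peak storage is then roughly k + n/k, which is smallest near k = sqrt(n). The helper name `peak_activations` is hypothetical.

```python
import math


def peak_activations(n_layers, n_segments):
    """Rough peak activation count: one checkpoint per segment,
    plus the activations of the single segment being recomputed."""
    segment_len = math.ceil(n_layers / n_segments)
    return n_segments + segment_len


n = 64
# Try every possible number of segments and keep the cheapest one.
best_k = min(range(1, n + 1), key=lambda k: peak_activations(n, k))
print(best_k, peak_activations(n, best_k))
```

For a 64-layer network this simple count is minimized with 8 segments, that is, sqrt(64).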

A closer look in PyTorch

Having introduced the background, we need to know whether it can be applied to our model. PyTorch already provides an official implementation of gradient checkpointing.

Let us focus on that official implementation. Its most important piece is this class:

class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
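        # Save the inputs so that backward can recompute the forward pass from them.
        ctx.save_for_backward(*args)
        with torch.no_grad():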
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
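            if ctx.preserve_rng_state:
                # Restore the CPU RNG state saved in forward before recomputing.
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)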
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads

The implementation is straightforward. The forward pass runs run_function under torch.no_grad(), so no intermediate activations are kept; the backward pass re-runs it with gradients enabled and backpropagates through the recomputed graph. The preserve_rng_state argument makes sure the CPU and CUDA random number generator (RNG) states are saved in the forward pass and restored for that recomputation, which keeps backpropagation consistent with the forward pass through Dropout layers. Also, unlike Batch Normalization, Layer Normalization does not record a running mean or variance, so running its forward pass twice has no side effects.

import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
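
A quick way to check both points is to compare gradients with and without checkpointing on a small block that contains dropout and layer normalization. This is a minimal sketch, not part of the original experiment: the toy `block`, its sizes, and the seed are made up for illustration, and passing `use_reentrant=False` assumes a recent PyTorch release that accepts that argument.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical toy block; layers and sizes are chosen only for illustration.
block = nn.Sequential(
    nn.Linear(16, 16),
    nn.Dropout(p=0.5),
    nn.LayerNorm(16),
    nn.Linear(16, 4),
)
block.train()  # keep dropout active
x = torch.randn(4, 16)


def grads(use_checkpoint):
    """Run one forward/backward pass and return copies of the parameter gradients."""
    block.zero_grad()
    torch.manual_seed(0)  # both runs start from the same RNG state, hence the same dropout mask
    if use_checkpoint:
        out = checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return [p.grad.clone() for p in block.parameters()]


plain = grads(use_checkpoint=False)
ckpt = grads(use_checkpoint=True)
grads_match = all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
print(grads_match)
```

If the RNG state were not restored before recomputation, the second pass through dropout would draw a different mask and the two sets of gradients would generally disagree.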

The remaining step is to find a good point in the code to add checkpointing. Looking at the Sparse Transformer’s implementation, the best location seems to be the Transformer block, in which multi-head attention and the GELU activation are computed. We modify the code in pytorch-pretrained-bert from Hugging Face. Thanks to PyTorch’s simplicity, it can be done with only three lines (much easier than the method in TensorFlow!):

import torch.utils
import torch.utils.checkpoint

# change line around 410
hidden_states = layer_module(hidden_states, attention_mask)
# into
hidden_states = torch.utils.checkpoint.checkpoint(layer_module, hidden_states, attention_mask)
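
The same pattern can be tried on a self-contained toy encoder. The sketch below is only an illustration of where the call goes: `ToyLayer`, `ToyEncoder`, and the `use_checkpoint` flag are hypothetical names, not part of pytorch-pretrained-bert, and `use_reentrant=False` again assumes a recent PyTorch release.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a Transformer block taking (hidden_states, attention_mask)."""

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask):
        # A real block applies the mask inside attention; here it just gates the update.
        update = F.gelu(self.dense(hidden_states))
        return hidden_states + update * attention_mask


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size=32, num_layers=4, use_checkpoint=True):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer(hidden_size) for _ in range(num_layers)])
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states, attention_mask):
        for layer_module in self.layers:
            if self.use_checkpoint and self.training:
                # Activations inside layer_module are dropped and recomputed in backward.
                hidden_states = checkpoint(
                    layer_module, hidden_states, attention_mask, use_reentrant=False
                )
            else:
                hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states


encoder = ToyEncoder()
hidden = torch.randn(2, 5, 32)
mask = torch.ones(2, 5, 1)
out = encoder(hidden, mask)
out.sum().backward()
```

One thing to keep in mind with the reentrant `CheckpointFunction` shown earlier: at least one tensor input of the checkpointed function has to require gradients, otherwise the segment receives no gradients. In BERT this holds because `hidden_states` comes out of the trainable embedding layer.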

Next, we test the performance on the SWAG task. I used the following settings on an 11 GB GPU for the benchmark.

export SWAG_DIR=/path/to/SWAG

python \
  --bert_model bert-base-uncased \
  --do_train \
  --do_lower_case \
  --do_eval \
  --data_dir $SWAG_DIR/data \
  --train_batch_size 24 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --max_seq_length 80 \
  --output_dir /tmp/swag_output/ \
  --gradient_accumulation_steps 1

We can get the results below by changing batch size and enabling gradient checkpointing.

| Gradient Checkpointing | Batch size | GPU Memory | Time for one epoch | Validation Accuracy after one epoch |
|---|---|---|---|---|
| No | 24 | 10972 MB | 27min05s | 0.7997 |
| Yes | 24 | 3944 MB | 36min50s | 0.7997 |
| Yes | 132 | 10212 MB | 31min20s | 0.7946 |

*Update: added validation accuracy to check correctness.
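
The trade-off in this table can be worked out directly from its entries. The snippet below only does arithmetic on the reported numbers; the variable names are mine.

```python
def to_seconds(minutes, seconds):
    """Convert an 'XXminYYs' table entry to seconds."""
    return 60 * minutes + seconds


# Rows copied from the benchmark table.
baseline = {"batch": 24, "mem_mb": 10972, "time_s": to_seconds(27, 5)}
ckpt_same = {"batch": 24, "mem_mb": 3944, "time_s": to_seconds(36, 50)}
ckpt_large = {"batch": 132, "mem_mb": 10212, "time_s": to_seconds(31, 20)}

memory_saving = baseline["mem_mb"] / ckpt_same["mem_mb"]     # same batch size
slowdown_same = ckpt_same["time_s"] / baseline["time_s"]     # same batch size
slowdown_large = ckpt_large["time_s"] / baseline["time_s"]   # larger batch
batch_gain = ckpt_large["batch"] / baseline["batch"]

print(f"memory saving: {memory_saving:.2f}x")
print(f"slowdown at batch 24: {slowdown_same:.2f}x")
print(f"slowdown at batch 132: {slowdown_large:.2f}x")
print(f"batch size gain: {batch_gain:.1f}x")
```

The epoch time depends strongly on whether the freed memory is spent on a larger batch, so it is worth reporting both settings.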


By applying gradient checkpointing, the so-called recompute technique, we can greatly reduce the memory required for training Transformers (about 2.8x in the benchmark above) at the cost of slower computation: about 36% more time per epoch at the same batch size, and about 16% more once the freed memory is spent on a larger batch. This lets us design more complex Transformer models and train on longer sequences of tokens. It may also inspire research on the next generation of Transformer architectures.

*However, it is reported to be extremely slow with multiple GPUs. I haven’t tested it myself, so I am not sure whether this is fixed in PyTorch 1.1. For those who are interested, there is a fix in a GitHub repo by csrhddlam. I have attached a table below.

| Method | # GPU | Batch | Memory | Time |
|---|---|---|---|---|
| Naive | 2 | 256 | 5.25G | 0.27s |
| Official | 2 | 256 | 2.98G | 1.41s |
| This repo | 2 | 256 | 2.97G | 0.31s |


[1] OpenAI GitHub. “Saving memory using gradient-checkpointing.” 2018.

[2] OpenAI Blog. “Generative Modeling with Sparse Transformers.” 2019.

[3] Chen, Tianqi, et al. “Training Deep Nets with Sublinear Memory Cost.” 2016.