Explore Gradient-Checkpointing in PyTorch
This is a practical analysis of how gradient checkpointing is implemented in PyTorch, and how to use it in Transformer models like BERT and GPT-2.
Recently, OpenAI published its work on the Sparse Transformer. Besides the contribution of sparse attention, the paper describes a practical way to reduce the memory usage of deep Transformers. This method is called gradient checkpointing, and it was first introduced in the paper “Training Deep Nets with Sublinear Memory Cost”.
Gradient checkpointing can reduce the memory cost of training an \(n\)-layer network to \(O(\sqrt{n})\), at the price of roughly one extra forward pass. This suits Transformers well, because they often contain many layers, and the most memory-intensive part of a Transformer is storing the \(N \times N\) attention matrices (\(N\) being the sequence length) produced by multi-head attention in every layer. The table below shows the attention memory usage of OpenAI's 64-layer, 4-head Transformer:
Data Type | Stored | Recomputed |
---|---|---|
1024 text tokens (several paragraphs) | 1.0 GB | 16 MB |
32x32x3 pixels (CIFAR-10 image) | 9.6 GB | 151 MB |
64x64x3 pixels (ImageNet 64 image) | 154 GB | 2.4 GB |
24,000 samples (~2 seconds of 12 kHz audio) | 590 GB | 9.2 GB |
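The numbers in the table follow from simple arithmetic. The sketch below rests on assumptions of mine that the table does not state: attention weights are float32 (4 bytes per entry), each head stores one \(N \times N\) matrix, and \(N\) is the number of tokens, pixel values or audio samples. "Stored" keeps the matrices of all 64 layers, while "Recomputed" holds only one layer's matrices at a time. Under these assumptions the results agree with the table up to rounding and GB/GiB conventions.

```python
# Assumed setup (not stated in the table): float32 attention weights,
# 64 layers, 4 heads, one N x N matrix per head per layer.
BYTES_PER_FLOAT = 4
N_LAYERS = 64
N_HEADS = 4

# Sequence length N for each row of the table.
SEQ_LENGTHS = {
    "1024 text tokens": 1024,
    "CIFAR-10 image (32x32x3)": 32 * 32 * 3,
    "ImageNet 64 image (64x64x3)": 64 * 64 * 3,
    "24,000 audio samples": 24_000,
}


def attention_bytes(seq_len: int, n_layers: int) -> int:
    """Bytes needed to hold the attention matrices of `n_layers` layers."""
    return n_layers * N_HEADS * seq_len * seq_len * BYTES_PER_FLOAT


def memory_table() -> dict:
    """Map each row name to (stored bytes, recomputed bytes)."""
    rows = {}
    for name, n in SEQ_LENGTHS.items():
        stored = attention_bytes(n, N_LAYERS)  # every layer's matrices kept
        recomputed = attention_bytes(n, 1)     # only one layer alive at a time
        rows[name] = (stored, recomputed)
    return rows


if __name__ == "__main__":
    for name, (stored, recomputed) in memory_table().items():
        print(f"{name:30s} stored={stored / 1e9:7.1f} GB  "
              f"recomputed={recomputed / 1e6:8.1f} MB  ratio={stored // recomputed}x")
```

The ratio comes out as exactly 64 for every row, which is simply the number of layers.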
From the table, we can observe that recomputation reduces attention memory by more than 10x: in fact by a factor of about 64, the number of layers, since only one layer's attention matrices need to be held at a time. For researchers who only have standard GPUs with 12 GB of memory, this is good news: we may not need expensive hardware such as TPUs or Tesla V100s to train large networks like BERT. (I wonder why the gradient checkpointing paper has only 80 citations compared to other papers like Dropout.)
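Where does the \(O(\sqrt{n})\) figure come from? The idea in "Training Deep Nets with Sublinear Memory Cost" is to split an \(n\)-layer network into \(k\) segments and keep only the activations at segment boundaries; during backward, one segment at a time is recomputed. Peak activation memory is then roughly \(k + n/k\) layer outputs, which is minimized at \(k = \sqrt{n}\), giving \(2\sqrt{n}\). The toy cost model below is my simplification (every layer output counts as one unit of memory, everything else is ignored); it just brute-forces the best \(k\).

```python
import math


def peak_activations(n_layers: int, n_segments: int) -> int:
    """Toy memory model for segment-wise checkpointing.

    One stored checkpoint per segment, plus the activations of the single
    segment being recomputed during backward. Each layer output = 1 unit.
    """
    segment_len = math.ceil(n_layers / n_segments)
    return n_segments + segment_len


def best_segments(n_layers: int) -> tuple:
    """Return (peak memory, number of segments) with the smallest peak."""
    return min((peak_activations(n_layers, k), k) for k in range(1, n_layers + 1))


if __name__ == "__main__":
    for n in (16, 64, 256, 1024):
        peak, k = best_segments(n)
        print(f"n={n:5d}  stored without checkpointing={n:5d}  "
              f"best k={k:3d}  peak={peak:4d}  2*sqrt(n)={2 * math.sqrt(n):6.1f}")
```

For perfect squares the optimum lands exactly on \(k = \sqrt{n}\) with a peak of \(2\sqrt{n}\). PyTorch exposes the same segment idea for `nn.Sequential` models through `torch.utils.checkpoint.checkpoint_sequential`, whose `segments` argument plays the role of \(k\).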
Closer look in PyTorch
Having introduced the background, we now need to know whether it can be applied to our model. PyTorch already provides an official implementation of gradient checkpointing in `torch.utils.checkpoint`, and we will focus on it. The most important piece is this class:
```python
import torch
from torch.utils.checkpoint import (
    check_backward_validity, detach_variable, get_device_states, set_device_states)


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        # Only the inputs are saved; activations inside run_function are not.
        ctx.save_for_backward(*args)
        # no_grad: no graph is recorded, so intermediate activations are freed.
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            # Recompute the forward pass, this time recording the graph.
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # Backpropagate the incoming gradients through the recomputed graph.
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        # No gradients for run_function and preserve_rng_state.
        return (None, None) + grads
```
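To see the core mechanism without the RNG bookkeeping, here is a stripped-down sketch of the same idea. `TinyCheckpoint` and `sketch_grad_gap` are hypothetical names, not part of PyTorch; the sketch assumes a single tensor input, a single tensor output and a function that uses no randomness. The helper compares its gradients with those of an ordinary forward pass that stores every activation.

```python
import torch
from torch import nn


class TinyCheckpoint(torch.autograd.Function):
    """Hypothetical minimal checkpoint: one tensor in, one tensor out, no RNG handling."""

    @staticmethod
    def forward(ctx, run_function, x):
        ctx.run_function = run_function
        ctx.save_for_backward(x)      # keep only the input, not the activations
        with torch.no_grad():         # no graph is recorded, so intermediates are freed
            return run_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():     # recompute the forward pass, recording the graph
            y = ctx.run_function(x)
        # Backpropagate through the recomputed graph; parameters used inside
        # run_function accumulate their .grad here.
        torch.autograd.backward(y, grad_output)
        return None, x.grad           # no gradient for run_function itself


def sketch_grad_gap() -> float:
    """Largest difference between plain and checkpointed gradients."""
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 8))
    x = torch.randn(4, 8, requires_grad=True)

    # Reference: ordinary forward/backward that stores every activation.
    net(x).sum().backward()
    ref = [x.grad.clone()] + [p.grad.clone() for p in net.parameters()]

    # Same computation through the checkpoint sketch.
    x.grad = None
    for p in net.parameters():
        p.grad = None
    TinyCheckpoint.apply(net, x).sum().backward()
    new = [x.grad] + [p.grad for p in net.parameters()]

    return max((a - b).abs().max().item() for a, b in zip(ref, new))


if __name__ == "__main__":
    print("max gradient difference:", sketch_grad_gap())
```

The extra cost is visible in `backward`: `run_function` runs a second time, which is the recomputation that trades compute for memory.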
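The implementation is straightforward. The forward pass runs `run_function` under `torch.no_grad()`, so no intermediate activations are kept; only the inputs are saved, and `backward` re-runs `run_function` with gradients enabled to rebuild the graph before backpropagating through it. The `preserve_rng_state` argument makes sure the random number generator (RNG) state of the CPU and, if CUDA is already initialized, of the GPUs holding the inputs is saved in forward and restored during recomputation. This keeps backpropagation consistent with the forward pass through Dropout layers, because the recomputed pass draws the same dropout masks. Also, unlike Batch Normalization, Layer Normalization does not record a running mean or variance, so running its forward pass a second time does no harm: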
```python
import torch
from torch import nn


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics come from the current input only; no running buffers are updated.
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
```
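Why the difference between the two normalization layers matters: with checkpointing, a module's forward runs twice per training step, once in forward and once during backward. The sketch below (assuming a PyTorch version recent enough for `checkpoint` to accept `use_reentrant`; the helper names are hypothetical) counts how often a `BatchNorm1d` layer updates its running statistics in one step, and confirms that `nn.LayerNorm` has no buffers that a second forward could disturb.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def batchnorm_updates_per_step(use_checkpoint: bool) -> int:
    """Run one training step and report how often BatchNorm updated its running stats."""
    torch.manual_seed(0)
    bn = nn.BatchNorm1d(16)               # training mode by default
    x = torch.randn(8, 16, requires_grad=True)
    if use_checkpoint:
        # use_reentrant=True mirrors the CheckpointFunction shown above.
        out = checkpoint(bn, x, use_reentrant=True)
    else:
        out = bn(x)
    out.sum().backward()
    return int(bn.num_batches_tracked)


def layernorm_buffers() -> list:
    """LayerNorm keeps only a learnable weight and bias, no running statistics."""
    return list(nn.LayerNorm(16).buffers())


if __name__ == "__main__":
    print("BatchNorm updates, plain:       ", batchnorm_updates_per_step(False))
    print("BatchNorm updates, checkpointed:", batchnorm_updates_per_step(True))
    print("LayerNorm buffers:", layernorm_buffers())
```

Under checkpointing the BatchNorm counter (and its running mean and variance) is updated twice per step, whereas LayerNorm has nothing to update.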
The remaining step is to find a good place in the code to add checkpointing. Judging from the Sparse Transformer's implementation, the best location seems to be the Transformer block, where multi-head attention and the GELU activation are computed. We modify the code in huggingface's pytorch-pretrained-bert. Thanks to PyTorch's simplicity, it takes only three lines (much easier than the method in TensorFlow!):
```python
import torch.utils
import torch.utils.checkpoint

# change line around 410
hidden_states = layer_module(hidden_states, attention_mask)
# into
hidden_states = torch.utils.checkpoint.checkpoint(layer_module, hidden_states, attention_mask)
```
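The same change can be tried on a toy model before touching BERT. The sketch below is not the pytorch-pretrained-bert code: `ToyBlock` and `ToyEncoder` are hypothetical stand-ins for a Transformer layer and encoder, with self-attention, a GELU MLP and dropout. Each block is wrapped in `torch.utils.checkpoint.checkpoint` exactly as above, and the helper checks that, from the same seeds, the checkpointed gradients match the plain ones. That equality relies on `preserve_rng_state` (on by default) replaying the dropout masks during recomputation.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a Transformer layer (not BertLayer itself)."""

    def __init__(self, d: int, p: float = 0.1):
        super().__init__()
        self.qkv = nn.Linear(d, 3 * d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.drop = nn.Dropout(p)
        self.norm1 = nn.LayerNorm(d)
        self.norm2 = nn.LayerNorm(d)

    def forward(self, h, mask):
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        # N x N attention scores per sequence; the additive mask broadcasts over rows.
        scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5 + mask
        attn = self.drop(torch.softmax(scores, dim=-1))
        h = self.norm1(h + attn @ v)
        return self.norm2(h + self.drop(self.mlp(h)))


class ToyEncoder(nn.Module):
    def __init__(self, d: int = 16, n_layers: int = 4, use_checkpoint: bool = False):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(d) for _ in range(n_layers))
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states, attention_mask):
        for layer_module in self.layers:
            if self.use_checkpoint:
                # Same pattern as the change above; use_reentrant=True matches
                # the CheckpointFunction discussed earlier.
                hidden_states = checkpoint(layer_module, hidden_states, attention_mask,
                                           use_reentrant=True)
            else:
                hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states


def encoder_grads(use_checkpoint: bool):
    torch.manual_seed(0)                       # identical weights, inputs, targets
    model = ToyEncoder(use_checkpoint=use_checkpoint).train()
    h = torch.randn(2, 5, 16, requires_grad=True)
    mask = torch.zeros(2, 1, 5)                # 0 = attend to every token
    target = torch.randn(2, 5, 16)
    torch.manual_seed(1)                       # identical dropout stream
    loss = (model(h, mask) - target).pow(2).mean()
    loss.backward()
    return [h.grad] + [p.grad for p in model.parameters()]


def checkpointed_grad_gap() -> float:
    plain, ckpt = encoder_grads(False), encoder_grads(True)
    return max((a - b).abs().max().item() for a, b in zip(plain, ckpt))


if __name__ == "__main__":
    print("max gradient difference:", checkpointed_grad_gap())
```

Note that the input to the first checkpointed block requires grad here; with the reentrant variant, if no input to a checkpointed block requires grad, the parameters inside that block receive no gradients.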
Next, we test the performance on the SWAG task with the run_swag.py example script. I used the following settings on an 11 GB GPU for the benchmark:
```bash
export SWAG_DIR=/path/to/SWAG
python run_swag.py \
  --bert_model bert-base-uncased \
  --do_train \
  --do_lower_case \
  --do_eval \
  --data_dir $SWAG_DIR/data \
  --train_batch_size 24 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --max_seq_length 80 \
  --output_dir /tmp/swag_output/ \
  --gradient_accumulation_steps 1
```
Changing the batch size and enabling gradient checkpointing gives the results below.
Gradient Checkpointing | Batch size | GPU Memory | Time for one epoch | Validation Accuracy after one epoch |
---|---|---|---|---|
No | 24 | 10972MB | 27min05s | 0.7997 |
Yes | 24 | 3944MB | 36min50s | 0.7997 |
Yes | 132 | 10212MB | 31min20s | 0.7946 |
*Updated: validation accuracy added to check correctness.
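The trade-off in the table is easier to read as ratios. The small helper below is plain Python; the row labels and the "mm:ss" time format are my own encoding of the table. It expresses each run's memory and epoch time relative to the plain batch-24 run.

```python
# Rows copied from the SWAG table: (checkpointing, batch size, memory in MB, epoch time "mm:ss").
RUNS = {
    "baseline, batch 24": (False, 24, 10972, "27:05"),
    "checkpoint, batch 24": (True, 24, 3944, "36:50"),
    "checkpoint, batch 132": (True, 132, 10212, "31:20"),
}


def seconds(mmss: str) -> int:
    """Convert an "mm:ss" string to seconds."""
    minutes, secs = mmss.split(":")
    return int(minutes) * 60 + int(secs)


def relative_to_baseline(name: str) -> tuple:
    """Return (memory ratio, epoch-time ratio) of a run versus the plain batch-24 run."""
    _, _, base_mem, base_epoch = RUNS["baseline, batch 24"]
    _, _, mem, epoch = RUNS[name]
    return mem / base_mem, seconds(epoch) / seconds(base_epoch)


if __name__ == "__main__":
    for name in RUNS:
        mem_ratio, time_ratio = relative_to_baseline(name)
        print(f"{name:22s} memory x{mem_ratio:.2f}  epoch time x{time_ratio:.2f}")
```

At the same batch size, checkpointing uses about 36% of the memory (roughly 2.8x less) and each epoch takes about 36% longer; spending the freed memory on a 5.5x larger batch brings the per-epoch overhead down to about 16%.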
Conclusion
By applying gradient checkpointing, also known as the recomputation technique, we can greatly reduce the memory required to train Transformers at the cost of slower computation (in the SWAG benchmark above, an epoch took about 16% longer with the larger batch and about 36% longer at the same batch size). This lets us design more complex Transformer-based models and train on longer token sequences. It may also inspire research on the next generation of Transformer architectures.
*However, the official implementation has been reported to be extremely slow with multiple GPUs. I haven't tested this myself, so I am not sure whether it is fixed in PyTorch 1.1. For those interested, the GitHub repo csrhddlam/pytorch-checkpoint (https://github.com/csrhddlam/pytorch-checkpoint) provides a fix. I have attached a comparison table below.
Method | # GPU | Batch | Memory | Time |
---|---|---|---|---|
Naive | 2 | 256 | 5.25G | 0.27s |
Official | 2 | 256 | 2.98G | 1.41s |
This repo | 2 | 256 | 2.97G | 0.31s |
References
[1] OpenAI GitHub. “Saving memory using gradient-checkpointing.” 2018.
[2] OpenAI Blog. “Generative Modeling with Sparse Transformers.” 2019.
[3] Chen, Tianqi, et al. “Training Deep Nets with Sublinear Memory Cost.” 2016.