Optimizing GPU Memory for Deep Learning Workloads
Technical deep-dive into GPU memory optimization techniques, including gradient checkpointing, mixed precision, and memory mapping strategies.
GPU memory is often the limiting factor when training large deep learning models. Modern GPUs like the A100 have 40-80GB of memory, but large language models can require hundreds of gigabytes. Here’s how to make the most of available GPU memory.
Understanding GPU Memory Usage
Memory Breakdown
During training, GPU memory is used for:
- Model parameters (~4 bytes per parameter for FP32)
- Gradients (same size as parameters)
- Optimizer states (2x parameters for Adam)
- Activations (varies by batch size and sequence length)
- Temporary buffers (framework overhead)
```python
# Back-of-the-envelope memory estimate for a 7B-parameter model (FP32 + Adam)
params = 7_000_000_000
model_memory = params * 4      # FP32 parameters: 4 bytes each
gradient_memory = params * 4   # FP32 gradients: 4 bytes each
optimizer_memory = params * 8  # Adam: two FP32 states (momentum, variance)
total_memory = model_memory + gradient_memory + optimizer_memory
print(f"Total memory needed: {total_memory / 1e9:.1f} GB")
# Output: Total memory needed: 112.0 GB
```
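The estimate above covers only parameters, gradients, and optimizer states; activation memory comes on top and scales with batch size and sequence length. A rough rule-of-thumb sketch (the bytes-per-unit constant here is an assumption; the true value depends on architecture, precision, and checkpointing):

```python
def activation_memory_gb(layers, batch, seq_len, hidden, bytes_per_unit=16):
    """Rough transformer activation-memory estimate in GB.

    bytes_per_unit is an assumed constant for how many bytes of
    activations are kept per hidden unit per layer; real values vary.
    """
    return layers * batch * seq_len * hidden * bytes_per_unit / 1e9

# A 7B-class model: 32 layers, hidden size 4096, batch 8, 2048-token sequences
print(f"{activation_memory_gb(32, 8, 2048, 4096):.1f} GB")
# Output: 34.4 GB
```

Even at a modest batch size, activations can rival the optimizer-state footprint, which is why the techniques below target them first.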
Memory Optimization Techniques
1. Gradient Checkpointing
Trade computation for memory by recomputing activations during backward pass:
```python
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        # Recompute this block's activations during the backward pass
        # (use_reentrant=False is the recommended mode in recent PyTorch)
        return checkpoint.checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = x + self.attention(x)
        x = x + self.mlp(x)
        return x
```
Memory savings: 50-80% reduction in activation memory
Cost: roughly 33% more compute, since each checkpointed block runs one extra forward pass during backward
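For a plain stack of layers, PyTorch also ships `torch.utils.checkpoint.checkpoint_sequential`, which splits an `nn.Sequential` into segments and checkpoints each one. A minimal sketch (the layer sizes are illustrative, not from any particular model):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A toy 8-layer stack; only segment boundaries keep their activations,
# interior activations are recomputed during backward
model = nn.Sequential(*[nn.Linear(128, 128) for _ in range(8)])
x = torch.randn(4, 128, requires_grad=True)

# Split the 8 layers into 2 checkpointed segments
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
```

Fewer segments save more memory but recompute more; the segment count is the knob for the memory/compute tradeoff.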
2. Mixed Precision Training
Use FP16 for most operations, FP32 for stability-critical parts:
```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()

    # Forward pass in mixed precision
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Backward pass with gradient scaling to avoid FP16 underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
Memory savings: ~50% reduction in model and activation memory
Cost: Minimal performance impact on modern GPUs
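On Ampere-class GPUs and newer, `bfloat16` is often the simpler choice: it keeps FP32's exponent range, so gradients don't underflow and no `GradScaler` is needed. A minimal sketch (the model and data here are toy placeholders; on a GPU you would pass `device_type="cuda"`):

```python
import torch

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = torch.randn(8, 16)
targets = torch.randn(8, 4)

optimizer.zero_grad()
# bfloat16 autocast needs no gradient scaling; use device_type="cuda" on GPU
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
```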
3. Activation Offloading
Move activations to CPU memory when not needed:
```python
import torch.nn as nn

class OffloadingModule(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # Move input to GPU for computation
        x = x.cuda()
        output = self.module(x)
        # Offload the result to CPU if it is not needed for the backward pass
        if not output.requires_grad:
            output = output.cpu()
        return output
```
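PyTorch also has built-in support for this pattern: `torch.autograd.graph.save_on_cpu` moves tensors saved for backward to host memory and copies them back when the backward pass needs them. A minimal CPU-runnable sketch:

```python
import torch

x = torch.randn(32, 64, requires_grad=True)
w = torch.randn(64, 64, requires_grad=True)

# Activations saved for backward are offloaded to CPU memory inside this
# context (pin_memory=True speeds up the copies when a GPU is involved)
with torch.autograd.graph.save_on_cpu(pin_memory=False):
    y = (x @ w).relu().sum()
y.backward()
```

The hook trades PCIe transfer time for GPU memory, so it pays off mostly for large activations on long sequences.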
4. Gradient Accumulation
Simulate larger batch sizes without using more memory:
```python
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    with autocast():
        outputs = model(inputs)
        # Scale the loss so accumulated gradients match a full-batch update
        loss = criterion(outputs, targets) / accumulation_steps

    scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
Advanced Techniques
ZeRO Optimizer
Partition optimizer states across multiple GPUs:
```python
# ZeRO Stage 1: partition optimizer states across GPUs
# ZeRO Stage 2: also partition gradients
# ZeRO Stage 3: also partition parameters
config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        },
        "offload_param": {
            "device": "cpu"
        }
    }
}
```
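A back-of-the-envelope sketch of what each stage buys, reusing the 16 bytes/parameter breakdown from earlier (FP32 weights + gradients + Adam states) across N data-parallel GPUs; this ignores activations and communication buffers, so treat it as an estimate, not a DeepSpeed measurement:

```python
def zero_per_gpu_gb(params, n_gpus, stage):
    """Approximate per-GPU training memory (GB) under each ZeRO stage."""
    weights, grads, optim = params * 4, params * 4, params * 8
    if stage == 0:    # plain data parallelism: everything replicated
        total = weights + grads + optim
    elif stage == 1:  # partition optimizer states
        total = weights + grads + optim / n_gpus
    elif stage == 2:  # also partition gradients
        total = weights + (grads + optim) / n_gpus
    else:             # stage 3: also partition parameters
        total = (weights + grads + optim) / n_gpus
    return total / 1e9

for stage in range(4):
    print(f"Stage {stage}: {zero_per_gpu_gb(7e9, 8, stage):.1f} GB per GPU")
# Output:
# Stage 0: 112.0 GB per GPU
# Stage 1: 63.0 GB per GPU
# Stage 2: 38.5 GB per GPU
# Stage 3: 14.0 GB per GPU
```

Stage 3 plus CPU offload is what lets a 7B model that nominally needs 112 GB fit on commodity 8-GPU nodes.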
Memory-Efficient Attention
Avoid materializing the full n × n attention matrix by processing queries in blocks, cutting peak score memory from O(n²) to O(block_size · n):

```python
import math
import torch

def flash_attention(q, k, v, block_size=1024):
    """Simplified memory-efficient attention: tile over query blocks.

    True FlashAttention also tiles keys/values and fuses the softmax
    into the loop; this sketch only avoids materializing the full
    (seq_len, seq_len) score matrix at once.
    """
    seq_len = q.size(1)
    scale = 1.0 / math.sqrt(q.size(-1))
    output = torch.zeros_like(q)

    for i in range(0, seq_len, block_size):
        end_i = min(i + block_size, seq_len)
        q_block = q[:, i:end_i]
        # Scores for this block are (block_size, seq_len), not (seq_len, seq_len)
        scores = torch.matmul(q_block, k.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        output[:, i:end_i] = torch.matmul(attn_weights, v)
    return output
```
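In practice you rarely hand-roll this: PyTorch 2.x ships `torch.nn.functional.scaled_dot_product_attention`, which dispatches to FlashAttention or memory-efficient kernels when the hardware supports them. A small sketch checking it against the naive computation:

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# Fused kernel: picks a memory-efficient backend when one is available
out = F.scaled_dot_product_attention(q, k, v)

# Naive reference that materializes the full (seq_len, seq_len) score matrix
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
ref = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(out, ref, atol=1e-4)
```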
Monitoring and Debugging
GPU Memory Profiling
```python
import torch

def print_gpu_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")

# Profile memory usage before and after model creation
print_gpu_memory()
model = create_model()
print_gpu_memory()
```
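Peak usage is often more informative than a point-in-time reading: `torch.cuda.max_memory_allocated` tracks the high-water mark, and `torch.cuda.reset_peak_memory_stats` resets it between phases. A guarded sketch that also runs on CPU-only machines:

```python
import torch

def peak_memory_report():
    """Return peak allocated/reserved GPU memory in GB (empty dict without CUDA)."""
    if not torch.cuda.is_available():
        return {}
    return {
        "peak_allocated_gb": torch.cuda.max_memory_allocated() / 1e9,
        "peak_reserved_gb": torch.cuda.max_memory_reserved() / 1e9,
    }

# Reset the high-water mark just before the phase you want to measure
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
print(peak_memory_report())
```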
Finding Memory Leaks
```python
import gc
import torch

def find_tensors():
    """Snapshot all live tensors as (shape, dtype, device) tuples."""
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append((obj.shape, obj.dtype, obj.device))
        except Exception:
            # Some objects raise on inspection; skip them
            continue
    return tensors

# Check for tensor leaks between training steps
before = find_tensors()
train_step()
after = find_tensors()
# Compare `before` and `after` to spot unexpected tensor growth
```
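A classic leak this technique catches: accumulating the `loss` tensor itself instead of `loss.item()`, which keeps every step's autograd graph and activations alive. A minimal illustration (the `step_loss` helper is a stand-in for a real training step):

```python
import torch

def step_loss():
    x = torch.randn(64, requires_grad=True)
    return (x * 2).sum()

# Leaky: `total` is a tensor whose graph references every step's activations
total = sum(step_loss() for _ in range(3))
assert total.requires_grad  # the whole history is still attached

# Fixed: .item() converts to a plain Python float, freeing each step's graph
total = sum(step_loss().item() for _ in range(3))
assert isinstance(total, float)
```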
Practical Tips
- Start with the biggest wins: mixed precision and gradient checkpointing
- Profile early: use tools like `torch.profiler` to understand bottlenecks
- Batch size tuning: find the largest batch size that fits in memory
- Model sharding: split large models across multiple GPUs
- CPU offloading: use CPU RAM as extended GPU memory
Conclusion
GPU memory optimization is crucial for training large models efficiently. The techniques above can often reduce memory usage by 4-8x, enabling training of much larger models on the same hardware. The key is to understand the memory/compute tradeoffs and apply the right combination of techniques for your specific use case.