Optimizing GPU Memory for Deep Learning Workloads

September 2024
GPU Programming, Deep Learning, Performance

Technical deep-dive into GPU memory optimization techniques, including gradient checkpointing, mixed precision, and memory mapping strategies.


GPU memory is often the limiting factor when training large deep learning models. Modern GPUs like the A100 have 40-80GB of memory, but large language models can require hundreds of gigabytes. Here’s how to make the most of available GPU memory.

Understanding GPU Memory Usage

Memory Breakdown

During training, GPU memory is used for:

  1. Model parameters (~4 bytes per parameter for FP32)
  2. Gradients (same size as parameters)
  3. Optimizer states (2x parameter memory for Adam's two moment estimates)
  4. Activations (varies by batch size and sequence length)
  5. Temporary buffers (framework overhead)

# Calculate memory usage for a 7B parameter model
params = 7_000_000_000
model_memory = params * 4  # FP32 parameters
gradient_memory = params * 4  # FP32 gradients  
optimizer_memory = params * 8  # Adam optimizer states (two FP32 moments)

total_memory = model_memory + gradient_memory + optimizer_memory
print(f"Total memory needed: {total_memory / 1e9:.1f} GB")
# Output: Total memory needed: 112.0 GB
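The fixed per-parameter costs above exclude activations, which (as item 4 notes) scale with batch size and sequence length. A rough sketch using a common rule of thumb (~34 activation values per token times the hidden size, per transformer layer; exact counts depend on the implementation and attention variant) and hypothetical 7B-class dimensions:

```python
# Rough activation memory estimate for a transformer.
# All model dimensions below are hypothetical, roughly 7B-class.
hidden_size = 4096
num_layers = 32
batch_size = 8
seq_len = 2048
bytes_per_value = 2  # FP16 activations

# Rule of thumb: ~34 * hidden_size activation values per token per layer
# (attention inputs/outputs plus MLP intermediates)
per_layer = batch_size * seq_len * hidden_size * 34 * bytes_per_value
total_activations = per_layer * num_layers
print(f"Approximate activation memory: {total_activations / 1e9:.1f} GB")
# Output: Approximate activation memory: 146.0 GB
```

Even at modest batch sizes, activations can rival the model-state footprint, which is why the techniques below target them first.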

Memory Optimization Techniques

1. Gradient Checkpointing

Trade computation for memory by recomputing activations during backward pass:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.mlp = MLP(config)
        
    def forward(self, x):
        # Use gradient checkpointing for this block
        return checkpoint.checkpoint(self._forward, x, use_reentrant=False)
    
    def _forward(self, x):
        x = x + self.attention(x)
        x = x + self.mlp(x)
        return x

Memory savings: 50-80% reduction in activation memory
Cost: roughly one extra forward pass (~33% more computation time)
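For plain sequential stacks, PyTorch also ships torch.utils.checkpoint.checkpoint_sequential, which splits an nn.Sequential into segments and checkpoints each one. A minimal CPU-runnable sketch (the toy layer sizes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A toy stack of 8 blocks; only segment boundaries keep their activations,
# everything inside a segment is recomputed during the backward pass.
layers = nn.Sequential(
    *[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)]
)

x = torch.randn(4, 64, requires_grad=True)
out = checkpoint_sequential(layers, segments=2, input=x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([4, 64])
```

With segments=2, peak activation memory holds ~2 boundary tensors instead of all 8 blocks' intermediates.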

2. Mixed Precision Training

Use FP16 for most operations, FP32 for stability-critical parts:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
    
    # Forward pass in mixed precision
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    
    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Memory savings: ~50% reduction in model and activation memory
Cost: Minimal performance impact on modern GPUs
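One accounting caveat: with Adam, standard FP16 mixed precision still keeps an FP32 master copy of the weights plus FP32 optimizer states, so the model-state footprint per parameter stays around 16 bytes. The arithmetic (a standard accounting; exact numbers vary by framework):

```python
# Per-parameter model-state memory in typical FP16 mixed precision with Adam
params = 7_000_000_000
fp16_params = params * 2   # FP16 working weights
fp16_grads = params * 2    # FP16 gradients
fp32_master = params * 4   # FP32 master copy of the weights
adam_states = params * 8   # two FP32 moments

total = fp16_params + fp16_grads + fp32_master + adam_states
print(f"Model state: {total / 1e9:.1f} GB ({total / params} bytes/param)")
# Output: Model state: 112.0 GB (16.0 bytes/param)
```

The real ~50% savings therefore come from activations and temporary buffers, which often dominate at large batch sizes.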

3. Activation Offloading

Move activations to CPU memory when not needed:

class OffloadingModule(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        
    def forward(self, x):
        # Move to GPU for computation
        x = x.cuda()
        output = self.module(x)
        
        # Offload result to CPU if not immediately needed
        if not output.requires_grad:
            output = output.cpu()
            
        return output

4. Gradient Accumulation

Simulate larger batch sizes without using more memory:

accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

for i, batch in enumerate(dataloader):
    inputs, targets = batch
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps
    
    scaler.scale(loss).backward()
    
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

Advanced Techniques

ZeRO Optimizer

Partition optimizer states across multiple GPUs:

import deepspeed

# ZeRO Stage 1: Partition optimizer states
# ZeRO Stage 2: Partition gradients
# ZeRO Stage 3: Partition parameters

config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        },
        "offload_param": {
            "device": "cpu"
        }
    }
}

# The config is applied when wrapping the model:
# model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=config)

Memory-Efficient Attention

Avoid materializing the full n×n attention matrix at once by processing queries in blocks:

import math

def flash_attention(q, k, v, block_size=1024):
    """Simplified block-wise attention sketch. (True FlashAttention also tiles
    keys/values with an online softmax; this version only blocks the queries.)"""
    seq_len = q.size(1)
    scale = 1.0 / math.sqrt(q.size(-1))
    output = torch.zeros_like(q)
    
    for i in range(0, seq_len, block_size):
        end_i = min(i + block_size, seq_len)
        q_block = q[:, i:end_i]
        
        # Attention scores for this query block only: (batch, block, seq_len)
        scores = torch.matmul(q_block, k.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        output[:, i:end_i] = torch.matmul(attn_weights, v)
    
    return output
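Because each query block still attends over the full key set, query-blocked attention is mathematically identical to dense attention; a quick CPU check with small hypothetical shapes confirms it:

```python
import math
import torch

def dense_attention(q, k, v):
    # Reference: materializes the full (seq_len x seq_len) score matrix
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.matmul(torch.softmax(scores, dim=-1), v)

def blocked_attention(q, k, v, block_size=16):
    # Query-blocked variant: never holds more than (block x seq_len) scores
    output = torch.zeros_like(q)
    for i in range(0, q.size(1), block_size):
        j = min(i + block_size, q.size(1))
        scores = torch.matmul(q[:, i:j], k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        output[:, i:j] = torch.matmul(torch.softmax(scores, dim=-1), v)
    return output

q, k, v = (torch.randn(2, 64, 32) for _ in range(3))
print(torch.allclose(dense_attention(q, k, v), blocked_attention(q, k, v), atol=1e-5))
# True
```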

Monitoring and Debugging

GPU Memory Profiling

import torch

def print_gpu_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")

# Profile memory usage
print_gpu_memory()
model = create_model()
print_gpu_memory()

Finding Memory Leaks

import gc
import torch

def find_tensors():
    """Find all tensors in memory"""
    tensors = []
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            tensors.append((obj.shape, obj.dtype, obj.device))
    return tensors

# Check for tensor leaks between training steps
before = find_tensors()
train_step()
after = find_tensors()

# Compare to find unexpected tensor growth

Practical Tips

  1. Start with the biggest wins: Mixed precision and gradient checkpointing
  2. Profile early: Use tools like torch.profiler to understand bottlenecks
  3. Batch size tuning: Find the largest batch size that fits in memory
  4. Model sharding: Split large models across multiple GPUs
  5. CPU offloading: Use CPU RAM as extended GPU memory
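For tip 2, torch.profiler can attribute memory as well as time to individual operators. A minimal CPU-only sketch (the toy model is arbitrary; profile_memory is a real torch.profiler flag):

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
x = torch.randn(32, 256)

# Record per-operator CPU time and memory for one forward pass
with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    model(x)

# Sort operators by how much memory they allocated themselves
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=5))
```

On GPU the same pattern works with ProfilerActivity.CUDA added to the activities list.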

Conclusion

GPU memory optimization is crucial for training large models efficiently. The techniques above can often reduce memory usage by 4-8x, enabling training of much larger models on the same hardware. The key is to understand the memory/compute tradeoffs and apply the right combination of techniques for your specific use case.