
CS336 Lecture 2 - 4. More Memory Optimizations


Gradient accumulation

Large batch sizes improve training stability. However, activation memory scales with batch size, so we might run out of memory.

B = 64     # Batch size
D = 1024   # Dimensionality
L = 16     # Number of layers
activation_memory = 2 * B * D * L  # (2 bytes for bf16) 
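To see the linear scaling, we can evaluate the same estimate for a few batch sizes. This is a minimal sketch. `activation_memory_bytes` is a hypothetical helper that, like the formula above, counts one D-dimensional bf16 activation per example per layer and ignores everything else a real block keeps around.

```python
def activation_memory_bytes(batch_size: int, dim: int, num_layers: int,
                            bytes_per_element: int = 2) -> int:
    """Rough estimate: one dim-sized activation per example per layer (bf16 by default)."""
    return bytes_per_element * batch_size * dim * num_layers


D, L = 1024, 16
for B in (64, 128, 256):
    mib = activation_memory_bytes(B, D, L) / 2**20
    print(f"B={B:4d}: {mib:.0f} MiB")
```

Each doubling of B doubles the estimate (2, 4 and 8 MiB here). At realistic model sizes this is what eventually exhausts accelerator memory.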

Gradient accumulation:

  • Compute gradient on micro batches
  • Accumulate the gradients (don't zero them out)
  • Every batch_size / micro_batch_size steps, update the parameters and zero out the gradients
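micro_batch_size = 16  # illustrative value: must be smaller than (and divide) the batch size B
num_accumulation_steps = B // micro_batch_size  # 4 micro-batches per parameter update
activation_memory = 2 * micro_batch_size * D * L  # (2 bytes for bf16)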
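The three steps above can be sketched in PyTorch as follows. This is a minimal illustration that rests on several assumptions. The linear model, the MSE loss and the helper name `accumulate_gradients` are hypothetical. Dividing each micro-batch loss by the number of micro-batches is one common choice: it makes the accumulated gradient equal the gradient of the full-batch mean loss.

```python
import torch
import torch.nn as nn


def accumulate_gradients(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
                         micro_batch_size: int) -> float:
    """Fill .grad with the full-batch gradient by summing micro-batch gradients."""
    batch_size = x.shape[0]
    assert batch_size % micro_batch_size == 0, "micro_batch_size must divide the batch size"
    num_micro_batches = batch_size // micro_batch_size
    loss_fn = nn.MSELoss()  # mean over the micro-batch
    total_loss = 0.0
    for start in range(0, batch_size, micro_batch_size):
        xb = x[start:start + micro_batch_size]
        yb = y[start:start + micro_batch_size]
        # Scale so the summed gradients equal the gradient of the full-batch mean loss
        loss = loss_fn(model(xb), yb) / num_micro_batches
        loss.backward()  # adds into .grad; we deliberately do not zero between micro-batches
        total_loss += loss.item()
    return total_loss


torch.manual_seed(0)
model = nn.Linear(32, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(64, 32), torch.randn(64, 1)  # one batch of B = 64 examples

optimizer.zero_grad()
loss = accumulate_gradients(model, x, y, micro_batch_size=16)  # 4 micro-batches
optimizer.step()       # one parameter update per full batch
optimizer.zero_grad()  # reset before the next batch
```

Only one micro-batch's activations are alive at any time. Peak activation memory therefore follows `micro_batch_size`, while the update still uses the full batch of 64 examples.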

Activation checkpointing

For training, we need to store the activations of all layers, because the backward pass needs them to compute gradients. For inference, we don't compute gradients, so we only need to keep the current layer's activations.
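We can observe this difference without a GPU by counting the bytes autograd saves for the backward pass. The sketch below uses PyTorch's `torch.autograd.graph.saved_tensors_hooks` to tally every tensor saved during one forward pass. The helper name `saved_bytes` and the small MLP are hypothetical. The tally also includes weights that some ops keep for backward, so it is an upper bound on activation memory rather than an exact measurement.

```python
import torch
import torch.nn as nn


def saved_bytes(model: nn.Module, x: torch.Tensor, grad_enabled: bool) -> int:
    """Total bytes of tensors autograd saves for backward during one forward pass."""
    total = 0

    def pack(t: torch.Tensor) -> torch.Tensor:
        nonlocal total
        total += t.numel() * t.element_size()  # called once per tensor saved for backward
        return t

    def unpack(t: torch.Tensor) -> torch.Tensor:
        return t

    with torch.set_grad_enabled(grad_enabled):
        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            model(x)
    return total


B, D, L = 64, 256, 4  # small hypothetical sizes
model = nn.Sequential(*[nn.Sequential(nn.Linear(D, D), nn.GELU()) for _ in range(L)])
x = torch.randn(B, D)

print("training :", saved_bytes(model, x, grad_enabled=True))
print("inference:", saved_bytes(model, x, grad_enabled=False))  # 0: nothing is saved
```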

Measuring the peak memory of a standard (non-checkpointed) forward and backward pass:

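import torch

B = 64     # Batch size
D = 1024   # Dimensionality
L = 16     # Number of layers

# cuda_if_available, DeepNetwork and get_max_memory_usage are helpers defined elsewhere (not shown here)
x = torch.randn(B, D, device=cuda_if_available(), requires_grad=True)
activation_memory = 2 * B * D * L  # (2 bytes for bf16)

model = DeepNetwork(dim=D, num_layers=L).to(cuda_if_available())
memory = get_max_memory_usage(lambda: model(x).sum().backward())  # peak memory of forward + backward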

Can we reduce this?

Activation checkpointing (also called gradient checkpointing or rematerialization). Key idea:

  • Forward pass: keep only the activations at a subset of layers (the checkpoints)
  • Backward pass: recompute the missing activations, starting from the nearest preceding checkpoint
  • Philosophy: trade extra compute for lower memory
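To make the key idea concrete, here is a hand-rolled version of the recompute trick for a single layer, written as a custom `torch.autograd.Function`. It is a simplified sketch in the spirit of PyTorch's reentrant checkpointing, not its actual implementation. The class name `RecomputeInBackward` is hypothetical. The sketch ignores RNG state, so it is only safe for deterministic layers (no dropout, for example). As with reentrant checkpointing, the input must require grad for parameter gradients to flow.

```python
import torch
import torch.nn as nn


class RecomputeInBackward(torch.autograd.Function):
    """Save only the layer input in forward; rerun the layer during backward."""

    @staticmethod
    def forward(ctx, layer, x):
        ctx.layer = layer
        ctx.save_for_backward(x)  # the checkpoint: only the input is kept
        return layer(x)           # autograd is off inside forward, so no intermediates are saved

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.layer(x)  # recompute the missing activations
        # Backprop through the recomputed graph; this also accumulates the layer's parameter .grad
        torch.autograd.backward(out, grad_out)
        return None, x.grad  # no gradient for the `layer` argument


torch.manual_seed(0)
layers = nn.ModuleList([nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4)])
x = torch.randn(8, 64, requires_grad=True)
h = x
for layer in layers:
    h = RecomputeInBackward.apply(layer, h)
h.sum().backward()  # each layer is rerun once, just before its gradients are needed
```

This trades one extra forward pass per layer for keeping only each layer's input. The gradients match an ordinary backward pass.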

import torch
import torch.nn as nn
import torch.utils.checkpoint

class DeepNetworkCheckpointed(nn.Module):
    """Same as DeepNetwork, but with activation checkpointing."""
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        # Block is the per-layer module defined elsewhere (not shown here)
        self.layers = nn.ModuleList([Block(dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply all the layers sequentially
        for layer in self.layers:
            # KEY: only store activations at checkpoints (each layer's input), recompute the rest
            x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x
# Store all activations:    x g1 h1 g2 h2 g3 h3 g4 h4
# Activation checkpointing: x    h1    h2    h3    h4
# Define the model with checkpointing
model = DeepNetworkCheckpointed(dim=D, num_layers=L).to(cuda_if_available())
checkpointed_memory = get_max_memory_usage(lambda: model(x).sum().backward())

Can we reduce this even more, especially for deep networks (large L)?

# Store all layers:   | h1 h2 h3 h4 h5 h6 h7 h8 h9 |
# Store no layers:    |                            |
# Store some layers:  |    h3       h6          h9 |
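PyTorch offers a close relative of the "store some layers" pattern: `torch.utils.checkpoint.checkpoint_sequential`. It splits a sequential model into segments and checkpoints every segment except the last, which runs normally and keeps its activations. The sketch below applies it to 9 layers with `segments=3`, roughly mirroring the diagram. The `Linear` + `GELU` blocks are hypothetical stand-ins for the lecture's `Block`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
D, L = 64, 9
model = nn.Sequential(*[nn.Sequential(nn.Linear(D, D), nn.GELU()) for _ in range(L)])
x = torch.randn(8, D, requires_grad=True)

# 9 layers in 3 segments of 3: the first two segments keep only their inputs and are
# recomputed during backward; the last segment runs normally
out = checkpoint_sequential(model, 3, x, use_reentrant=False)
out.sum().backward()
```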

How frequently to checkpoint?

  • If we store every layer's activations, activation memory is O(L) and there is no recomputation.
  • If we store no activations, activation memory is O(1) but compute is O(L^2), since we recompute from the input for each layer.
  • If we checkpoint every sqrt(L) layers, activation memory is O(sqrt(L)) and recomputation is O(L) (roughly one extra forward pass).
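A toy cost model shows why sqrt(L) is the sweet spot. This is a sketch under simplifying assumptions. `segment_checkpoint_cost` is a hypothetical helper that keeps the output of every k-th layer. While backpropagating through a segment, it rematerializes that segment's k - 1 inner outputs. Memory is counted in units of one layer's activations and compute in units of one layer's forward pass. The model covers the "store everything" and "every sqrt(L)" cases, not the "store nothing" strategy.

```python
import math


def segment_checkpoint_cost(num_layers: int, every: int) -> tuple[int, int]:
    """Toy cost model for keeping the output of every `every`-th layer as a checkpoint.

    Returns (peak activations held, extra layer forwards during backward).
    """
    num_checkpoints = math.ceil(num_layers / every)
    # Checkpoints stay resident; one segment's (every - 1) inner outputs are rematerialized at a time
    peak_memory = num_checkpoints + (every - 1)
    # Each layer whose output was not checkpointed is run forward a second time
    recompute = num_layers - num_checkpoints
    return peak_memory, recompute


L = 16
best = min(range(1, L + 1), key=lambda k: segment_checkpoint_cost(L, k)[0])
print(best, segment_checkpoint_cost(L, best))  # 4 (7, 12): k = sqrt(16) minimizes peak memory
```

Peak memory is about L/k + k, which is minimized at k = sqrt(L) and gives O(sqrt(L)). Recomputation stays below L layer forwards (at most one extra forward pass), which is the O(L) term above.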

Summary

  • Everything is operations on tensors (parameters, gradients, activations, optimizer states, data)
  • einops: better way to think about tensor operations
  • 6 × (# data points) × (# parameters) FLOPs per training step
  • Arithmetic intensity / roofline analysis: compute-bound or memory-bound?
  • Matrix multiplications are compute-bound, elementwise operations are memory-bound
  • Gradient accumulation, activation checkpointing: reduce memory to use bigger batch sizes
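As a compact reminder of the FLOP and roofline bullets, the sketch below computes two things: the 6 × (# data points) × (# parameters) estimate, and the arithmetic intensity (FLOPs per byte moved) of a square bf16 matmul versus an elementwise add. The byte counts assume each operand is read once and the result written once. `machine_balance` is a hypothetical accelerator's ratio of peak FLOP/s to memory bandwidth, not a measured figure.

```python
def training_flops(num_tokens: int, num_params: int) -> int:
    """Forward (2) + backward (4) FLOPs per token per parameter."""
    return 6 * num_tokens * num_params


def matmul_intensity(n: int, bytes_per_element: int = 2) -> float:
    """(n x n) @ (n x n): 2n^3 FLOPs over reading A, B and writing C."""
    return (2 * n**3) / (3 * n**2 * bytes_per_element)


def elementwise_add_intensity(n: int, bytes_per_element: int = 2) -> float:
    """x + y on n elements: n FLOPs over reading x, y and writing the result."""
    return n / (3 * n * bytes_per_element)


machine_balance = 100.0  # hypothetical accelerator: FLOPs it can do per byte it can move
for name, intensity in [("matmul", matmul_intensity(4096)),
                        ("elementwise add", elementwise_add_intensity(4096))]:
    bound = "compute-bound" if intensity > machine_balance else "memory-bound"
    print(f"{name}: {intensity:.2f} FLOPs/byte -> {bound}")
```

Matmul intensity grows with n (n/3 for bf16), while the elementwise add stays at 1/6 FLOPs per byte regardless of size. This is why the first is compute-bound and the second memory-bound on typical hardware.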