CS336 Lecture 2 - 3. Full Example

Deep linear network
#

Consider a deep network with L layers and D-dimensional inputs, activations, and outputs.

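import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def cuda_if_available() -> torch.device:
    """Helper assumed by the snippets in these notes: pick the GPU when there is one."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Block(nn.Module):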
    """Simple block that applies a linear transformation followed by a ReLU nonlinearity."""
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim) / math.sqrt(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x @ self.weight  # Linear
        x = F.relu(x)        # Activation
        return x

class DeepNetwork(nn.Module):
    """Map `dim`-vector to a `dim`-vector."""
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([Block(dim) for i in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply all the layers sequentially
        for layer in self.layers:
            x = layer(x)
        return x

def get_num_parameters(model: nn.Module) -> int:
    return sum(param.numel() for param in model.parameters())

D = 8
L = 3
model = DeepNetwork(dim=D, num_layers=L).to(cuda_if_available())

num_parameters = get_num_parameters(model)  
assert num_parameters == (D * D) * L

# Run the model on a batch of data
B = 4  # Batch size
x = torch.randn(B, D, device=cuda_if_available())  
y = model(x)

Gradients basics
#

So far, we’ve constructed tensors and passed them through operations (forward). Now, we’re going to compute the gradient (backward).

As a simple example, consider the linear model y = x · w (a dot product) with loss = 0.5 (y - 5)^2.

Forward pass: compute loss

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)  # Want gradient
pred_y = x @ w
loss = 0.5 * (pred_y - 5).pow(2)

Backward pass: compute gradients

loss.backward()
assert torch.equal(w.grad, torch.tensor([1., 2., 3.]))

What happens under the hood (PyTorch autograd)
#

In this example, PyTorch performs two phases: forward and backward.

1. Forward pass (build computation graph)
#

During the forward pass, PyTorch does two things at the same time:

  • Compute actual values
  • Record how those values were computed

Concretely:

x, w → matmul → pred_y → subtract(5) → pow(2) → multiply(0.5) → loss

Each operation creates a new tensor and attaches a grad_fn (gradient function), which represents:

  • how this tensor was computed
  • how to compute its gradient during backward

This forms a computation graph:

  • Nodes = tensors / operations
  • Edges = data dependencies
  • It is a directed acyclic graph (DAG) constructed dynamically

Only leaf tensors created with requires_grad=True (like w) end up with a populated .grad; x does not require gradients, so it is treated as a constant.
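To make the recorded graph concrete, one can walk the grad_fn chain from the loss back to the leaf. This is a minimal sketch: the helper name `walk_graph` is hypothetical, and the class names of the intermediate nodes may differ between PyTorch versions, so only the final `AccumulateGrad` node (the one that writes into w.grad) is relied on here.

```python
import torch

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)
loss = 0.5 * (x @ w - 5).pow(2)


def walk_graph(tensor: torch.Tensor) -> list[str]:
    """Follow the first recorded input of each backward node (hypothetical helper)."""
    names = []
    node = tensor.grad_fn
    while node is not None:
        names.append(type(node).__name__)
        # next_functions holds (node, index) pairs; the node is None for inputs
        # that do not need gradients (constants, or tensors like x).
        parents = [fn for fn, _ in node.next_functions if fn is not None]
        node = parents[0] if parents else None
    return names


graph_nodes = walk_graph(loss)
print(" -> ".join(graph_nodes))  # ends with AccumulateGrad, the node for the leaf w
```

The chain is printed in the order the backward pass visits it: from the loss toward w.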

2. Backward pass (compute gradients)
#

For a clearer view of backpropagation, see Backpropagation calculus | Deep Learning Chapter 4 by 3Blue1Brown [Bilibili]

When we call:

loss.backward()

PyTorch:

  1. Starts from the loss with:

    d(loss)/d(loss) = 1
  2. Traverses the computation graph in reverse order

  3. At each node, applies the chain rule: $$ \frac{\mathrm{d}L}{\mathrm{d}w} = \frac{\mathrm{d}L}{\mathrm{d}y} \cdot \frac{\mathrm{d}y}{\mathrm{d}w} $$

    Each operation knows its local derivative, so it multiplies:

    incoming gradient × local derivative
  4. Gradients are propagated backward step by step:

    loss → pred_y → w
  5. Final gradients are accumulated into leaf tensors:

    w.grad = [1, 2, 3]
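The steps above can be reproduced by hand for the same toy model. The sketch below splits the forward pass into named intermediates (`diff`, `d_diff`, and so on are names introduced here for illustration), applies "incoming gradient × local derivative" at each step, and compares the result with autograd. It also shows what "accumulated" means: a second backward pass adds to w.grad rather than overwriting it.

```python
import torch

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)

# Forward pass, kept as separate steps so each local derivative is visible
pred_y = x @ w            # 6.0
diff = pred_y - 5         # 1.0
loss = 0.5 * diff.pow(2)  # 0.5
loss.backward()

# Manual backward pass: incoming gradient × local derivative
with torch.no_grad():
    d_loss = torch.tensor(1.0)    # d(loss)/d(loss)
    d_diff = d_loss * diff        # d(0.5 * diff^2)/d(diff) = diff
    d_pred_y = d_diff * 1.0       # d(pred_y - 5)/d(pred_y) = 1
    manual_w_grad = d_pred_y * x  # d(x @ w)/d(w) = x

assert torch.allclose(w.grad, manual_w_grad)

# Gradients accumulate: a second backward pass adds to w.grad
loss2 = 0.5 * (x @ w - 5).pow(2)
loss2.backward()
accumulated = w.grad.clone()  # 2 * [1, 2, 3]
```

This accumulation is why the training loop at the end of these notes calls optimizer.zero_grad after every step.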

Gradient FLOPs
#

Let’s count the FLOPs for computing gradients.

B = 1024  # Number of points
D = 256   # Dimension

Define a simplified model (2-layer linear network):

from einops import einsum  # einops-style einsum: tensors first, pattern string last

x = torch.ones(B, D, device=cuda_if_available())
w1 = torch.randn(D, D, device=cuda_if_available(), requires_grad=True)
w2 = torch.randn(D, D, device=cuda_if_available(), requires_grad=True)

# Forward pass
h1 = einsum(x, w1, "batch in, in out -> batch out")  # x @ w1
h2 = einsum(h1, w2, "batch in, in out -> batch out")  # h1 @ w2
loss = (h2.mean() - 0)**2  # Regress everything to 0 (arbitrary)

# Backward pass
h1.retain_grad()  # For debugging
h2.retain_grad()  # For debugging
loss.backward()

Zoom in on one layer
#

Let’s focus on the second layer (h2 = h1 @ w2)

Forward pass: Recall the number of forward FLOPs:

num_forward_flops = 2 * B * D * D

Backward pass: How many FLOPs does the backward pass take?

We need to compute:

  • h1.grad = d loss / d h1
  • w2.grad = d loss / d w2

h1_grad = einsum(h2.grad, w2, "batch out, in out -> batch in")
assert torch.allclose(h1.grad, h1_grad)

w2_grad = einsum(h2.grad, h1, "batch out, batch in -> in out")
assert torch.allclose(w2.grad, w2_grad)

num_backward_flops = (2 * B * D * D) + (2 * B * D * D)  

Note that the backward pass costs twice as many FLOPs as the forward pass (4 * B * D * D versus 2 * B * D * D).
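The two einsum calls are the chain rule applied to h2 = h1 @ w2. Writing G for h2.grad (a symbol introduced here for brevity), a short derivation sketch:

$$
\frac{\partial\,\mathrm{loss}}{\partial h_1} = G\, w_2^{\top} \quad (B \times D),
\qquad
\frac{\partial\,\mathrm{loss}}{\partial w_2} = h_1^{\top} G \quad (D \times D)
$$

Each is a matrix product over the dimensions B, D, D, so each costs 2 * B * D * D FLOPs (one multiply and one add per term), the same as the forward matmul.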

Consider all layers
#

This was just for w2; the same count applies to every parameter in the network.

Putting it together:

  • Forward pass: 2 (# data points) (# parameters) FLOPs
  • Backward pass: 4 (# data points) (# parameters) FLOPs
  • Total: 6 (# data points) (# parameters) FLOPs

This is for multilayer perceptrons (MLPs), but it turns out to be a good approximation for Transformers at short context lengths as well.
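As a quick worked example, the rule of thumb can be wrapped in a small helper and applied to the 2-layer network from this section. The function name `training_flops` is hypothetical, and the result is an estimate, not an exact count (for instance, the first layer does not need a gradient with respect to x).

```python
def training_flops(num_data_points: int, num_parameters: int) -> dict[str, int]:
    """Rule-of-thumb FLOP counts for one forward + backward pass (hypothetical helper)."""
    forward = 2 * num_data_points * num_parameters
    backward = 4 * num_data_points * num_parameters
    return {"forward": forward, "backward": backward, "total": forward + backward}


# The 2-layer example: B = 1024 points, two D x D weight matrices with D = 256
B, D, num_layers = 1024, 256, 2
flops = training_flops(num_data_points=B, num_parameters=num_layers * D * D)
print(flops)  # {'forward': 268435456, 'backward': 536870912, 'total': 805306368}
```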

Optimizer
#

Recall our deep network.

B = 2  # Batch size
D = 4  # Dimensionality of input, activations, and output
L = 3  # Number of layers
model = DeepNetwork(dim=D, num_layers=L).to(cuda_if_available())  

Let’s define the AdaGrad optimizer. For context, here is how the common optimizers relate:

  • momentum = SGD + exponential averaging of grad
  • AdaGrad = SGD + scaling by the running sum of grad^2
  • RMSProp = AdaGrad but with exponential averaging of grad^2
  • Adam = RMSProp + momentum

AdaGrad [Duchi+ 2011]
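The snippets in this section use an `AdaGrad` class that is not shown in these notes. Below is a minimal sketch of what such a class could look like, assuming it subclasses torch.optim.Optimizer. The state key `g2` and the `eps` default are illustrative choices, not necessarily what the lecture used; PyTorch also ships its own torch.optim.Adagrad.

```python
from typing import Iterable

import torch
import torch.nn as nn


class AdaGrad(torch.optim.Optimizer):
    """Sketch of AdaGrad: divide each step by the root of the summed squared gradients."""

    def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01, eps: float = 1e-5):
        super().__init__(params, dict(lr=lr, eps=eps))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr, eps = group["lr"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Per-parameter state: running sum of squared gradients
                state = self.state[p]
                if "g2" not in state:
                    state["g2"] = torch.zeros_like(p)
                state["g2"] += p.grad ** 2
                # Parameter update, scaled elementwise
                p -= lr * p.grad / torch.sqrt(state["g2"] + eps)


# Tiny check: on the first step the update is roughly lr * sign(grad)
p = nn.Parameter(torch.tensor([1.0, 2.0]))
p.grad = torch.tensor([0.5, -1.0])
AdaGrad([p], lr=0.01).step()
```

Because the state holds one value per parameter element, it accounts for the "second moments" entry in the memory tally later on.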

optimizer = AdaGrad(model.parameters(), lr=0.01)
state = model.state_dict()

# Compute gradients
x = torch.randn(B, D, device=cuda_if_available())
y = torch.tensor([4., 5.], device=cuda_if_available())
pred_y = model(x).mean(dim=-1)  # One prediction per example, shape (B,), matching y
loss = F.mse_loss(input=pred_y, target=y)
loss.backward()

# Take a step
optimizer.step()
optimizer_state = {i: dict(p_state) for i, (p, p_state) in enumerate(optimizer.state.items())}

# Free up the memory
optimizer.zero_grad(set_to_none=True)

Memory
#

num_parameters = D * D * L
parameter_memory = 2 * num_parameters  # (2 bytes for bf16)
gradient_memory = 2 * num_parameters  # (2 bytes for bf16)
optimizer_state_memory = 4 * num_parameters  # (4 bytes for fp32)
activation_memory = 2 * (B * D * L)  # (2 bytes for bf16)

Optimizer state is customarily kept in fp32 for stability, since it accumulates (averages of) squared gradients over many steps. Optimizer state memory:

  • AdaGrad: 4 bytes/parameter for storing second moments
  • Adam: 8 bytes/parameter for storing first and second moments

# Putting it all together
total_memory = parameter_memory + activation_memory + gradient_memory + optimizer_state_memory  
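To see how the tally changes with the optimizer, the same arithmetic can be put in a function. This is a sketch under the byte counts stated above (bf16 parameters, gradients, and activations; fp32 optimizer state); the function name `training_memory_bytes` is hypothetical.

```python
def training_memory_bytes(B: int, D: int, L: int, optimizer: str = "adagrad") -> dict[str, int]:
    """Memory tally in bytes for the deep network above (hypothetical helper)."""
    num_parameters = D * D * L
    # fp32 optimizer state: one moment for AdaGrad, two for Adam
    optimizer_bytes_per_param = {"adagrad": 4, "adam": 8}[optimizer]
    memory = {
        "parameters": 2 * num_parameters,   # bf16
        "gradients": 2 * num_parameters,    # bf16
        "optimizer_state": optimizer_bytes_per_param * num_parameters,
        "activations": 2 * (B * D * L),     # bf16
    }
    memory["total"] = sum(memory.values())
    return memory


# Same sizes as the optimizer example: B = 2, D = 4, L = 3
adagrad_memory = training_memory_bytes(B=2, D=4, L=3, optimizer="adagrad")
adam_memory = training_memory_bytes(B=2, D=4, L=3, optimizer="adam")
print(adagrad_memory["total"], adam_memory["total"])  # 432 624
```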

Compute (for one training step)
#

num_parameters = D * D * L
flops = 6 * B * num_parameters  

Transformers
#

The accounting for a Transformer is more complicated, but the idea is the same. Assignment 1 will ask you to do that.

  • Blog post describing memory usage for Transformer training: [article]
  • Blog post describing FLOPs for a Transformer: [article]

Train loop
#

# True linear function with weights (0, 1, 2, ..., D-1)
D = 16  # Dimensionality
true_w = torch.arange(D, dtype=torch.float32, device=cuda_if_available())

# Data loader that generates (x, y) pairs
B = 4  # Batch size
def get_batch() -> tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(B, D).to(cuda_if_available())
    true_y = x @ true_w
    return (x, true_y)

# Define the model and optimizer
L = 2  # Number of layers
model = DeepNetwork(dim=D, num_layers=L).to(cuda_if_available()) 
optimizer = AdaGrad(model.parameters(), lr=0.01) 

# Train!
num_train_steps = 10
for t in range(num_train_steps):
    # Get data
    x, y = get_batch()

    # Forward (compute loss)
    pred_y = model(x).mean(dim=-1)  # One prediction per example, shape (B,), matching y
    loss = F.mse_loss(pred_y, y)

    # Backward (compute gradients)
    loss.backward()

    # Update parameters
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)