Deep linear network

Consider a deep network with L layers and D-dimensional inputs, activations, and outputs.
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def cuda_if_available() -> torch.device:
    # Helper used throughout these notes (definition assumed): GPU if present, else CPU
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Block(nn.Module):
    """Simple block that applies a linear transformation followed by a ReLU nonlinearity."""
    def __init__(self, dim: int):
        super().__init__()
        # 1/sqrt(dim) keeps the linear output on roughly the same scale as the input at init
        self.weight = nn.Parameter(torch.randn(dim, dim) / math.sqrt(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x @ self.weight  # Linear
        x = F.relu(x)  # Activation
        return x

class DeepNetwork(nn.Module):
    """Map a `dim`-vector to a `dim`-vector."""
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([Block(dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply all the layers sequentially
        for layer in self.layers:
            x = layer(x)
        return x

def get_num_parameters(model: nn.Module) -> int:
    return sum(param.numel() for param in model.parameters())
```
```python
D = 8
L = 3
model = DeepNetwork(dim=D, num_layers=L).to(cuda_if_available())
num_parameters = get_num_parameters(model)
assert num_parameters == (D * D) * L  # One D x D weight matrix per layer, no biases
# Run the model on a batch of data
B = 4  # Batch size
x = torch.randn(B, D, device=cuda_if_available())
y = model(x)  # Shape (B, D)
```
Gradients basics
So far, we’ve constructed tensors and passed them through operations (forward). Now, we’re going to compute the gradient (backward).
As a simple example, consider the linear model y = x · w (a dot product, written x @ w in the code) with loss = 0.5 (y - 5)^2.
Forward pass: compute loss
```python
import torch

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)  # Want gradient
pred_y = x @ w  # Dot product: 1 + 2 + 3 = 6
loss = 0.5 * (pred_y - 5).pow(2)  # 0.5 * (6 - 5)^2 = 0.5
```
Backward pass: compute gradients
```python
loss.backward()
# d loss / d w = (pred_y - 5) * x = 1 * [1, 2, 3]
assert torch.equal(w.grad, torch.tensor([1., 2, 3]))
```
What happens under the hood (PyTorch autograd)
In this example, PyTorch performs two phases: forward and backward.
1. Forward pass (build computation graph)
During the forward pass, PyTorch does two things at the same time:
- Compute actual values
- Record how those values were computed
Concretely:
x, w → matmul (dot product) → pred_y → subtract 5 → pow(2) → multiply by 0.5 → loss
Each operation whose inputs require gradients creates a new tensor and attaches a grad_fn (gradient function), which represents:
- how this tensor was computed
- how to compute its gradient during backward
This forms a computation graph:
- Nodes = tensors / operations
- Edges = data dependencies
- It is a directed acyclic graph (DAG) constructed dynamically
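The recorded graph for the example above can be inspected directly through each tensor's grad_fn and its next_functions. A minimal sketch: the walk follows only the first tracked parent of each node, and the exact node class names it prints (such as DotBackward0) are PyTorch internals that may differ between versions.

```python
import torch

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)
pred_y = x @ w
loss = 0.5 * (pred_y - 5).pow(2)

print(w.is_leaf, w.grad_fn)  # True None  (leaf: created by the user, not by an op)
print(x.requires_grad)       # False      (x is not tracked)

# Walk backward from loss, following the first tracked parent of each node
chain = []
node = loss.grad_fn
while node is not None:
    chain.append(type(node).__name__)
    parents = [fn for fn, _ in node.next_functions if fn is not None]
    node = parents[0] if parents else None
print(" -> ".join(chain))  # ends at AccumulateGrad, the node that writes into w.grad
```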
Only leaf tensors created with requires_grad=True (like w) get their gradients stored in .grad; x has requires_grad=False, so it is treated as a constant.
2. Backward pass (compute gradients)
For a clearer picture of backpropagation, see Backpropagation calculus | Deep Learning Chapter 4 by 3Blue1Brown [Bilibili].
When we call loss.backward(), PyTorch:
- Starts from the loss with d(loss)/d(loss) = 1
- Traverses the computation graph in reverse order
- At each node, applies the chain rule (here L denotes the loss, not the number of layers, and y = pred_y): $$ \frac{\mathrm{d}L}{\mathrm{d}w} = \frac{\mathrm{d}L}{\mathrm{d}y} \cdot \frac{\mathrm{d}y}{\mathrm{d}w} $$
- Each operation knows its local derivative, so it multiplies: incoming gradient × local derivative
- Propagates gradients backward step by step: loss → pred_y → w
- Accumulates the final gradients into leaf tensors: w.grad = (pred_y - 5) · x = [1, 2, 3]
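To make the chain rule concrete, here is a minimal sketch that redoes the backward pass for this example by hand and compares it with autograd. The second half shows what "accumulated" means in practice: a second backward() adds to w.grad instead of overwriting it.

```python
import torch

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)
pred_y = x @ w                      # 6
loss = 0.5 * (pred_y - 5).pow(2)    # 0.5
loss.backward()

# Manual backward pass: multiply the incoming gradient by each local derivative
dloss_dloss = 1.0
dloss_dpred = dloss_dloss * (pred_y.detach() - 5)  # d/dy 0.5 (y - 5)^2 = y - 5 = 1
dloss_dw = dloss_dpred * x                         # d(x . w)/dw = x
print(dloss_dw)  # tensor([1., 2., 3.])
assert torch.allclose(w.grad, dloss_dw)

# Gradients accumulate: a second backward on a fresh graph adds to w.grad
loss2 = 0.5 * (x @ w - 5).pow(2)
loss2.backward()
print(w.grad)  # tensor([2., 4., 6.])
```

This accumulation is why training loops call optimizer.zero_grad() between steps.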
Gradient FLOPs
Let's count the FLOPs for computing gradients.

```python
import torch
from einops import einsum

B = 1024  # Number of points
D = 256  # Dimension
```
Define a simplified model (2-layer linear network):
```python
x = torch.ones(B, D, device=cuda_if_available())
w1 = torch.randn(D, D, device=cuda_if_available(), requires_grad=True)
w2 = torch.randn(D, D, device=cuda_if_available(), requires_grad=True)
# Forward pass
h1 = einsum(x, w1, "batch in, in out -> batch out")  # h1 = x @ w1, shape (B, D)
h2 = einsum(h1, w2, "batch in, in out -> batch out")  # h2 = h1 @ w2, shape (B, D)
loss = (h2.mean() - 0)**2  # Regress everything to 0 (arbitrary)
# Backward pass
h1.retain_grad()  # Keep .grad on the non-leaf h1 (for debugging)
h2.retain_grad()  # Keep .grad on the non-leaf h2 (for debugging)
loss.backward()
```
Zoom in on one layer
Let’s focus on the second layer (h2 = h1 @ w2)
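Forward pass: recall the number of forward FLOPs (each of the B × D outputs is a length-D dot product, i.e. D multiplies and D adds):
```python
num_forward_flops = 2 * B * D * D
```
Backward pass: how many FLOPs does running the backward pass take?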
We need to compute:
- h1.grad = d loss / d h1
- w2.grad = d loss / d w2
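Both gradients are themselves matrix products, which is where the backward FLOP count comes from. Writing $\ell$ for the loss (to avoid a clash with $L$, the number of layers), with $h_1 \in \mathbb{R}^{B \times D}$ and $W_2 \in \mathbb{R}^{D \times D}$:

$$ \frac{\partial \ell}{\partial h_1} = \frac{\partial \ell}{\partial h_2}\, W_2^{\top} \qquad \underbrace{(B \times D)(D \times D)}_{2BD^2 \text{ FLOPs}} $$

$$ \frac{\partial \ell}{\partial W_2} = h_1^{\top}\, \frac{\partial \ell}{\partial h_2} \qquad \underbrace{(D \times B)(B \times D)}_{2BD^2 \text{ FLOPs}} $$

Adding the two gives $4BD^2$, twice the forward cost of $2BD^2$.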
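```python
# d loss / d h1 = (d loss / d h2) @ w2^T: a (B x D) @ (D x D) matmul
h1_grad = einsum(h2.grad, w2, "batch out, in out -> batch in")
assert torch.allclose(h1.grad, h1_grad)
# d loss / d w2 = h1^T @ (d loss / d h2): a (D x B) @ (B x D) matmul
w2_grad = einsum(h2.grad, h1, "batch out, batch in -> in out")
assert torch.allclose(w2.grad, w2_grad)
num_backward_flops = (2 * B * D * D) + (2 * B * D * D)
```
Note that the backward pass is 2x as expensive as the forward pass: it needs two matmuls instead of one.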
Consider all layers
This was just for w2; the same accounting applies to every parameter in the network.
Putting it together:
- Forward pass: 2 (# data points) (# parameters) FLOPs
- Backward pass: 4 (# data points) (# parameters) FLOPs
- Total: 6 (# data points) (# parameters) FLOPs
This derivation is for multilayer perceptrons (MLPs), but it turns out to be a good approximation for Transformers at short context lengths as well.
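As a worked instance of the 6 × (# data points) × (# parameters) rule, here is a minimal sketch for the two-layer network above (B = 1024, D = 256). The helper names matmul_flops and training_step_flops are hypothetical. The rule slightly overcounts the first layer: x does not require gradients, so autograd has no need to compute d loss / d x.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    # (m x k) @ (k x n): m*n outputs, each about k multiplies and k adds
    return 2 * m * k * n


def training_step_flops(num_points: int, num_parameters: int) -> int:
    # The 6 x (# data points) x (# parameters) rule of thumb
    return 6 * num_points * num_parameters


B, D = 1024, 256
num_parameters = 2 * D * D  # w1 and w2

# Per layer: forward = 1 matmul; backward = input-grad matmul + weight-grad matmul
forward = 2 * matmul_flops(B, D, D)
backward = 2 * (matmul_flops(B, D, D) + matmul_flops(D, B, D))
assert forward + backward == training_step_flops(B, num_parameters)
print(forward + backward)  # 805306368

# x needs no gradient, so the first layer's input-grad matmul can be skipped
exact = forward + backward - matmul_flops(B, D, D)
print(exact)  # 671088640
```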
Optimizer
Recall our deep network.
```python
B = 2  # Batch size
D = 4  # Dimensionality of input, activations, and output
L = 3  # Number of layers
model = DeepNetwork(dim=D, num_layers=L).to(cuda_if_available())
```
Let's define the AdaGrad optimizer. First, recall how the common optimizers relate:
- momentum = SGD + exponential averaging of grad
- AdaGrad = SGD + per-coordinate scaling by the accumulated grad^2 (each step is divided by the square root of the running sum)
- RMSProp = AdaGrad but with exponential averaging of grad^2
- Adam = RMSProp + momentum
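The four relationships above can be written as single-step update rules on plain tensors. A minimal sketch: the function names are hypothetical and the hyperparameter values are chosen for illustration. Momentum is written in the exponential-averaging form the list describes; torch.optim.SGD uses an undamped variant by default.

```python
import torch

def momentum_step(w, g, m, lr=0.01, beta=0.9):
    # Exponential moving average of gradients, then step along it
    m = beta * m + (1 - beta) * g
    return w - lr * m, m

def adagrad_step(w, g, s, lr=0.01, eps=1e-10):
    # Running sum of squared gradients (never decays) scales each coordinate
    s = s + g * g
    return w - lr * g / (s.sqrt() + eps), s

def rmsprop_step(w, g, v, lr=0.01, beta=0.99, eps=1e-8):
    # Like AdaGrad, but with an exponential moving average of squared gradients
    v = beta * v + (1 - beta) * g * g
    return w - lr * g / (v.sqrt() + eps), v

def adam_step(w, g, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    # Momentum numerator + RMSProp denominator, with bias correction (t starts at 1)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return w - lr * m_hat / (v_hat.sqrt() + eps), m, v
```

Starting from zero state, the first AdaGrad step and the first bias-corrected Adam step both move each coordinate by exactly lr against the sign of its gradient, since g / sqrt(g^2) = ±1; they only diverge on later steps.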
AdaGrad [Duchi+ 2011]
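The code below calls AdaGrad(model.parameters(), lr=0.01), but the class itself is not shown in these notes. One possible sketch, assuming it subclasses torch.optim.Optimizer and keeps a per-parameter running sum of squared gradients under a state key we call "g2" (a hypothetical name); this is not necessarily the implementation the notes use, and PyTorch's own built-in version is torch.optim.Adagrad.

```python
import torch

class AdaGrad(torch.optim.Optimizer):
    """Sketch of AdaGrad: per-coordinate step size lr / sqrt(sum of past squared grads)."""

    def __init__(self, params, lr: float = 0.01, eps: float = 1e-10):
        super().__init__(params, dict(lr=lr, eps=eps))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr, eps = group["lr"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue  # Parameter did not take part in the loss
                state = self.state[p]  # Per-parameter dict, empty on first access
                g2 = state.get("g2", torch.zeros_like(p))
                g2 += p.grad * p.grad  # Accumulate squared gradients
                state["g2"] = g2
                p -= lr * p.grad / (g2.sqrt() + eps)
```

The state is one extra tensor per parameter, which is where the 4 bytes/parameter AdaGrad figure in the memory accounting comes from when that state is kept in fp32.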
```python
optimizer = AdaGrad(model.parameters(), lr=0.01)
state = model.state_dict()
# Compute gradients
x = torch.randn(B, D, device=cuda_if_available())
y = torch.tensor([4., 5.], device=cuda_if_available())  # One target per example
pred_y = model(x).mean(dim=-1)  # One prediction per example, shape (B,)
loss = F.mse_loss(input=pred_y, target=y)
loss.backward()
# Take a step
optimizer.step()
optimizer_state = {i: dict(p_state) for i, (p, p_state) in enumerate(optimizer.state.items())}
# Free up the memory
optimizer.zero_grad(set_to_none=True)
```
Memory
```python
num_parameters = D * D * L
parameter_memory = 2 * num_parameters  # (2 bytes for bf16)
gradient_memory = 2 * num_parameters  # (2 bytes for bf16)
optimizer_state_memory = 4 * num_parameters  # (4 bytes for fp32)
activation_memory = 2 * (B * D * L)  # (2 bytes for bf16)
```
It is customary to keep optimizer state in fp32 for stability, since it accumulates averages of gradients and squared gradients over many steps. Optimizer state memory:
- AdaGrad: 4 bytes/parameter for storing the running sum of squared gradients (second moments)
- Adam: 8 bytes/parameter for storing first and second moments
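These byte counts can be checked against real tensors with Tensor.element_size(). A minimal sketch, assuming the D = 4, L = 3 network from above with bf16 weights and gradients and an fp32 AdaGrad accumulator; plain tensors stand in for the model's weights so the block runs on its own, and num_bytes is a hypothetical helper.

```python
import torch

D, L = 4, 3
num_parameters = D * D * L  # 48

# Stand-ins for the network's L weight matrices, cast to bf16
weights = [torch.randn(D, D).to(torch.bfloat16) for _ in range(L)]
grads = [torch.zeros_like(w) for w in weights]  # Gradients match the parameter dtype
adagrad_sums = [torch.zeros_like(w, dtype=torch.float32) for w in weights]  # fp32 state

def num_bytes(tensors) -> int:
    # element_size() is bytes per element: 2 for bf16, 4 for fp32
    return sum(t.numel() * t.element_size() for t in tensors)

print(num_bytes(weights), num_bytes(grads), num_bytes(adagrad_sums))  # 96 96 192
```

Adam would double the optimizer-state figure to 384 bytes, since it keeps two fp32 tensors per parameter.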
```python
# Putting it all together
total_memory = parameter_memory + activation_memory + gradient_memory + optimizer_state_memory
```
Compute (for one training step)
```python
num_parameters = D * D * L
flops = 6 * B * num_parameters
```
Transformers
The accounting for a Transformer is more complicated, but the idea is the same. Assignment 1 will ask you to do that.
- Blog post describing memory usage for Transformer training [article]
- Blog post describing FLOPs for a Transformer [article]
Train loop
```python
# True linear function with weights (0, 1, 2, ..., D-1)
D = 16  # Dimensionality
true_w = torch.arange(D, dtype=torch.float32, device=cuda_if_available())

# Data loader that generates (x, y) pairs
B = 4  # Batch size
def get_batch() -> tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(B, D, device=cuda_if_available())
    true_y = x @ true_w  # Shape (B,)
    return (x, true_y)

# Define the model and optimizer
L = 2  # Number of layers
model = DeepNetwork(dim=D, num_layers=L).to(cuda_if_available())
optimizer = AdaGrad(model.parameters(), lr=0.01)

# Train!
num_train_steps = 10
for t in range(num_train_steps):
    # Get data
    x, y = get_batch()
    # Forward (compute loss)
    pred_y = model(x).mean(dim=-1)  # One prediction per example, shape (B,)
    loss = F.mse_loss(pred_y, y)
    # Backward (compute gradients)
    loss.backward()
    # Update parameters
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```
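PyTorch also ships a built-in AdaGrad, torch.optim.Adagrad, which can be dropped into the same loop. A minimal self-contained sketch: the model here is a stand-in (an nn.Sequential of bias-free nn.Linear + ReLU layers, our assumption rather than the notes' DeepNetwork), each example gets its own prediction via mean(dim=-1), and the per-step losses are recorded so progress is visible. With only 10 steps and a ReLU network, the loss is not guaranteed to fall monotonically.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
D, B, L = 16, 4, 2
true_w = torch.arange(D, dtype=torch.float32)

# Stand-in for DeepNetwork: L bias-free linear layers, each followed by ReLU
model = nn.Sequential(*[m for _ in range(L) for m in (nn.Linear(D, D, bias=False), nn.ReLU())])
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

losses = []
for t in range(10):
    x = torch.randn(B, D)
    y = x @ true_w                   # Targets, shape (B,)
    pred_y = model(x).mean(dim=-1)   # One prediction per example, shape (B,)
    loss = F.mse_loss(pred_y, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    losses.append(loss.item())
print(losses)
```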