CS336 Lecture 5 - GPUs

Please read along with the original slides

Stanford CS336 shifts pretty hard here from “model architecture” into systems territory.
This lecture is basically the point where transformers stop feeling abstract and start colliding with hardware reality.

A lot of modern ML progress is not just “better models”. It’s also:

  • faster matrix multiplication
  • better memory layouts
  • lower precision arithmetic
  • compiler tricks
  • smarter scheduling
  • fewer memory accesses

Once you start looking at GPUs this way, FlashAttention stops looking like black magic and starts looking like a very aggressive memory optimization.


Why GPUs matter so much for LLMs
#

The lecture starts with a blunt observation: LLM scaling depends on compute scaling.

For a while, classical CPU scaling mostly came from Dennard scaling and frequency improvements. Clock speeds went up, transistors got smaller, everything got faster automatically. That trend basically died in the 2000s.

Modern scaling instead comes from parallelism. GPUs scaled extremely aggressively over the past decade, especially for matrix operations.

There is also an important asymmetry:

  • compute throughput keeps growing very quickly
  • memory bandwidth grows much more slowly

This imbalance ends up dominating almost every optimization later in the lecture.


CPUs vs GPUs
#

The lecture frames CPUs and GPUs as fundamentally different optimization targets.

CPUs care about latency. They want a small number of threads to finish quickly. So CPUs invest heavily in:

  • branch prediction
  • speculative execution
  • large caches
  • complicated control logic

GPUs care about throughput. They want enormous numbers of operations happening simultaneously, even if individual threads are relatively “dumb”. So GPUs instead spend their silicon budget on:

  • lots of ALUs
  • many lightweight execution units
  • huge parallel execution capability

This tradeoff matters because ML workloads are unusually regular. Matrix multiplications do not require complicated branching logic. They mostly need raw arithmetic throughput. That makes GPUs a very good fit.

The execution structure of a GPU
#

Streaming Multiprocessors (SMs)
#

Source: https://jonathan-hui.medium.com/ai-chips-a100-gpu-with-nvidia-ampere-architecture-3034ed685e6e and https://www.nvidia.com/content/PDF/fermi_white_papers/NVIDIA_Fermi_Compute_Architecture_Whitepaper.pdf

An SM is basically a large execution cluster inside the GPU. A GPU contains many SMs operating independently.

Each SM has:

  • compute units (ALUs)
  • schedulers
  • registers
  • shared memory

You can think of SMs as tiny massively parallel processors living inside the GPU.

Threads
#

Source: https://jonathan-hui.medium.com/ai-chips-a100-gpu-with-nvidia-ampere-architecture-3034ed685e6e

Threads are the smallest execution units. Each thread runs the same instruction sequence on different data.

This is the SIMT model: Single Instruction, Multiple Threads.

Blocks
#

Threads are grouped into blocks. A block executes on one SM and shares access to that SM’s shared memory.

Warps
#

Threads execute in groups called warps. On NVIDIA GPUs, a warp contains 32 threads. The GPU scheduler issues instructions warp-by-warp, not thread-by-thread. This detail explains a surprising amount of GPU behavior later.
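To make the thread/block/warp hierarchy concrete, here is a small Python sketch (not CUDA) of how a 1-D launch might index its threads. It assumes a hypothetical block size of 256 and mirrors CUDA's usual `blockIdx.x * blockDim.x + threadIdx.x` indexing. `thread_layout` is a made-up helper.

```python
WARP_SIZE = 32  # warp width on NVIDIA GPUs

def thread_layout(block_idx: int, thread_idx: int, block_dim: int = 256):
    """Return (global index, warp id within the block, lane within the warp).

    block_dim = 256 is an assumed, illustrative block size.
    """
    global_idx = block_idx * block_dim + thread_idx  # same formula as CUDA 1-D indexing
    warp_id = thread_idx // WARP_SIZE                # consecutive threads share a warp
    lane = thread_idx % WARP_SIZE                    # position inside that warp
    return global_idx, warp_id, lane

# Thread 70 of block 3: global index 3 * 256 + 70 = 838, warp 2, lane 6.
print(thread_layout(3, 70))
```

With 256 threads per block, each block is split into 8 warps of consecutive threads, and the scheduler issues instructions to each of those warps as a unit.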

GPU memory hierarchy
#

The hierarchy roughly looks like this:

  1. Registers
  2. Shared memory / L1 cache
  3. L2 cache
  4. Global memory (HBM / DRAM)

Shared memory is SRAM-based and extremely fast, but expensive and limited. Global memory is large but comparatively slow.

This creates the central problem of GPU programming: how do you keep the compute units busy without constantly waiting for memory?

Tensor cores changed everything
#

Modern GPUs contain dedicated matrix multiplication hardware called tensor cores.

This matters because tensor-core matrix multiplication throughput grew far faster than ordinary (non-matmul) floating point throughput. Transformers spend most of their FLOPs in matrix multiplications, which is why they map so well onto this hardware.


The memory wall
#

Compute scaling outpaced memory scaling. This means modern GPUs can theoretically perform absurd amounts of arithmetic, but often cannot fetch data fast enough to stay fully utilized. A lot of ML systems engineering is really about avoiding memory bottlenecks rather than increasing raw compute.

The roofline model
#

Performance is limited by either:

  • compute throughput
  • memory bandwidth

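Arithmetic intensity is the number of floating point operations performed per byte moved to or from memory (FLOPs per byte). Attainable throughput is roughly min(peak compute, memory bandwidth × arithmetic intensity), so the two limits meet at a ridge point equal to peak compute divided by memory bandwidth.

  • If arithmetic intensity is below the ridge point, the workload is memory-bound.
  • If arithmetic intensity is above the ridge point, the workload is compute-bound.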
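The roofline fits in a few lines of Python. The peak numbers below are assumptions, roughly the A100 40GB spec-sheet figures: 312 TFLOP/s dense BF16 on tensor cores and about 1.555 TB/s of HBM bandwidth. The matmul traffic is an idealized lower bound in which each matrix crosses HBM exactly once.

```python
# Assumed hardware numbers: approximate A100 40GB spec-sheet figures.
PEAK_TFLOPS = 312.0      # dense BF16 tensor-core peak, TFLOP/s
BANDWIDTH_TB_S = 1.555   # HBM bandwidth, TB/s

def attainable_tflops(intensity: float) -> float:
    """Roofline: throughput is capped by compute or by bandwidth * intensity (FLOP/byte)."""
    memory_roof = BANDWIDTH_TB_S * intensity  # TB/s * FLOP/B = TFLOP/s
    return min(PEAK_TFLOPS, memory_roof)

# Ridge point: the intensity where the two roofs meet.
ridge = PEAK_TFLOPS / BANDWIDTH_TB_S

# Elementwise add of two BF16 vectors: 1 FLOP per element over 6 bytes
# (read two 2-byte inputs, write one 2-byte output).
add_intensity = 1 / 6

# Square n x n BF16 matmul: 2 * n**3 FLOPs; idealized traffic reads A and B
# and writes C exactly once (3 * n**2 elements, 2 bytes each).
n = 8192
matmul_intensity = (2 * n**3) / (3 * n**2 * 2)

print(f"ridge point: {ridge:.0f} FLOP/byte")
print(f"add:    {attainable_tflops(add_intensity):.2f} TFLOP/s (memory-bound)")
print(f"matmul: {attainable_tflops(matmul_intensity):.0f} TFLOP/s (compute-bound)")
```

The elementwise add sits far to the left of the ridge point, so it runs at a tiny fraction of peak no matter how fast the ALUs are. The large matmul sits far to the right of it.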

Control divergence
#

GPUs execute threads in warps, and all threads in the same warp share one instruction stream. A conditional branch is efficient when every thread in the warp takes the same path, because the warp can continue executing normally.

The problem appears when different threads in the same warp take different branches. The GPU cannot execute both paths independently at the same time. Instead, it runs one branch while masking out the inactive threads, then runs the other branch while masking out the remaining threads. The final result is correct, but part of the warp is idle during each branch.

This is called control divergence. It is not mainly a memory bandwidth problem; it comes from the SIMT execution model itself. Divergence reduces effective parallelism, so GPU kernels usually try to keep neighboring threads following the same control flow whenever possible.
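Python can't show real warp scheduling, but a toy model captures the masking described above. `run_warp_branch` is a hypothetical helper. It runs both sides of an `if/else` for a 32-lane warp, enables only the lanes that took each side, and reports what fraction of the issued lane-slots did useful work.

```python
import numpy as np

WARP_SIZE = 32

def run_warp_branch(x: np.ndarray):
    """Toy model of one warp executing: y = x * 2 if x > 0 else -x.

    Each branch that any lane needs is issued for the whole warp, with the
    other lanes masked off. Returns the result and the lane utilization.
    """
    assert x.shape == (WARP_SIZE,)
    y = np.empty_like(x)
    taken = x > 0
    issued_passes = 0
    useful_lane_slots = 0
    for mask, branch in ((taken, lambda v: v * 2), (~taken, lambda v: -v)):
        if mask.any():                    # a branch is issued only if some lane takes it
            y[mask] = branch(x[mask])     # masked-off lanes sit idle during this pass
            issued_passes += 1
            useful_lane_slots += int(mask.sum())
    utilization = useful_lane_slots / (issued_passes * WARP_SIZE)
    return y, utilization

uniform = np.arange(1, 33, dtype=np.float32)  # every lane takes the "if" path
divergent = np.where(np.arange(32) % 2 == 0, 1.0, -1.0).astype(np.float32)  # lanes alternate

print(run_warp_branch(uniform)[1])    # one pass, all lanes active
print(run_warp_branch(divergent)[1])  # two passes, half the lanes idle in each
```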


Low precision computation
#

Lower precision means:

  • less memory traffic
  • smaller tensors
  • higher arithmetic intensity
  • better tensor core throughput

For FP32:

  • 4-byte reads
  • 4-byte writes

For FP16:

  • 2-byte reads
  • 2-byte writes

Same operation count, less data movement.
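A quick way to see the halving is to count the HBM bytes for a simple elementwise kernel that reads its input once and writes its output once. This is an idealized model, and `elementwise_traffic_bytes` is a hypothetical helper.

```python
import numpy as np

def elementwise_traffic_bytes(n: int, dtype) -> int:
    """Idealized HBM bytes for y = f(x) over n elements: read x once, write y once."""
    itemsize = np.dtype(dtype).itemsize  # 4 for float32, 2 for float16
    return 2 * n * itemsize

n = 1 << 20  # 1M elements, an arbitrary example size
for dtype in (np.float32, np.float16):
    mib = elementwise_traffic_bytes(n, dtype) / 2**20
    print(f"{np.dtype(dtype).name}: {mib:.0f} MiB moved for {n} operations")
```

The number of operations is identical in both cases, so halving the bytes per element doubles arithmetic intensity, which is exactly what a memory-bound kernel needs.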

Mixed precision and tensor cores
#

Source: https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf

Modern tensor cores heavily optimize lower precision operations like:

  • FP16
  • BF16
  • FP8

Usually accumulation still happens in FP32 for numerical stability.

This is one reason mixed precision training became standard.
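Why accumulate in FP32? A small NumPy experiment on the CPU shows the failure mode this guards against. It sums 10,000 FP16 copies of 0.1, once with an FP16 running sum and once with an FP32 running sum. It uses explicit loops on purpose, because `np.sum` uses pairwise summation, which would partly hide the effect. This illustrates rounding only, not how tensor cores actually schedule their accumulation.

```python
import numpy as np

x = np.full(10_000, 0.1, dtype=np.float16)  # 10k FP16 inputs, true sum ~ 1000

# Accumulate in FP16: once the running sum reaches 256, the spacing between
# adjacent FP16 values (0.25) is too coarse and adding ~0.1 rounds away.
acc16 = np.float16(0.0)
for v in x:
    acc16 = np.float16(acc16 + v)

# Same FP16 inputs, but with an FP32 running sum.
acc32 = np.float32(0.0)
for v in x:
    acc32 = np.float32(acc32 + np.float32(v))

print(acc16, acc32)
```

The FP16 accumulator stalls at 256. The FP32 accumulator lands on the true sum of the FP16 inputs, which is slightly below 1000 because FP16 cannot represent 0.1 exactly.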

FP8 and MXFP8
#

Source: https://developer.nvidia.com/blog/floating-point-8-an-introduction-to-efficient-lower-precision-ai-training/

MXFP8 is a block-scaled FP8 format. Instead of using one scale for an entire tensor, it assigns a separate scaling factor to each small block of values. In the OCP Microscaling (MX) spec, a block is 32 consecutive elements that share one 8-bit power-of-two scale. This gives the format more local dynamic range, which helps preserve accuracy while still keeping the storage and memory bandwidth benefits of FP8.

The extra scaling metadata also makes layout transformations more complicated. A transpose is no longer just a different view of the same underlying values, because the scale factors are tied to the original block layout. After transposition, the grouping of values changes, so the tensor often needs to be quantized again with a different set of scales.

In practice, MXFP8 training systems may keep separate quantized layouts for the same tensor: one for the original orientation and one for the transposed orientation. This reduces transpose overhead during computation, but it also increases memory usage and implementation complexity.

Source: https://arxiv.org/html/2506.08027v2
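The transpose problem can be illustrated with a simplified NumPy sketch. It follows the OCP MX rule of one power-of-two scale per 32-element block, with the shared exponent taken as floor(log2(block max)) minus the largest E4M3 exponent. FP8 storage is simulated by rounding float64 values onto the E4M3 grid, and all helper names are hypothetical. Real kernels store 8-bit codes and 8-bit scales.

```python
import numpy as np

# Simplified MXFP8 model. Assumptions: OCP MX block size of 32, a power-of-two
# shared scale per block, and FP8 E4M3 elements simulated by rounding float64
# values onto the E4M3 grid. Helper names are hypothetical.
BLOCK = 32
E4M3_MAX = 448.0   # largest finite E4M3 value (1.75 * 2**8)
E4M3_EMAX = 8      # exponent of the largest E4M3 power of two

def round_to_e4m3(x):
    """Round to the nearest E4M3 value (3 mantissa bits), saturating at +-448."""
    exp = np.floor(np.log2(np.maximum(np.abs(x), 2.0**-6)))  # subnormals use exponent -6
    step = 2.0 ** (exp - 3)                                   # value spacing in that binade
    return np.clip(np.round(x / step) * step, -E4M3_MAX, E4M3_MAX)

def mx_quantize_rows(a):
    """Quantize blocks of BLOCK consecutive elements along each row.

    Returns the E4M3-valued elements and one scale per block."""
    rows, cols = a.shape
    blocks = a.reshape(rows, cols // BLOCK, BLOCK)
    amax = np.abs(blocks).max(axis=-1, keepdims=True)
    scale = 2.0 ** (np.floor(np.log2(np.maximum(amax, 1e-30))) - E4M3_EMAX)
    q = round_to_e4m3(blocks / scale)
    return q.reshape(rows, cols), scale.squeeze(-1)

def dequantize(q, scale):
    rows, cols = q.shape
    return (q.reshape(rows, cols // BLOCK, BLOCK) * scale[..., None]).reshape(rows, cols)

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 128))

# Original orientation: blocks run along the rows of a.
q_row, s_row = mx_quantize_rows(a)
# Transposed operand: blocks must run along the rows of a.T (the columns of a),
# so we requantize from the high-precision tensor instead of transposing q_row.
q_col, s_col = mx_quantize_rows(a.T)

print(s_row.shape, s_col.shape)
print(np.abs(dequantize(q_row, s_row).T - dequantize(q_col, s_col)).max())
```

The two layouts have differently shaped scale tensors, (64, 4) versus (128, 2), so the transposed operand cannot simply reuse the original scales. The dequantized values also differ. In this sketch, the differences come mostly from elements that saturate at 448 under one blocking but not the other.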

Operator fusion
#

Source: https://medium.com/data-science/how-pytorch-2-0-accelerates-deep-learning-with-operator-fusion-and-cpu-gpu-code-generation-35132a85bd26

Many elementwise operations are limited more by memory traffic than by arithmetic cost. If each operation is executed as a separate CUDA kernel, every intermediate result has to be written to global memory and then read back by the next kernel.

For example:

```python
y = sin(x)**2 + cos(x)**2
```

A naive implementation may launch separate kernels for sin, cos, squaring, addition, and other intermediate steps. Although each operation is simple, the repeated global memory reads and writes dominate the runtime.

Operator fusion reduces this overhead by combining multiple operations into a single kernel. Intermediate values can stay in registers or local on-chip storage instead of being materialized in global memory. This reduces kernel launch overhead and, more importantly, cuts down unnecessary memory traffic.

This kind of fusion is especially effective for pointwise operations, where the computation per element is small and the main bottleneck is moving data to and from HBM.
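A back-of-the-envelope traffic count for the `sin`/`cos` example makes the gap concrete. This is an idealized model: FP32 elements, every unfused kernel reads its inputs from HBM and writes its output back, and caches are ignored. `hbm_traffic` is a hypothetical helper.

```python
def hbm_traffic(n_elements: int, bytes_per_elem: int = 4):
    """Idealized HBM bytes for y = sin(x)**2 + cos(x)**2, unfused vs fused.

    Unfused: one kernel per op, each reading its inputs from and writing its
    output to global memory. Fused: a single kernel, intermediates in registers.
    """
    unfused_kernels = [
        ("a = sin(x)", 1, 1),   # (op, tensors read, tensors written)
        ("b = cos(x)", 1, 1),
        ("c = a**2",   1, 1),
        ("d = b**2",   1, 1),
        ("y = c + d",  2, 1),
    ]
    tensor_passes = sum(r + w for _, r, w in unfused_kernels)
    unfused = tensor_passes * n_elements * bytes_per_elem
    fused = 2 * n_elements * bytes_per_elem  # read x once, write y once
    return unfused, fused

unfused, fused = hbm_traffic(1 << 24)
print(f"unfused / fused traffic: {unfused / fused:.1f}x")
```

Under this model the fused kernel moves 5.5× less data for identical arithmetic. In PyTorch 2.x, `torch.compile` is one way to get this kind of pointwise fusion automatically.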

Recomputation
#

Source: https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467

During training, the backward pass typically requires intermediate activations produced during the forward pass.

A straightforward implementation stores all activations in memory so they can later be reused during gradient computation.

However, activation storage creates substantial memory traffic:

  • activations must be written to global memory during the forward pass
  • later retrieved again during the backward pass

For deep networks, especially transformers, these memory accesses can become more expensive than the extra arithmetic needed to recompute the activations.

Recomputation (or activation checkpointing) trades additional compute for reduced memory usage and lower memory bandwidth pressure.

Instead of storing every intermediate activation, the system stores only selected checkpoints and recomputes missing activations when needed during backpropagation.

On modern GPUs, this tradeoff is often favorable because compute throughput scales faster than memory bandwidth.
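Here is a toy scalar version of activation checkpointing in plain Python. Each of the `n` hypothetical layers is just `sin`, so the gradient is a product of `cos` terms. The checkpointed version keeps only every `k`-th layer input and recomputes the rest one segment at a time during the backward pass. It sketches the idea only and is not how `torch.utils.checkpoint` is implemented.

```python
import math

def forward_store_all(x: float, n: int):
    """Plain forward pass through n hypothetical sin layers, storing every layer input."""
    acts = []
    for _ in range(n):
        acts.append(x)
        x = math.sin(x)
    return acts

def grad_store_all(x: float, n: int) -> float:
    """d(out)/dx by the chain rule: product of cos(layer input), in backward order."""
    grad = 1.0
    for a in reversed(forward_store_all(x, n)):
        grad *= math.cos(a)
    return grad

def grad_checkpointed(x: float, n: int, k: int):
    """Same gradient, storing only every k-th layer input and recomputing the rest."""
    checkpoints = {}
    evals = 0
    for i in range(n):                      # forward: keep a checkpoint every k layers
        if i % k == 0:
            checkpoints[i] = x
        x = math.sin(x)
        evals += 1
    grad = 1.0
    for start in reversed(range(0, n, k)):  # backward: one segment at a time
        end = min(start + k, n)
        seg = [checkpoints[start]]
        for _ in range(start, end - 1):     # recompute the segment's missing inputs
            seg.append(math.sin(seg[-1]))
            evals += 1
        for a in reversed(seg):
            grad *= math.cos(a)
    return grad, len(checkpoints), evals

g_full = grad_store_all(0.5, 16)
g_ckpt, stored, evals = grad_checkpointed(0.5, 16, 4)
print(g_full, g_ckpt, stored, evals)
```

Stored activations drop from 16 to 4, at the cost of 12 extra `sin` evaluations. Peak storage is roughly n/k checkpoints plus one recomputed segment of length k, which is why choosing k ≈ √n is a common rule of thumb.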

Memory coalescing
#

Global memory is served through aligned memory transactions rather than isolated scalar reads. When a warp issues a load instruction, the GPU checks the addresses requested by its 32 threads. If these addresses are contiguous or fall within a small number of aligned memory segments, the load can be served efficiently. This access pattern is called memory coalescing.

For row-major matrices, elements in the same row are contiguous, while elements in the same column are separated by the row stride. Therefore, a kernel where neighboring threads read along a row is usually more memory efficient than one where neighboring threads read down a column, even if both perform the same arithmetic.

Coalescing explains why memory layout and thread mapping matter in CUDA. Tiling builds on the same idea: a kernel first loads matrix tiles from global memory into shared memory using coalesced accesses, then reuses those values locally instead of repeatedly reading from global memory.
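The effect can be estimated by counting memory sectors. The sketch below assumes 32-byte sectors as the unit of global memory traffic, which simplifies the real cache hierarchy. It also assumes float32 elements and a row-major 1024 × 1024 matrix. The helper names are hypothetical.

```python
def sectors_touched(byte_addresses, sector_bytes: int = 32) -> int:
    """Simplified model: number of distinct 32-byte sectors a warp's load touches."""
    return len({addr // sector_bytes for addr in byte_addresses})

WARP, ELEM = 32, 4  # 32 threads per warp, 4-byte float32 elements
N = 1024            # row-major N x N matrix, so the row stride is N elements

def addr(row: int, col: int) -> int:
    """Byte address of element (row, col)."""
    return (row * N + col) * ELEM

# Neighboring threads read neighboring elements of one row: contiguous addresses.
row_read = [addr(0, t) for t in range(WARP)]
# Neighboring threads read down one column: addresses N * 4 bytes apart.
col_read = [addr(t, 0) for t in range(WARP)]

print(sectors_touched(row_read), sectors_touched(col_read))
```

Both loads deliver the same 128 useful bytes. The column-wise pattern fetches 32 sectors (1024 bytes) to get them, so only 12.5% of that traffic is useful.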

Tiling
#

Tiling reorganizes matrix multiplication around blocks of data rather than individual output elements. Instead of having each thread repeatedly fetch values from global memory, the kernel loads a tile from matrix A and a tile from matrix B into shared memory, then uses those tiles to compute part of a tile in matrix C.

The outer loop moves across tiles along the reduction dimension. For each step, the current A and B tiles are loaded once from global memory and reused by many threads inside the block. The inner loop then multiplies elements from those shared-memory tiles and accumulates partial sums for the output tile.

This changes the memory behavior of matrix multiplication. In a naive kernel, the same values from A and B may be loaded from global memory many times by different threads. With tiling, those values are fetched from global memory much less often and reused from faster on-chip memory.

The main benefit is not fewer FLOPs. The arithmetic is the same. The benefit is higher arithmetic intensity: more multiply-add operations are performed for each byte loaded from global memory.
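The loop structure can be sketched in NumPy. This is only a CPU model of the data flow. Each output tile stands in for one thread block, and the sliced tiles stand in for data staged in shared memory. The tile size `T = 32` is an arbitrary assumption.

```python
import numpy as np

def tiled_matmul(A: np.ndarray, B: np.ndarray, T: int = 32) -> np.ndarray:
    """Blocked C = A @ B. Each (i, j) output tile plays the role of one thread
    block; the A and B tiles it slices out stand in for shared-memory tiles."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=np.result_type(A, B))
    for i in range(0, M, T):                   # output tile rows
        for j in range(0, N, T):               # output tile columns
            acc = np.zeros((min(T, M - i), min(T, N - j)), dtype=C.dtype)
            for k in range(0, K, T):           # outer loop along the reduction dimension
                a_tile = A[i:i + T, k:k + T]   # "load" one tile of A
                b_tile = B[k:k + T, j:j + T]   # "load" one tile of B
                acc += a_tile @ b_tile         # inner work reuses both tiles
            C[i:i + T, j:j + T] = acc
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 70))
B = rng.standard_normal((70, 90))
print(np.allclose(tiled_matmul(A, B), A @ B))
```

NumPy slicing simply returns smaller edge tiles when the shapes are not multiples of `T`. A real kernel has to handle those boundary tiles explicitly.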

Why tiling improves arithmetic intensity
#

Tiling does not change the number of multiply-add operations in matrix multiplication. It changes where the input values are read from.

In a non-tiled matrix multiplication, each input value may be read from global memory N times, because different output elements repeatedly need the same values. With tile size T, each input value is read from global memory only N/T times. After a tile is loaded into shared memory, the values inside that tile can be reused T times before the kernel moves to the next tile.

This gives a factor of T reduction in global memory access under the lecture's simplified square-matrix model.

The result is higher arithmetic intensity. The kernel performs the same arithmetic, but it performs more multiply-add operations for each byte fetched from global memory. This makes it easier for the GPU to use its compute units instead of waiting on HBM.
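The same counting argument fits in a few lines. It follows the simplified model above: square N × N matrices, FP32 elements, every global load counted, and no caching. `T = 1` stands in for the untiled kernel.

```python
def matmul_intensity(N: int, T: int, bytes_per_elem: int = 4) -> float:
    """FLOPs per byte of global-memory traffic for an N x N x N matmul.

    Simplified model: with tile size T (T = 1 means no tiling), each input
    element is fetched from global memory N / T times, giving 2 * N**3 / T
    loads in total across A and B. Output writes are ignored.
    """
    flops = 2 * N**3          # one multiply and one add per (i, j, k)
    loads = 2 * N**3 // T     # element loads from A and B
    return flops / (loads * bytes_per_elem)

N = 4096  # arbitrary example size
for T in (1, 16, 32):
    print(f"T={T:>2}: {matmul_intensity(N, T):.2f} FLOP/byte")
```

Intensity scales linearly with T. Even so, 8 FLOPs per byte is modest, which is why production GEMM kernels add further reuse, for example in registers, on top of shared-memory tiling.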

Practical limits of tiling
#

Source: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html

The ideal tile shape is constrained by the hardware. A tile has to fit into shared memory, map well to warps, and produce coalesced global memory loads. Matrix dimensions also matter. If the dimensions are not divisible by the tile size, boundary tiles may contain inactive threads or unused elements.

Source: https://www.thonking.ai/p/what-shapes-do-matrix-multiplications

Alignment adds another constraint. Global memory is served through aligned memory transactions, so a tile is efficient when its rows line up with these transaction boundaries. In the aligned case, the data loaded from global memory mostly belongs to the tile being computed.

In the unaligned case, the same logical tile may cross transaction boundaries. The GPU then has to fetch extra memory segments, and some of the fetched values are not useful for the current tile. The arithmetic work is almost unchanged, but the number of memory transactions increases.
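Alignment can be sketched with the same sector-counting idea. The example below assumes a 32 × 32 float32 tile and 32-byte sectors. It compares a row stride (leading dimension) of 1024 elements with a hypothetical 1025, which shifts each row's start by 4 bytes. `tile_sectors` is a made-up helper.

```python
def tile_sectors(ld: int, tile_rows: int = 32, tile_cols: int = 32,
                 elem_bytes: int = 4, sector: int = 32) -> int:
    """Simplified model: 32-byte sectors touched when loading the top-left
    tile_rows x tile_cols tile of a row-major matrix with row stride `ld`."""
    sectors = set()
    for r in range(tile_rows):
        start = r * ld * elem_bytes               # byte address of the tile's row r
        end = start + tile_cols * elem_bytes - 1  # last byte of that tile row
        sectors.update(range(start // sector, end // sector + 1))
    return len(sectors)

print(tile_sectors(1024))  # every tile row starts on a sector boundary
print(tile_sectors(1025))  # row starts drift by 4 bytes per row
```

Both cases load the same 4 KiB of useful data. The misaligned stride touches 156 sectors instead of 128, roughly 22% more memory transactions for identical arithmetic.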

This is one reason GPU performance curves are often jagged. A small change in matrix shape can change tile utilization, memory alignment, or scheduling behavior, even when the FLOP count changes only slightly.

Wave quantization
#

Wave quantization is another source of uneven performance. The GPU schedules thread blocks across a fixed number of SMs. If the number of tiles maps cleanly onto the available SMs, most SMs stay busy. If the tile count slightly exceeds a multiple of the SM count, the GPU may need an extra scheduling wave with only a small amount of remaining work.

For example, growing a square matrix from 1792 × 1792 to 1793 × 1793 can increase the number of tiles from 98 to 120 for a particular tile shape. On an A100 with 108 SMs, 98 tiles fit into one wave, while 120 tiles require a second wave. The second wave is under-filled, so utilization drops.

This explains why a larger matrix can sometimes run slower than a slightly smaller one. The arithmetic changed only a little, but the mapping onto hardware changed a lot.
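The numbers in this example can be reproduced with a few lines of arithmetic. The 256 × 128 output tile is our assumption, chosen because it yields exactly 98 and 120 tiles for square 1792 and 1793 matrices. The model also assumes one tile per thread block and one block per SM at a time. The 108 SMs come from the A100 figure above.

```python
import math

def wave_stats(M: int, N: int, tile_m: int = 256, tile_n: int = 128, num_sms: int = 108):
    """Tiles, waves, and SM utilization for an M x N output.

    Assumptions: 256 x 128 output tiles, one tile per thread block,
    and one block per SM at a time (a simplification)."""
    tiles = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    waves = math.ceil(tiles / num_sms)
    utilization = tiles / (waves * num_sms)  # fraction of SM slots doing work
    return tiles, waves, utilization

for n in (1792, 1793):
    tiles, waves, util = wave_stats(n, n)
    print(f"{n}: {tiles} tiles, {waves} wave(s), {util:.0%} of SM slots busy")
```

In the 1793 case, the extra row and column of tiles each contain only a single valid row or column of output, so almost all of the added work is padding.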


FlashAttention
#

FlashAttention uses the same hardware-aware ideas from tiled matrix multiplication, but applies them to attention. The point is not to approximate attention or reduce the number of attention scores. It still computes exact attention. The improvement comes from reducing memory traffic.

Why standard attention is memory hungry
#

Scaled dot-product attention can be written as follows, where d is the head dimension:

$$ S = \frac{QK^T}{\sqrt{d}} $$

$$ P = \operatorname{softmax}(S) $$

$$ O = PV $$

The expensive part is not only the matrix multiplication. A naive implementation materializes the attention score matrix S and often the probability matrix P in global memory. For a sequence length of N, each of these matrices has N × N entries per attention head, so their size grows quadratically with sequence length.

This creates a large amount of HBM traffic. The GPU computes QK^T, writes the scores to global memory, reads them back for softmax, writes the softmax result, then reads it again for the multiplication with V.

The FLOP count is high, but the memory traffic is the larger problem.
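To see why materializing S and P hurts, here is the size of one N × N matrix per head, assuming 2-byte (FP16/BF16) elements. `score_matrix_bytes` is a hypothetical helper, and the sequence lengths are arbitrary examples.

```python
def score_matrix_bytes(seq_len: int, batch: int = 1, heads: int = 1,
                       bytes_per_elem: int = 2) -> int:
    """Bytes to materialize one seq_len x seq_len matrix (S or P) per head.

    bytes_per_elem = 2 assumes FP16/BF16 storage."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

for n in (2048, 8192, 32768):
    print(f"N={n}: {score_matrix_bytes(n) / 2**20:,.0f} MiB per head")
```

Memory grows quadratically: every 4× increase in sequence length costs 16× more bytes. The naive pipeline also writes and reads this tensor several times per layer, for every head and batch element.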

Tiling attention
#

FlashAttention avoids materializing the full attention matrix in HBM. It splits Q, K, and V into blocks and computes attention block by block.

For each query block, the kernel loads a block of K and V, computes a partial attention result, and updates the output. The intermediate attention scores are kept in on-chip memory as much as possible. They do not need to be written out as a full N × N matrix.

This is the same basic idea as tiled matrix multiplication: move a small block of data into fast memory, reuse it, and avoid repeated global memory access.

The softmax problem
#

The difficult part is softmax. Matrix multiplication can be tiled directly, but softmax normally needs information from the whole row because each output depends on the row maximum and the normalization sum.

FlashAttention handles this with online softmax. Instead of computing softmax after the full attention row has been materialized, the kernel updates the row maximum and normalization term incrementally as it processes each tile.

For each query row, the algorithm carries the following state from one tile to the next:

  • the running maximum for numerical stability
  • the running normalization denominator
  • the partial weighted sum with V

This makes it possible to compute the same softmax result tile by tile without storing the full attention matrix.
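For a single query row, the running updates can be written as follows. The notation is ours. $S^{(j)}$ is the row's attention scores against the $j$-th key tile, $V^{(j)}_k$ is the matching $k$-th row of the $j$-th value tile, $m$ is the running maximum, $\ell$ is the running denominator, and $o$ is the unnormalized output.

$$ m^{(j)} = \max\left(m^{(j-1)},\ \max_k S^{(j)}_k\right) $$

$$ \ell^{(j)} = e^{m^{(j-1)} - m^{(j)}} \, \ell^{(j-1)} + \sum_k e^{S^{(j)}_k - m^{(j)}} $$

$$ o^{(j)} = e^{m^{(j-1)} - m^{(j)}} \, o^{(j-1)} + \sum_k e^{S^{(j)}_k - m^{(j)}} \, V^{(j)}_k $$

The updates start from $m^{(0)} = -\infty$, $\ell^{(0)} = 0$, and $o^{(0)} = 0$. Whenever the maximum increases, the factor $e^{m^{(j-1)} - m^{(j)}}$ rescales everything accumulated so far to the new reference point. After the last tile $J$, the output row $o^{(J)} / \ell^{(J)}$ therefore equals the corresponding row of $\operatorname{softmax}(S)V$ exactly.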

Forward pass intuition
#

In the forward pass, FlashAttention combines several ideas:

  • tile-wise computation of QK^T
  • fusion of scaling, masking, exponentiation, and normalization
  • online softmax across tiles
  • immediate multiplication with the corresponding V tile

The intermediate attention matrix exists conceptually, but not as a large tensor stored in HBM. This is the main reason FlashAttention is faster and more memory efficient than a naive implementation.
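Putting the pieces together, here is a NumPy sketch of the forward pass for a single head, compared against a naive reference. It models the data flow on the CPU with hypothetical block sizes of 16 and no masking or dropout, so it says nothing about speed. Only one `block_q × block_k` score tile exists at a time, and each query row carries its running max, denominator, and unnormalized output across key tiles.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materializes the full N x N score and probability matrices."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

def tiled_attention(Q, K, V, block_q: int = 16, block_k: int = 16):
    """Exact attention computed tile by tile with online softmax.

    Block sizes are illustrative assumptions; no masking or dropout."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(V)
    for qs in range(0, N, block_q):
        q = Q[qs:qs + block_q]
        m = np.full(q.shape[0], -np.inf)    # running row maximum
        l = np.zeros(q.shape[0])            # running softmax denominator
        o = np.zeros((q.shape[0], V.shape[1]))  # unnormalized output rows
        for ks in range(0, K.shape[0], block_k):
            s = q @ K[ks:ks + block_k].T * scale   # one score tile
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])         # unnormalized tile probabilities
            correction = np.exp(m - m_new)         # rescale earlier partial sums
            l = correction * l + p.sum(axis=1)
            o = correction[:, None] * o + p @ V[ks:ks + block_k]
            m = m_new
        O[qs:qs + block_q] = o / l[:, None]
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 32)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V)))
```

In a real kernel, the score tile and probabilities stay in registers and shared memory. The only large write to HBM is the output, plus small per-row softmax statistics that the backward pass can reuse.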

Backward pass and recomputation
#

The backward pass uses the same memory-saving philosophy. Instead of storing the N × N attention matrix from the forward pass, FlashAttention keeps only the output and small per-row softmax statistics. It then recomputes the attention scores tile by tile during backward.

This trades extra computation for much lower memory usage. On modern GPUs, this tradeoff often makes sense because compute throughput is abundant while HBM bandwidth and capacity are more limited.


Takeaway
#

The lecture’s main message is that GPU performance is often limited by data movement rather than raw arithmetic.

Several techniques follow from that idea:

  • coalescing makes global memory transactions more efficient
  • fusion avoids unnecessary intermediate writes
  • recomputation trades extra FLOPs for lower memory pressure
  • tiling moves reusable data into shared memory
  • FlashAttention applies these ideas to attention

This is a useful way to read many ML systems papers. The important question is often not just “how many FLOPs does this use?”, but “where does the data live, and how many times does it move?”