Operator Fusion & FlashAttention Explained
The Kernel Launch Problem
What is Operator Fusion?
Operator fusion combines multiple GPU kernels into a single kernel so intermediate results stay in fast on-chip memory (registers, SRAM) instead of making round-trips through slow HBM. This is the most broadly applicable GPU optimization, because it eliminates unnecessary memory traffic between operations.
The Factory Conveyor Problem
Imagine a factory where each machine writes its output to a warehouse shelf, and the next machine fetches it from the shelf to continue. Three machines = three warehouse round-trips. Fusion puts the machines side by side — output passes directly from one to the next without touching the warehouse.
[Figure: unfused pipeline (3 kernels — matmul, + bias, + ReLU) makes 6 HBM accesses; fused (1 kernel) makes 2. 3× fewer HBM accesses — same computation.]
In GPU terms: each CUDA kernel reads inputs from HBM, computes, and writes results back to HBM. When you chain three kernels (matmul → bias → ReLU), each intermediate result makes a round-trip through HBM:
- matmul reads inputs from HBM, writes output to HBM
- bias reads matmul output from HBM, writes result to HBM
- ReLU reads bias output from HBM, writes result to HBM
That's 6 HBM accesses, four of them for intermediates the math never needed. Fused into one kernel: 2 HBM accesses (read input once, write final output once). Same math, 3× less memory traffic.
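To make the round-trips concrete, here is a NumPy stand-in (not real GPU code; each named intermediate models a tensor parked in HBM between kernel launches):

```python
import numpy as np

def linear_unfused(x, W, b):
    y1 = x @ W                # kernel 1: matmul (writes y1 to 'HBM')
    y2 = y1 + b               # kernel 2: bias   (reads y1, writes y2)
    return np.maximum(y2, 0)  # kernel 3: ReLU   (reads y2, writes output)
```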
On the right panel: Compare unfused (left, red arrows) vs fused (right, green). Count the HBM arrows — that's the bandwidth you're paying for. The fused version finishes while the unfused is still shuttling data.
Operator Fusion
Which Operations Can Fuse?
Not all operations fuse equally. The general rule: elementwise operations fuse freely; reductions create barriers.
Two categories of operations:
- Elementwise means `output[i] = f(input[i])` — each element is processed independently, without looking at any other element. Adding a number, multiplying, applying ReLU (if negative → 0) — the GPU can process all elements in parallel with no coordination.
- Reduction means `output = f(input[0], input[1], ..., input[N])` — the result depends on the entire row or column. Sum, mean, max, softmax all need to read every element before producing one output.
This distinction matters for fusion: elementwise ops can be tacked onto any kernel (process each element as it flows through), while reductions force the kernel to wait for all data before continuing.
matmul + bias + activation
Every linear layer in a neural network computes y = activation(W × x + b) — a matmul, then a bias add, then an activation function. Without fusion, that's three separate kernels with two HBM round-trips between them.
- bias add ✅ elementwise — adds a constant to each element: `output[i] = matmul_result[i] + b`. No element depends on any other.
- activation (ReLU/GELU) ✅ elementwise — applies a function per element: `output[i] = relu(input[i])` (if negative → 0, else keep). Each element is independent.
Both are cheap follow-ups to the matmul. Because they're elementwise, they can be computed on each output element immediately as the matmul produces it — no need to write intermediates to HBM. This is the most common fusion in deep learning.
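A sketch of the fused version (plain Python loops standing in for the kernel; `fused_linear` is an illustrative name, not a real API):

```python
import numpy as np

def fused_linear(x, W, b):
    """Conceptual fused matmul + bias + ReLU: each output element is
    finished while it is still 'on chip' (the loop body), so the
    intermediates never exist as arrays."""
    out = np.empty((x.shape[0], W.shape[1]))
    for i in range(x.shape[0]):
        for j in range(W.shape[1]):
            acc = x[i] @ W[:, j]       # matmul accumulator (a register)
            acc += b[j]                # bias: elementwise follow-up
            out[i, j] = max(acc, 0.0)  # ReLU, then the single write
    return out
```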
layernorm (mean + variance + normalize)
Layer normalization ensures each layer's outputs have consistent scale — preventing values from exploding or vanishing as they flow through the network. It computes three things in sequence:
- mean — average all values in the row: `μ = (x₁ + x₂ + ... + xₙ) / n` (reduction — needs all elements)
- variance — measure how spread out they are: `σ² = Σ(xᵢ - μ)² / n` (reduction — needs all elements + the mean)
- normalize — scale each element: `output[i] = (xᵢ - μ) / √(σ² + ε)` (elementwise — once you have μ and σ², each element is independent)
Without fusion: three separate kernels, each reading the entire row from HBM — three round-trips. Fused: one kernel reads each element once, computes mean and variance on the fly, and normalizes before writing back. 3× less HBM traffic.
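A single-pass sketch (illustrative NumPy; the Python loop models the on-chip sweep a real kernel would parallelize):

```python
import numpy as np

def fused_layernorm(x, eps=1e-5):
    """Each row is read once: sum and sum-of-squares accumulate in one
    sweep (the two reductions), then the row, still 'on chip', is
    normalized before the single write back."""
    out = np.empty_like(x, dtype=float)
    for r, row in enumerate(x):            # the one 'HBM' read per row
        s = sq = 0.0
        for v in row:                      # both reductions in one sweep
            s += v
            sq += v * v
        mean = s / row.size
        var = sq / row.size - mean * mean  # E[x^2] - (E[x])^2
        out[r] = (row - mean) / np.sqrt(var + eps)
    return out
```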
attention score pipeline (Q·K^T + scale + mask + softmax)
Recall from the LLM Internals Attention module: attention computes how much each token should "pay attention" to every other token. The pipeline has four steps:
- Q·K^T — matrix multiply between queries and keys, producing raw attention scores (one score per token pair)
- scale — divide each score by √d to prevent large values: `score[i][j] /= √128` (elementwise ✅ — each score is independent)
- mask — set certain scores to -∞ so tokens can't attend to future tokens: `if (j > i) score[i][j] = -∞` (elementwise ✅ — each position is independent)
- softmax — convert scores to probabilities that sum to 1 across each row: `prob[i] = e^(score[i]) / Σ e^(score[j])` (reduction ⚠️ — needs the entire row to compute the sum)
Steps 1-3 can fuse freely — scale and mask are cheap elementwise follow-ups to the matmul. But softmax is the fusion barrier: it must read every score in the row before producing any output. Everything after softmax (multiplying by V) needs a separate kernel.
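A sketch of how far fusion gets (NumPy as pseudocode; `fused_scores` is a hypothetical name, and the mask shown is the causal one):

```python
import numpy as np

def fused_scores(Q, K):
    """Steps 1-3 in one fused pass: scale and mask are applied to
    each score as the matmul produces it."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])  # Q·K^T + scale, fused
    i, j = np.indices(S.shape)
    S[j > i] = -np.inf                   # causal mask, still elementwise
    return S  # softmax (a reduction) needs a separate pass from here
```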
This is exactly the wall that FlashAttention breaks — by computing softmax incrementally across tiles using the online softmax trick (Step 4).
Fusion Rules
- Elementwise (add, multiply, ReLU, GELU, dropout, mask): ✅ fuse freely — each element is independent
- Reductions (softmax, layernorm, sum, mean): ⚠️ fusion barriers — they need data from the entire row/column before producing output
On the right panel: Select a fusion example to see before/after HBM traffic. Notice the ✅ and ⚠️ badges — elementwise ops fuse freely, reductions create barriers. The roofline dot shifts right as fusion increases arithmetic intensity.
FlashAttention: The Problem
Quick Recap: What Attention Does
Recall from the LLM Internals Attention module: attention lets each token decide how much to "look at" every other token. For a sequence of N tokens, this means computing an N×N grid of scores — every token scored against every other token.
The three matrices involved:
- Q (queries) — "what am I looking for?" (N × d, where d is the head dimension, typically 128)
- K (keys) — "what do I contain?" (N × d)
- V (values) — "what information do I offer?" (N × d)
The attention matrix S = Q · K^T is the N×N grid: S[i][j] = how much token i should attend to token j. After softmax (converting scores to probabilities that sum to 1 per row), we multiply by V to get the final output — a weighted blend of values based on attention scores.
The Problem: Three Passes Over N×N
In standard attention, this N×N matrix is computed and stored in HBM (GPU main memory). The pipeline makes three passes over it (sketched in code after the list):
- Compute scores: S = Q · K^T → write the N×N matrix to HBM
- Softmax: read S from HBM, apply softmax → write P back to HBM
- Weight values: read P from HBM, multiply by V → write output to HBM
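A NumPy sketch of those three passes (each intermediate array stands in for an N×N matrix parked in HBM):

```python
import numpy as np

def standard_attention(Q, K, V):
    """Three passes, each materializing or re-reading the N x N matrix."""
    d = Q.shape[1]
    S = (Q @ K.T) / np.sqrt(d)                    # pass 1: write N x N scores
    P = np.exp(S - S.max(axis=1, keepdims=True))  # pass 2: read S, softmax,
    P /= P.sum(axis=1, keepdims=True)             #         write N x N probs
    return P @ V                                  # pass 3: read P, weight by V
```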
Each pass reads or writes the entire N×N matrix. That's a lot of memory traffic — and it grows quadratically with sequence length:
| Sequence Length (N) | N×N Matrix (FP16) | 3-Pass HBM Traffic |
|---|---|---|
| 512 | 0.5 MB | 1.5 MB |
| 1,024 | 2 MB | 6 MB |
| 2,048 | 8 MB | 24 MB |
| 4,096 | 32 MB | 96 MB |
| 8,192 | 128 MB | 384 MB |
For N=4096 with 32 attention heads: ~3 GB of HBM traffic just for one attention layer.
Same FLOPs, Different IO
The computation is O(N²d) FLOPs regardless — the matrix multiply and softmax do the same work no matter where the data lives. The question is: how many bytes move through HBM? Standard attention: O(N²) bytes. Can we do better?
On the right panel: Watch the N×N attention matrix (the large pink square) materialize in HBM, get read for softmax, written back, read again for V multiply. Three round-trips for one matrix — that's the bottleneck. Try different sequence lengths to see the quadratic scaling.
FlashAttention: The Solution
What is FlashAttention?
FlashAttention (Tri Dao, 2022) is a way to compute attention without ever building the full N×N score matrix in HBM. The name comes from the key idea: attention computed from fast on-chip SRAM instead of slow HBM — "flash" as in fast, because the data stays in the fastest memory the GPU has.
It combines two ideas you already know:
- Tiling (from Module 6) — process data in small chunks that fit in fast SRAM instead of loading everything into slow HBM
- Online softmax (explained below) — a trick to compute softmax incrementally, tile by tile, without needing the entire row first
The result is mathematically identical to standard attention — same scores, same probabilities, same output. The speedup comes entirely from avoiding HBM traffic.
Instead of the standard approach (compute ALL scores → store N×N matrix in HBM → read it back for softmax → store again → read for V multiply), FlashAttention processes small tiles:
- Load a small chunk of Q and K into SRAM
- Compute a small piece of scores (stays in SRAM — never touches HBM)
- Update running softmax numbers
- Multiply by a chunk of V, accumulate partial output
- Move to next chunk → repeat
- When done: write only the final output to HBM
The N×N matrix never gets built. Only small tile-sized pieces exist at any moment, and only in fast SRAM. But there's one challenge to solve first...
The Problem: Softmax Needs the Whole Row
There's a catch. In Step 2, we learned that softmax is a reduction barrier — it needs the entire row before it can produce any output. If we're processing tiles, we only see part of the row at a time. How can we compute softmax without the full picture?
To understand the solution, let's first understand what softmax actually does.
What Softmax Does (Quick Refresher)
Softmax turns a row of raw scores into percentages that sum to 1. For example, attention scores [2, 4, 1, 3] become probabilities of roughly [9%, 64%, 3%, 24%] — the highest score (4) gets the most weight.
The formula has two steps:
- Exponentiate each score: compute e^(score) for each value. This makes all values positive and amplifies differences. (e ≈ 2.718, so e² ≈ 7.4, e⁴ ≈ 54.6 — bigger inputs grow much faster.)
- Divide by the total: each e^(score) ÷ sum of all e^(scores). This makes them sum to 1.
The overflow problem: If a score is large (like 1000), e^1000 is astronomically huge — it overflows. The standard trick: subtract the maximum from all scores first. For scores [2, 4, 1, 3], max=4, so compute e^(2-4), e^(4-4), e^(1-4), e^(3-4) = e^(-2), e^0, e^(-3), e^(-1). Now all exponents are ≤ 0, so all values are between 0 and 1. The ratios stay the same — the final percentages are identical.
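The same trick in NumPy:

```python
import numpy as np

def stable_softmax(scores):
    shifted = scores - scores.max()  # all exponents now <= 0: no overflow
    e = np.exp(shifted)
    return e / e.sum()               # the shift cancels out in the ratio

print(stable_softmax(np.array([2., 4., 1., 3.])))
# ~[0.09, 0.64, 0.03, 0.24], matching the percentages above
```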
Why this needs the whole row: You need to know the maximum before you start, and you need the sum of all exponentials for the denominator. Both require seeing every score in the row.
The Solution: Online Softmax
When processing in tiles, we don't have the whole row. Online softmax handles this by keeping two running numbers as we go through tiles:
- m — the largest score seen so far (so we can subtract it)
- l — the running sum of exponentials (the denominator, built up tile by tile)
Think of it like grading exams on a curve without knowing the top score in advance. You grade each batch as it arrives, and when a later batch contains a new top score, you rescale your running totals instead of regrading every exam.
Tile 0: scores [2, 4, 1, 3]
This is straightforward — just normal softmax on 4 values:
- Max so far: m = 4
- Subtract max and exponentiate: e^(2-4) + e^(4-4) + e^(1-4) + e^(3-4) = 0.135 + 1.0 + 0.050 + 0.368
- Running sum: l = 1.553
Tile 1: scores [5, 1, 2, 3]
The new tile has a 5 — that's bigger than our old max of 4. This means all our previous exponentials were computed relative to the wrong reference point (we subtracted 4, but should have subtracted 5). We need to correct:
- New max: m = 5 (updated from 4)
- Correction: multiply the old sum by e^(4-5) = e^(-1) ≈ 0.368. This adjusts our previous work to the new reference point. Corrected old sum: 1.553 × 0.368 = 0.571
- Add new tile's exponentials (relative to max=5): e^(5-5) + e^(1-5) + e^(2-5) + e^(3-5) = 1.0 + 0.018 + 0.050 + 0.135 = 1.203
- Updated running sum: l = 0.571 + 1.203 = 1.774
Done! After all tiles: softmax[i] = e^(score[i] - 5) / 1.774. This gives the exact same result as computing softmax on all 8 values at once.
[Figure: input tiles [2, 4, 1, 3] and [5, 1, 2, 3]. After Tile 0: l₀ = e^(2−4) + e^(4−4) + e^(1−4) + e^(3−4). After Tile 1: final softmax[i] = e^(x[i] − 5) / 1.774 — no second pass needed; the max was corrected online.]
The key insight: the correction factor e^(old_max - new_max) is always ≤ 1 (because the new max is bigger), so it just scales down previous work. If the max doesn't change in a new tile, the correction factor is e^0 = 1 — no change needed.
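Here is the bookkeeping as a small NumPy function (a sketch of the algorithm, not the real kernel):

```python
import numpy as np

def online_softmax(tiles):
    """Streaming softmax bookkeeping: one pass over the tiles,
    carrying only the running max m and running denominator l."""
    m, l = -np.inf, 0.0
    for tile in tiles:
        m_new = max(m, tile.max())
        l = l * np.exp(m - m_new)        # rescale previous work to the new max
        l += np.exp(tile - m_new).sum()  # add this tile's exponentials
        m = m_new
    return m, l

m, l = online_softmax([np.array([2., 4., 1., 3.]),
                       np.array([5., 1., 2., 3.])])
print(m, l)  # 5.0 1.7747... (the worked example rounds this to 1.774)
x = np.array([2., 4., 1., 3., 5., 1., 2., 3.])
print(np.exp(x - m) / l)  # identical to softmax over all 8 values at once
```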
Same FLOPs, Different IO
FlashAttention computes the exact same result as standard attention — same O(N²d) FLOPs. The only difference: the N×N matrix never materializes in HBM. Intermediates live in SRAM. Same math, dramatically less memory traffic.
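Putting tiling and online softmax together, here is a single-head sketch (illustrative NumPy with no masking; the real FlashAttention is a fused CUDA kernel that also handles masking, dropout, and the backward pass):

```python
import numpy as np

def flash_attention(Q, K, V, tile=64):
    """Tiled attention: per block of query rows, stream K/V tiles and
    maintain a running max m, denominator l, and an unnormalized output
    accumulator. The full N x N score matrix never exists."""
    N, d = Q.shape
    out = np.empty((N, d))
    for qs in range(0, N, tile):
        q = Q[qs:qs+tile]                         # Q tile ('in SRAM')
        m = np.full(len(q), -np.inf)              # running row maxes
        l = np.zeros(len(q))                      # running denominators
        acc = np.zeros((len(q), d))               # unnormalized outputs
        for ks in range(0, N, tile):
            s = q @ K[ks:ks+tile].T / np.sqrt(d)  # score tile, SRAM-only
            m_new = np.maximum(m, s.max(axis=1))
            corr = np.exp(m - m_new)              # rescale previous work
            p = np.exp(s - m_new[:, None])
            l = l * corr + p.sum(axis=1)
            acc = acc * corr[:, None] + p @ V[ks:ks+tile]
            m = m_new
        out[qs:qs+tile] = acc / l[:, None]        # the only HBM write
    return out
```

Its output matches the three-pass sketch from earlier up to floating-point rounding: same math, never an N×N matrix.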
On the right panel: Step through the tiling: watch Q and K tiles load into SRAM, partial scores computed without touching HBM, online softmax updating m and l. Notice the traffic counter barely moves during compute phases — only ticking during load/write. That's the whole point.
IO Complexity: Why It's Faster
Why FlashAttention Is Faster
Standard attention HBM accesses: O(N²) — the N×N matrix is read and written multiple times.
FlashAttention HBM accesses: O(N²d²/M) where M = SRAM size. For d=128, M=100KB on A100:
Reduction factor ≈ M / d² ≈ 100,000 / 16,384 ≈ 6×
At longer sequences, both are quadratic in N — but FlashAttention's constant factor is M/d² smaller. That constant factor is the entire speedup.
Same FLOPs, Different IO
This is the lesson of the entire module: FlashAttention doesn't reduce computation. It reduces memory traffic.
The attention mechanism does the same O(N²d) multiply-adds. By tiling into SRAM and never materializing the N×N matrix in HBM, it transforms attention from memory-bound to compute-bound.
On the roofline: the total FLOPs are unchanged, but arithmetic intensity rises, so the attention dot shifts right, out of the memory-bound region and up toward the compute ceiling.
Connecting the Tracks
The Attention module in the LLM Internals track shows what attention computes — the Q·K^T scores, softmax, and V multiplication. This module shows why the naive version is slow and how FlashAttention fixes it — same result, dramatically less memory traffic.
What's Next
Module 9 introduces Triton & torch.compile — FlashAttention was hand-written in CUDA. Triton lets you write similar fused kernels in Python, and torch.compile applies fusion automatically. The final module connects all nine modules into one mental model.