Parallax — Local-linear attention vs FlashAttention 2/3
The news. On May 29, 2026, the Parallax paper (arXiv:2605.29157) reframed softmax attention as a local-constant (Nadaraya–Watson) estimator and upgraded it to local-linear estimation. Tested at 0.6B and 1.7B pretraining scales, it reports consistent perplexity gains under both parameter-matched and compute-matched controls, plus a hardware-aware decode kernel that matches or outperforms FlashAttention 2/3 across batch sizes and context lengths. The authors call it "the first empirical demonstration of strong architecture–optimizer codesign for attention mechanisms."
Picture the hillside. You're standing at a spot (the query) and you want to guess your altitude from a handful of nearby trail markers whose heights you know. The lazy method is to average the markers: add up their altitudes, divide, done. But if the markers sit on a slope and aren't balanced around you (say most of them are downhill), their average lands downhill of where you actually stand. The flat guess is biased precisely because it threw away the slope. That average-the-neighbors move is what softmax attention does: it turns scores into weights and reports a weighted average of the value vectors. Statisticians have a name for it: the Nadaraya–Watson, or local-constant, estimator.
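A minimal sketch of the flat guess on a toy hillside, assuming a one-dimensional trail where each marker's "key" is its position and its "value" is its altitude. The scores here are negative squared distances (a Gaussian kernel), a stand-in chosen to fit the analogy; softmax attention uses query-key dot products instead. The hill, marker positions, and bandwidth are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def nadaraya_watson(query, keys, values, bandwidth=1.0):
    """Local-constant estimate: a softmax-weighted average of the values.

    Scores are negative squared distances (Gaussian kernel), an illustrative
    choice; attention scores with query-key dot products instead.
    """
    scores = [-((k - query) ** 2) / (2 * bandwidth ** 2) for k in keys]
    weights = softmax(scores)
    return sum(w * v for w, v in zip(weights, values))

def altitude(x):
    """Hypothetical hill: rises 12 m per unit distance, 612 m at x = 1."""
    return 600.0 + 12.0 * x

query = 1.0
markers = [-1.0, 0.0, 0.5, 1.5]   # lopsided: most markers are downhill of the query
heights = [altitude(x) for x in markers]

estimate = nadaraya_watson(query, markers, heights)
print(f"flat guess: {estimate:.1f} m, true altitude: {altitude(query):.1f} m")
```

Here the weighted mean marker position is about 0.65, short of the query at 1.0, so the weighted average lands roughly 4 m downhill of the true 612 m.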
Parallax keeps the same markers but fits their slope. Instead of one flat number, it draws a short line through the weighted neighborhood and reads off the value at your exact spot — a local-linear fit. On the hill that's the difference between "about 600 m, give or take the slope" and "612 m, right here." The extra ingredient is a small query-like probe that measures the KV covariance — how value trends with key across the neighborhood — so the model knows which way, and how steeply, to tilt the line. Earlier Local Linear Attention needed a numerical solver to do this; Parallax derives the estimator in closed form and drops the solver entirely.
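The same toy hillside with a closed-form local-linear fit: a weighted least-squares line through the markers, evaluated at the query. This is the textbook one-dimensional estimator, not Parallax's kernel. The slope's numerator is a weighted key–value covariance, the kind of quantity the probe is described as measuring, but how Parallax computes it for vector-valued keys is not shown here. The hill, markers, and bandwidth are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def local_linear(query, keys, values, bandwidth=1.0):
    """Closed-form local-linear estimate (1-D weighted least squares).

    Uses the same Gaussian-kernel softmax weights as the flat guess and
    returns (estimate at query, fitted slope). No iterative solver is needed.
    """
    scores = [-((k - query) ** 2) / (2 * bandwidth ** 2) for k in keys]
    w = softmax(scores)
    k_mean = sum(wi * k for wi, k in zip(w, keys))
    v_mean = sum(wi * v for wi, v in zip(w, values))  # this alone is the flat guess
    cov_kv = sum(wi * (k - k_mean) * (v - v_mean) for wi, k, v in zip(w, keys, values))
    var_k = sum(wi * (k - k_mean) ** 2 for wi, k in zip(w, keys))
    slope = cov_kv / var_k  # needs at least two distinct keys
    # Start from the weighted average, then tilt along the slope to the query.
    return v_mean + slope * (query - k_mean), slope

def altitude(x):
    """Hypothetical hill: rises 12 m per unit distance, 612 m at x = 1."""
    return 600.0 + 12.0 * x

query = 1.0
markers = [-1.0, 0.0, 0.5, 1.5]   # same lopsided markers as the flat-guess example
heights = [altitude(x) for x in markers]

estimate, slope = local_linear(query, markers, heights)
print(f"local-linear: {estimate:.1f} m, fitted slope: {slope:.1f} m per unit")
```

On an exactly linear hill the line recovers the 12 m slope and the 612 m altitude despite the lopsided markers; on a curved hill it would still leave the curvature error that a straight line cannot capture.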
Here's the part that makes it more than a statistics upgrade. Fitting a slope costs more math per neighbor than averaging, and on a GPU, more math per byte fetched is exactly what you want. Single-token decode is memory-bound: the compute units sit idle waiting for the KV cache to stream in from HBM. FlashAttention is already IO-optimal, so on that path it has little speed headroom left; it is limited by bandwidth, not compute. Parallax's heavier slope kernel raises arithmetic intensity, moving the operating point rightward along the roofline toward the ridge point, where memory time and compute time balance. Up to the ridge, the slope work is largely free, because those FLOPs run on units that were otherwise stalled.
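A back-of-envelope count of why single-token decode is memory-bound, assuming one new query token, a 2-byte (fp16/bf16) KV cache, and counting only the QKᵀ and PV multiply-adds. Softmax FLOPs and the traffic for the query and output vectors are ignored, and the shapes are hypothetical.

```python
def decode_attention_intensity(context_len, head_dim, bytes_per_elem=2, q_heads_per_kv_head=1):
    """Arithmetic intensity (FLOP/byte) of one decode step for one KV head.

    Each query head scores every cached key (2*L*d FLOPs) and takes a
    weighted sum over every cached value (2*L*d FLOPs). The K and V caches
    are read once and shared by all query heads in a grouped-query group.
    """
    flops = 4 * context_len * head_dim * q_heads_per_kv_head
    bytes_read = 2 * context_len * head_dim * bytes_per_elem  # K cache + V cache
    return flops / bytes_read

print(decode_attention_intensity(4096, 128))                         # multi-head: 1.0
print(decode_attention_intensity(4096, 128, q_heads_per_kv_head=8))  # 8 query heads per KV head: 8.0
```

Context length cancels out, so the intensity stays at roughly one FLOP per byte per query head sharing the cache, far below the ridge point of current datacenter GPUs. Batching more sequences does not help here, since each sequence reads its own KV cache.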
Where the extra FLOPs are free
Hold one decode step fixed and walk the roofline (numbers illustrative). Say attention reads a KV block of 4 MB from HBM and FlashAttention does ~8 MFLOPs on it — an arithmetic intensity of 2 FLOP/byte. If the GPU's ridge point is at ~10 FLOP/byte, that operating point sits 5× to the left of the ridge: deep in memory-bound territory, where the compute units idle ~80% of the time. Now Parallax's slope fit does ~5× the FLOPs — ~40 MFLOPs on the same 4 MB — for an intensity of 10 FLOP/byte, landing right at the ridge. In this illustrative setup the byte traffic is unchanged, so the memory cost stays the same; the new FLOPs run on units that were stalled. That's why a heavier kernel can match or beat a leaner one on wall-clock: the win isn't fewer operations, it's trading idle cycles for a sharper estimate.
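The walk above can be checked with the standard roofline bound, where a kernel takes at least the larger of its compute time and its memory time. The peak FLOP/s and bandwidth below are hypothetical values picked only so the ridge lands at 10 FLOP/byte, as in the illustration; real GPUs have considerably higher ridge points.

```python
def roofline_time(flops, bytes_moved, peak_flops, bandwidth):
    """Roofline lower bound on kernel time: the slower of compute and memory."""
    return max(flops / peak_flops, bytes_moved / bandwidth)

# Hypothetical GPU with a ridge point of PEAK / BW = 10 FLOP/byte.
PEAK = 10e12    # peak compute, FLOP/s
BW = 1e12       # HBM bandwidth, bytes/s
KV_BYTES = 4e6  # KV block read per decode step, the same for every kernel

kernels = {
    "averaging (FlashAttention-like)": 8e6,
    "slope fit, 5x FLOPs": 40e6,
    "hypothetical 10x FLOPs": 80e6,  # past the ridge
}
for name, flops in kernels.items():
    t = roofline_time(flops, KV_BYTES, PEAK, BW)
    busy = (flops / PEAK) / t  # fraction of the step the compute units do work
    print(f"{name}: {flops / KV_BYTES:.0f} FLOP/byte, {t * 1e6:.1f} us, compute busy {busy:.0%}")
```

The first two kernels both take 4 µs, with compute busy 20% versus 100% of the step. The third, hypothetical row shows the limit: once a kernel is past the ridge, each extra FLOP adds wall-clock time.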
Three ways to estimate a query's output
| Method | Estimator | Extra cost | Decode kernel |
|---|---|---|---|
| Softmax attention | local-constant (Nadaraya–Watson) — weighted average | none | memory-bound (FlashAttention) |
| Local Linear Attention (prior) | local-linear via a numerical solver | solver per step | solver overhead, hard to fuse |
| Parallax | local-linear, closed-form + KV-covariance probe | slope math per neighbor, no solver (amount setup-dependent, illustrative) | higher arithmetic intensity, toward the ridge; matches/beats FA 2/3 (paper) |
The table isn't a claim that averaging is wrong. For many neighborhoods the slope is near zero (or the neighbors are balanced around the query) and the flat guess is fine, which is why softmax attention works at all. The point is that whenever the neighbors do sit on a slope, a constant fit leaves a bias that a line removes, and on decode hardware, paying for that line is nearly free. Parallax's contribution is making the local-linear estimator cheap enough to fuse into a decode kernel, and codesigning it with the Muon optimizer so the gains hold up in training.
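In the one-dimensional textbook form (an assumption here; the paper's closed form for vector keys may differ), the two estimates differ by a single correction term. With softmax weights w_i summing to one, keys k_i, values v_i, and query q:

```latex
\begin{aligned}
\bar k &= \sum_i w_i k_i, \qquad \bar v = \sum_i w_i v_i = \hat v_{\mathrm{NW}}(q), \\
\hat\beta &= \frac{\sum_i w_i (k_i - \bar k)(v_i - \bar v)}{\sum_i w_i (k_i - \bar k)^2}, \\
\hat v_{\mathrm{LL}}(q) &= \bar v + \hat\beta\,(q - \bar k) = \hat v_{\mathrm{NW}}(q) + \hat\beta\,(q - \bar k).
\end{aligned}
```

The correction vanishes when the fitted slope is zero or when the weighted neighbors are centered on the query, which is exactly when the flat guess is already fine.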
Related explainers
- I/O-optimal approximate attention (Near-linear I/O vs FlashAttention): another attention kernel measured against FlashAttention, but one that attacks IO complexity instead of the estimator
- Gated DeltaNet-2 (Decoupled erase/write gates): a different rethink of what attention computes, from the linear-recurrence side