When Triton Stops Being the Right Tool

A field report from optimizing 3-bit quantized inference in vLLM.

The promise

Triton is supposed to be the compromise that ends the CUDA-vs-productivity debate. Write Python-flavoured kernels. Get 90% of hand-CUDA performance. Keep the other 10% as your margin for moving faster and iterating more.

For the common case of dense GEMM, attention, and anything that maps cleanly to tile-of-tensor-cores abstractions, this is true. Triton's kernels routinely match or beat hand-written CUDA on modern NVIDIA hardware. Key algorithmic ideas in FlashAttention-2 (looping over K/V, parallelising over sequence length) first landed in Phil Tillet's Triton FlashAttention implementation before the CUDA rewrite. It's not a toy.

The catch: Triton's abstractions were designed for "programs that look like dense linear algebra." When your program does not look like dense linear algebra, the same abstractions that paid off elsewhere become taxes you can't remove. Think bit-level manipulation, reading sub-byte-packed data, or running at batch sizes where tensor cores can't engage.

We hit exactly this wall this week, optimizing a 3-bit weight-only quantized Linear layer in vLLM. It's a short story about a long ceiling.

Three interventions, three lessons

Lesson 1: When your memory access pattern doesn't match a power-of-two stride, Triton's vectorizer just doesn't fire

The starting kernel loaded packed 3-bit weights via two scattered 2D loads, one for each of the two bytes a value can span. The stride pattern isn't arbitrary. It falls out of the packing format: every 8 three-bit values share three consecutive bytes, with values 2 and 5 (counting from 0) straddling byte boundaries. Calculating the byte offset for each thread's value therefore needs integer division and modulo arithmetic, producing a non-linear stride:

bi0 = g8 * 3 + fb                       # [0,0,0,1,1,1,2,2,3,3,3,...]
p0 = packed_row[:, None] * stride_cn + bi0[None, :] * stride_ck
b0 = tl.load(packed_ptr + p0, ...)

The index pattern is deterministic but not stride-1. Triton's vectorizer sees indices that aren't a clean tl.arange and falls back to per-byte transactions. The A100's memory controllers then saturate on request count rather than on bandwidth. On Qwen3-8B this cost us a 50× slowdown vs BF16 at batch size 1.
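To make the index pattern concrete, here is a minimal pure-Python sketch of the layout. It assumes LSB-first bit order inside each 3-byte group (value i occupies bits 3i to 3i+2), which is our reading of the format rather than something the kernel source confirms here. The names pack3, first_byte_index, and unpack3_two_loads are hypothetical and exist only for illustration.

```python
def pack3(values):
    """Pack 3-bit values (0..7) into bytes, 8 values per 3-byte group.

    Assumption: value i of a group occupies bits 3*i .. 3*i+2 of the
    24-bit little-endian group word.
    """
    assert len(values) % 8 == 0
    out = bytearray()
    for g in range(0, len(values), 8):
        word = 0
        for i, v in enumerate(values[g:g + 8]):
            assert 0 <= v < 8
            word |= v << (3 * i)
        out += word.to_bytes(3, "little")
    return bytes(out)


def first_byte_index(k):
    """Byte holding the lowest bit of value k (the g8 * 3 + fb expression)."""
    g8, i = divmod(k, 8)      # group number, position inside the group
    fb = (3 * i) // 8         # first byte inside the 3-byte group
    return g8 * 3 + fb


def unpack3_two_loads(packed, k):
    """Decode value k from per-value byte reads, mimicking the slow path."""
    b0_idx = first_byte_index(k)
    shift = (3 * (k % 8)) % 8
    b0 = packed[b0_idx]
    # The second byte only matters when the value straddles a byte boundary
    # (shift 6 or 7, i.e. values 2 and 5 of each group).
    b1 = packed[b0_idx + 1] if shift > 5 else 0
    return ((b0 | (b1 << 8)) >> shift) & 0x7


values = [(7 * k + 3) % 8 for k in range(128)]
packed = pack3(values)
byte_indices = [first_byte_index(k) for k in range(11)]
decoded = [unpack3_two_loads(packed, k) for k in range(128)]
```

The first eleven byte indices come out as 0,0,0,1,1,1,2,2,3,3,3: deterministic, but not something a vectorizer can treat as a contiguous range.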

Fix: one coalesced tl.arange(0, 64) bulk load of all 48 packed bytes per row (padded to the next power of two), then two tl.gathers against the in-register buffer. Same bit math, different load pattern.
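The shape of that fix can be emulated on the CPU. This is a sketch, not the kernel: plain Python lists stand in for the in-register buffer, list indexing stands in for tl.gather, and the LSB-first bit order is an assumption. pack_row and decode_row_bulk are hypothetical names.

```python
ROW_BYTES = 48   # 128 three-bit values
PADDED = 64      # next power of two, matching tl.arange(0, 64)


def pack_row(values):
    """Pack 128 three-bit values into 48 bytes (assumed LSB-first order)."""
    bits = 0
    for k, v in enumerate(values):
        bits |= (v & 0x7) << (3 * k)
    return bits.to_bytes(ROW_BYTES, "little")


def decode_row_bulk(packed_row):
    # Step 1: one contiguous load of the whole row, zero-padded to 64 bytes.
    buf = list(packed_row) + [0] * (PADDED - ROW_BYTES)

    # Step 2: per-value indices into the already-loaded buffer.
    bit_pos = [3 * k for k in range(128)]
    idx0 = [b // 8 for b in bit_pos]
    idx1 = [i + 1 for i in idx0]     # always in range thanks to the padding
    shift = [b % 8 for b in bit_pos]

    # Step 3: two gathers against the buffer (stand-ins for tl.gather).
    b0 = [buf[i] for i in idx0]
    b1 = [buf[i] for i in idx1]

    # Step 4: the same bit math as the scattered-load version.
    return [((lo | (hi << 8)) >> s) & 0x7 for lo, hi, s in zip(b0, b1, shift)]


values = [(5 * k + 1) % 8 for k in range(128)]
decoded = decode_row_bulk(pack_row(values))
```

The only thing that changes relative to the original kernel is where the irregular indexing happens: against a buffer that was fetched in one contiguous request, instead of against global memory.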

Result: 5× speedup, verified end-to-end.

At this level, Triton's abstractions work exactly as advertised. The fix was a few lines, expressed natively in the language, and the compiler translated it to the right PTX. This is Triton at its best.

Lesson 2: Every high-level tensor op in Triton has a silent PTX price

Flush with success, we pushed further. tl.gather was added in Triton 3.6 with "naive codegen" flagged in the commit message itself. On sm_80 at bs=1, tl.gather with dynamic indices appears to compile to either a conditional switch cascade or a register spill. We decided to eliminate the gather entirely by using a static decode: three unit-stride loads for the three bytes of each 3-bit triplet, then eight compile-time bit-shifts per triplet, then a tree of tl.join + tl.trans + tl.reshape to assemble the 128-element output tensor.

CPU bit-equivalence: verified on 100 random + edge cases. The math is identical.
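A check of that kind might look like the following sketch. The actual test harness isn't reproduced in this post, so decode_reference, decode_static, and check_equivalence are hypothetical names, and the LSB-first bit order is an assumption. The eight shift-and-mask expressions mirror the "eight compile-time bit-shifts per triplet" idea.

```python
import random


def decode_reference(packed):
    """Treat the row as one little-endian integer and slice 3 bits at a time."""
    bits = int.from_bytes(packed, "little")
    return [(bits >> (3 * k)) & 0x7 for k in range(len(packed) * 8 // 3)]


def decode_static(packed):
    """Decode each 3-byte triplet with eight fixed shift/mask expressions."""
    out = []
    for g in range(0, len(packed), 3):
        b0, b1, b2 = packed[g], packed[g + 1], packed[g + 2]
        out += [
            b0 & 0x7,                        # bits 0..2
            (b0 >> 3) & 0x7,                 # bits 3..5
            ((b0 >> 6) | (b1 << 2)) & 0x7,   # bits 6..8, straddles b0/b1
            (b1 >> 1) & 0x7,                 # bits 9..11
            (b1 >> 4) & 0x7,                 # bits 12..14
            ((b1 >> 7) | (b2 << 1)) & 0x7,   # bits 15..17, straddles b1/b2
            (b2 >> 2) & 0x7,                 # bits 18..20
            (b2 >> 5) & 0x7,                 # bits 21..23
        ]
    return out


def check_equivalence(num_random=100, row_bytes=48, seed=0):
    rng = random.Random(seed)
    rows = [bytes(rng.randrange(256) for _ in range(row_bytes))
            for _ in range(num_random)]
    # Edge cases: all zeros, all ones, alternating bit patterns.
    rows += [bytes(row_bytes),
             bytes([0xFF]) * row_bytes,
             bytes([0xAA, 0x55, 0xAA]) * (row_bytes // 3)]
    return all(decode_static(r) == decode_reference(r) for r in rows)
```

Calling check_equivalence() returns True when both decoders agree on every row.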

GPU measurement: no relative improvement over the tl.gather version. Per-call latency on H100 scaled almost exactly by the hardware-improvement ratio from the A100 tl.gather measurement. Triton-the-compiler treated the tl.join + tl.trans + tl.reshape chain the same way it treated gather: register shuffles all the way down.

What this showed: you can't escape the ALU-pressure problem by dressing it up as a different operation. Every high-level Triton primitive has a PTX price. That price is usually invisible for tile-of-tensor-cores code. It's never invisible when you're moving 3-bit values around a register file.

Lesson 3: Specializing for M=1 doesn't rescue a 2D-grid kernel

The kernel's output accumulation was acc += tl.dot(x_tile, tl.trans(values)). Tensor cores are hardware-optimized for 16×16×16 matrix blocks. At batch size 1 the input is a 1×K row vector, so we're asking a GEMM pipeline to do a GEMV. Triton pads the M dimension to 16 internally; 15/16ths of each tensor-core op is wasted, and we still pay the shared-memory staging cost the tensor-core path requires. The real bottleneck at bs=1 is memory bandwidth, not tensor-core throughput. Fixing tensor-core utilization wouldn't help; you want to skip them entirely. Obvious fix: specialize for M=1 to skip tl.dot and use elementwise multiply + tl.sum:

if BLOCK_M == 1:
    acc += tl.sum(values * x_tile.reshape(BLOCK_K)[None, :], axis=1)[None, :]
else:
    acc += tl.dot(x_tile, tl.trans(values))

Zero measurable improvement (8.35 → 8.16 tok/s, within noise).
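For reference, the two accumulation paths compute the same numbers; only the amount of issued work differs. A pure-Python sketch, assuming the 16-row padding described above (no Triton involved, and gemv_mul_sum and gemm_padded are hypothetical names):

```python
M_PAD = 16  # assumed tile height, per the 16x16x16 figure in the text


def gemv_mul_sum(x, w):
    """y[n] = sum_k w[n][k] * x[k]: the elementwise-multiply + sum path."""
    return [sum(w_nk * x_k for w_nk, x_k in zip(row, x)) for row in w]


def gemm_padded(x, w):
    """Emulate the dot path: pad x to M_PAD rows, multiply, keep row 0."""
    k = len(x)
    x_pad = [x] + [[0] * k for _ in range(M_PAD - 1)]
    out = [[sum(xr[i] * row[i] for i in range(k)) for row in w]
           for xr in x_pad]
    macs_issued = M_PAD * len(w) * k   # work the padded tile performs
    macs_useful = len(w) * k           # work that feeds the real output row
    return out[0], macs_useful / macs_issued


x = [k % 5 - 2 for k in range(128)]
w = [[(n * 3 + k) % 8 for k in range(128)] for n in range(32)]
y_sum = gemv_mul_sum(x, w)
y_dot, utilization = gemm_padded(x, w)
```

Here y_sum and y_dot are identical and utilization is 1/16, which is the 15/16ths waste in arithmetic form.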

Why: the kernel's thread block layout is fundamentally 2D over (pid_m, pid_n). The K-dimension reduction is a parallel reduction across that layout. Triton can be coaxed into a 1-D grid, and tl.sum does lower to warp-shuffle reductions under the hood. But it abstracts away the warp-level primitives you'd need to hand-orchestrate peak small-M reductions. The abstraction is tile-oriented. At M=1, the tile is the problem, and no amount of inner-loop surgery moves the needle.
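To show what "hand-orchestrating" a small-M reduction means, here is a CPU emulation of the butterfly pattern that CUDA's __shfl_xor_sync enables. It is an illustration only: Python lists stand in for the 32 lanes of a warp, and shfl_xor, warp_reduce_sum, and gemv_row are hypothetical names, not Triton or CUDA APIs.

```python
WARP = 32


def shfl_xor(lanes, mask):
    """Emulate __shfl_xor_sync: lane i reads the value held by lane i ^ mask."""
    return [lanes[i ^ mask] for i in range(WARP)]


def warp_reduce_sum(lanes):
    """Butterfly reduction: 5 exchange steps for 32 lanes."""
    assert len(lanes) == WARP
    vals = list(lanes)
    mask = WARP // 2
    while mask >= 1:
        other = shfl_xor(vals, mask)
        vals = [a + b for a, b in zip(vals, other)]
        mask //= 2
    return vals  # every lane now holds the full sum


def gemv_row(w_row, x):
    """One output element: each lane walks a strided slice of K, then reduces."""
    partial = [sum(w_row[k] * x[k] for k in range(lane, len(x), WARP))
               for lane in range(WARP)]
    return warp_reduce_sum(partial)[0]


x = [k % 7 - 3 for k in range(128)]
w_row = [(3 * k + 1) % 8 for k in range(128)]
y = gemv_row(w_row, x)
```

In a hand-written kernel the author chooses how K is split across lanes and when the reduction happens. In Triton that choice belongs to the compiler.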

The pattern: abstraction taxes compound at edge cases

All three symptoms have the same root cause. Triton's abstractions are designed for the statistical center of GPU workloads: 2D GEMMs with power-of-two tiles, dense tensor cores, regular memory access.

At edges like sub-byte decode, non-aligned stride patterns, and batch sizes below tensor-core minimums, each abstraction costs something small that you'd never notice in the common case. The compound effect at the edges is what looks like a ceiling.

Measured, our kernel spends:

We're 40× over the theoretical floor, 10× over BF16, and 17× over the hand-written-CUDA ceiling that other quantization libraries ship with. Each gap tells a different story. The Triton-vs-CUDA gap specifically is the one we can't close from inside the language.

Signs you've outgrown Triton (in my current reading)

These aren't dogma. They're heuristics. Any one of them is a yellow flag. Two or more is probably a sign to drop to CUDA:

  1. You need a 1D grid, not 2D. Triton wants pid_m, pid_n. If your workload is fundamentally "one output slice per CTA, loop through K", you're fighting the grid model. Classic bs=1 GEMV.
  2. You need to bypass tensor cores deliberately. Tensor cores are a Triton optimization target. Reducing without them means expressing everything through tl.sum + broadcasts, which compile suboptimally compared to register-level __shfl_xor_sync patterns.
  3. Your data is packed with non-power-of-two strides. 48 bytes per group of 128 3-bit values. 52 bytes for 10 per int32. Triton's tile sizes want power-of-two; everything around that gets awkward.
  4. You want to emit specific PTX instructions like lop3.b32 (arbitrary 3-input bitwise logic in a single instruction) for your hot path. Triton 3.3+ has tl.inline_asm_elementwise as an escape hatch, genuinely useful when you know exactly what PTX you want. But once you're composing multiple inline-asm ops, you've left the Triton programming model behind.
  5. Your best reference implementation is 100 lines of CUDA. Marlin, AWQ, and FLUTE each fit their hot kernel in ~100-200 lines of raw CUDA. At that scale the "productivity vs performance" math flips: writing the raw-CUDA version to hit the hardware floor is often faster than spending three days trying to trick Triton's compiler into emitting the right PTX. You don't save meaningful wall-clock time using Triton here; you just ship slower code.
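The byte counts in item 3 follow from simple arithmetic, sketched below together with a toy version of the 10-values-per-int32 layout. The bit placement (value j in bits 3j to 3j+2, top two bits unused) is an assumption for illustration, and pack10 and unpack10 are hypothetical names.

```python
VALUES_PER_WORD = 10  # 10 * 3 = 30 bits used, 2 bits of each int32 unused


def bytes_dense_3bit(n):
    """Bytes needed when 3-bit values are packed back to back."""
    return (3 * n + 7) // 8


def bytes_10_per_int32(n):
    """Bytes needed when each 32-bit word carries 10 values."""
    return 4 * ((n + VALUES_PER_WORD - 1) // VALUES_PER_WORD)


def pack10(values):
    words = []
    for start in range(0, len(values), VALUES_PER_WORD):
        word = 0
        for j, v in enumerate(values[start:start + VALUES_PER_WORD]):
            word |= (v & 0x7) << (3 * j)
        words.append(word)
    return words


def unpack10(words, n):
    # No value straddles a word boundary: one shift and one mask each.
    return [(words[k // VALUES_PER_WORD] >> (3 * (k % VALUES_PER_WORD))) & 0x7
            for k in range(n)]


values = [(3 * k + 2) % 8 for k in range(128)]
words = pack10(values)
roundtrip = unpack10(words, len(values))
```

For a group of 128 values this gives 48 bytes dense and 52 bytes (13 words) in the 10-per-word layout. Neither is a power of two. The second format spends about 8% more memory to remove the straddling cases entirely.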

Caveats

A few things I'm deliberately not claiming:

What Triton is for, then

Everything else. Which is most kernels. Which is why Triton is a good default. The point isn't that Triton is bad. Its productivity gains are real for exactly the workloads its abstractions were designed for, and they shrink or reverse when you step outside that design envelope. Knowing where the envelope is saves time. Attempting to drag Triton into a 1D GEMV at bs=1 with sub-byte decode is exactly the kind of project where you can spend days and end up where you started, because the tool isn't cooperating with the problem.

The final-boss kernel for our PR is a 1D grid bs=1 GEMV with a 10-values-per-int32 pack format and PTX lop3.b32 decodes. It will be written in raw CUDA, because the project outgrew what Triton was built for.
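For readers who haven't met lop3.b32: it evaluates an arbitrary 3-input boolean function per bit, selected by an 8-bit lookup table. The sketch below emulates those semantics in Python as we understand them from the PTX documentation. It is not the planned kernel, and lop3, lut_for, and extract_field are hypothetical helper names.

```python
MASK32 = 0xFFFFFFFF


def lop3(a, b, c, imm_lut):
    """Emulate lop3.b32: per bit, look up imm_lut at index (a_bit, b_bit, c_bit)."""
    result = 0
    for bit in range(32):
        idx = ((((a >> bit) & 1) << 2)
               | (((b >> bit) & 1) << 1)
               | ((c >> bit) & 1))
        result |= ((imm_lut >> idx) & 1) << bit
    return result


def lut_for(fn):
    """Derive the table by applying fn to the canonical constants."""
    return fn(0xF0, 0xCC, 0xAA) & 0xFF


# (a & b) | c collapses a mask-then-or pair into one operation.
AND_OR = lut_for(lambda a, b, c: (a & b) | c)


def extract_field(word, j):
    """Pull 3-bit field j out of a packed word via the emulated lop3."""
    return lop3((word >> (3 * j)) & MASK32, 0x7, 0, AND_OR)


word = sum(((3 * j + 1) % 8) << (3 * j) for j in range(10))
fields = [extract_field(word, j) for j in range(10)]
```

In real code the third operand would typically carry a constant to OR into the result, so the mask and the OR cost one instruction instead of two.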

Coda: three rules from this week

  1. Measure before optimizing. We had four plausible hypotheses for the bs=1 slowdown. Three were wrong. One torch.profiler capture and three ablation runs (total cost: ~€0.60 of GPU time) ruled them out cheaply before we invested in any restructuring.
  2. Release notes are honest when they warn you. "Naive codegen" on tl.gather in Triton 3.6's release notes was not exaggeration. If the tool's authors mark something as provisional, believe them.
  3. A good PR is not a complete PR. We could hold this work until we'd closed the full latency gap. We probably shouldn't. The memory-win PR is mergeable today with a clean 46-line diff and an honest "latency follow-up in a dedicated kernel" note. The follow-up kernel is a different project, a different review surface, and a different time budget. Splitting them is the only way either lands.