For years, efficient-attention models traded speed for smarts. Jet-Nemotron, from NVIDIA researchers, tries to end that bargain with a pragmatic recipe: don’t pretrain a new architecture—start from a strong full-attention model, keep its MLPs, and search only the attention stack. They call it Post Neural Architecture Search (PostNAS), and the result is a 2–4B-parameter family that rivals or beats same-size full-attention baselines while massively upping tokens-per-second.
What PostNAS actually does
PostNAS is a four-step, hardware-aware exploration loop layered on a pre-trained LLM: (1) learn where to keep or drop full-attention layers; (2) select the best linear-attention block; (3) optionally design a new block (“JetBlock”); and (4) tune hyperparameters for real GPUs. Freezing MLP weights keeps search cheap while letting attention do the heavy lifting.
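A minimal sketch of that loop, with the actual search routines passed in as callables; the names, signatures, and the "mlp" module-name check are illustrative assumptions, not the paper's code:

```python
from typing import Callable, Sequence

def post_nas_sketch(model, search_placement: Callable, linear_blocks: Sequence,
                    refine_block: Callable, tune_for_hardware: Callable,
                    evaluate: Callable):
    """Shape of the PostNAS loop; the real search procedures are stand-ins here."""
    # Freeze MLPs so only the attention stack is searched/trained
    # (assumes HF-style parameter names containing "mlp").
    for name, p in model.named_parameters():
        if "mlp" in name:
            p.requires_grad = False

    placement = search_placement(model)                     # (1) keep/drop full attention per layer
    best = max(linear_blocks,                               # (2) pick the best linear-attention block
               key=lambda blk: evaluate(model, placement, blk))
    block = refine_block(best)                              # (3) optionally design a new block (JetBlock)
    return tune_for_hardware(model, placement, block)       # (4) heads/key/value dims vs. real GPU profile
```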
JetBlock in a sentence
JetBlock mixes linear attention with dynamic, input-conditioned causal convolutions on values (and trims redundant static convs on Q/K), yielding accuracy gains with little runtime overhead.
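To show the idea rather than the exact design, here is a toy sketch of a JetBlock-style layer: a dynamic, input-conditioned causal depthwise convolution filters the values before a causal linear-attention pass. The class name, feature map, kernel normalization, and per-head (rather than per-channel) kernels are all assumptions for illustration; the paper's block differs in detail.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class JetBlockSketch(nn.Module):
    """Illustrative only: linear attention over values filtered by a
    dynamic (input-conditioned) causal convolution."""

    def __init__(self, dim, n_heads, kernel_size=4):
        super().__init__()
        assert dim % n_heads == 0
        self.h, self.d, self.ks = n_heads, dim // n_heads, kernel_size
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        # Each token predicts its own small causal kernel (one per head).
        self.kernel_gen = nn.Linear(dim, n_heads * kernel_size)

    def forward(self, x):                                   # x: (B, T, dim)
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.h, self.d)
        k = self.k_proj(x).view(B, T, self.h, self.d)
        v = self.v_proj(x).view(B, T, self.h, self.d)

        # Dynamic causal conv on values: mix each value with its previous
        # (kernel_size - 1) neighbors using input-conditioned weights.
        w = self.kernel_gen(x).view(B, T, self.h, self.ks).softmax(dim=-1)
        v_pad = F.pad(v, (0, 0, 0, 0, self.ks - 1, 0))      # left-pad along T
        v_win = v_pad.unfold(1, self.ks, 1)                 # (B, T, h, d, ks)
        v = torch.einsum("bthdk,bthk->bthd", v_win, w)

        # Causal linear attention via running sums (O(T), constant-size state).
        q, k = F.elu(q) + 1, F.elu(k) + 1
        kv = torch.einsum("bthd,bthe->bthde", k, v).cumsum(dim=1)
        z = k.cumsum(dim=1)
        num = torch.einsum("bthd,bthde->bthe", q, kv)
        den = torch.einsum("bthd,bthd->bth", q, z).unsqueeze(-1) + 1e-6
        return self.o_proj((num / den).reshape(B, T, -1))

x = torch.randn(2, 16, 256)
print(JetBlockSketch(dim=256, n_heads=4)(x).shape)          # torch.Size([2, 16, 256])
```

The key property is that the recurrent state (the running `kv` sum) has a fixed size per head, independent of context length, which is exactly what keeps the KV footprint small.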
The headline numbers
- Throughput: On H100s, Jet-Nemotron-2B logs up to 53.6× decoding and 6.14× prefilling speedups at 256K context vs Qwen3-1.7B-Base, and it still shows gains at shorter contexts.
- Accuracy: Despite being hybrid (mostly linear attention), Jet-Nemotron-2B/4B match or beat leading full-attention peers (Qwen2.5/3, Gemma3, Llama3.2) across MMLU/Pro, math, retrieval, coding, and long-context suites at similar scales.
- Coding & long-context: In the paper's tables, Jet-Nemotron-4B leads average coding accuracy and outpaces Qwen3-1.7B-Base on long-context tasks while running ~21× faster.
Why it’s fast (and why that matters)
A core finding is blunt but useful: KV-cache size, not parameter count, is the dominant limiter of long-context throughput. Keep KV small and you can batch more sequences; decoding is typically memory-bandwidth-bound. PostNAS bakes that into a hardware-aware search that tweaks heads/keys/values to hold speed while buying back accuracy.
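A back-of-the-envelope calculation makes the point concrete. The layer/head/dimension numbers below are illustrative stand-ins for a small full-attention model, not the published configuration of Jet-Nemotron or Qwen3:

```python
# KV-cache footprint per sequence: 2 (K and V) x layers x KV heads x head dim
# x context length x bytes per element. Shape numbers are illustrative.
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, context_len, dtype_bytes=2):
    return 2 * n_layers * n_kv_heads * head_dim * context_len * dtype_bytes

per_seq = kv_cache_bytes(n_layers=28, n_kv_heads=8, head_dim=128, context_len=256_000)
print(f"{per_seq / 2**30:.1f} GiB per 256K-token sequence")   # ~27.3 GiB in fp16
# On an 80 GB GPU, that cache (not the ~4 GB of fp16 weights for a ~2B model)
# caps how many long sequences you can batch, and batch size is what
# memory-bandwidth-bound decoding throughput scales with.
```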
Why it’s interesting for builders
- Upgrade path, not a moonshot. You can retrofit an existing model: freeze MLPs, swap/search attention, and ship meaningful speedups without full pretraining (a minimal sketch of this swap follows this list).
- Hybrid done right. Strategically retain a few full-attention layers (learned placement beats uniform) to keep retrieval and tricky benchmarks strong.
- Long-context economics. If you serve 128K–256K prompts, the 53.6× decoding and 6.14× prefilling gains translate directly into lower latency or higher concurrency.
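As a rough sketch of the retrofit step, the swap amounts to replacing attention in every decoder layer except the ones the placement search decides to keep as full attention. The `self_attn` attribute name, the `keep_full` indices, and `make_linear_block` are assumptions about an HF-style decoder, not code from the paper:

```python
import torch.nn as nn

def hybridize(decoder_layers: nn.ModuleList, keep_full: set, make_linear_block):
    """Swap attention in every layer except the learned 'keep_full' positions.
    'self_attn' follows common Hugging Face decoder naming; adjust for your backbone."""
    for i, layer in enumerate(decoder_layers):
        if i not in keep_full:
            # Build a linear-attention replacement sized like the original module.
            layer.self_attn = make_linear_block(layer.self_attn)
    return decoder_layers
```

In the paper, which positions stay full attention comes out of the learned placement search rather than being spaced uniformly by hand.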
Bottom line
Jet-Nemotron reframes efficient LMs as an architecture-search problem on top of pre-trained backbones. With JetBlock and a KV-aware, GPU-realistic search, it shows you don’t have to choose between accuracy and speed—especially at long context lengths that crush classic Transformers.
Paper link: arXiv 2508.15884 (PDF)