PyTorch और Meta की ads-ranking team ने इस हफ़्ते एक Blackwell-specific attention kernel drop किया — TLX Block Attention — और बड़ी story इसके नीचे की layer है। **TLX (Triton Language Extensions)** Triton की productivity और Blackwell पर raw CUTLASS-level control के बीच का DSL bridge है, जो नए tcgen05 async tensor cores, TMA descriptors और TMEM (256KB-per-SM Tensor Memory) को Triton primitives की तरह expose करता है — जैसे `tlx.async_dot`, `tlx.async_descriptor_load`, `tlx.local_trans`, plus producer-consumer warp pipelines के लिए mBarrier synchronization। Repo: github.com/triton-lang/triton-ext। यह वो layer है जिसमें 2026 में Blackwell kernels लिखने वाले ज़्यादातर builders रहेंगे एक बार stabilize होने पर।

Kernel ख़ुद fixed-block sparse self-attention को target करता है — 64-token blocks, block-diagonal pattern, compile-time-known। यह shape specifically Meta के ads-ranking और recommendation models के लिए है, LLM attention नहीं। Pattern compile time पर known होने की वजह से, kernel Flash Attention के multi-tile iteration loop, online-softmax correction factors, logsumexp HBM round-trip, और separate Di preprocessing को eliminate करता है — हर Q tile exactly एक K/V tile पर attend करता है, single GEMM, कोई correction नहीं चाहिए। Forward pass per CTA 15 warps एक specialized pipeline में use करता है (1 load / 1 QK-MMA / 4 softmax / 1 PV-MMA / 8 epilogue); backward 20 warps 7 stages में use करता है। Forward में TMEM triple-buffered (~169KB / 256KB), backward में double-buffered (~162KB / 256KB)। B200, BF16, sparsity=70% — forward 0.98ms vs Flash Attention v2 का 1.81ms (1.85×), backward 2.36ms vs 5.89ms (2.50×), total 2.31×। Numerical accuracy FA v2 को max dQ diff पर 53% से beat करती है।

Fused rotary backward दूसरा highlight है और generalizable pattern है। Standalone attention backward 1.56ms plus rotary backward 4.88ms = 6.44ms unfused; एक single kernel में fused जो dV को TMEM/registers में FP32 रखता है, rotary conjugate in-place apply करता है, फिर एक BF16 global store करता है = 1.82ms। **3.54× तेज़।** Sabak ads workloads के बाहर भी portable है: जब तुम्हारे पास registers/TMEM में FP32 intermediate values हैं, अपनी epilogue math FP32 में करना और एक बार BF16 store करना global memory के through round-trips eliminate करता है जो वरना dominate करते हैं। यह वो insight है जिसे builders TLX या Blackwell के बिना भी अन्य fused-op kernels पर apply कर सकते हैं।

Monday सुबह: यह kernel as-shipped तुम्हारे लिए useful है अगर तुम B200/B300 GPUs पर block-diagonal attention वाले ad-ranking, recsys या feature-interaction models ship करते हो — facebookresearch/ads_model_kernel_library clone करो और benchmark करो। अगर तुम LLM builder हो, kernel apply नहीं होता (causal, sliding-window, और arbitrary sparse patterns explicitly excluded हैं), पर TLX DSL ख़ुद वो part है जिसे watch करना है — यही है कि कैसे Blackwell-aware Triton kernels लिखे जाएँगे, और ज़्यादातर architectural primitives (warp specialization, TMA descriptors, TMEM accumulators) तुम्हारी stack को जिस भी attention shape की ज़रूरत हो उस पर generalize होते हैं। Honest limits: सिर्फ़ Blackwell (sm_100+), कोई Ampere/Hopper fallback नहीं, head_dim 64 या 128 पर hardcoded, block size 64 fixed, license blog में नहीं बताई (repo check करो)। इन techniques के साथ LLM-shaped attention के लिए, Flash Attention 3 का Blackwell-port और उसके successors अगले quarter के watch item होंगे।