L'équipe PyTorch de Meta a publié les détails d'architecture sur In-Kernel Broadcast Optimization (IKBO), une technique de fusion de kernel qui élimine un des patterns silencieusement chers dans l'inférence RecSys : matérialiser les tenseurs de broadcast avant les couches d'interaction. Dans une requête de recommandation typique, ~15 user embeddings sont répliqués 70x pour matcher un batch de 1024 candidats, puis droppés immédiatement après la matmul. IKBO encode la logique de broadcast dans le kernel GPU lui-même — accepte des batch sizes mismatched, fait des index lookups à l'intérieur du kernel, ne matérialise jamais le tenseur répliqué. Les chiffres clés sur H100 SXM5 : 4x speedup cumulatif sur le kernel de compression linéaire (1,944ms → 0,482ms), 6,4x throughput sur Flash Attention end-to-end incluant le coût de broadcasting (vs baseline CuTeDSL FA4-Hopper), et 621 BF16 TFLOPs soutenus sur un workload qui avant était IO-bound à 250 TFLOPs.
L'insight technique, c'est que le broadcast est un concern de data-layout, pas une nécessité computationnelle, et les savings cascadent à travers quatre étages de co-design progressif. Étage 1 — matmul decomposition — fait tourner le GEMM user-side à son batch naturel de 15 rows et le GEMM candidate-side à 1024, puis broadcaste seulement le petit résultat, coupant le compute user-side 70x. Étage 2 — memory alignment — pad K à des multiples de 8 pour des loads TMA 128-bit alignés sur Hopper, équilibrant la pipeline L1/TEX de 84% saturée à balanced et droppant la latence GEMM de 0,984ms à 0,400ms. Étage 3 — in-kernel broadcast fusion — plie le broadcast-add dans l'épilogue du GEMM candidate via index lookup, éliminant 0,87 GB de trafic DRAM intermédiaire. Étage 4 — warp-specialized multi-stage fusion via TLX — partitionne la CTA en producer + deux consumer warp groups qui ping-pong des tiles pour overlapper les stalls WGMMA, fusionne les GEMMs user et candidate en un seul kernel persistent, et lift le L2 throughput de 74% à 84% du peak. L'histoire Flash Attention est encore plus intéressante : la SDPA standard sit à ~60 FLOPs/Byte (IO-bound), tandis qu'IKBO FA pousse l'arithmetic intensity à ~833 FLOPs/Byte au ratio 70:1 — passé le balance point H100 de 495 FLOPs/Byte, le mettant fermement compute-bound où la warp specialization et le TMA async de Hopper paient effectivement.
Lecture ecosystem : c'est une classe d'optimisation que la plupart des ML engineers n'ont pas pensé à, mais elle généralise largement. Tout workload d'inférence avec dimensions de batch mismatched — user/item, vendor/product, ranking hiérarchique avec broadcast multi-niveau — a le même pattern. Le code vit dans `pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/ikbo` (pas encore mergé dans PyTorch core), et Meta l'a déployé à travers le RecSys de prod incluant MTIA. Deux chemins d'adoption : les auteurs de modèles intègrent les kernels IKBO directement, ou un pass de compilateur ML swap les ops standards pour des équivalents IKBO à l'inference. Pour les builders qui font tourner du ranking, retrieval ou recommendation à grande échelle, le match de shape de workload est ce qui détermine si tu obtiens 2x, 4x ou 6x ; le ratio candidate-to-user scale les savings linéairement. La couche TLX (Triton-based warp specialization) mérite aussi d'être trackée pour elle-même — c'est le genre de contrôle de kernel low-level qui a été dur à obtenir sans aller au CUDA brut, et l'investissement de Meta ici suggère que ça va se faire merger upstream.
Move pratique : si tu fais tourner du RecSys, ranking, ou n'importe quel pipeline d'inférence où une dimension de tenseur est beaucoup plus petite qu'une autre (pense personnalisation, vendor selection, retrieval reranking), check si ton kernel hot-path matérialise des tenseurs de broadcast. Si oui, le module experimental d'IKBO mérite un benchmark — Meta rapporte jusqu'à 2/3 de réduction de latence nette sur les modèles co-designés, robuste à travers les batch sizes 256-4096 et les ratios de 10:1 à 10 000:1. Le ratio 70:1 dans leur benchmark par défaut est réaliste pour le ranking d'ads et la personnalisation de feed. Si tu es sur AMD ou hardware non-Hopper, l'insight architectural (fold broadcast dans l'épilogue du kernel, élimine la matérialisation) porte — les chiffres spécifiques non, mais le pattern oui. Pour les ML compiler folks, le chemin d'inference-time transformation est celui à surveiller ; si le pass de compilateur Meta passe upstream, ça devient gratuit pour le reste de l'écosystème.
