I wrote a fused MoE dispatch kernel in pure Triton that beats Megablocks on Mixtral and DeepSeek at inference batch sizes
Posted by bassrehab@reddit | LocalLLaMA | View on Reddit | 2 comments
Been working on custom Triton kernels for LLM inference for a while. My latest project: a fused MoE dispatch pipeline that handles the full forward pass in 5 kernel launches instead of 24+ in the naive approach.
Results on Mixtral-8x7B (A100):
| Tokens | Speedup vs PyTorch | Throughput vs Megablocks |
|---|---|---|
| 32 | 4.9x | 131% |
| 128 | 5.8x | 124% |
| 512 | 6.5x | 89% |
At 32 and 128 tokens (where most inference serving actually happens), it's faster than Stanford's CUDA-optimized Megablocks. At 512+ tokens, Megablocks pulls ahead on the strength of its hand-tuned block-sparse matmuls.
The key trick is fusing the gate and up projections so both GEMMs share the same input tile out of L2 cache, and applying the SiLU activation in registers so the intermediate never touches global memory. That saves ~470MB of memory traffic per forward pass on Mixtral.
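A back-of-envelope check of that traffic saving. This is not the OP's exact accounting (the writeup's roofline analysis will have the real numbers): the FFN width, top-k, and layer count are Mixtral-8x7B's published config, but the batch size and the choice of which intermediates count toward HBM traffic are my assumptions.

```python
# Rough memory-traffic estimate for fusing gate+up+SiLU, NOT the OP's
# exact accounting. Mixtral-8x7B dims; batch size and the set of
# intermediates counted are my assumptions.
BYTES = 2              # fp16/bf16 activations
D_FF = 14336           # Mixtral expert FFN width
TOP_K = 2              # experts activated per token
N_LAYERS = 32          # Mixtral decoder layers
tokens = 64            # assumed batch

per_tensor = tokens * TOP_K * D_FF * BYTES     # one d_ff-wide intermediate

# Unfused: gate_out and up_out are each written to HBM and read back by
# the elementwise SiLU-and-multiply kernel (2 write+read pairs), which
# then writes its own output.
unfused = 4 * per_tensor + per_tensor
# Fused: only the final SiLU(gate) * up result ever leaves registers.
fused = per_tensor

saved = (unfused - fused) * N_LAYERS
print(f"saved ~= {saved / 1e6:.0f} MB per forward pass")  # prints: saved ~= 470 MB per forward pass
```

Under these assumptions a 64-token batch lands right around the quoted ~470MB; a different batch size or intermediate-counting convention shifts the number, so treat this as a plausibility check, not a derivation.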
Also tested on DeepSeek-V3 (256 experts) and Qwen2-MoE. Ran the full suite on an AMD MI300X with zero code changes; all 162 tests pass.
Code: https://github.com/bassrehab/triton-kernels
Full writeup with roofline analysis: https://subhadipmitra.com/blog/2026/fused-moe-dispatch-triton/
mrtrly@reddit
The dispatch overhead is where everyone leaves performance on the table. Fusing it into the forward pass instead of treating it as a separate stage is the obvious move once you see it, but getting the memory layout right across those 5 launches is the hard part. Are you handling the all-to-all communication as a single kernel or splitting that piece out?
bassrehab@reddit (OP)
nah, single-GPU only. EP is on the list, haven't done it yet. single-GPU dispatch is the easy case honestly. you've got the global view of token-to-expert assignments, sort and permute in shared memory, done. multi-GPU is where it gets ugly: overlapping the all-to-all with expert compute, stragglers from imbalanced expert loads, the return-trip all-to-all for unpermute.
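The sort-and-permute step described above can be sketched in a few lines; this is a plain-Python stand-in for what a dispatch kernel does on-device (function and variable names are mine, not from the OP's repo), shown with top-1 routing for simplicity:

```python
# Minimal sketch of single-GPU MoE dispatch: with a global view of
# token-to-expert assignments, stable-sort token indices by expert id,
# gather into expert-contiguous order, and record per-expert offsets.
# Names are mine, not from the OP's repo; top-1 routing for simplicity.
def dispatch(tokens, expert_ids, num_experts):
    order = sorted(range(len(tokens)), key=lambda i: expert_ids[i])  # stable sort
    permuted = [tokens[i] for i in order]          # contiguous rows per expert
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)            # expert e owns permuted[offsets[e]:offsets[e+1]]
    return permuted, order, offsets

tokens = ["t0", "t1", "t2", "t3"]
expert_ids = [2, 0, 2, 1]
permuted, order, offsets = dispatch(tokens, expert_ids, num_experts=3)
# order == [1, 3, 0, 2]; offsets == [0, 1, 2, 4]

# un-permute after expert compute: scatter back through `order`
restored = [None] * len(tokens)
for dst, src in zip(order, range(len(permuted))):
    restored[dst] = permuted[src]
assert restored == tokens
```

In the real kernel the sort/permute would live in shared memory as the OP says; the multi-GPU case is harder precisely because no single rank has this global view of `expert_ids`.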
My guess is the all-to-all has to stay outside the GEMM kernel. Triton doesn't expose NCCL ops so you can't actually fuse comm into compute. But you could overlap: fire off the dispatch, and as tokens arrive on each rank start the gate+up on whatever's landed. the fused gate+up I built would help here since it shortens the critical path between "tokens arrived" and "ready for down projection".
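The overlap structure being described (fire the dispatch, start gate+up on whatever's landed) can be shown with a toy producer/consumer. To be clear: a real implementation would use CUDA streams and NCCL, not Python threads; this only illustrates the pipelining shape, and all names here are made up for the sketch.

```python
import queue
import threading

# Toy stand-in for comm/compute overlap in MoE: a "comm" thread delivers
# token chunks (simulating staggered all-to-all arrivals) while the main
# thread runs "gate+up" on each chunk as soon as it lands, instead of
# waiting for the whole batch. Real code = CUDA streams + NCCL; this
# just shows the pipelining structure.
def overlap_forward(chunks, compute):
    landed = queue.Queue()

    def comm():
        for c in chunks:          # simulate chunks arriving over time
            landed.put(c)
        landed.put(None)          # sentinel: all-to-all finished

    threading.Thread(target=comm, daemon=True).start()

    out = []
    while (c := landed.get()) is not None:
        out.append(compute(c))    # gate+up starts per-chunk, on arrival
    return out

results = overlap_forward([[1, 2], [3], [4, 5]], compute=lambda c: [x * 2 for x in c])
# results == [[2, 4], [6], [8, 10]]
```

The shorter the per-chunk compute (which is exactly what the fused gate+up buys), the less arrival latency you can hide, which is why the straggler problem the OP mentions dominates at small batches.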
anyone actually doing comm-compute overlap for MoE in Triton? only thing I know of is DeepEP and that's CUDA.