Training MoEs - efficiency

May 2, 2026
$$ \newcommand{\on}[1]{\operatorname{#1}} \newcommand{\dmodel}{d_{\on{model}}} \newcommand{\dexpert}{d_{\on{expert}}} $$

MoEs are a major new paradigm for scaling models up without necessarily increasing compute costs, and they seem to work quite well. There's a lot of interesting work here, but this blog post focuses on recent work on improving the efficiency of MoE kernels. I'm mostly writing down my understanding as I read.

I'm reading two papers about scaling MoEs: Krajewski et al. on scaling laws for fine-grained MoEs, and SonicMoE.

They actually use different definitions of "granularity"! The two are closely related:

  • $G_{\on{Sonic}} = d_{\on{model}} / d_{\on{expert}}$ is the ratio of the embedding (residual) size to the intermediate size of a single expert
  • $G_{\on{Krajewski}} = d_{\on{ff}} / d_{\on{expert}}$ is the ratio of the "original" dense feed-forward intermediate size to the intermediate size of a single expert
  • in practice, usually $d_{\on{ff}} = 4\cdot d_{\on{model}}$, so $G_{\on{Krajewski}} = 4\,G_{\on{Sonic}}$; but $G_{\on{Sonic}}$ is internally defined (a function of the architecture alone), while $G_{\on{Krajewski}}$ depends on a reference dense transformer model
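To make the relationship concrete, here is a tiny Python sketch of both definitions. It assumes the usual reference FFN with $d_{\on{ff}} = 4\cdot d_{\on{model}}$, and the helper names are made up for illustration:

```python
# Hypothetical helpers for the two granularity definitions above.
# Assumption: the reference dense FFN uses d_ff = ff_ratio * d_model (usually 4).

def g_sonic(d_model: int, d_expert: int) -> float:
    """SonicMoE-style granularity: residual width over expert intermediate width."""
    return d_model / d_expert


def g_krajewski(d_ff: int, d_expert: int) -> float:
    """Krajewski-style granularity: dense FFN width over expert intermediate width."""
    return d_ff / d_expert


def krajewski_to_sonic(g_k: float, ff_ratio: float = 4.0) -> float:
    """Convert G_Krajewski to G_Sonic under the d_ff = ff_ratio * d_model assumption."""
    return g_k / ff_ratio


d_model, d_expert = 7168, 2048              # DeepSeek V3.2's dimensions
print(g_sonic(d_model, d_expert))           # 3.5
print(g_krajewski(4 * d_model, d_expert))   # 14.0
print(krajewski_to_sonic(64))               # 16.0
```

The conversion only holds under the $4\times$ assumption; a model whose reference FFN uses a different ratio needs a different `ff_ratio`.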

Let's define granularity to be $G_{\on{Sonic}}$, i.e. the ratio of the embedding size to the expert size, since it can be computed without a reference model. $K$ is the number of activated experts per token, and $E$ is the total number of experts. SonicMoE defines sparsity as $\rho = K/E$.

The SonicMoE blog post has a nice description of some recent models, which I've fleshed out a little here:

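| model | year | $d_{\on{model}}$ | $d_{\on{expert}}$ | $K$ | $E$ | $G$ | $\rho$ |
|---|---|---|---|---|---|---|---|
| Mixtral 8x22B | 2024 | 6144 | 16384 | 2 | 8 | 0.375 | 0.250 |
| DeepSeek V2 | 2024 | 5120 | 1536 | 6 | 160 | 3.33 | 0.0375 |
| DeepSeek V3.2 | 2025 | 7168 | 2048 | 8 | 256 | 3.5 | 0.031 |
| Kimi K2.5 | 2025 | 7168 | 2048 | 8 | 384 | 3.5 | 0.021 |
| Qwen3-Next-80B-A3B-Instruct | 2025 | 2048 | 512 | 10 | 512 | 4.0 | 0.020 |
| Qwen3.5-397B-A17B | 2026 | 4096 | 1024 | 10 | 512 | 4.0 | 0.020 |
| Qwen3.6-35B-A3B | 2026 | 2048 | 512 | 8 | 256 | 4.0 | 0.031 |
| Arcee Trinity Large | 2026 | 3072 | 3072 | 4 | 256 | 1.0 | 0.016 |
| z.AI GLM-5.1 | 2026 | 6144 | 2048 | 8 | 256 | 3.0 | 0.031 |
| MiniMax M2.5 | 2026 | 3072 | 1536 | 8 | 256 | 2.0 | 0.031 |
| Ant Ling 2.5-1T | 2026 | 8192 | 2048 | 8 | 256 | 4.0 | 0.031 |
| DeepSeek V4 Flash | 2026 | 4096 | 2048 | 6 | 256 | 2.0 | 0.023 |
| DeepSeek V4 Pro | 2026 | 7168 | 3072 | 6 | 384 | 2.33 | 0.016 |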
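As a sanity check, the derived columns can be recomputed from the raw dimensions. This is a small Python sketch; the dictionary simply transcribes the table (shared experts excluded):

```python
# Recompute G = d_model / d_expert and rho = K / E from the table above.
# Values are (d_model, d_expert, K, E); shared experts are not counted.
MODELS = {
    "Mixtral 8x22B": (6144, 16384, 2, 8),
    "DeepSeek V2": (5120, 1536, 6, 160),
    "DeepSeek V3.2": (7168, 2048, 8, 256),
    "Kimi K2.5": (7168, 2048, 8, 384),
    "Qwen3-Next-80B-A3B-Instruct": (2048, 512, 10, 512),
    "Qwen3.5-397B-A17B": (4096, 1024, 10, 512),
    "Qwen3.6-35B-A3B": (2048, 512, 8, 256),
    "Arcee Trinity Large": (3072, 3072, 4, 256),
    "z.AI GLM-5.1": (6144, 2048, 8, 256),
    "MiniMax M2.5": (3072, 1536, 8, 256),
    "Ant Ling 2.5-1T": (8192, 2048, 8, 256),
    "DeepSeek V4 Flash": (4096, 2048, 6, 256),
    "DeepSeek V4 Pro": (7168, 3072, 6, 384),
}


def granularity_and_sparsity(d_model, d_expert, k, e):
    """Return (G, rho) for one MoE layer configuration."""
    return d_model / d_expert, k / e


for name, cfg in MODELS.items():
    g, rho = granularity_and_sparsity(*cfg)
    print(f"{name:<28} G={g:5.2f}  rho={rho:.3f}")

max_g = max(granularity_and_sparsity(*cfg)[0] for cfg in MODELS.values())
print("most granular G in the table:", max_g)  # 4.0
```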

Note that many of the above models have a "shared expert" (which is always activated), but that's not included in the sparsity calculation above. There's a lot of variability in $G$ and $\rho$ even in models from 2026!

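Krajewski et al. showed that compute-optimal MoE configurations become increasingly granular as the compute budget grows (see their Table 2). For example, for a pretraining run like DeepSeek V4 Pro's, which was reportedly around $10^{25}$ FLOPs, their results suggest $G=64$ (and only about 8T tokens). That $G$ uses their $d_{\on{ff}}/d_{\on{expert}}$ definition; with $d_{\on{ff}} = 4\cdot d_{\on{model}}$ it corresponds to $G_{\on{Sonic}} = 16$, which is still way bigger than even the most granular model above. So how do we get there?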

There are two big issues[1]:

  1. As we increase the granularity (i.e. make the expert size $d_{\on{expert}}$ smaller relative to the embedding size $d_{\on{model}}$), each activated expert does fewer FLOPs. To keep the FLOPs constant (i.e. use our whole budget), we need to activate more experts; however, in the forward and backward passes there are activations whose size depends not on $d_{\on{expert}}$ but on $d_{\on{model}}$, which means they simply get bigger as we increase $K$.
  2. Increasing the granularity decreases the arithmetic intensity of the kernels, which means we quickly become memory bound.
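Issue 1 can be illustrated with a little arithmetic. This Python sketch assumes SwiGLU experts, bf16 activations, and a hypothetical baseline of $K=2$ experts with $d_{\on{expert}} = d_{\on{model}}$. It scales $K$ with $G$ so that the routed-expert FLOPs per token stay fixed, and tracks the size of the gathered $(TK, d_{\on{model}})$ activation:

```python
# Issue 1: hold routed-expert FLOPs per token fixed while increasing G.
# Assumptions: SwiGLU experts (3 matmuls of d_model x d_expert), bf16 (2 bytes),
# and K scaled so that K * d_expert stays constant. Sizes are hypothetical.

def expert_flops_per_token(d_model, d_expert, k):
    # up, gate, down: 3 matmuls of 2 * d_model * d_expert FLOPs each, for k experts
    return 6 * d_model * d_expert * k


def gathered_activation_bytes(tokens, d_model, k, bytes_per_elem=2):
    # Gathered input (or pre-scatter output) has shape (T * K, d_model)
    return tokens * k * d_model * bytes_per_elem


d_model, tokens = 4096, 8192
base_k = 2                          # G = 1 baseline: d_expert = d_model
for g in (1, 2, 4, 8, 16):
    d_expert = d_model // g
    k = base_k * g                  # keep K * d_expert (and hence FLOPs) constant
    flops = expert_flops_per_token(d_model, d_expert, k)
    act = gathered_activation_bytes(tokens, d_model, k)
    print(f"G={g:2d} K={k:2d} FLOPs/token={flops:.3e} gathered={act / 2**20:.0f} MiB")
```

The FLOPs column stays flat while the gathered activation grows linearly with $G$.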

The latter is worked out in detail in the SonicMoE paper. Let's take a look at the forward pass and assume that when we concatenate the $M$ rows of the tokens routed to the expert we're looking at, they form a matrix $X_e \in \mathbb{R}^{M \times d_{\on{model}}}$. With $T$ tokens each sent to $K$ of the $E$ experts, each expert receives $TK/E = T\rho$ rows on average, so let's consider that perfectly balanced case, $M = T\rho$. For SwiGLU, we have to up-project, gate, and down-project, which are multiplications by two matrices of size $(d_{\on{model}}, d_{\on{expert}})$ and one of size $(d_{\on{expert}}, d_{\on{model}})$. Each matmul is $2 d_{\on{model}}d_{\on{expert}}M$ FLOPs, for a total of $6 d_{\on{model}}d_{\on{expert}}T\rho$ FLOPs, which is the numerator of the arithmetic intensity.

For the denominator, assuming everything is bf16 (2 bytes), there are two operations whose HBM traffic we count: the up-projection + SwiGLU as a single fused operation, then the down-projection as a separate operation (we don't write the intermediate pre-gating value back to HBM, presumably because the SwiGLU is easy to fuse into the up-projection):

  • The first operation in terms of bandwidth is reading the input ($2M\dmodel$ bytes), reading the two matrices ($4\dexpert\dmodel$ bytes), and writing out the result to GMEM ($2M\dexpert$ bytes).
  • The second operation is reading that result ($2M\dexpert$ bytes), reading the down-proj matrix ($2\dexpert\dmodel$ bytes), and writing the result ($2M\dmodel$ bytes).

So the arithmetic intensity (with $M = T\rho$) is: $$ \begin{aligned} \on{ArithInt} &= \frac{6M\dmodel\dexpert}{4M\dexpert + 6\dexpert\dmodel + 4M\dmodel} \\ &= \frac{3}{\frac{2}{\dmodel} + \frac{3}{T\rho} + \frac{2}{\dexpert}} \end{aligned} $$ Since $G = \dmodel / \dexpert$, we have $2/\dexpert = 2G/\dmodel$, so the denominator is really two terms: $3/(T\rho)$ and $(2 + 2G)/\dmodel$. So, holding $T$ and $\dmodel$ fixed, increasing the granularity $G$ OR making the model sparser (decreasing $\rho$) decreases the arithmetic intensity. Once one term dominates the denominator, the intensity scales roughly like $T\rho$ (sparsity term) or like $3\dmodel/(2+2G) \approx 1.5\,\dexpert$ (granularity term, for large $G$).
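The two forms can be checked against each other numerically. This Python sketch counts bytes per kernel exactly as in the bullets above (bf16, perfectly balanced routing so $M = T\rho$); the example sizes are arbitrary:

```python
import math


def moe_mlp_flops(m, d_model, d_expert):
    # up-proj, gate, down-proj: 3 matmuls of 2 * m * d_model * d_expert FLOPs each
    return 6 * m * d_model * d_expert


def moe_mlp_bytes(m, d_model, d_expert, bytes_per_elem=2):
    # Kernel 1 (up-proj + SwiGLU): read X_e, read W_up and W_gate, write A
    k1 = m * d_model + 2 * d_expert * d_model + m * d_expert
    # Kernel 2 (down-proj): read A, read W_down, write the output rows
    k2 = m * d_expert + d_expert * d_model + m * d_model
    return bytes_per_elem * (k1 + k2)


def arith_int_direct(m, d_model, d_expert):
    return moe_mlp_flops(m, d_model, d_expert) / moe_mlp_bytes(m, d_model, d_expert)


def arith_int_closed_form(m, d_model, g):
    # 3 / (3/M + (2 + 2G)/d_model), using d_expert = d_model / G
    return 3 / (3 / m + (2 + 2 * g) / d_model)


tokens, k, e, d_model = 16384, 8, 256, 4096
m = tokens * k // e                     # balanced routing: M = T * rho = 512
for g in (1, 2, 4, 8, 16):
    d_expert = d_model // g
    direct = arith_int_direct(m, d_model, d_expert)
    closed = arith_int_closed_form(m, d_model, g)
    assert math.isclose(direct, closed)  # both forms agree
    print(f"G={g:2d}  d_expert={d_expert:4d}  FLOPs/byte={direct:7.1f}")
```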

The above specific bookkeeping isn't really that important to understanding the intuition, though. When you are increasing the granularity, you're really making the matrices we multiply by a little smaller ($\dexpert$ is getting smaller), which is making the matmuls less "square". If you squint at it, the above is really the same as the arithmetic intensity of a regular $(M, K, N)$ matmul, and if you make $K$ smaller, you get lower arithmetic intensity (this is a different $K$ from the number of experts activated per token).
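For comparison, here is the same calculation for a single plain matmul. It is a Python sketch assuming each operand is read from HBM once and the output written once, in bf16; shrinking the inner dimension $K$ drags the intensity down:

```python
def matmul_arith_int(m, k, n, bytes_per_elem=2):
    """FLOPs per byte of HBM traffic for C = A @ B, A: (m, k), B: (k, n)."""
    flops = 2 * m * k * n
    traffic = bytes_per_elem * (m * k + k * n + m * n)  # read A, read B, write C
    return flops / traffic


for inner in (4096, 1024, 256, 64):
    print(f"K={inner:4d}  FLOPs/byte={matmul_arith_int(4096, inner, 4096):7.1f}")
```

With $M = N$ fixed, the $MN$ output term stops being amortized once $K$ is small, which is the same effect as the $(2 + 2G)/\dmodel$ term above.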

Another way to look at it: to keep FLOPs constant as you increase granularity, you need to decrease the size of each expert and increase the number of activated experts, but some activations are linear in the model size $\dmodel$, which remains constant, and in $K$, the number of activated experts.

Naive MoE kernels

There were several MoE kernels available before this work, like MoMoE and ScatterMoE. Below I'll summarize how a very naive kernel works, following SonicMoE's blog post (I think both of the above are more efficient than this naive version).

Naive implementation (using total tokens $T$):

  1. Gather the input $X$ of shape $(T, \dmodel)$ into an expanded form, repeating each input $K$ times so that it's $X_{\on{gathered}}:(TK, \dmodel)$
  2. Apply a grouped GEMM to $X_{\on{gathered}}$ along with the corresponding computed expert offsets (i.e. routing plan) and the up-project and gating weights (which are of course $(\dmodel, 2\dexpert)$ in size). The result is $H: (TK, 2\dexpert)$.
  3. Apply the SwiGLU operation to this to get the pre-down projection activations, $A: (TK, \dexpert)$.
  4. Apply the grouped GEMM with the down projection (and the expert offsets again) to get something of size $Y: (TK, \dmodel)$.
  5. Scatter and aggregate with the routing scores $S$ to get an output of size $(T, \dmodel)$
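The five steps can be sketched in plain Python (no GPU, no real kernels) to make the shapes concrete. This is only an illustration under a few assumptions: $H$ stores the up half before the gate half, routing is given as precomputed top-$K$ indices and scores, and all function names are hypothetical. The point is that `x_gathered` and the pre-scatter `y` both have $TK$ rows:

```python
import math
import random


def matmul(a, b):
    """(rows x inner) @ (inner x cols) on nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def swiglu(h_row, d_expert):
    # Assumption: columns [0, d_expert) are "up", [d_expert, 2*d_expert) are "gate".
    up, gate = h_row[:d_expert], h_row[d_expert:]
    return [u * g / (1.0 + math.exp(-g)) for u, g in zip(up, gate)]  # up * silu(gate)


def naive_moe_forward(x, topk_idx, scores, w1, w2, num_experts):
    t, d_model, d_expert = len(x), len(x[0]), len(w2[0])
    # Routing plan: (expert, token, slot) triples sorted by expert, plus offsets.
    plan = sorted((e, i, j) for i, row in enumerate(topk_idx) for j, e in enumerate(row))
    offsets = [0] * (num_experts + 1)
    for e, _, _ in plan:
        offsets[e + 1] += 1
    for e in range(num_experts):
        offsets[e + 1] += offsets[e]
    # 1. Gather: each token's row is repeated once per chosen expert -> (T*K, d_model).
    x_gathered = [x[i] for _, i, _ in plan]
    y = []  # pre-scatter output, ends up (T*K, d_model)
    for e in range(num_experts):
        block = x_gathered[offsets[e]:offsets[e + 1]]
        if not block:
            continue
        h = matmul(block, w1[e])                  # 2. up + gate: (M_e, 2*d_expert)
        a = [swiglu(row, d_expert) for row in h]  # 3. SwiGLU: (M_e, d_expert)
        y.extend(matmul(a, w2[e]))                # 4. down-proj: (M_e, d_model)
    # 5. Scatter and aggregate with the routing scores -> (T, d_model).
    out = [[0.0] * d_model for _ in range(t)]
    for (_, i, j), y_row in zip(plan, y):
        for c in range(d_model):
            out[i][c] += scores[i][j] * y_row[c]
    return out, len(x_gathered), len(y)


def reference_moe(x, topk_idx, scores, w1, w2):
    """Per-token loop with no gather/scatter, for checking the result."""
    d_expert = len(w2[0])
    out = []
    for i, row in enumerate(x):
        acc = [0.0] * len(row)
        for j, e in enumerate(topk_idx[i]):
            a = swiglu(matmul([row], w1[e])[0], d_expert)
            y_row = matmul([a], w2[e])[0]
            acc = [s + scores[i][j] * v for s, v in zip(acc, y_row)]
        out.append(acc)
    return out


def randn(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
T, K, E, D_MODEL, D_EXPERT = 5, 2, 4, 6, 3
x = randn(T, D_MODEL)
w1 = [randn(D_MODEL, 2 * D_EXPERT) for _ in range(E)]
w2 = [randn(D_EXPERT, D_MODEL) for _ in range(E)]
topk_idx = [random.sample(range(E), K) for _ in range(T)]
scores = [[random.random() for _ in range(K)] for _ in range(T)]
out, gathered_rows, y_rows = naive_moe_forward(x, topk_idx, scores, w1, w2, E)
print(gathered_rows, y_rows)  # 10 10, i.e. T*K rows each
```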

The real issue here is that as we get more granular (which, at fixed compute, means activating more experts $K$ per token), the activations of size $TK\dmodel$ get really big, as mentioned above. These have to fit in HBM and also be moved between HBM and the compute units and back. There's a similar issue with the backward pass.

SonicMoE

There's some smart fusing here, which is basically oriented around trying to avoid materializing anything that looks like the above shape. You can see the precise set of operations here.

So far so good - obviously a clever team that has found a smart way to avoid caching or materializing things that aren't necessary. The core idea is to avoid anything that is size $\mathcal{O}(TK\dmodel)$.

  • In the forward pass that's the $X$ after the gather operation (i.e. duplicating rows of $X$) but before the up-proj, and the $Y$ before the scatter-and-sum operation. SonicMoE fuses the up-proj and down-proj to avoid materializing the big $X$ and $Y$, and doesn't cache the results of the intermediate operations except for $H$, the output from the first matmul of the SwiGLU above.
  • It's more complicated for the backward pass. We start with the grad of the output, $dO$, and need to compute $dX$, $dS$ (the grads of the routing scores), and $dW_1$ and $dW_2$, the grads of the SwiGLU weights. If you store $S$ and $Y$ during the forward pass, computing $dS$ is just an inner product with $dO$. If we don't store $Y$, the authors find a workaround: because the down-proj is linear, you can think of the multiplication by $S$ as happening before the down-proj, call the scaled result $A'$, and then you never need $Y$ at all.

Here's a more detailed explanation. Normally, to compute $dS$, we take the inner product of $dO$ and $Y$, which requires caching $Y$. Instead, let $A' = \on{diag}(S) \cdot A$, where $S$ has first been routed/gathered into one score per row of $A$ (a vector of length $TK$, since $A$ is $TK \times \dexpert$). Then $dA = \on{diag}(S) \cdot dA'$[2], and since $A'$ is multiplied directly by $W_2$ to produce the (pre-aggregation) output, $dA' = dO\,W_2^\intercal$ with $dO$ gathered to the same $TK$ rows. So we can compute $dA'$ and $dA$ without storing $Y$; $A$ itself can be recomputed from the cached $H$. The interesting part is this, from the blog post and the paper:

$$ dS = \langle dO, Y\rangle = \langle dO, AW_2\rangle = \langle dO\,W_2^\intercal, A\rangle = \langle dA', A\rangle $$ where each inner product is taken row by row, giving one score gradient per (token, expert) pair. Very cool. Then, since $A'$ is what multiplies $W_2$, you can compute $dW_2 = A'^\intercal dO$ (per expert) once $A'$ has been recomputed in the backward pass, and we can compute $dW_1$ since we store $H$. There's also a detail about exploiting L2 cache locality to make this faster.
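The identity is easy to check numerically. Below is a plain-Python sketch for a single expert: the rows stand in for the gathered $TK$ rows routed to it, and `dO` is already gathered to those rows. It computes $dS$ once through $Y$ and once through $dA' = dO\,W_2^\intercal$ without ever forming $Y$; sizes and data are arbitrary:

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def vec_mat(v, m):
    """Row vector v (length n) times matrix m (n x p) -> length p."""
    return [dot(v, col) for col in zip(*m)]


def randn(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
ROWS, D_EXPERT, D_MODEL = 8, 3, 5
A = randn(ROWS, D_EXPERT)       # SwiGLU outputs (recomputable from the cached H)
W2 = randn(D_EXPERT, D_MODEL)   # down-proj weight of this expert
dO = randn(ROWS, D_MODEL)       # output grad, gathered to one row per (token, expert)
S = [random.random() for _ in range(ROWS)]  # gathered routing scores

# Path 1: needs Y = A W2, which we would have to cache from the forward pass.
Y = [vec_mat(a, W2) for a in A]
dS_via_Y = [dot(g, y) for g, y in zip(dO, Y)]

# Path 2: never forms Y. dA' = dO W2^T, then dS = <dA', A> row by row.
W2_T = [list(col) for col in zip(*W2)]
dA_prime = [vec_mat(g, W2_T) for g in dO]
dS_via_A = [dot(da, a) for da, a in zip(dA_prime, A)]
dA = [[s * v for v in row] for s, row in zip(S, dA_prime)]  # dA = diag(S) dA'

print(all(math.isclose(p, q, rel_tol=1e-9, abs_tol=1e-12)
          for p, q in zip(dS_via_Y, dS_via_A)))  # True
```

Path 2 only needs $A$, $W_2$, and $dO$, which is exactly why $Y$ never has to be cached.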

In my head, the overall intuition for how SonicMoE avoids caching these big tensors is something like: this is kind of like avoiding materializing a super big matrix when you can instead do something like a torch.Tensor.scatter_. This isn't a particularly good intuition yet.

QuACK

This is where it gets really interesting. Generally speaking, when we write optimized kernels, we keep a certain target hardware in mind (e.g. H100 or B200). In CUDA, this usually means a specific compute capability (SM90 for H100, SM100 for B200). There were significant hardware changes between those two generations, and it took a long time for kernels to fully support Blackwell (Triton had to rethink a lot of details for Blackwell, for example).

SonicMoE supports a LOT of different architectures: H100, B200, and, surprisingly, the SM120 architecture in consumer Blackwell (and the RTX 6000, a chip I work with). How do they do it? One recent trend in kernel design is increasing warp specialization, with producer and consumer warps. SonicMoE is built on a library, QuACK, that lets you split the computation into three stages (prologue, mainloop, and epilogue) and customize each one. It turns out that all of the SonicMoE kernels can be written by customizing only the prologue and the epilogue, leaving the mainloop as a generic MMA (WGMMA or UMMA, depending on the hardware).

Why? The grouped GEMM kernels needed for MoE differ mostly in how the producer gathers its inputs and how the epilogue scatters (or otherwise post-processes) its outputs. For example (my understanding), fusing the gather into the up-projection is really just a matmul whose producer loads rows from different tokens (or different expert weights) according to the routing decisions. So the compute-expensive (and usually hardware-specific) part is confined to the mainloop in the middle, while the flexibility (scatter/gather/activations/etc.) lives in the producer and the epilogue.
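Here is a toy Python sketch of that idea. It is not QuACK's API (`tiled_gemm`, `load_row`, and `store_row` are hypothetical names), and a real kernel works on tiles with asynchronous copies and tensor-core MMAs rather than row by row. It only shows how one generic mainloop can be reused with a gathering producer and a scatter-add epilogue:

```python
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def tiled_gemm(num_rows: int,
               load_row: Callable[[int], Sequence[float]],
               weight: Matrix,
               store_row: Callable[[int, List[float]], None]) -> None:
    """Generic 'mainloop': knows nothing about routing, only multiplies rows by weight."""
    cols = list(zip(*weight))
    for r in range(num_rows):
        a = load_row(r)                                             # prologue / producer
        acc = [sum(x * w for x, w in zip(a, col)) for col in cols]  # mainloop (the MMA)
        store_row(r, acc)                                           # epilogue


# Toy data: 3 tokens with d_model = 2, two of which are routed to this expert.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
token_of_row = [2, 0]   # hypothetical routing: row r of this expert <- token
scores = [0.5, 2.0]     # routing score for each of those rows
w = [[1.0, 0.0, 1.0],
     [0.0, 1.0, 1.0]]   # (d_model, n) weight


def gather_load(r: int) -> Sequence[float]:
    # Gathering producer: read the token's row directly, no materialized copy.
    return x[token_of_row[r]]


h: Matrix = [[] for _ in token_of_row]


def plain_store(r: int, acc: List[float]) -> None:
    h[r] = acc


out: Matrix = [[0.0] * len(w[0]) for _ in x]


def scatter_add_store(r: int, acc: List[float]) -> None:
    # Scatter-add epilogue: scale by the routing score, add into the token's row.
    t = token_of_row[r]
    for c, v in enumerate(acc):
        out[t][c] += scores[r] * v


tiled_gemm(len(token_of_row), gather_load, w, plain_store)
tiled_gemm(len(token_of_row), gather_load, w, scatter_add_store)
print(h)    # [[5.0, 6.0, 11.0], [1.0, 2.0, 3.0]]
print(out)  # [[2.0, 4.0, 6.0], [0.0, 0.0, 0.0], [2.5, 3.0, 5.5]]
```

On a GPU, several experts may scatter-add into the same token row concurrently, which needs atomics or a separate reduction; the sequential toy sidesteps that.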

Scheduler

This is something I don't have any experience with. I haven't fully understood this part; what I took away is that Blackwell differs from Hopper because tcgen05.mma is asynchronous and accumulates into tensor memory (TMEM) rather than registers, so a more complicated scheduler (which you can apparently customize) is needed to keep track of which TMEM buffers are in use and which are ready.

General notes

One thing I keep wondering while reading this: it often feels like NVIDIA designs its hardware and software abstractions with exactly this stuff in mind. Some of SonicMoE's advances can be summarized as "how can we make a grouped varlen GEMM behave just like a regular GEMM?" Because matmuls are so central, the hardware primitives needed for that almost certainly exist already, so it's not crazy that everything is available; it's still very impressive.

  1. I'm just rewriting the ideas from the blog post as I parse through them; this isn't intended to explain anything better than the blog post. Feel free to point out mistakes in understanding.

  2. Excuse all the abuse of notation; this is just my reading notes but it's easier if you have all the shapes.