Learning in Low Precision
$\newcommand{\on}[1]{\operatorname{#1}}$ I think I should be writing down some of the things I learn a little more frequently, so I'll try with a small note here.
I read this paper recently: Training LLMs with MXFP4. MXFP4 is a 4-bit OCP low precision format, most well known for its use in OpenAI's gpt-oss models. It uses an E2M1 layout (1 sign, 2 exponent, 1 mantissa bit) for the main data, with an extra shared scale (an 8-bit power-of-two exponent, E8M0) for every 32 elements (blockwise scaling). Note that this discussion is mostly about the weight format; the activations would typically stay in MXFP8 or BF16. In practice I think the main contribution of this paper is showing that you can do it all in MXFP4, including the activations (see section 3).
There are two major advantages I see to using 4-bit weights at inference time:
- memory: transferring weights from GMEM to SMEM/RMEM/TMEM is much faster when there's ~1/4 as much of it
- compute: if the hardware supports kernels that take a 4-bit weight and multiply it by an MXFP8/BF16 activation, you get a speedup in the raw FLOPs available
The former is much easier, and the latter usually requires Blackwell (or similar).
This paper shows a way to do training with MXFP4 as well. Since training is usually so compute bound, this can be a huge benefit if you can deal with the instability. This paper introduces two methods to reduce the variance of the gradient estimates (sound familiar?!):
- Stochastic rounding, i.e. what my friends do to me when we go out for dinner
- Random Hadamard transforms, something I've heard of many times in a compressed sensing context but didn't remember well
Let's dive in.
Stochastic Rounding
The idea here is that if your low precision data type can only represent a small set of values (e.g. just {1.0, 2.0}) and you need to round a higher precision value down, instead of always rounding 1.6 to 2.0, you can round it to 1.0 sometimes (with probability $p$). The goal is for the expectation to be exact, i.e. $\mathbb{E}[\on{round}(x)] = x$. In this specific case, solving $p \cdot 1 + (1-p) \cdot 2 = x$ gives $p = 2 - x$, or basically an inverse-distance weighting.
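Here's a minimal numpy sketch of this (my own illustration, not the paper's code): round onto a small grid with inverse-distance probabilities, then check that the sample mean comes back to the original value.

```python
import numpy as np

def stochastic_round(x, grid):
    """Stochastically round each value in x to one of its two neighboring
    grid points, with probability inversely proportional to distance,
    so that E[round(x)] == x for x inside the grid's range."""
    grid = np.sort(np.asarray(grid, dtype=np.float64))
    x = np.clip(x, grid[0], grid[-1])
    # index of the grid point at or below each x
    lo_idx = np.clip(np.searchsorted(grid, x, side="right") - 1, 0, len(grid) - 2)
    lo, hi = grid[lo_idx], grid[lo_idx + 1]
    # probability of rounding *up* is the fractional distance past `lo`
    p_up = (x - lo) / (hi - lo)
    return np.where(np.random.rand(*np.shape(x)) < p_up, hi, lo)

# round 1.6 onto {1.0, 2.0}: expect 2.0 with p=0.6, 1.0 with p=0.4
samples = stochastic_round(np.full(50_000, 1.6), [1.0, 2.0])
print(samples.mean())  # ≈ 1.6
```

Each individual rounded value is now noisier than round-to-nearest, but the errors average out instead of accumulating in one direction.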
When I first heard about this, I thought that this would be way too expensive to use, but it turns out there's a hardware thing called dithering which makes this almost no cost.
To make this work you also need to rescale the original values by 3/4 first (the paper has an interesting demonstration of why, but I think I understand it: the original power-of-two scaling maps the block into (-8, 8), but E2M1 can only represent magnitudes up to 6, so anything above 6 gets clipped, which introduces bias; shrinking by 3/4 maps the max to 6 instead). We of course have to undo this later when applying the GEMM that uses this quantized weight.
Then we apply the stochastic rounding, which gives a quantized weight that is, in expectation, the right number to multiply our activations by.
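Putting the two pieces together, here's my sketch of what quantizing one 32-element block might look like (the E2M1 magnitude grid is the real one, but the exact scale-exponent choice and all of the code are my assumptions, not the paper's implementation):

```python
import numpy as np

# the nonnegative magnitudes representable in E2M1 (sign is a separate bit)
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_GRID = np.unique(np.concatenate([-E2M1_GRID, E2M1_GRID]))  # signed grid

def quantize_block_mxfp4(block, rng):
    """Quantize one block, MXFP4-style (my sketch, not the paper's code):
    pick a power-of-two shared scale from the block's absmax, shrink by
    3/4 so nothing lands above 6 (avoiding clipping bias), then
    stochastically round onto the E2M1 grid.  Returns dequantized values
    so we can check unbiasedness directly."""
    amax = np.abs(block).max()
    # power-of-two scale mapping amax into [4, 8) -- the naive choice
    scale = 2.0 ** (np.floor(np.log2(amax)) - 2) if amax > 0 else 1.0
    scaled = block / scale * 0.75              # 3/4 shrink: now |scaled| <= 6
    # stochastic rounding between the two neighboring grid points
    idx = np.clip(np.searchsorted(FP4_GRID, scaled, side="right") - 1,
                  0, len(FP4_GRID) - 2)
    lo, hi = FP4_GRID[idx], FP4_GRID[idx + 1]
    p_up = (scaled - lo) / (hi - lo)           # inverse-distance probability
    q = np.where(rng.random(scaled.shape) < p_up, hi, lo)
    return q * scale / 0.75                    # undo both the scale and the 3/4
```

Averaged over many independent roundings, the dequantized block matches the original, which is exactly what stochastic rounding buys you.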
Random Hadamard Transform
This transform is $\on{RHT}: x \to HSx$, where $H$ is a Hadamard matrix and $S$ is a diagonal matrix of random signs. The Hadamard matrices used here are a specific recursively defined family (Sylvester's construction: one literal sequence of matrices $H_1, H_2, H_4, \dots$, not a general class), and scaled by $1/\sqrt{n}$ they're orthogonal. From that, we have a nice result that $\on{RHT}(X)^\top \on{RHT}(Y) = X^\top Y$.
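A quick numpy check of the construction and the inner-product identity (my own illustration):

```python
import numpy as np

def hadamard(n):
    """Sylvester's construction: H_1 = [1], H_2n = [[H, H], [H, -H]].
    n must be a power of two."""
    H = np.ones((1, 1))
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)  # scale so H is orthonormal

def rht(x, H, s):
    """Random Hadamard transform: flip signs with s, then rotate by H."""
    return H @ (s * x)

rng = np.random.default_rng(0)
n = 16
H = hadamard(n)
s = rng.choice([-1.0, 1.0], size=n)   # shared random sign vector
x, y = rng.standard_normal(n), rng.standard_normal(n)
# inner products are preserved: (HSx)^T (HSy) = x^T S H^T H S y = x^T y
print(np.allclose(rht(x, H, s) @ rht(y, H, s), x @ y))  # True
```

The key point is that both sides of a GEMM get the *same* $H$ and $S$, so the rotation cancels in the product while still scrambling each operand before quantization.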
There's a very nice theorem (what I like to call a Mitzenmacher-style theorem) that says the variance of the dot product of two quantized vectors grows linearly in the vector length $b$ times the max-norms of the vectors, while the variance of the dot product of two quantized-after-RHT vectors grows (with high probability) only logarithmically in $b$, times the L2 norms of the vectors.
At face value the result is nice, because $\log(b)$ is obviously better than $b$. However, we know the max-norm is at most the L2 norm, and in fact the L2 norm grows approximately as $\sqrt{b}$... which should cancel out the effect? I think the resolution is that the max-norm bound is tight for "spiky" vectors: if all the energy sits in one coordinate, then $\|x\|_\infty = \|x\|_2$ and the non-RHT bound really is worse by a factor of about $b/\log(b)$. Maxes over many dimensions can lead to weird effects like this, and the RHT smooths them out.
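To convince myself, here's a tiny demo of the worst case for the max-norm bound: a "spiky" vector whose max-norm equals its L2 norm. The RHT spreads its energy evenly across all coordinates, pushing the max-norm down toward $\|x\|_2/\sqrt{b}$ while leaving the L2 norm unchanged:

```python
import numpy as np
from functools import reduce

n = 256
# Sylvester-Hadamard matrix via repeated Kronecker products, made orthonormal
H2 = np.array([[1.0, 1.0], [1.0, -1.0]])
H = reduce(np.kron, [H2] * 8) / np.sqrt(n)  # 2^8 = 256

rng = np.random.default_rng(0)
s = rng.choice([-1.0, 1.0], size=n)

spike = np.zeros(n)
spike[0] = 1.0                        # max-norm == L2 norm == 1: worst case
rotated = H @ (s * spike)

print(np.max(np.abs(spike)))          # 1.0
print(np.max(np.abs(rotated)))        # 1/sqrt(256) = 0.0625
print(np.linalg.norm(rotated))        # still ≈ 1.0
```

After the rotation every coordinate has the same magnitude, which is exactly the regime where a per-block shared scale wastes the least precision.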
Figure 2 shows the relationship between these bounds (I didn't think about the experimental setup) and they appear to both be kind of linear after a while, though indeed the Hadamard-transformed version is lower variance.
But how do we compute it? It's not free to multiply your data by a matrix, even if it's random.
They have a solution here that seems to involve some kind of blockwise RHT (the transform is a kind of "mix", and you can say "I want to only mix within the first X elements, then the next X elements, etc.", I think). I didn't have time to understand this.
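Since I didn't fully get it either, here's only my guess at what a blockwise RHT could look like (hedge: every detail below is an assumption, not the paper's scheme): apply an independent sign-flip plus Hadamard rotation within each block of elements, so each rotation is a small, cheap multiply instead of one giant one.

```python
import numpy as np

def blockwise_rht(x, block=32, rng=None):
    """Sketch of a blockwise RHT: split x into chunks of `block` elements
    and rotate within each chunk, so the "mixing" never crosses chunk
    boundaries and each rotation costs block x block (or O(block log block)
    with a fast Walsh-Hadamard transform) instead of len(x)^2."""
    if rng is None:
        rng = np.random.default_rng()
    # build the orthonormal Sylvester-Hadamard matrix of size `block`
    H = np.array([[1.0]])
    while H.shape[0] < block:
        H = np.block([[H, H], [H, -H]])
    H /= np.sqrt(block)
    s = rng.choice([-1.0, 1.0], size=block)  # shared random sign flips
    chunks = x.reshape(-1, block)            # len(x) must be a multiple of block
    # H is symmetric, so (s * chunk) @ H applies H @ (S x) to each row
    return ((s * chunks) @ H).reshape(-1)
```

Because each block's rotation is orthonormal, norms (and blockwise inner products) are still preserved; you just give up mixing across block boundaries in exchange for the cheaper transform.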
Evaluation
This is a pre-training paper, which means every eval is extremely expensive.
They didn't have access to FP4 hardware (Blackwell) when running these experiments, which is definitely a bummer (this is from late 2024, I think). Instead they used a neat Microsoft library that emulates MX-format data types in PyTorch. They attempt to measure the real slowdowns/speedups, but it's not clear to me that those comparisons carry over to specific hardware.