Learning in low precision
$\newcommand{\on}[1]{\operatorname{#1}}$ I think I should be writing down some of the things I learn a little more frequently, so I'll try with a small note here.
I read this paper recently: Training LLMs with MXFP4. MXFP4 is a 4-bit OCP low-precision format, best known for its use in OpenAI's gpt-oss models. It uses an E2M1 layout (1 sign, 2 exponent, 1 mantissa bit) for the elements, plus a shared scale value (E8M0) for every 32 elements (blockwise scaling). Note that this discussion is mostly about the weight format; the activation format would typically be MXFP8 or BF16. The main contribution of this paper, I think, is attempting to show that you can do it all in MXFP4, including the activations (see section 3). In practice they show it is possible to use MXFP4 in the linear backward passes, while the forward pass stays in FP8 or BF16.
There are two major advantages I see to using 4-bit weights at inference time:
- memory: transferring weights from GMEM to SMEM/RMEM/TMEM is much faster if there's ~1/4 of it
- compute: if the hardware supports compute kernels that take a 4-bit weight and multiply it by an MXFP8/BF16 activation, then you get a speedup in raw FLOPs
The former is much easier, and the latter usually requires Blackwell (or similar).
This paper shows a way to do training with MXFP4 as well. Since training is usually so compute bound, this can be a huge benefit if you can deal with the instability. This paper introduces two methods to reduce the variance of the gradient estimates (sound familiar?!):
- Stochastic rounding, i.e. what my friends do to me when we go out for dinner
- Random hadamard transforms, something I've heard of many times in a compressed sensing context but didn't remember well
Let's dive in.
Stochastic Rounding
The idea here is that if your low-precision data type can only represent a small set of values (e.g. just {1.0, 2.0}) and you need to round a higher-precision value down to it, then instead of always rounding 1.6 to 2.0, you can round it to 1.0 sometimes (with a probability $p$). The goal is for the expectation to be preserved, i.e. $\mathbb{E}[\on{round}(x)] = x$. In this specific case we'd want $p = 2 - x$ (so that $1 \cdot p + 2 \cdot (1 - p) = x$), i.e. the probability of rounding to a point is proportional to the distance from the other one.
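As a minimal sketch (my own illustration, not the paper's kernel), here's stochastic rounding to an arbitrary sorted grid, demonstrated on the {1.0, 2.0} example above:

```python
import numpy as np

def stochastic_round(x, grid):
    """Stochastically round each element of x to one of its two nearest
    grid neighbours, so that E[round(x)] == x inside the grid's range.
    grid must be sorted; values outside the range are clipped."""
    grid = np.asarray(grid, dtype=np.float64)
    x = np.clip(np.asarray(x, dtype=np.float64), grid[0], grid[-1])
    # index of the grid point just below (or equal to) x
    lo = np.clip(np.searchsorted(grid, x, side="right") - 1, 0, len(grid) - 2)
    a, b = grid[lo], grid[lo + 1]
    # probability of rounding *up* is the relative distance to the lower point
    p_up = (x - a) / (b - a)
    return np.where(np.random.rand(*x.shape) < p_up, b, a)

x = np.full(100_000, 1.6)
rounded = stochastic_round(x, [1.0, 2.0])
print(rounded.mean())  # close to 1.6, whereas round-to-nearest would give 2.0
```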
When I first heard about this, I thought that this would be way too expensive to use, but it turns out there's a hardware thing called dithering which makes this almost no cost.
To make this work you also need to rescale the original values by 3/4 first. (The paper has a very interesting demonstration of why; as I understand it: the standard shared scale maps the block into (-8, 8), but E2M1 can only represent magnitudes up to 6, so any values above 6 get clipped down, which introduces bias. Pre-scaling by 3/4 maps the block into [-6, 6] instead.) We have to undo this later, of course, when applying the GEMM that uses this quantized weight.
Then we apply the stochastic rounding, which in expectation is the right number to multiply our activations by.
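Putting the two pieces together, here's a sketch of quantizing one 32-element block (my reading of the trick, not the paper's exact algorithm): compute an OCP-style power-of-two scale, pre-scale by 3/4 so nothing lands above 6, then stochastically round onto the E2M1 grid. Averaging many dequantized draws recovers the original block, showing the unbiasedness:

```python
import numpy as np

rng = np.random.default_rng(0)
# Representable values of E2M1, the MXFP4 element type
E2M1_GRID = np.array([-6, -4, -3, -2, -1.5, -1, -0.5, 0,
                      0.5, 1, 1.5, 2, 3, 4, 6], dtype=np.float64)

def mxfp4_quantize_block(block):
    """Sketch of one MXFP4 block quantization with the 3/4 trick.

    The shared scale 2^(floor(log2 amax) - 2) maps the block into (-8, 8),
    so magnitudes in (6, 8) would clip to 6 and bias the result.
    Pre-scaling by 3/4 maps the block into [-6, 6) instead; the 4/3
    is undone later, at GEMM time (here, at dequantization)."""
    amax = np.abs(block).max()
    scale = 2.0 ** (np.floor(np.log2(amax)) - 2)
    scaled = block / scale * 0.75  # now strictly within [-6, 6)
    # stochastic rounding between the two nearest E2M1 grid points
    lo = np.clip(np.searchsorted(E2M1_GRID, scaled, side="right") - 1,
                 0, len(E2M1_GRID) - 2)
    a, b = E2M1_GRID[lo], E2M1_GRID[lo + 1]
    q = np.where(rng.random(block.shape) < (scaled - a) / (b - a), b, a)
    return q * scale / 0.75  # dequantize for inspection

block = rng.standard_normal(32)
deq = np.mean([mxfp4_quantize_block(block) for _ in range(2000)], axis=0)
print(np.abs(deq - block).max())  # small: unbiased in expectation
```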
Random Hadamard Transform
This transform is $\on{RHT}: x \mapsto HSx$, where $H$ is a (normalized) Hadamard matrix and $S$ is a diagonal matrix of random signs. Hadamard matrices are a specific family (not a class of matrices like upper triangular, but a particular sequence parametrized by a size $n$) where $H_n$ is recursively defined and satisfies $H_n^\top H_n = nI$, so $H_n / \sqrt{n}$ is orthogonal. From that, we conveniently have that $\on{RHT}(X)^\top \on{RHT}(Y) = X^\top Y$, which makes it easy to use in linear layers.
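A quick numerical check of that inner-product identity, using the Sylvester recursion to build $H_n$ (the helper names here are mine):

```python
import numpy as np

def hadamard(n):
    """Sylvester construction: H_1 = [1], H_2m = [[H, H], [H, -H]].
    n must be a power of two."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

n = 8
H = hadamard(n) / np.sqrt(n)                    # normalized: H.T @ H = I
S = np.diag(np.random.choice([-1.0, 1.0], n))   # random sign flips

rht = lambda x: H @ S @ x
X, Y = np.random.randn(n, 3), np.random.randn(n, 3)
# inner products are preserved: RHT(X).T @ RHT(Y) == X.T @ Y
print(np.allclose(rht(X).T @ rht(Y), X.T @ Y))  # True
```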
There's a very nice theorem (what I like to call a Mitzenmacher-style theorem) that says that applying the RHT before quantization changes the dependence of the quantization variance on the block size $b$ from (roughly) $b$ to $\log b$, with failure probability parametrized by $\epsilon$; the precise form is in the paper. At face value the result is nice, because $\log b$ is obviously better than $b$. However, the bound also swaps the max-norm for the L2 norm, and we know the max-norm is at most the L2 norm; in fact the L2 norm grows approximately as $\sqrt{b}$... which seems like it should cancel out the effect?
I dug slightly into this, and it seems like a desirable property of LLM weights is that they're incoherent, i.e. their entries don't contain a few large outliers, because such matrices can be quantized without significant accuracy loss. Approximately, the definition is "no individual entry is too big compared to the Frobenius norm". My intuition at this point is that the Hadamard transform preserves the energy (it's an orthogonal transform), but it spreads a very big outlier in one coordinate out across all the coordinates, via the random signs. Then, since we share a scale amongst many values when quantizing, the quantization error goes way down. The specific bound used here is a Hoeffding-style concentration bound for signed sums, which gives the "sub-Gaussian" shape.
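This intuition is easy to demonstrate numerically: a vector with all its energy in one coordinate keeps its L2 norm under the RHT, but its max-norm collapses by roughly $\sqrt{n}$ (a toy demo of my own, not from the paper):

```python
import numpy as np

def hadamard(n):
    """Sylvester construction of the n x n Hadamard matrix (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

n = 256
x = np.zeros(n)
x[0] = 10.0                          # one big outlier coordinate
x += 0.01 * np.random.randn(n)       # plus a little background noise

H = hadamard(n) / np.sqrt(n)
s = np.random.choice([-1.0, 1.0], n)
y = H @ (s * x)                      # RHT

print(np.linalg.norm(x), np.linalg.norm(y))  # L2 norm (energy) unchanged
print(np.abs(x).max(), np.abs(y).max())      # max-norm: ~10 -> ~10/sqrt(n)
```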
Figure 2 shows the relationship between these bounds (I didn't think hard about the experimental setup); both appear to be roughly linear after a while, though the Hadamard-transformed version indeed has lower variance.
But how do we compute it? It's not free to multiply your data by a matrix, even if it's random. There's a fast algorithm to apply the full RHT in $n \log n$ time, but I think the issue here is that we'd be mixing across the batch dimension, which is very expensive if you're doing any kind of data parallel work (different rows of the data live on different GPUs). They have a solution here that seems to involve some kind of blockwise RHT (the transform is a kind of "mix", and you can say "I want to only mix across the first X bits, then the next X bits, etc.", I think). I didn't have time to understand this.
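For concreteness, here's the classic $O(n \log n)$ butterfly algorithm (the fast Walsh-Hadamard transform), plus my guess at what a blockwise variant looks like: transform each chunk independently, which is equivalent to multiplying by a block-diagonal Hadamard matrix and never mixes values across chunk boundaries.

```python
import numpy as np

def fwht(x):
    """Fast Walsh-Hadamard transform via butterflies: O(n log n) instead
    of the O(n^2) dense matrix multiply. len(x) must be a power of two."""
    x = np.asarray(x, dtype=np.float64).copy()
    h, n = 1, len(x)
    while h < n:
        for i in range(0, n, 2 * h):
            a, b = x[i:i + h].copy(), x[i + h:i + 2 * h].copy()
            x[i:i + h], x[i + h:i + 2 * h] = a + b, a - b
        h *= 2
    return x

def blockwise_fwht(x, block):
    """Hypothetical blockwise variant: apply the transform to each
    `block`-sized chunk independently (a block-diagonal Hadamard), so
    nothing mixes across chunk boundaries."""
    return np.concatenate([fwht(c) for c in np.split(x, len(x) // block)])
```

For example, `fwht([1, 0, 0, 0])` gives `[1, 1, 1, 1]` (the first column of $H_4$), and `blockwise_fwht(x, 4)` on a length-8 vector transforms the two halves separately.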
Evaluation
This is a pre-training paper, which means every eval is extremely expensive.
They didn't have access to FP4 hardware (Blackwell) when running these experiments, which is definitely a bummer (this is from late 2024, I think). Instead they used a neat Microsoft library that emulates MX-format data types in PyTorch. They attempt to measure the real slowdowns / speedups, but it's not clear to me that the comparisons transfer to any specific hardware.
Related results - more Hadamard rotations
I just saw this new paper by the Together AI lab (aka Tri Dao & co), focused on inference, which also uses a block diagonal Hadamard rotation:
> Our central finding is that a simple design—token-wise INT4 quantization with block-diagonal Hadamard rotation—consistently achieves the best accuracy–efficiency trade-off.
This seems to be quite common in FP4 and INT4 training and inference, e.g. see this paper.