(Not) fast dot products via SIMD

August 13, 2023

Lately I've been tinkering with optimizing Sasha's llama2.rs, a fast Rust port of Karpathy's llama2.c. It takes advantage of Rust nightly's portable_simd, which lets you emit AVX2 or AVX512 instructions through a relatively clean set of abstractions, and it runs inference on a quantized llama2, so it's pretty fast.

One of the core loops looks like this:

let mask = (1 << BITS) - 1; // BITS is 4
let elems_per_i32 = 32 / BITS;
let ipg: usize = GROUPSIZE / 32 * BITS; // GROUPSIZE is 128, so this is 16 (left-associative)
let mask_4bits = i32x8::splat(mask);
let shift_right = i32x8::from_array([0, 4, 8, 12, 16, 20, 24, 28]);
...
// Do K at a time
let zero = f32x8::splat(0.0);
let qzeros = &self.qzeros[oi / elems_per_i32]; // self is the neural net block this is part of
let out_elem = oi % elems_per_i32;
let qweight = self.qweight[oi].chunks_exact(ipg);

let collect = self.scales[oi]
    .into_iter()
    .zip(qweight)
    .enumerate()
    .map(|(group, (scale, qweight))| {
        let qz = ((qzeros[group] >> (BITS * out_elem)) & mask) + 1;
        let scale_simd = f32x8::splat(scale);
        let zero_simd = i32x8::splat(qz);
        let in_pos = group * GROUPSIZE;
        let xs = x[in_pos..in_pos + GROUPSIZE].chunks_exact(8);
        qweight
            .iter()
            .zip(xs)
            .map(|(v, x)| {
                //Extract v into 8 chunks
                let x = f32x8::from_slice(x);
                let num_simd = i32x8::splat(*v);
                let qw: i32x8 = (num_simd >> shift_right) & mask_4bits;
                let combine: f32x8 = (qw - zero_simd).cast::<f32>();
                let weight: f32x8 = scale_simd * combine;
                weight * x
            })
            .fold(zero, |x, y| x + y)
    })
    .fold(zero, |x, y| x + y);
*o = collect.reduce_sum(); // output reference to write to

Here, f32x8 and i32x8 are aliases for the std::simd types denoting SIMD vectors with 8 lanes of f32 and i32 respectively (256 bits each). I have a Ryzen 5700X, a Zen 3 CPU with AVX2 support but not AVX512.
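If you haven't used portable_simd before, here's a minimal sketch of the kind of code it lets you write (nightly only; the prelude import reflects a recent nightly, and older toolchains exposed the same items directly under std::simd):

#![feature(portable_simd)]
use std::simd::prelude::*; // f32x8, i32x8, reduce_sum, casts, ...

fn main() {
    let a = f32x8::splat(2.0); // 8 lanes, all set to 2.0
    let b = f32x8::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    let c = a * b;             // lane-wise multiply; with AVX2 this is a single vmulps
    println!("{}", c.reduce_sum()); // horizontal sum across lanes: 72
}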

When I profiled this, I saw that a lot of the time was being spent in the inner fold (i.e. the sum over the results of mapping over qweight).
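For reference, stripped of the SIMD machinery, each group's contribution boils down to something like the following scalar loop. This is my paraphrase of the code above, not anything from the repo; the names just mirror the snippet, and it assumes x.len() == 8 * qweight.len().

// Each i32 in `qweight` packs eight 4-bit quantized weights; `scale` and
// `qz` are the per-group dequantization parameters (qz already includes
// the +1 from above). Sketch only.
fn dot_group(scale: f32, qz: i32, qweight: &[i32], x: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (i, &packed) in qweight.iter().enumerate() {
        for lane in 0..8 {
            let q = (packed >> (4 * lane)) & 0xF; // unpack one 4-bit weight
            let w = scale * (q - qz) as f32;      // dequantize it
            acc += w * x[i * 8 + lane];           // accumulate the dot product
        }
    }
    acc
}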

Ok, so this is basically a SIMD dot-product between x[in_pos..in_pos + GROUPSIZE] and whatever the weight becomes. My limited SIMD experience comes from writing CUDA kernels, so naturally I thought this might be faster if I rewrote it to use the fused multiply-add (fma) instructions that have been around for the past few years.

I rewrote the interior slightly:

// Do K at a time
let zero = f32x8::splat(0.0);
let qzeros = &self.qzeros[oi / elems_per_i32];
let out_elem = oi % elems_per_i32;
let qweight = self.qweight[oi].chunks_exact(ipg);

let collect = self.scales[oi]
    .into_iter()
    .zip(qweight)
    .enumerate()
    .map(|(group, (scale, qweight))| {
        let qz = ((qzeros[group] >> (BITS * out_elem)) & mask) + 1;
        let scale_simd = f32x8::splat(scale);
        let qzero_simd = i32x8::splat(qz);
        let in_pos = group * GROUPSIZE;
        let xs = x[in_pos..in_pos + GROUPSIZE]
            .chunks_exact(8)
            .map(f32x8::from_slice);
        let q_op = qweight
            .iter()
            .map(|v| {
                //Extract v into 8 chunks
                let num_simd = i32x8::splat(*v);
                let qw: i32x8 = (num_simd >> shift_right) & mask_4bits;
                (qw - qzero_simd).cast::<f32>()
            });
        xs.zip(q_op).fold(zero, |acc, (x, y)| {
            x.mul_add(y, acc)
        }) * scale_simd
    })
    .fold(zero, |x, y| x + y);
*o = collect.reduce_sum();

However, this tanked performance. I'm working on getting a proper hyperfine benchmark up and running[1], but it's safe to say that the performance is at least 50% worse. I thought issuing 1 instruction (fma) per element of the vector (well, N/8 of them, since I have 8 lanes) would be faster than issuing 2 instructions (a mul and an add).

The reason, as best as I can understand it (not a SIMD expert, yet!), is that in the code above we're bottlenecked on instruction latency, not throughput. Apparently the computation

result = weight[0] * x[0] + weight[1] * x[1] + ... + weight[n] * x[n]

can be run highly out-of-order, because it doesn't matter which multiplication happens first, or technically in which order the results are added[2]. That's why the add shows up in the profiling output: it's the instruction where the CPU ends up waiting for all the individual multiplications to finish.

In contrast:

result = fma(x[n], weight[n], ... fma(x[1], weight[1], fma(x[0], weight[0], 0)) ...)

is highly order-dependent: each fma needs the result of the previous one before it can start, so execution is much slower (and I think fma latency here is 4 cycles vs 3 cycles for add and mul, which stretches the chain out even further). Compilers are pretty smart! And modern out-of-order execution is pretty impressive.
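For completeness, the textbook way to keep fma and still hide its latency is to break the dependency chain into several independent accumulators, so the CPU has multiple chains in flight at once. Here's a rough sketch of that idea; I haven't benchmarked anything like this in llama2.rs, and the function and names are mine, not the repo's:

#![feature(portable_simd)]
use std::simd::prelude::*;
use std::simd::StdFloat; // provides mul_add on f32x8

// Four independent fma chains: chain i % 4 only waits on itself, so the
// ~4-cycle fma latency is overlapped across chains. Assumes the slices
// have equal length that is a multiple of 8. Sketch only.
fn dot(xs: &[f32], ws: &[f32]) -> f32 {
    let mut acc = [f32x8::splat(0.0); 4];
    for (i, (x, w)) in xs.chunks_exact(8).zip(ws.chunks_exact(8)).enumerate() {
        let (x, w) = (f32x8::from_slice(x), f32x8::from_slice(w));
        acc[i % 4] = x.mul_add(w, acc[i % 4]);
    }
    (acc[0] + acc[1] + acc[2] + acc[3]).reduce_sum()
}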

Here are the people who helped me figure this out:

Is vfmadd132pd slow on AMD Zen 3 architecture?

Latency and throughput

[1] The full program is hard to benchmark because a ton of the initial time is spent mmaping the huge weights into memory, so I think the numbers depend a lot on what the rest of my machine is doing (I think I have a fairly slow SSD at the moment).

[2] I'm not entirely sure about this, actually, since I think Rust doesn't enable -ffast-math-style optimizations in release builds (and I'm not sure how to turn them on at the moment), so the compiler presumably still has to do these non-associative additions in source order.