Dot products in Rust

March 1, 2024

This post is a work in progress. I'll update it as I go, and I might be missing some very obvious things (or just haven't gotten around to them yet). Feel free to shoot me an email if you have a comment.

A few months ago I was helping with a Rust-based Llama2 inference project and learned a few things about optimizing CPU SIMD code. One thing I couldn't shake is that codegen in Rust is still pretty bad at the moment, at least for neural network inference.

Here's a comparison of two dot product implementations: a Rust one taken from the portable_simd group's examples, and a C++ one written by StackOverflow user Soonts.

#![feature(array_chunks)]
#![feature(slice_as_chunks)]
#![feature(portable_simd)]

use std::simd::num::*;
use std::simd::*;

// Dot product using portable_simd: walk both slices in 4-lane chunks,
// fused-multiply-adding into a single f32x4 accumulator, then sum its lanes.
// Note that array_chunks silently drops any tail elements past a multiple of 4.
pub fn dot_prod_simd_5(a: &[f32], b: &[f32]) -> f32 {
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
        .fold(f32x4::splat(0.), |acc, (a, b)| a.mul_add(b, acc))
        .reduce_sum()
}

fn main() {
    // initialize two large arrays with sin(x/100) and cos(x/100), each with N elements
    // multiply K times
    let N = 1000000;
    let K = 10000;
    let a: Vec<f32> = (0..N).map(|i| ((i as f32) / 100.0).sin()).collect();
    let b: Vec<f32> = (0..N).map(|i| ((i as f32) / 100.0).cos()).collect();

    // compute the dot product of the two arrays
    // do this dot product enough times to reduce fixed costs
    let mut result = 0.0;
    for _ in 0..K {
        result = dot_prod_simd_5(&a, &b);
    }
    println!("result: {}", result);
}

To compile this, you need a nightly Rust compiler (for the array_chunks, slice_as_chunks, and portable_simd features), and you have to set rustflags = ["-Ctarget-cpu=native"] so that codegen actually targets your AVX2-capable machine.
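For concreteness, in a plain Cargo project that setup looks something like this (this is just the standard Cargo mechanism for passing rustflags, nothing specific to this post): put

[build]
rustflags = ["-Ctarget-cpu=native"]

in .cargo/config.toml, and then build in release mode:

cargo +nightly build --release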

And in C++ (NOTE! not my code, this is from StackOverflow, credit Soonts):

#include <immintrin.h>
#include <vector>
#include <algorithm>
#include <assert.h>
#include <stdint.h>
#include <cmath>
#include <cstdio>

using std::ptrdiff_t;

// CPUs support RAM access like this: "ymmword ptr [rax+64]"
// Using templates with an integer offset argument to make it easier for the compiler to emit good code.

// Returns acc + ( p1 * p2 ), for 8 float lanes
template<int offsetRegs>
inline __m256 fma8( __m256 acc, const float* p1, const float* p2 )
{
    constexpr ptrdiff_t lanes = offsetRegs * 8;
    const __m256 a = _mm256_loadu_ps( p1 + lanes );
    const __m256 b = _mm256_loadu_ps( p2 + lanes );
    return _mm256_fmadd_ps( a, b, acc );
}

#ifdef __AVX2__
inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
{
    // Make a mask of 8 bytes
    // These aren't branches, they should compile to conditional moves
    missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
    uint64_t mask = -( missingLanes < 8 );
    mask >>= missingLanes * 8;
    // Sign extend the bytes into int32 lanes in AVX vector
    __m128i tmp = _mm_cvtsi64_si128( (int64_t)mask );
    return _mm256_cvtepi8_epi32( tmp );
}
#else
// Aligned by 64 bytes
// The load will only touch a single cache line, no penalty for unaligned load
static const int alignas( 64 ) s_remainderLoadMask[ 16 ] = {
    -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0 };
inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
{
    // These aren't branches, they compile to conditional moves
    missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
    missingLanes = std::min( missingLanes, (ptrdiff_t)8 );
    // Unaligned load from a constant array
    const int* rsi = &s_remainderLoadMask[ missingLanes ];
    return _mm256_loadu_si256( ( const __m256i* )rsi );
}
#endif

// Same as fma8(), load conditionally using the mask
// When the mask has all bits set, an equivalent of fma8(), but 1 instruction longer
// When the mask is a zero vector, the function won't load anything, will return `acc`
template<int offsetRegs>
inline __m256 fma8rem( __m256 acc, const float* p1, const float* p2, ptrdiff_t rem )
{
    constexpr ptrdiff_t lanes = offsetRegs * 8;
    // Generate the mask for conditional loads
    // The implementation depends on whether AVX2 is enabled with compiler switches
    const __m256i mask = makeRemainderMask( ( 8 + lanes ) - rem );
    // These conditional load instructions produce zeros for the masked out lanes
    const __m256 a = _mm256_maskload_ps( p1 + lanes, mask );
    const __m256 b = _mm256_maskload_ps( p2 + lanes, mask );
    return _mm256_fmadd_ps( a, b, acc );
}

// Compute dot product of float vectors, using 8-wide FMA instructions
float dotProductFma( const std::vector<float>& a, const std::vector<float>& b )
{
    assert( a.size() == b.size() );
    const size_t length = a.size();
    if( length == 0 )
        return 0.0f;

    const float* p1 = a.data();
    const float* p2 = b.data();
    // Compute length of the remainder; 
    // We want a remainder of length [ 1 .. 32 ] instead of [ 0 .. 31 ]
    const ptrdiff_t rem = ( ( length - 1 ) % 32 ) + 1;
    const float* const p1End = p1 + length - rem;

    // Initialize accumulators with zeros
    __m256 dot0 = _mm256_setzero_ps();
    __m256 dot1 = _mm256_setzero_ps();
    __m256 dot2 = _mm256_setzero_ps();
    __m256 dot3 = _mm256_setzero_ps();

    // Process the majority of the data.
    // The code uses FMA instructions to multiply + accumulate, consuming 32 values per loop iteration.
    // Unrolling manually for 2 reasons:
    // 1. To reduce data dependencies. With a single register, every loop iteration would depend on the previous result.
    // 2. Unrolled code checks for exit condition 4x less often, therefore more CPU cycles spent computing useful stuff.
    while( p1 < p1End )
    {
        dot0 = fma8<0>( dot0, p1, p2 );
        dot1 = fma8<1>( dot1, p1, p2 );
        dot2 = fma8<2>( dot2, p1, p2 );
        dot3 = fma8<3>( dot3, p1, p2 );
        p1 += 32;
        p2 += 32;
    }

    // Handle the last, possibly incomplete batch of length [ 1 .. 32 ]
    // To save multiple branches, we load that entire batch with `vmaskmovps` conditional loads
    // On modern CPUs, the performance of such loads is pretty close to normal full vector loads
    dot0 = fma8rem<0>( dot0, p1, p2, rem );
    dot1 = fma8rem<1>( dot1, p1, p2, rem );
    dot2 = fma8rem<2>( dot2, p1, p2, rem );
    dot3 = fma8rem<3>( dot3, p1, p2, rem );

    // Add 32 values into 8
    dot0 = _mm256_add_ps( dot0, dot2 );
    dot1 = _mm256_add_ps( dot1, dot3 );
    dot0 = _mm256_add_ps( dot0, dot1 );
    // Add 8 values into 4
    __m128 r4 = _mm_add_ps( _mm256_castps256_ps128( dot0 ),
        _mm256_extractf128_ps( dot0, 1 ) );
    // Add 4 values into 2
    r4 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) );
    // Add 2 lower values into the scalar result
    r4 = _mm_add_ss( r4, _mm_movehdup_ps( r4 ) );

    // Return the lowest lane of the result vector.
    // The intrinsic below compiles into noop, modern compilers return floats in the lowest lane of xmm0 register.
    return _mm_cvtss_f32( r4 );
}

int main(int argc, char** argv) {
    int N = 1000000;
    int K = 10000;
    std::vector<float> a(N);
    std::vector<float> b(N);
    for (int i = 0; i < N; i++) {
        a[i] = std::sin(i / 100.0);
        b[i] = std::cos(i / 100.0);
    }
    float result = 0.0;
    for (int i = 0; i < K; i++) {
        result = dotProductFma(a, b);
    }
    std::printf("result: %.7f\n", result);
    return 0;
}

I compiled this with g++ -O3 -march=native -o main.

As far as I can tell, these two programs are functionally identical. (One small caveat: the Rust version silently ignores any tail elements past a multiple of 4, but N = 1,000,000 is divisible by 4, so both compute the same sum here.) And with K = 10,000 repetitions of the dot product, the one-time cost of building the arrays shouldn't matter.

They do produce slightly different results, which is expected: the two versions accumulate partial sums in different orders and lane widths, and floating-point addition isn't associative.
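To make that concrete, here's a tiny standalone Rust example (mine, not from either program above; the numbers are contrived) showing that merely regrouping an f32 sum changes the result:

fn main() {
    // f32 addition isn't associative: 1e8 + 1.0 rounds back to 1e8,
    // so different groupings of the same terms give different sums.
    let xs = [1.0e8f32, 1.0, -1.0e8, 1.0];
    let left_to_right: f32 = xs.iter().sum(); // ((1e8 + 1) - 1e8) + 1 == 1.0
    let regrouped = (xs[0] + xs[2]) + (xs[1] + xs[3]); // 0 + 2 == 2.0
    println!("left to right: {left_to_right}, regrouped: {regrouped}");
}

The programs above differ in both lane width and accumulation order (one f32x4 accumulator versus four f32x8 accumulators plus a masked tail), so a small discrepancy in the last digits is expected.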

~/proj/profile/rust_dp main ?33 ❯ ./target/release/rust_dp
result: 4.5250244
~/proj/profile/cpp_dp main ?33 ❯ ./main
result: 4.5243621

and for good measure using clang++:

~/proj/profile/cpp_dp main ?33 ❯ ./main
result: 4.5243621

Benchmarks

~/proj/profile/rust_dp main ?33 ❯ hyperfine --warmup 3 "target/release/rust_dp"
Benchmark 1: target/release/rust_dp
  Time (mean ± σ):      2.261 s ±  0.006 s    [User: 2.234 s, System: 0.006 s]
  Range (min … max):    2.255 s …  2.273 s    10 runs
~/proj/profile/cpp_dp main ?33 ❯ hyperfine --warmup 3 "./main"
Benchmark 1: ./main
  Time (mean ± σ):     752.5 ms ±  19.3 ms    [User: 726.2 ms, System: 3.1 ms]
  Range (min … max):   724.9 ms … 785.9 ms    10 runs

and wow here's clang++:

~/proj/profile/cpp_dp main ?33 ❯ hyperfine --warmup 3 "./main"
Benchmark 1: ./main
  Time (mean ± σ):     233.4 ms ±   3.2 ms    [User: 207.6 ms, System: 3.3 ms]
  Range (min … max):   228.7 ms … 238.4 ms    11 runs

Ideas

  1. The Rust code could still be generating bounds checks.
  2. The Rust version isn't unrolled: it keeps a single accumulator, so every FMA depends on the previous one, whereas the C++ version uses four independent accumulators (see the sketch after this list).
  3. We might need -ffast-math (or a Rust equivalent, which doesn't really exist on stable) to get some small micro-optimizations.
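For idea 2, here's a rough sketch of what a manually unrolled portable_simd version might look like, mirroring the four-accumulator structure of the C++ code. This is my own untested, unbenchmarked sketch (the name dot_prod_unrolled, the 32-element stride, and the scalar tail handling are choices made purely for illustration), and it assumes the same nightly setup as the Rust program above:

#![feature(portable_simd)]

use std::simd::num::*;
use std::simd::*;

// Four independent f32x8 accumulators so consecutive FMAs don't form one
// long dependency chain; the last len % 32 elements get a scalar loop.
pub fn dot_prod_unrolled(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let chunks = a.len() / 32;
    let mut acc = [f32x8::splat(0.0); 4];
    for i in 0..chunks {
        let base = i * 32;
        for j in 0..4 {
            let off = base + j * 8;
            let va = f32x8::from_slice(&a[off..off + 8]);
            let vb = f32x8::from_slice(&b[off..off + 8]);
            acc[j] = va.mul_add(vb, acc[j]);
        }
    }
    // Scalar tail for whatever doesn't fill a 32-element block.
    let mut tail = 0.0f32;
    for i in chunks * 32..a.len() {
        tail += a[i] * b[i];
    }
    ((acc[0] + acc[1]) + (acc[2] + acc[3])).reduce_sum() + tail
}

Whether rustc actually keeps four accumulators in registers, unrolls the inner fixed-count loop, and avoids bounds checks on the slicing here would need to be verified in the generated assembly.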