Dot products in Rust
A few months ago I was helping with a Rust-based Llama2 inference project and learned a few things about optimizing CPU SIMD code. One thing I couldn't shake is that codegen in Rust is still pretty bad at the moment, at least for neural network inference.
Here's a comparison of two dot product implementations: one in Rust, written by the portable_simd group, and one in C++, written by StackOverflow user Soonts.
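As a baseline, both of them compute the same thing as a plain scalar loop like this (my own sketch, name and all; the whole question is how well each toolchain turns something like it into vector code):

pub fn dot_prod_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Element-wise multiply, then sum; zip stops at the shorter slice.
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

First, the Rust version, built on portable_simd: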
#![feature(array_chunks)]
#![feature(slice_as_chunks)]
#![feature(portable_simd)]
use std::simd::num::*;
use std::simd::*;

pub fn dot_prod_simd_5(a: &[f32], b: &[f32]) -> f32 {
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
        .fold(f32x4::splat(0.), |acc, (a, b)| a.mul_add(b, acc))
        .reduce_sum()
}

fn main() {
    // initialize two large arrays with sin(x/100) and cos(x/100), each with N elements
    // multiply K times
    let N = 1000000;
    let K = 10000;
    let a: Vec<f32> = (0..N).map(|i| ((i as f32) / 100.0).sin()).collect();
    let b: Vec<f32> = (0..N).map(|i| ((i as f32) / 100.0).cos()).collect();
    // compute the dot product of the two arrays
    // do this dot product enough times to reduce fixed costs
    let mut result = 0.0;
    for _ in 0..K {
        result = dot_prod_simd_5(&a, &b);
    }
    println!("result: {}", result);
}
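One caveat: array_chunks::<4>() silently drops any trailing elements when the slice length isn't a multiple of 4. That's fine here (N is a multiple of 4), but the C++ version below handles an arbitrary tail, so for strict parity you'd want something along these lines. This is my own sketch (the function name is made up), using the slice_as_chunks feature the file already enables:

// Sketch only: same idea as dot_prod_simd_5, but the trailing partial chunk
// is handled with a scalar loop instead of being dropped.
pub fn dot_prod_simd_with_tail(a: &[f32], b: &[f32]) -> f32 {
    let (a_chunks, a_rem) = a.as_chunks::<4>();
    let (b_chunks, b_rem) = b.as_chunks::<4>();
    let body = a_chunks
        .iter()
        .zip(b_chunks)
        .fold(f32x4::splat(0.0), |acc, (a, b)| {
            f32x4::from_array(*a).mul_add(f32x4::from_array(*b), acc)
        })
        .reduce_sum();
    let tail: f32 = a_rem.iter().zip(b_rem).map(|(x, y)| x * y).sum();
    body + tail
}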
To compile this, you need a nightly Rust compiler (for the three feature gates at the top), and you have to set rustflags = ["-Ctarget-cpu=native"] so that codegen actually uses the AVX2 and FMA instructions your machine supports.
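One way to do that (my setup, not the only one) is to build with cargo +nightly and put this in the project's .cargo/config.toml; exporting RUSTFLAGS="-Ctarget-cpu=native" before cargo build --release works as well:

[build]
# passed to every rustc invocation in this project
rustflags = ["-Ctarget-cpu=native"]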
And in C++ (note: not my code, this is from StackOverflow, credit Soonts):
#include <immintrin.h>
#include <vector>
#include <algorithm>
#include <assert.h>
#include <stdint.h>
#include <cmath>
#include <cstdio>

using std::ptrdiff_t;

// CPUs support RAM access like this: "ymmword ptr [rax+64]"
// Using templates with offset int argument to make easier for compiler to emit good code.

// Returns acc + ( p1 * p2 ), for 8 float lanes
template<int offsetRegs>
inline __m256 fma8( __m256 acc, const float* p1, const float* p2 )
{
    constexpr ptrdiff_t lanes = offsetRegs * 8;
    const __m256 a = _mm256_loadu_ps( p1 + lanes );
    const __m256 b = _mm256_loadu_ps( p2 + lanes );
    return _mm256_fmadd_ps( a, b, acc );
}

#ifdef __AVX2__
inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
{
    // Make a mask of 8 bytes
    // These aren't branches, they should compile to conditional moves
    missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
    uint64_t mask = -( missingLanes < 8 );
    mask >>= missingLanes * 8;
    // Sign extend the bytes into int32 lanes in AVX vector
    __m128i tmp = _mm_cvtsi64_si128( (int64_t)mask );
    return _mm256_cvtepi8_epi32( tmp );
}
#else
// Aligned by 64 bytes
// The load will only touch a single cache line, no penalty for unaligned load
static const int alignas( 64 ) s_remainderLoadMask[ 16 ] = {
    -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0 };
inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
{
    // These aren't branches, they compile to conditional moves
    missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
    missingLanes = std::min( missingLanes, (ptrdiff_t)8 );
    // Unaligned load from a constant array
    const int* rsi = &s_remainderLoadMask[ missingLanes ];
    return _mm256_loadu_si256( ( const __m256i* )rsi );
}
#endif

// Same as fma8(), load conditionally using the mask
// When the mask has all bits set, an equivalent of fma8(), but 1 instruction longer
// When the mask is a zero vector, the function won't load anything, will return `acc`
template<int offsetRegs>
inline __m256 fma8rem( __m256 acc, const float* p1, const float* p2, ptrdiff_t rem )
{
    constexpr ptrdiff_t lanes = offsetRegs * 8;
    // Generate the mask for conditional loads
    // The implementation depends on whether AVX2 is enabled with compiler switches
    const __m256i mask = makeRemainderMask( ( 8 + lanes ) - rem );
    // These conditional load instructions produce zeros for the masked out lanes
    const __m256 a = _mm256_maskload_ps( p1 + lanes, mask );
    const __m256 b = _mm256_maskload_ps( p2 + lanes, mask );
    return _mm256_fmadd_ps( a, b, acc );
}

// Compute dot product of float vectors, using 8-wide FMA instructions
float dotProductFma( const std::vector<float>& a, const std::vector<float>& b )
{
    assert( a.size() == b.size() );
    const size_t length = a.size();
    if( length == 0 )
        return 0.0f;

    const float* p1 = a.data();
    const float* p2 = b.data();
    // Compute length of the remainder;
    // We want a remainder of length [ 1 .. 32 ] instead of [ 0 .. 31 ]
    const ptrdiff_t rem = ( ( length - 1 ) % 32 ) + 1;
    const float* const p1End = p1 + length - rem;

    // Initialize accumulators with zeros
    __m256 dot0 = _mm256_setzero_ps();
    __m256 dot1 = _mm256_setzero_ps();
    __m256 dot2 = _mm256_setzero_ps();
    __m256 dot3 = _mm256_setzero_ps();

    // Process the majority of the data.
    // The code uses FMA instructions to multiply + accumulate, consuming 32 values per loop iteration.
    // Unrolling manually for 2 reasons:
    // 1. To reduce data dependencies. With a single register, every loop iteration would depend on the previous result.
    // 2. Unrolled code checks for exit condition 4x less often, therefore more CPU cycles spent computing useful stuff.
    while( p1 < p1End )
    {
        dot0 = fma8<0>( dot0, p1, p2 );
        dot1 = fma8<1>( dot1, p1, p2 );
        dot2 = fma8<2>( dot2, p1, p2 );
        dot3 = fma8<3>( dot3, p1, p2 );
        p1 += 32;
        p2 += 32;
    }

    // Handle the last, possibly incomplete batch of length [ 1 .. 32 ]
    // To save multiple branches, we load that entire batch with `vmaskmovps` conditional loads
    // On modern CPUs, the performance of such loads is pretty close to normal full vector loads
    dot0 = fma8rem<0>( dot0, p1, p2, rem );
    dot1 = fma8rem<1>( dot1, p1, p2, rem );
    dot2 = fma8rem<2>( dot2, p1, p2, rem );
    dot3 = fma8rem<3>( dot3, p1, p2, rem );

    // Add 32 values into 8
    dot0 = _mm256_add_ps( dot0, dot2 );
    dot1 = _mm256_add_ps( dot1, dot3 );
    dot0 = _mm256_add_ps( dot0, dot1 );
    // Add 8 values into 4
    __m128 r4 = _mm_add_ps( _mm256_castps256_ps128( dot0 ),
        _mm256_extractf128_ps( dot0, 1 ) );
    // Add 4 values into 2
    r4 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) );
    // Add 2 lower values into the scalar result
    r4 = _mm_add_ss( r4, _mm_movehdup_ps( r4 ) );

    // Return the lowest lane of the result vector.
    // The intrinsic below compiles into noop, modern compilers return floats in the lowest lane of xmm0 register.
    return _mm_cvtss_f32( r4 );
}

int main(int argc, char** argv) {
    int N = 1000000;
    int K = 10000;
    std::vector<float> a(N);
    std::vector<float> b(N);
    for (int i = 0; i < N; i++) {
        a[i] = sin(i/100.0);
        b[i] = cos(i/100.0);
    }
    float result = 0.0;
    for (int i = 0; i < K; i++) {
        result = dotProductFma(a, b);
    }
    std::printf("result: %.7f\n", result);
    return 0;
}
I compiled this with g++ -O3 -march=native -o main.
As far as I can tell, these pieces of code are functionally identical. The dot product is repeated enough times (K = 10000) that the one-time cost of building the input arrays shouldn't matter. They do produce slightly different results, which comes down to different floating-point accumulation patterns.
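Specifically, the Rust version folds everything into a single 4-wide accumulator while the C++ version keeps four 8-wide accumulators, and since f32 addition isn't associative, the grouping changes the low bits. A standalone illustration of the effect (not the benchmark code, just the principle):

fn main() {
    let xs: Vec<f32> = (0..1_000_000).map(|i| ((i as f32) / 100.0).sin()).collect();

    // Plain left-to-right accumulation.
    let sequential: f32 = xs.iter().sum();

    // Eight partial accumulators combined at the end, roughly the shape of a
    // vectorized reduction.
    let mut lanes = [0.0f32; 8];
    for chunk in xs.chunks_exact(8) {
        for (lane, x) in lanes.iter_mut().zip(chunk) {
            *lane += x;
        }
    }
    let lanewise: f32 = lanes.iter().sum();

    // The two totals typically differ in the last few bits.
    println!("{sequential} vs {lanewise}");
}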
~/proj/profile/rust_dp main ?33 ❯ ./target/release/rust_dp
result: 4.5250244
~/proj/profile/cpp_dp main ?33 ❯ ./main
result: 4.5243621
and for good measure, using clang++:
~/proj/profile/cpp_dp main ?33 ❯ ./main
result: 4.5243621
Benchmarks
~/proj/profile/rust_dp main ?33 ❯ hyperfine --warmup 3 "target/release/rust_dp"
Benchmark 1: target/release/rust_dp
Time (mean ± σ): 2.261 s ± 0.006 s [User: 2.234 s, System: 0.006 s]
Range (min … max): 2.255 s … 2.273 s 10 runs
~/proj/profile/cpp_dp main ?33 ❯ hyperfine --warmup 3 "./main"
Benchmark 1: ./main
Time (mean ± σ): 752.5 ms ± 19.3 ms [User: 726.2 ms, System: 3.1 ms]
Range (min … max): 724.9 ms … 785.9 ms 10 runs
and wow, here's clang++:
~/proj/profile/cpp_dp main ?33 ❯ hyperfine --warmup 3 "./main"
Benchmark 1: ./main
Time (mean ± σ): 233.4 ms ± 3.2 ms [User: 207.6 ms, System: 3.3 ms]
Range (min … max): 228.7 ms … 238.4 ms 11 runs
Ideas
- The Rust code could still be generating bounds checks.
- We aren't unrolling optimally above and using all the registers the way the C++ version does (see the sketch below).
- We may need the equivalent of -ffast-math to let the compiler reassociate the reduction and pick up some small micro-optimizations.
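To poke at the first two ideas, here's roughly what a more C++-shaped version could look like with portable_simd, dropped into the same file as above (same imports and feature gates): assert the lengths up front so the optimizer has a better shot at dropping bounds checks, and keep four independent 8-wide accumulators like the C++ code. Writing the accumulators out by hand also gives the compiler the reassociation freedom that -ffast-math would otherwise have to grant. This is a sketch (the function name is mine), not something I've benchmarked:

pub fn dot_prod_simd_unrolled(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // 32 floats per iteration across four independent accumulators,
    // mirroring the manual unrolling in the C++ version.
    let (a_chunks, a_rem) = a.as_chunks::<32>();
    let (b_chunks, b_rem) = b.as_chunks::<32>();
    let mut acc = [f32x8::splat(0.0); 4];
    for (ca, cb) in a_chunks.iter().zip(b_chunks) {
        for i in 0..4 {
            let va = f32x8::from_slice(&ca[i * 8..(i + 1) * 8]);
            let vb = f32x8::from_slice(&cb[i * 8..(i + 1) * 8]);
            acc[i] = va.mul_add(vb, acc[i]);
        }
    }
    let body = (acc[0] + acc[1]) + (acc[2] + acc[3]);
    // Scalar tail for the last (length % 32) elements.
    let tail: f32 = a_rem.iter().zip(b_rem).map(|(x, y)| x * y).sum();
    body.reduce_sum() + tail
}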