Optimize fma with GF2 with NEON

Improves performance by ~2x on very large symbol counts
This commit is contained in:
Christopher Berner 2021-02-14 13:49:16 -08:00
parent 8bbc99c5cd
commit 893e1c7c79
2 changed files with 154 additions and 41 deletions

@ -74,51 +74,51 @@ The following were run on a Raspberry Pi 3 B+ (Cortex-A53 @ 1.4GHz)
```
Symbol size: 1280 bytes (without pre-built plan)
symbol count = 10, encoded 127 MB in 11.898secs, throughput: 86.1Mbit/s
symbol count = 100, encoded 127 MB in 8.862secs, throughput: 115.5Mbit/s
symbol count = 250, encoded 127 MB in 9.103secs, throughput: 112.4Mbit/s
symbol count = 500, encoded 127 MB in 8.806secs, throughput: 115.9Mbit/s
symbol count = 1000, encoded 126 MB in 9.412secs, throughput: 107.9Mbit/s
symbol count = 2000, encoded 126 MB in 7.041secs, throughput: 144.2Mbit/s
symbol count = 5000, encoded 122 MB in 12.119secs, throughput: 80.6Mbit/s
symbol count = 10000, encoded 122 MB in 9.694secs, throughput: 100.7Mbit/s
symbol count = 20000, encoded 122 MB in 12.087secs, throughput: 80.8Mbit/s
symbol count = 50000, encoded 122 MB in 23.912secs, throughput: 40.8Mbit/s
symbol count = 10, encoded 127 MB in 9.478secs, throughput: 108.0Mbit/s
symbol count = 100, encoded 127 MB in 6.281secs, throughput: 162.9Mbit/s
symbol count = 250, encoded 127 MB in 7.216secs, throughput: 141.8Mbit/s
symbol count = 500, encoded 127 MB in 7.623secs, throughput: 133.9Mbit/s
symbol count = 1000, encoded 126 MB in 8.424secs, throughput: 120.6Mbit/s
symbol count = 2000, encoded 126 MB in 8.775secs, throughput: 115.7Mbit/s
symbol count = 5000, encoded 122 MB in 8.439secs, throughput: 115.7Mbit/s
symbol count = 10000, encoded 122 MB in 8.297secs, throughput: 117.7Mbit/s
symbol count = 20000, encoded 122 MB in 9.329secs, throughput: 104.7Mbit/s
symbol count = 50000, encoded 122 MB in 11.724secs, throughput: 83.3Mbit/s
Symbol size: 1280 bytes (with pre-built plan)
symbol count = 10, encoded 127 MB in 8.399secs, throughput: 121.9Mbit/s
symbol count = 100, encoded 127 MB in 4.660secs, throughput: 219.6Mbit/s
symbol count = 250, encoded 127 MB in 6.373secs, throughput: 160.5Mbit/s
symbol count = 500, encoded 127 MB in 4.699secs, throughput: 217.2Mbit/s
symbol count = 1000, encoded 126 MB in 5.978secs, throughput: 169.9Mbit/s
symbol count = 2000, encoded 126 MB in 6.182secs, throughput: 164.3Mbit/s
symbol count = 5000, encoded 122 MB in 5.958secs, throughput: 163.9Mbit/s
symbol count = 10000, encoded 122 MB in 7.228secs, throughput: 135.1Mbit/s
symbol count = 20000, encoded 122 MB in 6.764secs, throughput: 144.4Mbit/s
symbol count = 50000, encoded 122 MB in 6.649secs, throughput: 146.9Mbit/s
symbol count = 10, encoded 127 MB in 6.298secs, throughput: 162.6Mbit/s
symbol count = 100, encoded 127 MB in 5.402secs, throughput: 189.5Mbit/s
symbol count = 250, encoded 127 MB in 5.312secs, throughput: 192.6Mbit/s
symbol count = 500, encoded 127 MB in 5.296secs, throughput: 192.7Mbit/s
symbol count = 1000, encoded 126 MB in 4.081secs, throughput: 248.9Mbit/s
symbol count = 2000, encoded 126 MB in 4.110secs, throughput: 247.1Mbit/s
symbol count = 5000, encoded 122 MB in 5.947secs, throughput: 164.2Mbit/s
symbol count = 10000, encoded 122 MB in 6.271secs, throughput: 155.7Mbit/s
symbol count = 20000, encoded 122 MB in 6.745secs, throughput: 144.8Mbit/s
symbol count = 50000, encoded 122 MB in 6.646secs, throughput: 146.9Mbit/s
Symbol size: 1280 bytes
symbol count = 10, decoded 127 MB in 13.727secs using 0.0% overhead, throughput: 74.6Mbit/s
symbol count = 100, decoded 127 MB in 9.727secs using 0.0% overhead, throughput: 105.2Mbit/s
symbol count = 250, decoded 127 MB in 12.135secs using 0.0% overhead, throughput: 84.3Mbit/s
symbol count = 500, decoded 127 MB in 10.658secs using 0.0% overhead, throughput: 95.8Mbit/s
symbol count = 1000, decoded 126 MB in 10.414secs using 0.0% overhead, throughput: 97.5Mbit/s
symbol count = 2000, decoded 126 MB in 10.828secs using 0.0% overhead, throughput: 93.8Mbit/s
symbol count = 5000, decoded 122 MB in 12.545secs using 0.0% overhead, throughput: 77.8Mbit/s
symbol count = 10000, decoded 122 MB in 10.667secs using 0.0% overhead, throughput: 91.5Mbit/s
symbol count = 20000, decoded 122 MB in 19.769secs using 0.0% overhead, throughput: 49.4Mbit/s
symbol count = 50000, decoded 122 MB in 25.817secs using 0.0% overhead, throughput: 37.8Mbit/s
symbol count = 10, decoded 127 MB in 11.529secs using 0.0% overhead, throughput: 88.8Mbit/s
symbol count = 100, decoded 127 MB in 8.011secs using 0.0% overhead, throughput: 127.8Mbit/s
symbol count = 250, decoded 127 MB in 9.322secs using 0.0% overhead, throughput: 109.7Mbit/s
symbol count = 500, decoded 127 MB in 9.388secs using 0.0% overhead, throughput: 108.7Mbit/s
symbol count = 1000, decoded 126 MB in 7.614secs using 0.0% overhead, throughput: 133.4Mbit/s
symbol count = 2000, decoded 126 MB in 6.706secs using 0.0% overhead, throughput: 151.5Mbit/s
symbol count = 5000, decoded 122 MB in 8.677secs using 0.0% overhead, throughput: 112.5Mbit/s
symbol count = 10000, decoded 122 MB in 9.529secs using 0.0% overhead, throughput: 102.5Mbit/s
symbol count = 20000, decoded 122 MB in 10.766secs using 0.0% overhead, throughput: 90.7Mbit/s
symbol count = 50000, decoded 122 MB in 13.497secs using 0.0% overhead, throughput: 72.4Mbit/s
symbol count = 10, decoded 127 MB in 11.557secs using 5.0% overhead, throughput: 88.6Mbit/s
symbol count = 100, decoded 127 MB in 9.586secs using 5.0% overhead, throughput: 106.8Mbit/s
symbol count = 250, decoded 127 MB in 11.725secs using 5.0% overhead, throughput: 87.2Mbit/s
symbol count = 500, decoded 127 MB in 10.859secs using 5.0% overhead, throughput: 94.0Mbit/s
symbol count = 1000, decoded 126 MB in 7.036secs using 5.0% overhead, throughput: 144.3Mbit/s
symbol count = 2000, decoded 126 MB in 11.247secs using 5.0% overhead, throughput: 90.3Mbit/s
symbol count = 5000, decoded 122 MB in 12.590secs using 5.0% overhead, throughput: 77.6Mbit/s
symbol count = 10000, decoded 122 MB in 15.379secs using 5.0% overhead, throughput: 63.5Mbit/s
symbol count = 20000, decoded 122 MB in 18.543secs using 5.0% overhead, throughput: 52.7Mbit/s
symbol count = 50000, decoded 122 MB in 32.090secs using 5.0% overhead, throughput: 30.4Mbit/s
symbol count = 10, decoded 127 MB in 14.057secs using 5.0% overhead, throughput: 72.8Mbit/s
symbol count = 100, decoded 127 MB in 10.187secs using 5.0% overhead, throughput: 100.5Mbit/s
symbol count = 250, decoded 127 MB in 9.220secs using 5.0% overhead, throughput: 110.9Mbit/s
symbol count = 500, decoded 127 MB in 9.276secs using 5.0% overhead, throughput: 110.0Mbit/s
symbol count = 1000, decoded 126 MB in 8.117secs using 5.0% overhead, throughput: 125.1Mbit/s
symbol count = 2000, decoded 126 MB in 8.459secs using 5.0% overhead, throughput: 120.1Mbit/s
symbol count = 5000, decoded 122 MB in 8.410secs using 5.0% overhead, throughput: 116.1Mbit/s
symbol count = 10000, decoded 122 MB in 11.370secs using 5.0% overhead, throughput: 85.9Mbit/s
symbol count = 20000, decoded 122 MB in 11.923secs using 5.0% overhead, throughput: 81.9Mbit/s
symbol count = 50000, decoded 122 MB in 17.768secs using 5.0% overhead, throughput: 55.0Mbit/s
```
### Public API

@ -84,6 +84,22 @@ pub fn fused_addassign_mul_scalar_binary(
}
}
}
#[cfg(all(target_arch = "aarch64", feature = "use_neon"))]
{
if is_aarch64_feature_detected!("neon") {
unsafe {
return fused_addassign_mul_scalar_binary_neon(octets, other, scalar);
}
}
}
#[cfg(all(target_arch = "arm", feature = "use_neon"))]
{
if is_arm_feature_detected!("neon") {
unsafe {
return fused_addassign_mul_scalar_binary_neon(octets, other, scalar);
}
}
}
// TODO: write an optimized fallback that doesn't call .to_octet_vec()
if *scalar == Octet::one() {
@ -93,6 +109,79 @@ pub fn fused_addassign_mul_scalar_binary(
}
}
/// NEON implementation of `octets[i] ^= scalar * other[i]`, where `other` is a
/// bit-packed binary vector.
///
/// The packed words of `other.elements` are viewed as a sequence of `u16`
/// values. The first meaningful bit is at index `other.padding_bits()`; any bits
/// before the next 16-bit boundary are handled one at a time, after which each
/// `u16` is expanded into sixteen byte lanes and combined with sixteen bytes of
/// `octets` per iteration.
///
/// NOTE(review): viewing the words as `u16` and indexing them by `bit / 16`
/// appears to rely on a little-endian memory layout -- confirm for any
/// big-endian ARM target.
///
/// # Panics
///
/// * if `other.elements` is too short to contain the `u16` holding the first bit
///   (this includes an empty `other`),
/// * if `octets` is shorter than the unaligned head, or the remainder after the
///   head is not a multiple of 16 bytes,
/// * if `other` does not hold enough 16-bit groups to cover all of `octets`.
///
/// # Safety
///
/// The caller must ensure the CPU supports NEON (this function is compiled with
/// `#[target_feature(enable = "neon")]`).
#[cfg(all(
    any(target_arch = "arm", target_arch = "aarch64"),
    feature = "use_neon"
))]
#[target_feature(enable = "neon")]
unsafe fn fused_addassign_mul_scalar_binary_neon(
    octets: &mut [u8],
    other: &BinaryOctetVec,
    scalar: &Octet,
) {
    #[cfg(target_arch = "aarch64")]
    use std::arch::aarch64::*;
    #[cfg(target_arch = "arm")]
    use std::arch::arm::*;
    use std::mem;

    let first_bit = other.padding_bits();
    // Reinterpret each packed element as four u16 values, so that one u16 maps to
    // exactly one 16-byte batch of `octets`.
    let other_u16 = std::slice::from_raw_parts(
        other.elements.as_ptr() as *const u16,
        other.elements.len() * 4,
    );
    let mut other_batch_start_index = first_bit / 16;
    let first_bits = other_u16[other_batch_start_index];
    let bit_in_first_bits = first_bit % 16;
    let mut remaining = octets.len();
    let mut self_neon_ptr = octets.as_mut_ptr();

    // Handle first bits to make remainder 16bit aligned
    if bit_in_first_bits > 0 {
        let head_len = 16 - bit_in_first_bits;
        // Fail with a clear message before touching `octets`, instead of relying on
        // the subtraction below overflowing.
        assert!(
            octets.len() >= head_len,
            "octets is shorter than the unaligned head of the binary vector"
        );
        for (i, val) in octets.iter_mut().enumerate().take(head_len) {
            // TODO: replace with UBFX instruction, once it's supported in arm intrinsics
            let selected_bit = first_bits & (0x1 << (bit_in_first_bits + i));
            let other_byte = if selected_bit == 0 { 0 } else { 1 };
            // other_byte is binary, so u8 multiplication is the same as GF256 multiplication
            *val ^= scalar.byte() * other_byte;
        }
        remaining -= head_len;
        other_batch_start_index += 1;
        self_neon_ptr = self_neon_ptr.add(head_len);
    }

    assert_eq!(remaining % 16, 0);
    let batches = remaining / mem::size_of::<uint8x16_t>();
    // The loop below reads `other` through a raw pointer, so the range it touches
    // must be validated here: batch `i` reads `other_u16[other_batch_start_index + i]`.
    assert!(
        other_batch_start_index + batches <= other_u16.len(),
        "binary vector is too short for the given octets"
    );

    // Lanes 0-7 take byte 0 (low half of the u16), lanes 8-15 take byte 1 (high half)
    let shuffle_mask = vld1q_u8([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1].as_ptr());
    // Lane k keeps only bit (k % 8) of its byte
    let bit_select_mask = vld1q_u8(
        [
            1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80,
        ]
        .as_ptr(),
    );
    let scalar_neon = vdupq_n_u8(scalar.byte());
    // Loop invariant: threshold used to turn "any bit set" into an all-ones lane
    let ones = vdupq_n_u8(1);
    let other_neon = other_u16.as_ptr();

    // Process the rest in 128bit chunks
    for i in 0..batches {
        // Convert from bit packed u16 to 16xu8
        let other_vec = vld1q_dup_u16(other_neon.add(other_batch_start_index + i));
        let other_vec: uint8x16_t = mem::transmute(other_vec);
        let other_vec = vqtbl1q_u8(other_vec, shuffle_mask);
        let other_vec = vandq_u8(other_vec, bit_select_mask);
        // The bits are now unpacked, but aren't in a defined position (may be in 0-7 bit of each byte)
        // Test non-zero to get 0xFF or 0x00 in each byte
        let other_vec = vcgeq_u8(other_vec, ones);
        // Multiply by scalar. other_vec is binary (0xFF or 0x00), so just mask with the scalar
        let product = vandq_u8(other_vec, scalar_neon);
        // Add to self
        let self_vec = vld1q_u8(self_neon_ptr.add(i * mem::size_of::<uint8x16_t>()));
        let result = veorq_u8(self_vec, product);
        // NOTE(review): `self_neon_ptr` has no particular alignment; `store_neon` is
        // defined elsewhere and presumably tolerates unaligned destinations -- confirm.
        store_neon((self_neon_ptr as *mut uint8x16_t).add(i), result);
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "bmi1")]
@ -749,8 +838,10 @@ mod tests {
use rand::Rng;
use crate::octet::Octet;
use crate::octets::fused_addassign_mul_scalar;
use crate::octets::mulassign_scalar;
use crate::octets::{
fused_addassign_mul_scalar, fused_addassign_mul_scalar_binary, BinaryOctetVec,
};
#[test]
fn mul_assign() {
@ -785,4 +876,26 @@ mod tests {
assert_eq!(expected, data1);
}
#[test]
fn fma_binary() {
    // Checks fused_addassign_mul_scalar_binary against a reference result computed
    // one octet at a time from the unpacked form of the binary vector.
    // 41 is neither a multiple of 16 nor of 64.
    let size = 41;
    let scalar = Octet::new(rand::thread_rng().gen_range(2..255));

    // Random bit-packed input
    let packed_words: Vec<u64> = (0..(size + 63) / 64)
        .map(|_| rand::thread_rng().gen())
        .collect();
    let binary_octet_vec = BinaryOctetVec::new(packed_words, size);
    let data2: Vec<u8> = binary_octet_vec.to_octet_vec();

    // Random accumulator, plus the expected value of data1[i] + data2[i] * scalar
    let mut data1: Vec<u8> = (0..size).map(|_| rand::thread_rng().gen()).collect();
    let expected: Vec<u8> = (0..size)
        .map(|i| (Octet::new(data1[i]) + &Octet::new(data2[i]) * &scalar).byte())
        .collect();

    fused_addassign_mul_scalar_binary(&mut data1, &binary_octet_vec, &scalar);

    assert_eq!(expected, data1);
}
}