// raptorq/src/octets.rs
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::octet::Octet;
use crate::octet::OCTET_MUL;
#[cfg(all(
any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "arm",
target_arch = "aarch64",
),
feature = "std"
))]
use crate::octet::OCTET_MUL_HI_BITS;
#[cfg(all(
any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "arm",
target_arch = "aarch64",
),
feature = "std"
))]
use crate::octet::OCTET_MUL_LOW_BITS;
#[cfg(all(target_arch = "aarch64", feature = "std"))]
use std::arch::is_aarch64_feature_detected;
// An octet vec containing only binary values, which are bit-packed for efficiency
pub struct BinaryOctetVec {
// Values are packed toward the most significant bits, with the last value at the highest bit of
// the last word. Any unused (padding) bits are the least significant bits of the first word.
elements: Vec<u64>,
length: usize,
}
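// For example, with length = 5 there is a single u64 element, padding_bits() is 59, the five
// values occupy bits 59..=63 of elements[0] (the last value at bit 63), and to_octet_vec()
// expands them into five bytes that are each 0 or 1.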
impl BinaryOctetVec {
pub(crate) const WORD_WIDTH: usize = 64;
pub fn new(elements: Vec<u64>, length: usize) -> Self {
assert_eq!(
elements.len(),
(length + Self::WORD_WIDTH - 1) / Self::WORD_WIDTH
);
BinaryOctetVec { elements, length }
}
pub fn len(&self) -> usize {
self.length
}
fn to_octet_vec(&self) -> Vec<u8> {
let mut word = 0;
let mut bit = self.padding_bits();
let result = (0..self.length)
.map(|_| {
let value = if self.elements[word] & BinaryOctetVec::select_mask(bit) == 0 {
0
} else {
1
};
bit += 1;
if bit == 64 {
word += 1;
bit = 0;
}
value
})
.collect();
assert_eq!(word, self.elements.len());
assert_eq!(bit, 0);
result
}
pub fn padding_bits(&self) -> usize {
(BinaryOctetVec::WORD_WIDTH - (self.length % BinaryOctetVec::WORD_WIDTH))
% BinaryOctetVec::WORD_WIDTH
}
pub fn select_mask(bit: usize) -> u64 {
1u64 << (bit as u64)
}
}
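// Computes octets[i] ^= scalar * other[i] in GF(256), where `other` holds bit-packed binary
// (0/1) values.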
pub fn fused_addassign_mul_scalar_binary(
octets: &mut [u8],
other: &BinaryOctetVec,
scalar: &Octet,
) {
debug_assert_ne!(
*scalar,
Octet::zero(),
"Don't call with zero. It's very inefficient"
);
assert_eq!(octets.len(), other.len());
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("bmi1") {
unsafe {
return fused_addassign_mul_scalar_binary_avx2(octets, other, scalar);
}
}
}
#[cfg(all(target_arch = "aarch64", feature = "std"))]
{
if is_aarch64_feature_detected!("neon") {
unsafe {
return fused_addassign_mul_scalar_binary_neon(octets, other, scalar);
}
}
}
#[cfg(all(target_arch = "arm", feature = "std"))]
{
// TODO: enable when stable
// if is_arm_feature_detected!("neon") {
// unsafe {
// return fused_addassign_mul_scalar_binary_neon(octets, other, scalar);
// }
// }
}
// TODO: write an optimized fallback that doesn't call .to_octet_vec()
if *scalar == Octet::one() {
return add_assign(octets, &other.to_octet_vec());
} else {
return fused_addassign_mul_scalar(octets, &other.to_octet_vec(), scalar);
}
}
#[cfg(target_arch = "aarch64")]
// TODO: enable when stable
// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
// #[target_feature(enable = "neon")]
unsafe fn fused_addassign_mul_scalar_binary_neon(
octets: &mut [u8],
other: &BinaryOctetVec,
scalar: &Octet,
) {
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "arm")]
use std::arch::arm::*;
use std::mem;
let first_bit = other.padding_bits();
let other_u16 = std::slice::from_raw_parts(
other.elements.as_ptr() as *const u16,
other.elements.len() * 4,
);
let mut other_batch_start_index = first_bit / 16;
let first_bits = other_u16[other_batch_start_index];
let bit_in_first_bits = first_bit % 16;
let mut remaining = octets.len();
let mut self_neon_ptr = octets.as_mut_ptr();
// Handle the leading bits so that the remainder is 16-bit aligned
if bit_in_first_bits > 0 {
for (i, val) in octets.iter_mut().enumerate().take(16 - bit_in_first_bits) {
// TODO: replace with the UBFX instruction, once it's supported in the arm intrinsics
let selected_bit = first_bits & (0x1 << (bit_in_first_bits + i));
let other_byte = if selected_bit == 0 { 0 } else { 1 };
// other_byte is binary, so u8 multiplication is the same as GF256 multiplication
*val ^= scalar.byte() * other_byte;
}
remaining -= 16 - bit_in_first_bits;
other_batch_start_index += 1;
self_neon_ptr = self_neon_ptr.add(16 - bit_in_first_bits);
}
assert_eq!(remaining % 16, 0);
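// shuffle_mask copies the low byte of the loaded u16 into lanes 0-7 and its high byte into
// lanes 8-15; bit_select_mask then isolates a different bit of that byte in each lane.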
let shuffle_mask = vld1q_u8([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1].as_ptr());
let bit_select_mask = vld1q_u8(
[
1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80,
]
.as_ptr(),
);
let scalar_neon = vdupq_n_u8(scalar.byte());
let other_neon = other_u16.as_ptr();
// Process the rest in 128-bit chunks
for i in 0..(remaining / mem::size_of::<uint8x16_t>()) {
// Convert from bit packed u16 to 16xu8
let other_vec = vld1q_dup_u16(other_neon.add(other_batch_start_index + i));
let other_vec: uint8x16_t = mem::transmute(other_vec);
let other_vec = vqtbl1q_u8(other_vec, shuffle_mask);
let other_vec = vandq_u8(other_vec, bit_select_mask);
// The bits are now unpacked, but land at a different bit position (0-7) in each byte
// Test for non-zero to set each byte to 0xFF or 0x00
let other_vec = vcgeq_u8(other_vec, vdupq_n_u8(1));
// Multiply by scalar. other_vec is binary (0xFF or 0x00), so just mask with the scalar
let product = vandq_u8(other_vec, scalar_neon);
// Add to self
let self_vec = vld1q_u8(self_neon_ptr.add(i * mem::size_of::<uint8x16_t>()));
let result = veorq_u8(self_vec, product);
store_neon((self_neon_ptr as *mut uint8x16_t).add(i), result);
}
}
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "bmi1")]
unsafe fn fused_addassign_mul_scalar_binary_avx2(
octets: &mut [u8],
other: &BinaryOctetVec,
scalar: &Octet,
) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let first_bit = other.padding_bits();
let other_u32 = std::slice::from_raw_parts(
other.elements.as_ptr() as *const u32,
other.elements.len() * 2,
);
let mut other_batch_start_index = first_bit / 32;
let first_bits = other_u32[other_batch_start_index];
let bit_in_first_bits = first_bit % 32;
let mut remaining = octets.len();
let mut self_avx_ptr = octets.as_mut_ptr();
// Handle the leading bits so that the remainder is 32-bit aligned
if bit_in_first_bits > 0 {
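// _bextr2_u32 takes a control word of (start | (length << 8)); 0x100 selects a one-bit field,
// and adding `i` advances the start position, extracting bit (bit_in_first_bits + i).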
let control = bit_in_first_bits as u32 | 0x100;
for (i, val) in octets.iter_mut().enumerate().take(32 - bit_in_first_bits) {
let other_byte = _bextr2_u32(first_bits, control + i as u32) as u8;
// other_byte is binary, so u8 multiplication is the same as GF256 multiplication
*val ^= scalar.byte() * other_byte;
}
remaining -= 32 - bit_in_first_bits;
other_batch_start_index += 1;
self_avx_ptr = self_avx_ptr.add(32 - bit_in_first_bits);
}
assert_eq!(remaining % 32, 0);
// See: https://stackoverflow.com/questions/24225786/fastest-way-to-unpack-32-bits-to-a-32-byte-simd-vector
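// shuffle_mask replicates byte j of the broadcast word into output bytes 8j..8j+7, and
// bit_select_mask picks a different bit of that byte in each of those positions.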
let shuffle_mask = _mm256_set_epi64x(
0x0303_0303_0303_0303,
0x0202_0202_0202_0202,
0x0101_0101_0101_0101,
0,
);
let bit_select_mask = _mm256_set1_epi64x(0x8040_2010_0804_0201u64 as i64);
let scalar_avx = _mm256_set1_epi8(scalar.byte() as i8);
// Process the rest in 256-bit chunks
for i in 0..(remaining / 32) {
// Convert from bit packed u32 to 32xu8
let other_vec = _mm256_set1_epi32(other_u32[other_batch_start_index + i] as i32);
let other_vec = _mm256_shuffle_epi8(other_vec, shuffle_mask);
let other_vec = _mm256_andnot_si256(other_vec, bit_select_mask);
// The bits are now unpacked, but land at a different bit position (0-7) in each byte,
// and are inverted (due to the ANDNOT)
// Compare against zero to undo the inversion and set each byte to 0xFF or 0x00
let other_vec = _mm256_cmpeq_epi8(other_vec, _mm256_setzero_si256());
// Multiply by scalar. other_vec is binary (0xFF or 0x00), so just mask with the scalar
let product = _mm256_and_si256(other_vec, scalar_avx);
// Add to self
#[allow(clippy::cast_ptr_alignment)]
let self_vec = _mm256_loadu_si256((self_avx_ptr as *const __m256i).add(i));
let result = _mm256_xor_si256(self_vec, product);
#[allow(clippy::cast_ptr_alignment)]
_mm256_storeu_si256((self_avx_ptr as *mut __m256i).add(i), result);
}
}
fn mulassign_scalar_fallback(octets: &mut [u8], scalar: &Octet) {
let scalar_index = usize::from(scalar.byte());
for item in octets {
let octet_index = usize::from(*item);
// SAFETY: `OCTET_MUL` is a 256x256 matrix, both indexes are `u8` inputs.
*item = unsafe {
*OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(octet_index)
};
}
}
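// The SIMD implementations below multiply each byte by `scalar` with two 16-entry nibble tables:
// OCTET_MUL_LOW_BITS[scalar] maps the low nibble and OCTET_MUL_HI_BITS[scalar] maps the high
// nibble. Because GF(256) multiplication distributes over XOR, the full product is the XOR of
// the two table lookups.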
#[cfg(target_arch = "aarch64")]
// TODO: enable when stable
// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
// #[target_feature(enable = "neon")]
unsafe fn mulassign_scalar_neon(octets: &mut [u8], scalar: &Octet) {
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "arm")]
use std::arch::arm::*;
use std::mem;
let mask = vdupq_n_u8(0x0F);
let self_neon_ptr = octets.as_mut_ptr();
#[allow(clippy::cast_ptr_alignment)]
let low_table = vld1q_u8(OCTET_MUL_LOW_BITS[scalar.byte() as usize].as_ptr());
#[allow(clippy::cast_ptr_alignment)]
let hi_table = vld1q_u8(OCTET_MUL_HI_BITS[scalar.byte() as usize].as_ptr());
for i in 0..(octets.len() / mem::size_of::<uint8x16_t>()) {
// Multiply by scalar
#[allow(clippy::cast_ptr_alignment)]
let self_vec = vld1q_u8(self_neon_ptr.add(i * mem::size_of::<uint8x16_t>()));
let low = vandq_u8(self_vec, mask);
let low_result = vqtbl1q_u8(low_table, low);
let hi = vshrq_n_u8(self_vec, 4);
let hi = vandq_u8(hi, mask);
let hi_result = vqtbl1q_u8(hi_table, hi);
let result = veorq_u8(hi_result, low_result);
store_neon((self_neon_ptr as *mut uint8x16_t).add(i), result);
}
let remainder = octets.len() % mem::size_of::<uint8x16_t>();
let scalar_index = scalar.byte() as usize;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) = *OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(*octets.get_unchecked(i) as usize);
}
}
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
#[target_feature(enable = "avx2")]
unsafe fn mulassign_scalar_avx2(octets: &mut [u8], scalar: &Octet) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let low_mask = _mm256_set1_epi8(0x0F);
let hi_mask = _mm256_set1_epi8(0xF0u8 as i8);
let self_avx_ptr = octets.as_mut_ptr();
// Safe because _mm256_loadu_si256 loads from unaligned memory
#[allow(clippy::cast_ptr_alignment)]
let low_table =
_mm256_loadu_si256(OCTET_MUL_LOW_BITS[scalar.byte() as usize].as_ptr() as *const __m256i);
// Safe because _mm256_loadu_si256 loads from unaligned memory
#[allow(clippy::cast_ptr_alignment)]
let hi_table =
_mm256_loadu_si256(OCTET_MUL_HI_BITS[scalar.byte() as usize].as_ptr() as *const __m256i);
for i in 0..(octets.len() / 32) {
#[allow(clippy::cast_ptr_alignment)]
let self_vec = _mm256_loadu_si256((self_avx_ptr as *const __m256i).add(i));
let low = _mm256_and_si256(self_vec, low_mask);
let low_result = _mm256_shuffle_epi8(low_table, low);
let hi = _mm256_and_si256(self_vec, hi_mask);
let hi = _mm256_srli_epi64(hi, 4);
let hi_result = _mm256_shuffle_epi8(hi_table, hi);
let result = _mm256_xor_si256(hi_result, low_result);
#[allow(clippy::cast_ptr_alignment)]
_mm256_storeu_si256((self_avx_ptr as *mut __m256i).add(i), result);
}
let remainder = octets.len() % 32;
let scalar_index = scalar.byte() as usize;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) = *OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(*octets.get_unchecked(i) as usize);
}
}
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
#[target_feature(enable = "ssse3")]
unsafe fn mulassign_scalar_ssse3(octets: &mut [u8], scalar: &Octet) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let low_mask = _mm_set1_epi8(0x0F);
let hi_mask = _mm_set1_epi8(0xF0u8 as i8);
let self_ssse_ptr = octets.as_mut_ptr();
#[allow(clippy::cast_ptr_alignment)]
let low_table =
_mm_loadu_si128(OCTET_MUL_LOW_BITS[scalar.byte() as usize].as_ptr() as *const __m128i);
#[allow(clippy::cast_ptr_alignment)]
let hi_table =
_mm_loadu_si128(OCTET_MUL_HI_BITS[scalar.byte() as usize].as_ptr() as *const __m128i);
for i in 0..(octets.len() / 16) {
#[allow(clippy::cast_ptr_alignment)]
let self_vec = _mm_loadu_si128((self_ssse_ptr as *const __m128i).add(i));
let low = _mm_and_si128(self_vec, low_mask);
let low_result = _mm_shuffle_epi8(low_table, low);
let hi = _mm_and_si128(self_vec, hi_mask);
let hi = _mm_srli_epi64(hi, 4);
let hi_result = _mm_shuffle_epi8(hi_table, hi);
let result = _mm_xor_si128(hi_result, low_result);
#[allow(clippy::cast_ptr_alignment)]
_mm_storeu_si128((self_ssse_ptr as *mut __m128i).add(i), result);
}
let remainder = octets.len() % 16;
let scalar_index = scalar.byte() as usize;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) = *OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(*octets.get_unchecked(i) as usize);
}
}
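// Computes octets[i] = octets[i] * scalar in GF(256), dispatching to the fastest implementation
// available at runtime.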
pub fn mulassign_scalar(octets: &mut [u8], scalar: &Octet) {
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
{
if is_x86_feature_detected!("avx2") {
unsafe {
return mulassign_scalar_avx2(octets, scalar);
}
}
if is_x86_feature_detected!("ssse3") {
unsafe {
return mulassign_scalar_ssse3(octets, scalar);
}
}
}
#[cfg(all(target_arch = "aarch64", feature = "std"))]
{
if is_aarch64_feature_detected!("neon") {
unsafe {
return mulassign_scalar_neon(octets, scalar);
}
}
}
#[cfg(all(target_arch = "arm", feature = "std"))]
{
// TODO: enable when stable
// if is_arm_feature_detected!("neon") {
// unsafe {
// return mulassign_scalar_neon(octets, scalar);
// }
// }
}
return mulassign_scalar_fallback(octets, scalar);
}
fn fused_addassign_mul_scalar_fallback(octets: &mut [u8], other: &[u8], scalar: &Octet) {
let scalar_index = scalar.byte() as usize;
for i in 0..octets.len() {
unsafe {
*octets.get_unchecked_mut(i) ^= *OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(*other.get_unchecked(i) as usize);
}
}
}
#[cfg(target_arch = "aarch64")]
// TODO: enable when stable
// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
// #[target_feature(enable = "neon")]
unsafe fn fused_addassign_mul_scalar_neon(octets: &mut [u8], other: &[u8], scalar: &Octet) {
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "arm")]
use std::arch::arm::*;
use std::mem;
let mask = vdupq_n_u8(0x0F);
let self_neon_ptr = octets.as_mut_ptr();
let other_neon_ptr = other.as_ptr();
#[allow(clippy::cast_ptr_alignment)]
let low_table = vld1q_u8(OCTET_MUL_LOW_BITS[scalar.byte() as usize].as_ptr());
#[allow(clippy::cast_ptr_alignment)]
let hi_table = vld1q_u8(OCTET_MUL_HI_BITS[scalar.byte() as usize].as_ptr());
for i in 0..(octets.len() / mem::size_of::<uint8x16_t>()) {
// Multiply by scalar
#[allow(clippy::cast_ptr_alignment)]
let other_vec = vld1q_u8(other_neon_ptr.add(i * mem::size_of::<uint8x16_t>()));
let low = vandq_u8(other_vec, mask);
let low_result = vqtbl1q_u8(low_table, low);
let hi = vshrq_n_u8(other_vec, 4);
let hi = vandq_u8(hi, mask);
let hi_result = vqtbl1q_u8(hi_table, hi);
let other_vec = veorq_u8(hi_result, low_result);
// Add to self
#[allow(clippy::cast_ptr_alignment)]
let self_vec = vld1q_u8(self_neon_ptr.add(i * mem::size_of::<uint8x16_t>()));
let result = veorq_u8(self_vec, other_vec);
#[allow(clippy::cast_ptr_alignment)]
store_neon((self_neon_ptr as *mut uint8x16_t).add(i), result);
}
let remainder = octets.len() % mem::size_of::<uint8x16_t>();
let scalar_index = scalar.byte() as usize;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) ^= *OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(*other.get_unchecked(i) as usize);
}
}
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
#[target_feature(enable = "avx2")]
unsafe fn fused_addassign_mul_scalar_avx2(octets: &mut [u8], other: &[u8], scalar: &Octet) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let low_mask = _mm256_set1_epi8(0x0F);
let hi_mask = _mm256_set1_epi8(0xF0u8 as i8);
let self_avx_ptr = octets.as_mut_ptr();
let other_avx_ptr = other.as_ptr();
// Safe because _mm256_loadu_si256 loads from unaligned memory
#[allow(clippy::cast_ptr_alignment)]
let low_table =
_mm256_loadu_si256(OCTET_MUL_LOW_BITS[scalar.byte() as usize].as_ptr() as *const __m256i);
// Safe because _mm256_loadu_si256 loads from unaligned memory
#[allow(clippy::cast_ptr_alignment)]
let hi_table =
_mm256_loadu_si256(OCTET_MUL_HI_BITS[scalar.byte() as usize].as_ptr() as *const __m256i);
for i in 0..(octets.len() / 32) {
// Multiply by scalar
#[allow(clippy::cast_ptr_alignment)]
let other_vec = _mm256_loadu_si256((other_avx_ptr as *const __m256i).add(i));
let low = _mm256_and_si256(other_vec, low_mask);
let low_result = _mm256_shuffle_epi8(low_table, low);
let hi = _mm256_and_si256(other_vec, hi_mask);
let hi = _mm256_srli_epi64(hi, 4);
let hi_result = _mm256_shuffle_epi8(hi_table, hi);
let other_vec = _mm256_xor_si256(hi_result, low_result);
// Add to self
#[allow(clippy::cast_ptr_alignment)]
let self_vec = _mm256_loadu_si256((self_avx_ptr as *const __m256i).add(i));
let result = _mm256_xor_si256(self_vec, other_vec);
#[allow(clippy::cast_ptr_alignment)]
_mm256_storeu_si256((self_avx_ptr as *mut __m256i).add(i), result);
}
let remainder = octets.len() % 32;
let scalar_index = scalar.byte() as usize;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) ^= *OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(*other.get_unchecked(i) as usize);
}
}
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
#[target_feature(enable = "ssse3")]
unsafe fn fused_addassign_mul_scalar_ssse3(octets: &mut [u8], other: &[u8], scalar: &Octet) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let low_mask = _mm_set1_epi8(0x0F);
let hi_mask = _mm_set1_epi8(0xF0u8 as i8);
let self_ssse_ptr = octets.as_mut_ptr();
let other_ssse_ptr = other.as_ptr();
#[allow(clippy::cast_ptr_alignment)]
let low_table =
_mm_loadu_si128(OCTET_MUL_LOW_BITS[scalar.byte() as usize].as_ptr() as *const __m128i);
#[allow(clippy::cast_ptr_alignment)]
let hi_table =
_mm_loadu_si128(OCTET_MUL_HI_BITS[scalar.byte() as usize].as_ptr() as *const __m128i);
for i in 0..(octets.len() / 16) {
// Multiply by scalar
#[allow(clippy::cast_ptr_alignment)]
let other_vec = _mm_loadu_si128((other_ssse_ptr as *const __m128i).add(i));
let low = _mm_and_si128(other_vec, low_mask);
let low_result = _mm_shuffle_epi8(low_table, low);
let hi = _mm_and_si128(other_vec, hi_mask);
let hi = _mm_srli_epi64(hi, 4);
let hi_result = _mm_shuffle_epi8(hi_table, hi);
let other_vec = _mm_xor_si128(hi_result, low_result);
// Add to self
#[allow(clippy::cast_ptr_alignment)]
let self_vec = _mm_loadu_si128((self_ssse_ptr as *const __m128i).add(i));
let result = _mm_xor_si128(self_vec, other_vec);
#[allow(clippy::cast_ptr_alignment)]
_mm_storeu_si128((self_ssse_ptr as *mut __m128i).add(i), result);
}
let remainder = octets.len() % 16;
let scalar_index = scalar.byte() as usize;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) ^= *OCTET_MUL
.get_unchecked(scalar_index)
.get_unchecked(*other.get_unchecked(i) as usize);
}
}
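// Computes octets[i] ^= other[i] * scalar in GF(256), dispatching to the fastest implementation
// available at runtime.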
pub fn fused_addassign_mul_scalar(octets: &mut [u8], other: &[u8], scalar: &Octet) {
debug_assert_ne!(
*scalar,
Octet::one(),
"Don't call this with one. Use += instead"
);
debug_assert_ne!(
*scalar,
Octet::zero(),
"Don't call with zero. It's very inefficient"
);
assert_eq!(octets.len(), other.len());
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
{
if is_x86_feature_detected!("avx2") {
unsafe {
return fused_addassign_mul_scalar_avx2(octets, other, scalar);
}
}
if is_x86_feature_detected!("ssse3") {
unsafe {
return fused_addassign_mul_scalar_ssse3(octets, other, scalar);
}
}
}
#[cfg(all(target_arch = "aarch64", feature = "std"))]
{
if is_aarch64_feature_detected!("neon") {
unsafe {
return fused_addassign_mul_scalar_neon(octets, other, scalar);
}
}
}
#[cfg(all(target_arch = "arm", feature = "std"))]
{
// TODO: enable when stable
// if is_arm_feature_detected!("neon") {
// unsafe {
// return fused_addassign_mul_scalar_neon(octets, other, scalar);
// }
// }
}
return fused_addassign_mul_scalar_fallback(octets, other, scalar);
}
fn add_assign_fallback(octets: &mut [u8], other: &[u8]) {
assert_eq!(octets.len(), other.len());
let self_ptr = octets.as_mut_ptr();
let other_ptr = other.as_ptr();
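// XOR in 8-byte chunks using unaligned u64 loads and stores; the tail is handled byte by byte below.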
for i in 0..(octets.len() / 8) {
unsafe {
#[allow(clippy::cast_ptr_alignment)]
let self_value = (self_ptr as *const u64).add(i).read_unaligned();
#[allow(clippy::cast_ptr_alignment)]
let other_value = (other_ptr as *const u64).add(i).read_unaligned();
let result = self_value ^ other_value;
#[allow(clippy::cast_ptr_alignment)]
(self_ptr as *mut u64).add(i).write_unaligned(result);
}
}
let remainder = octets.len() % 8;
for i in (octets.len() - remainder)..octets.len() {
unsafe {
*octets.get_unchecked_mut(i) ^= other.get_unchecked(i);
}
}
}
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::uint8x16_t;
// TODO: enable when stable
// #[cfg(target_arch = "arm")]
// use std::arch::arm::uint8x16_t;
#[cfg(target_arch = "aarch64")]
// TODO: enable when stable
// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
// #[target_feature(enable = "neon")]
unsafe fn store_neon(ptr: *mut uint8x16_t, value: uint8x16_t) {
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "arm")]
use std::arch::arm::*;
// TODO: replace with vst1q_u8 when it's supported
let reinterp = vreinterpretq_u64_u8(value);
(ptr as *mut u64).write_unaligned(vgetq_lane_u64(reinterp, 0));
(ptr as *mut u64)
.add(1)
.write_unaligned(vgetq_lane_u64(reinterp, 1));
}
#[cfg(all(target_arch = "aarch64", feature = "std"))]
// TODO: enable when stable
// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
// #[target_feature(enable = "neon")]
unsafe fn add_assign_neon(octets: &mut [u8], other: &[u8]) {
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "arm")]
use std::arch::arm::*;
use std::mem;
assert_eq!(octets.len(), other.len());
let self_neon_ptr = octets.as_mut_ptr();
let other_neon_ptr = other.as_ptr();
for i in 0..(octets.len() / 16) {
#[allow(clippy::cast_ptr_alignment)]
let self_vec = vld1q_u8(self_neon_ptr.add(i * mem::size_of::<uint8x16_t>()));
#[allow(clippy::cast_ptr_alignment)]
let other_vec = vld1q_u8(other_neon_ptr.add(i * mem::size_of::<uint8x16_t>()));
let result = veorq_u8(self_vec, other_vec);
#[allow(clippy::cast_ptr_alignment)]
store_neon((self_neon_ptr as *mut uint8x16_t).add(i), result);
}
let remainder = octets.len() % 16;
let self_ptr = octets.as_mut_ptr();
let other_ptr = other.as_ptr();
for i in ((octets.len() - remainder) / 8)..(octets.len() / 8) {
#[allow(clippy::cast_ptr_alignment)]
let self_value = (self_ptr as *mut u64).add(i).read_unaligned();
#[allow(clippy::cast_ptr_alignment)]
let other_value = (other_ptr as *mut u64).add(i).read_unaligned();
let result = self_value ^ other_value;
#[allow(clippy::cast_ptr_alignment)]
(self_ptr as *mut u64).add(i).write_unaligned(result);
}
let remainder = octets.len() % 8;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) ^= other.get_unchecked(i);
}
}
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
#[target_feature(enable = "avx2")]
unsafe fn add_assign_avx2(octets: &mut [u8], other: &[u8]) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
assert_eq!(octets.len(), other.len());
let self_avx_ptr = octets.as_mut_ptr();
let other_avx_ptr = other.as_ptr();
for i in 0..(octets.len() / 32) {
#[allow(clippy::cast_ptr_alignment)]
let self_vec = _mm256_loadu_si256((self_avx_ptr as *const __m256i).add(i));
#[allow(clippy::cast_ptr_alignment)]
let other_vec = _mm256_loadu_si256((other_avx_ptr as *const __m256i).add(i));
let result = _mm256_xor_si256(self_vec, other_vec);
#[allow(clippy::cast_ptr_alignment)]
_mm256_storeu_si256((self_avx_ptr as *mut __m256i).add(i), result);
}
let remainder = octets.len() % 32;
let self_ptr = octets.as_mut_ptr();
let other_ptr = other.as_ptr();
for i in ((octets.len() - remainder) / 8)..(octets.len() / 8) {
#[allow(clippy::cast_ptr_alignment)]
let self_value = (self_ptr as *mut u64).add(i).read_unaligned();
#[allow(clippy::cast_ptr_alignment)]
let other_value = (other_ptr as *mut u64).add(i).read_unaligned();
let result = self_value ^ other_value;
#[allow(clippy::cast_ptr_alignment)]
(self_ptr as *mut u64).add(i).write_unaligned(result);
}
let remainder = octets.len() % 8;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) ^= other.get_unchecked(i);
}
}
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
#[target_feature(enable = "ssse3")]
unsafe fn add_assign_ssse3(octets: &mut [u8], other: &[u8]) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
assert_eq!(octets.len(), other.len());
let self_ssse_ptr = octets.as_mut_ptr();
let other_ssse_ptr = other.as_ptr();
for i in 0..(octets.len() / 16) {
#[allow(clippy::cast_ptr_alignment)]
let self_vec = _mm_loadu_si128((self_ssse_ptr as *const __m128i).add(i));
#[allow(clippy::cast_ptr_alignment)]
let other_vec = _mm_loadu_si128((other_ssse_ptr as *const __m128i).add(i));
let result = _mm_xor_si128(self_vec, other_vec);
#[allow(clippy::cast_ptr_alignment)]
_mm_storeu_si128((self_ssse_ptr as *mut __m128i).add(i), result);
}
let remainder = octets.len() % 16;
let self_ptr = octets.as_mut_ptr();
let other_ptr = other.as_ptr();
for i in ((octets.len() - remainder) / 8)..(octets.len() / 8) {
#[allow(clippy::cast_ptr_alignment)]
let self_value = (self_ptr as *mut u64).add(i).read_unaligned();
#[allow(clippy::cast_ptr_alignment)]
let other_value = (other_ptr as *mut u64).add(i).read_unaligned();
let result = self_value ^ other_value;
#[allow(clippy::cast_ptr_alignment)]
(self_ptr as *mut u64).add(i).write_unaligned(result);
}
let remainder = octets.len() % 8;
for i in (octets.len() - remainder)..octets.len() {
*octets.get_unchecked_mut(i) ^= other.get_unchecked(i);
}
}
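// Computes octets[i] ^= other[i] (addition in GF(256) is XOR), dispatching to the fastest
// implementation available at runtime.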
pub fn add_assign(octets: &mut [u8], other: &[u8]) {
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "std"))]
{
if is_x86_feature_detected!("avx2") {
unsafe {
return add_assign_avx2(octets, other);
}
}
if is_x86_feature_detected!("ssse3") {
unsafe {
return add_assign_ssse3(octets, other);
}
}
}
#[cfg(all(target_arch = "aarch64", feature = "std"))]
{
if is_aarch64_feature_detected!("neon") {
unsafe {
return add_assign_neon(octets, other);
}
}
}
#[cfg(all(target_arch = "arm", feature = "std"))]
{
// TODO: enable when stable
// if is_arm_feature_detected!("neon") {
// unsafe {
// return add_assign_neon(octets, other);
// }
// }
}
return add_assign_fallback(octets, other);
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
use rand::Rng;
use std::vec::Vec;
use crate::octet::Octet;
use crate::octets::mulassign_scalar;
use crate::octets::{
fused_addassign_mul_scalar, fused_addassign_mul_scalar_binary, BinaryOctetVec,
};
#[test]
fn mul_assign() {
let size = 41;
let scalar = Octet::new(rand::thread_rng().gen_range(1..255));
let mut data1: Vec<u8> = vec![0; size];
let mut expected: Vec<u8> = vec![0; size];
for i in 0..size {
data1[i] = rand::thread_rng().gen();
expected[i] = (&Octet::new(data1[i]) * &scalar).byte();
}
mulassign_scalar(&mut data1, &scalar);
assert_eq!(expected, data1);
}
#[test]
fn fma() {
let size = 41;
let scalar = Octet::new(rand::thread_rng().gen_range(2..255));
let mut data1: Vec<u8> = vec![0; size];
let mut data2: Vec<u8> = vec![0; size];
let mut expected: Vec<u8> = vec![0; size];
for i in 0..size {
data1[i] = rand::thread_rng().gen();
data2[i] = rand::thread_rng().gen();
expected[i] = (Octet::new(data1[i]) + &Octet::new(data2[i]) * &scalar).byte();
}
fused_addassign_mul_scalar(&mut data1, &data2, &scalar);
assert_eq!(expected, data1);
}
#[test]
fn fma_binary() {
let size = 41;
let scalar = Octet::new(rand::thread_rng().gen_range(2..255));
let mut binary_vec: Vec<u64> = vec![0; (size + 63) / 64];
for i in 0..binary_vec.len() {
binary_vec[i] = rand::thread_rng().gen();
}
let binary_octet_vec = BinaryOctetVec::new(binary_vec, size);
let mut data1: Vec<u8> = vec![0; size];
let data2: Vec<u8> = binary_octet_vec.to_octet_vec();
let mut expected: Vec<u8> = vec![0; size];
for i in 0..size {
data1[i] = rand::thread_rng().gen();
expected[i] = (Octet::new(data1[i]) + &Octet::new(data2[i]) * &scalar).byte();
}
fused_addassign_mul_scalar_binary(&mut data1, &binary_octet_vec, &scalar);
assert_eq!(expected, data1);
}
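// Extra sanity check (same pattern as the tests above): addition in GF(256) is XOR, so the
// expected result of add_assign can be computed with a plain bytewise XOR.
#[test]
fn add_assign_xor() {
let size = 41;
let mut data1: Vec<u8> = vec![0; size];
let mut data2: Vec<u8> = vec![0; size];
let mut expected: Vec<u8> = vec![0; size];
for i in 0..size {
data1[i] = rand::thread_rng().gen();
data2[i] = rand::thread_rng().gen();
expected[i] = data1[i] ^ data2[i];
}
crate::octets::add_assign(&mut data1, &data2);
assert_eq!(expected, data1);
}
// Fixed-input check of the bit layout described above: five values packed into the high bits of
// a single u64 word, with the last value at bit 63.
#[test]
fn binary_octet_vec_layout() {
let packed = (1u64 << 59) | (1u64 << 61) | (1u64 << 62);
let vec = BinaryOctetVec::new(vec![packed], 5);
assert_eq!(vec.padding_bits(), 59);
assert_eq!(vec.to_octet_vec(), vec![1, 0, 1, 1, 0]);
}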
}