From 102c6a5a867acfd81a0a547500ae3bdb95aab57f Mon Sep 17 00:00:00 2001 From: Christopher Berner Date: Sun, 6 Dec 2020 21:22:30 -0800 Subject: [PATCH] Optimize DenseBinaryMatrix Switch to a single contiguous vector instead of vec of vecs This improves performance by ~5%, especially for smaller symbol counts --- README.md | 60 ++++++++++----------- src/iterators.rs | 49 ++++++++++++----- src/matrix.rs | 137 ++++++++++++++++++++++++++++++++++------------- src/util.rs | 21 ++++++++ 4 files changed, 186 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index e7e0e97..eda9ae0 100644 --- a/README.md +++ b/README.md @@ -23,16 +23,16 @@ The following were run on an Intel Core i5-6600K @ 3.50GHz ``` Symbol size: 1280 bytes (without pre-built plan) -symbol count = 10, encoded 127 MB in 0.484secs, throughput: 2115.5Mbit/s -symbol count = 100, encoded 127 MB in 0.509secs, throughput: 2010.7Mbit/s -symbol count = 250, encoded 127 MB in 0.482secs, throughput: 2122.3Mbit/s -symbol count = 500, encoded 127 MB in 0.463secs, throughput: 2204.1Mbit/s -symbol count = 1000, encoded 126 MB in 0.492secs, throughput: 2064.3Mbit/s -symbol count = 2000, encoded 126 MB in 0.565secs, throughput: 1797.6Mbit/s -symbol count = 5000, encoded 122 MB in 0.594secs, throughput: 1644.0Mbit/s -symbol count = 10000, encoded 122 MB in 0.716secs, throughput: 1363.9Mbit/s -symbol count = 20000, encoded 122 MB in 1.059secs, throughput: 922.2Mbit/s -symbol count = 50000, encoded 122 MB in 1.508secs, throughput: 647.6Mbit/s +symbol count = 10, encoded 127 MB in 0.465secs, throughput: 2202.0Mbit/s +symbol count = 100, encoded 127 MB in 0.483secs, throughput: 2118.9Mbit/s +symbol count = 250, encoded 127 MB in 0.474secs, throughput: 2158.1Mbit/s +symbol count = 500, encoded 127 MB in 0.460secs, throughput: 2218.5Mbit/s +symbol count = 1000, encoded 126 MB in 0.490secs, throughput: 2072.7Mbit/s +symbol count = 2000, encoded 126 MB in 0.562secs, throughput: 1807.2Mbit/s +symbol count = 5000, encoded 122 MB in 0.578secs, throughput: 1689.6Mbit/s +symbol count = 10000, encoded 122 MB in 0.687secs, throughput: 1421.5Mbit/s +symbol count = 20000, encoded 122 MB in 1.019secs, throughput: 958.4Mbit/s +symbol count = 50000, encoded 122 MB in 1.432secs, throughput: 682.0Mbit/s Symbol size: 1280 bytes (with pre-built plan) symbol count = 10, encoded 127 MB in 0.220secs, throughput: 4654.2Mbit/s @@ -47,27 +47,27 @@ symbol count = 20000, encoded 122 MB in 0.427secs, throughput: 2287.0Mbit/s symbol count = 50000, encoded 122 MB in 0.540secs, throughput: 1808.4Mbit/s Symbol size: 1280 bytes -symbol count = 10, decoded 127 MB in 0.706secs using 0.0% overhead, throughput: 1450.3Mbit/s -symbol count = 100, decoded 127 MB in 0.619secs using 0.0% overhead, throughput: 1653.4Mbit/s -symbol count = 250, decoded 127 MB in 0.568secs using 0.0% overhead, throughput: 1801.0Mbit/s -symbol count = 500, decoded 127 MB in 0.560secs using 0.0% overhead, throughput: 1822.3Mbit/s -symbol count = 1000, decoded 126 MB in 0.601secs using 0.0% overhead, throughput: 1689.9Mbit/s -symbol count = 2000, decoded 126 MB in 0.670secs using 0.0% overhead, throughput: 1515.9Mbit/s -symbol count = 5000, decoded 122 MB in 0.767secs using 0.0% overhead, throughput: 1273.2Mbit/s -symbol count = 10000, decoded 122 MB in 0.970secs using 0.0% overhead, throughput: 1006.8Mbit/s -symbol count = 20000, decoded 122 MB in 1.222secs using 0.0% overhead, throughput: 799.2Mbit/s -symbol count = 50000, decoded 122 MB in 2.046secs using 0.0% overhead, throughput: 477.3Mbit/s +symbol count = 10, decoded 127 MB in 0.679secs using 0.0% overhead, throughput: 1508.0Mbit/s +symbol count = 100, decoded 127 MB in 0.583secs using 0.0% overhead, throughput: 1755.5Mbit/s +symbol count = 250, decoded 127 MB in 0.564secs using 0.0% overhead, throughput: 1813.7Mbit/s +symbol count = 500, decoded 127 MB in 0.539secs using 0.0% overhead, throughput: 1893.3Mbit/s +symbol count = 1000, decoded 126 MB in 0.571secs using 0.0% overhead, throughput: 1778.7Mbit/s +symbol count = 2000, decoded 126 MB in 0.708secs using 0.0% overhead, throughput: 1434.5Mbit/s +symbol count = 5000, decoded 122 MB in 0.769secs using 0.0% overhead, throughput: 1269.9Mbit/s +symbol count = 10000, decoded 122 MB in 0.902secs using 0.0% overhead, throughput: 1082.7Mbit/s +symbol count = 20000, decoded 122 MB in 1.135secs using 0.0% overhead, throughput: 860.4Mbit/s +symbol count = 50000, decoded 122 MB in 1.929secs using 0.0% overhead, throughput: 506.3Mbit/s -symbol count = 10, decoded 127 MB in 0.698secs using 5.0% overhead, throughput: 1466.9Mbit/s -symbol count = 100, decoded 127 MB in 0.617secs using 5.0% overhead, throughput: 1658.7Mbit/s -symbol count = 250, decoded 127 MB in 0.565secs using 5.0% overhead, throughput: 1810.5Mbit/s -symbol count = 500, decoded 127 MB in 0.545secs using 5.0% overhead, throughput: 1872.5Mbit/s -symbol count = 1000, decoded 126 MB in 0.563secs using 5.0% overhead, throughput: 1804.0Mbit/s -symbol count = 2000, decoded 126 MB in 0.599secs using 5.0% overhead, throughput: 1695.5Mbit/s -symbol count = 5000, decoded 122 MB in 0.689secs using 5.0% overhead, throughput: 1417.4Mbit/s -symbol count = 10000, decoded 122 MB in 0.881secs using 5.0% overhead, throughput: 1108.5Mbit/s -symbol count = 20000, decoded 122 MB in 1.117secs using 5.0% overhead, throughput: 874.3Mbit/s -symbol count = 50000, decoded 122 MB in 1.848secs using 5.0% overhead, throughput: 528.4Mbit/s +symbol count = 10, decoded 127 MB in 0.669secs using 5.0% overhead, throughput: 1530.5Mbit/s +symbol count = 100, decoded 127 MB in 0.582secs using 5.0% overhead, throughput: 1758.5Mbit/s +symbol count = 250, decoded 127 MB in 0.550secs using 5.0% overhead, throughput: 1859.9Mbit/s +symbol count = 500, decoded 127 MB in 0.520secs using 5.0% overhead, throughput: 1962.5Mbit/s +symbol count = 1000, decoded 126 MB in 0.548secs using 5.0% overhead, throughput: 1853.3Mbit/s +symbol count = 2000, decoded 126 MB in 0.582secs using 5.0% overhead, throughput: 1745.1Mbit/s +symbol count = 5000, decoded 122 MB in 0.658secs using 5.0% overhead, throughput: 1484.1Mbit/s +symbol count = 10000, decoded 122 MB in 0.835secs using 5.0% overhead, throughput: 1169.5Mbit/s +symbol count = 20000, decoded 122 MB in 1.105secs using 5.0% overhead, throughput: 883.8Mbit/s +symbol count = 50000, decoded 122 MB in 1.784secs using 5.0% overhead, throughput: 547.4Mbit/s ``` ### Public API diff --git a/src/iterators.rs b/src/iterators.rs index b89af5d..387654c 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -8,6 +8,8 @@ pub struct ClonedOctetIter { end_col: usize, dense_elements: Option>, dense_index: usize, + dense_word_index: usize, + dense_bit_index: usize, sparse_elements: Option>, sparse_index: usize, } @@ -29,16 +31,20 @@ impl Iterator for ClonedOctetIter { return None; } else { let old_index = self.dense_index; - self.dense_index += 1; - let (word, bit) = DenseBinaryMatrix::bit_position(old_index); - let value = if self.dense_elements.as_ref().unwrap()[word] - & DenseBinaryMatrix::select_mask(bit) + let value = if self.dense_elements.as_ref().unwrap()[self.dense_word_index] + & DenseBinaryMatrix::select_mask(self.dense_bit_index) == 0 { Octet::zero() } else { Octet::one() }; + self.dense_index += 1; + self.dense_bit_index += 1; + if self.dense_bit_index == 64 { + self.dense_bit_index = 0; + self.dense_word_index += 1; + } return Some((old_index, value)); } } @@ -49,8 +55,10 @@ pub struct OctetIter<'a> { sparse: bool, start_col: usize, end_col: usize, - dense_elements: Option<&'a Vec>, + dense_elements: Option<&'a [u64]>, dense_index: usize, + dense_word_index: usize, + dense_bit_index: usize, sparse_elements: Option<&'a SparseBinaryVec>, sparse_index: usize, sparse_physical_col_to_logical: Option<&'a [u16]>, @@ -69,6 +77,8 @@ impl<'a> OctetIter<'a> { end_col, dense_elements: None, dense_index: 0, + dense_word_index: 0, + dense_bit_index: 0, sparse_elements: Some(sparse_elements), sparse_index: 0, sparse_physical_col_to_logical: Some(sparse_physical_col_to_logical), @@ -79,7 +89,8 @@ impl<'a> OctetIter<'a> { pub fn new_dense_binary( start_col: usize, end_col: usize, - dense_elements: &'a Vec, + start_bit: usize, + dense_elements: &'a [u64], ) -> OctetIter<'a> { OctetIter { sparse: false, @@ -87,6 +98,8 @@ impl<'a> OctetIter<'a> { end_col, dense_elements: Some(dense_elements), dense_index: start_col, + dense_word_index: 0, + dense_bit_index: start_bit, sparse_elements: None, sparse_index: 0, sparse_physical_col_to_logical: None, @@ -111,8 +124,10 @@ impl<'a> OctetIter<'a> { ClonedOctetIter { sparse: self.sparse, end_col: self.end_col, - dense_elements: self.dense_elements.cloned(), + dense_elements: self.dense_elements.map(|x| x.to_vec()), dense_index: self.dense_index, + dense_word_index: self.dense_word_index, + dense_bit_index: self.dense_bit_index, sparse_elements, sparse_index: self.sparse_index, } @@ -144,13 +159,19 @@ impl<'a> Iterator for OctetIter<'a> { } else { let old_index = self.dense_index; self.dense_index += 1; - let (word, bit) = DenseBinaryMatrix::bit_position(old_index); - let value = - if self.dense_elements.unwrap()[word] & DenseBinaryMatrix::select_mask(bit) == 0 { - Octet::zero() - } else { - Octet::one() - }; + let value = if self.dense_elements.unwrap()[self.dense_word_index] + & DenseBinaryMatrix::select_mask(self.dense_bit_index) + == 0 + { + Octet::zero() + } else { + Octet::one() + }; + self.dense_bit_index += 1; + if self.dense_bit_index == 64 { + self.dense_bit_index = 0; + self.dense_word_index += 1; + } return Some((old_index, value)); } } diff --git a/src/matrix.rs b/src/matrix.rs index e2b072a..ab2ea22 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -2,7 +2,7 @@ use crate::gf2::add_assign_binary; use crate::iterators::OctetIter; use crate::octet::Octet; use crate::octets::BinaryOctetVec; -use crate::util::get_both_indices; +use crate::util::get_both_ranges; use std::mem::size_of; // TODO: change this struct to not use the Octet class, since it's binary not GF(256) @@ -60,14 +60,25 @@ pub struct DenseBinaryMatrix { height: usize, width: usize, // Values are bit-packed into u64 - // TODO: optimize into a single dimensional vec - elements: Vec>, + elements: Vec, } impl DenseBinaryMatrix { // Returns (word in elements vec, and bit in word) for the given col - pub fn bit_position(col: usize) -> (usize, usize) { - return (col / WORD_WIDTH, col % WORD_WIDTH); + fn bit_position(&self, row: usize, col: usize) -> (usize, usize) { + return ( + row * self.row_word_width() + Self::word_offset(col), + col % WORD_WIDTH, + ); + } + + fn word_offset(col: usize) -> usize { + col / WORD_WIDTH + } + + // Number of words required per row + fn row_word_width(&self) -> usize { + (self.width + WORD_WIDTH - 1) / WORD_WIDTH } // Returns mask to select the given bit in a word @@ -98,7 +109,7 @@ impl DenseBinaryMatrix { impl BinaryMatrix for DenseBinaryMatrix { fn new(height: usize, width: usize, _: usize) -> DenseBinaryMatrix { - let elements = vec![vec![0; DenseBinaryMatrix::bit_position(width).0 + 1]; height]; + let elements = vec![0; height * (width + WORD_WIDTH - 1) / WORD_WIDTH]; DenseBinaryMatrix { height, width, @@ -107,11 +118,11 @@ impl BinaryMatrix for DenseBinaryMatrix { } fn set(&mut self, i: usize, j: usize, value: Octet) { - let (word, bit) = DenseBinaryMatrix::bit_position(j); + let (word, bit) = self.bit_position(i, j); if value == Octet::zero() { - DenseBinaryMatrix::clear_bit(&mut self.elements[i][word], bit); + DenseBinaryMatrix::clear_bit(&mut self.elements[word], bit); } else { - DenseBinaryMatrix::set_bit(&mut self.elements[i][word], bit); + DenseBinaryMatrix::set_bit(&mut self.elements[word], bit); } } @@ -125,32 +136,32 @@ impl BinaryMatrix for DenseBinaryMatrix { fn size_in_bytes(&self) -> usize { let mut bytes = size_of::(); - bytes += size_of::>() * self.elements.len(); - bytes += size_of::() * self.height * self.width; + bytes += size_of::>(); + bytes += size_of::() * self.elements.len(); bytes } fn count_ones(&self, row: usize, start_col: usize, end_col: usize) -> usize { - let (start_word, start_bit) = DenseBinaryMatrix::bit_position(start_col); - let (end_word, end_bit) = DenseBinaryMatrix::bit_position(end_col); + let (start_word, start_bit) = self.bit_position(row, start_col); + let (end_word, end_bit) = self.bit_position(row, end_col); // Handle case when there is only one word if start_word == end_word { let mut mask = DenseBinaryMatrix::select_bit_and_all_left_mask(start_bit); mask &= DenseBinaryMatrix::select_all_right_of_mask(end_bit); - let bits = self.elements[row][start_word] & mask; + let bits = self.elements[start_word] & mask; return bits.count_ones() as usize; } - let first_word_bits = self.elements[row][start_word] - & DenseBinaryMatrix::select_bit_and_all_left_mask(start_bit); + let first_word_bits = + self.elements[start_word] & DenseBinaryMatrix::select_bit_and_all_left_mask(start_bit); let mut ones = first_word_bits.count_ones(); for word in (start_word + 1)..end_word { - ones += self.elements[row][word].count_ones(); + ones += self.elements[word].count_ones(); } if end_bit > 0 { let bits = - self.elements[row][end_word] & DenseBinaryMatrix::select_all_right_of_mask(end_bit); + self.elements[end_word] & DenseBinaryMatrix::select_all_right_of_mask(end_bit); ones += bits.count_ones(); } @@ -158,7 +169,14 @@ impl BinaryMatrix for DenseBinaryMatrix { } fn get_row_iter(&self, row: usize, start_col: usize, end_col: usize) -> OctetIter { - OctetIter::new_dense_binary(start_col, end_col, &self.elements[row]) + let (first_word, first_bit) = self.bit_position(row, start_col); + let (last_word, _) = self.bit_position(row, end_col); + OctetIter::new_dense_binary( + start_col, + end_col, + first_bit, + &self.elements[first_word..=last_word], + ) } fn get_ones_in_column(&self, col: usize, start_row: usize, end_row: usize) -> Vec { @@ -202,8 +220,8 @@ impl BinaryMatrix for DenseBinaryMatrix { } fn get(&self, i: usize, j: usize) -> Octet { - let (word, bit) = DenseBinaryMatrix::bit_position(j); - if self.elements[i][word] & DenseBinaryMatrix::select_mask(bit) == 0 { + let (word, bit) = self.bit_position(i, j); + if self.elements[word] & DenseBinaryMatrix::select_mask(bit) == 0 { return Octet::zero(); } else { return Octet::one(); @@ -211,27 +229,33 @@ impl BinaryMatrix for DenseBinaryMatrix { } fn swap_rows(&mut self, i: usize, j: usize) { - self.elements.swap(i, j); + let (row_i, _) = self.bit_position(i, 0); + let (row_j, _) = self.bit_position(j, 0); + for k in 0..self.row_word_width() { + self.elements.swap(row_i + k, row_j + k); + } } fn swap_columns(&mut self, i: usize, j: usize, start_row_hint: usize) { - let (word_i, bit_i) = DenseBinaryMatrix::bit_position(i); - let (word_j, bit_j) = DenseBinaryMatrix::bit_position(j); + // Lookup for row zero to get the base word offset + let (word_i, bit_i) = self.bit_position(0, i); + let (word_j, bit_j) = self.bit_position(0, j); let unset_i = !DenseBinaryMatrix::select_mask(bit_i); let unset_j = !DenseBinaryMatrix::select_mask(bit_j); let bit_i = DenseBinaryMatrix::select_mask(bit_i); let bit_j = DenseBinaryMatrix::select_mask(bit_j); - for row in start_row_hint..self.elements.len() { - let i_set = self.elements[row][word_i] & bit_i != 0; - if self.elements[row][word_j] & bit_j == 0 { - self.elements[row][word_i] &= unset_i; + let row_width = self.row_word_width(); + for row in start_row_hint..self.height { + let i_set = self.elements[row * row_width + word_i] & bit_i != 0; + if self.elements[row * row_width + word_j] & bit_j == 0 { + self.elements[row * row_width + word_i] &= unset_i; } else { - self.elements[row][word_i] |= bit_i; + self.elements[row * row_width + word_i] |= bit_i; } if i_set { - self.elements[row][word_j] |= bit_j; + self.elements[row * row_width + word_j] |= bit_j; } else { - self.elements[row][word_j] &= unset_j; + self.elements[row * row_width + word_j] &= unset_j; } } } @@ -250,20 +274,37 @@ impl BinaryMatrix for DenseBinaryMatrix { fn add_assign_rows(&mut self, dest: usize, src: usize, _start_col: usize) { assert_ne!(dest, src); - let (dest_row, temp_row) = get_both_indices(&mut self.elements, dest, src); + let (dest_word, _) = self.bit_position(dest, 0); + let (src_word, _) = self.bit_position(src, 0); + let row_width = self.row_word_width(); + let (dest_row, temp_row) = + get_both_ranges(&mut self.elements, dest_word, src_word, row_width); add_assign_binary(dest_row, temp_row); } fn resize(&mut self, new_height: usize, new_width: usize) { assert!(new_height <= self.height); assert!(new_width <= self.width); - let (new_words, _) = DenseBinaryMatrix::bit_position(new_width); - self.elements.truncate(new_height); - for row in 0..self.elements.len() { - self.elements[row].truncate(new_words + 1); - } + let old_row_width = self.row_word_width(); self.height = new_height; self.width = new_width; + let new_row_width = self.row_word_width(); + let words_to_remove = old_row_width - new_row_width; + if words_to_remove > 0 { + let mut src = 0; + let mut dest = 0; + while dest < new_height * new_row_width { + self.elements[dest] = self.elements[src]; + src += 1; + dest += 1; + if dest % new_row_width == 0 { + // After copying each row, skip over the elements being dropped + src += words_to_remove; + } + } + assert_eq!(src, new_height * old_row_width); + } + self.elements.truncate(new_height * self.row_word_width()); } } @@ -306,6 +347,28 @@ mod tests { } } + #[test] + fn row_iter() { + // rand_dense_and_sparse uses set(), so just check that it works + let (dense, sparse) = rand_dense_and_sparse(8); + for row in 0..dense.height() { + let start_col = rand::thread_rng().gen_range(0, dense.width() - 2); + let end_col = rand::thread_rng().gen_range(start_col + 1, dense.width()); + let mut dense_iter = dense.get_row_iter(row, start_col, end_col); + let mut sparse_iter = sparse.get_row_iter(row, start_col, end_col); + for col in start_col..end_col { + assert_eq!(dense.get(row, col), sparse.get(row, col)); + assert_eq!((col, dense.get(row, col)), dense_iter.next().unwrap()); + // Sparse iter is not required to return zeros + if sparse.get(row, col) != Octet::zero() { + assert_eq!((col, sparse.get(row, col)), sparse_iter.next().unwrap()); + } + } + assert!(dense_iter.next().is_none()); + assert!(sparse_iter.next().is_none()); + } + } + #[test] fn swap_rows() { // rand_dense_and_sparse uses set(), so just check that it works diff --git a/src/util.rs b/src/util.rs index 4054e3d..ce79bd8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,3 +1,24 @@ +// Get two non-overlapping ranges starting at i & j, both with length len +pub fn get_both_ranges( + vector: &mut Vec, + i: usize, + j: usize, + len: usize, +) -> (&mut [T], &mut [T]) { + debug_assert_ne!(i, j); + debug_assert!(i + len <= vector.len()); + debug_assert!(j + len <= vector.len()); + if i < j { + debug_assert!(i + len <= j); + let (first, last) = vector.split_at_mut(j); + return (&mut first[i..(i + len)], &mut last[0..len]); + } else { + debug_assert!(j + len <= i); + let (first, last) = vector.split_at_mut(i); + return (&mut last[0..len], &mut first[j..(j + len)]); + } +} + pub fn get_both_indices(vector: &mut Vec, i: usize, j: usize) -> (&mut T, &mut T) { debug_assert_ne!(i, j); debug_assert!(i < vector.len());