Heuristic to optimize FMA on large sparse vectors

This commit is contained in:
Christopher Berner 2019-04-07 18:44:38 -07:00
parent 59e88ad758
commit d4e16d8b98

@ -426,13 +426,16 @@ impl <T: Clone> SparseVec<T> {
/// Sparse vector of `Octet` values, stored in a `SparseVec<Octet>`.
#[derive(Clone, Debug, PartialEq)]
struct SparseOctetVec {
// Kept sorted by the usize (key)
elements: SparseVec<Octet>
elements: SparseVec<Octet>,
// Number of explicitly stored zero-valued entries in `elements`.
// Incremented when a zero result is stored rather than dropped, and reset to 0
// when `elements` is rebuilt with its zero entries filtered out.
zeros: usize
}
impl SparseOctetVec {
/// Creates a `SparseOctetVec` whose backing `SparseVec` is built with
/// `SparseVec::with_capacity(capacity)` and whose stored-zero counter starts at 0.
pub fn with_capacity(capacity: usize) -> SparseOctetVec {
SparseOctetVec {
elements: SparseVec::with_capacity(capacity)
elements: SparseVec::with_capacity(capacity),
zeros: 0
}
}
@ -442,16 +445,31 @@ impl SparseOctetVec {
// TODO: Probably wouldn't need this if we implemented "Furthermore, the row operations
// required for the HDPC rows may be performed for all such rows in one
// process, by using the algorithm described in Section 5.3.3.3."
if other.elements.elements.len() == 1 {
if other.elements.elements.len() == 1 &&
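(self.zeros as f32 <= 0.9 * self.elements.elements.len() as f32) { // Heuristic: only take this single-element fast path while stored zeros are at most 90% of the vector; otherwise fall through to the full merge, which compresses the zeros out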
let (other_col, other_value) = &other.elements.elements[0];
if let Some(self_value) = self.elements.remove(*other_col) {
// XXX: heuristic for handling large rows, since these are somewhat common (HDPC rows)
if self.elements.elements.len() > 1000 {
let self_value= self.elements.get(other_col)
.map(|x| x.clone())
.unwrap_or(Octet::zero());
let value = &self_value + &(other_value * scalar);
if value != Octet::zero() {
self.elements.insert(*other_col, value);
if value == Octet::zero() {
// Keep track of stored zeros, so they can be GC'ed later
self.zeros += 1;
}
self.elements.insert(*other_col, value);
}
else {
self.elements.insert(*other_col, other_value * scalar);
if let Some(self_value) = self.elements.remove(*other_col) {
let value = &self_value + &(other_value * scalar);
if value != Octet::zero() {
self.elements.insert(*other_col, value);
}
}
else {
self.elements.insert(*other_col, other_value * scalar);
}
}
return vec![];
}
@ -467,7 +485,9 @@ impl SparseOctetVec {
if let Some((self_col, self_value)) = self_entry {
if let Some((other_col, other_value)) = other_entry {
if self_col < other_col {
result.push((*self_col, self_value.clone()));
if *self_value != Octet::zero() {
result.push((*self_col, self_value.clone()));
}
self_entry = self_iter.next();
}
else if self_col == other_col {
@ -479,20 +499,26 @@ impl SparseOctetVec {
other_entry = other_iter.next();
}
else {
new_columns.push(*other_col);
result.push((*other_col, other_value * scalar));
if *other_value != Octet::zero() {
new_columns.push(*other_col);
result.push((*other_col, other_value * scalar));
}
other_entry = other_iter.next();
}
}
else {
result.push((*self_col, self_value.clone()));
if *self_value != Octet::zero() {
result.push((*self_col, self_value.clone()));
}
self_entry = self_iter.next();
}
}
else {
if let Some((other_col, other_value)) = other_entry {
new_columns.push(*other_col);
result.push((*other_col, other_value * scalar));
if *other_value != Octet::zero() {
new_columns.push(*other_col);
result.push((*other_col, other_value * scalar));
}
other_entry = other_iter.next();
}
else {
@ -501,6 +527,7 @@ impl SparseOctetVec {
}
}
self.elements.elements = result;
self.zeros = 0;
return new_columns;
}