raptorq/src/decoder.rs
Christopher Berner 82eb5dee14 Resolve 2.0 TODOs
2024-03-14 20:55:54 -07:00

614 lines
20 KiB
Rust

#[cfg(feature = "std")]
use std::{collections::HashSet as Set, iter, vec::Vec};
#[cfg(not(feature = "std"))]
use core::iter;
#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeSet as Set, vec::Vec};
use crate::base::intermediate_tuple;
use crate::base::partition;
use crate::base::EncodingPacket;
use crate::base::ObjectTransmissionInformation;
use crate::constraint_matrix::enc_indices;
use crate::constraint_matrix::generate_constraint_matrix;
use crate::encoder::SPARSE_MATRIX_THRESHOLD;
use crate::matrix::{BinaryMatrix, DenseBinaryMatrix};
use crate::octet_matrix::DenseOctetMatrix;
use crate::pi_solver::fused_inverse_mul_symbols;
use crate::sparse_matrix::SparseBinaryMatrix;
use crate::symbol::Symbol;
use crate::systematic_constants::num_hdpc_symbols;
use crate::systematic_constants::num_ldpc_symbols;
use crate::systematic_constants::{
calculate_p1, extended_source_block_symbols, num_lt_symbols, num_pi_symbols, systematic_index,
};
use crate::util::int_div_ceil;
#[cfg(feature = "serde_support")]
use serde::{Deserialize, Serialize};
/// Decoder for a whole transfer object that is split across one or more source blocks.
///
/// It owns one [`SourceBlockDecoder`] per source block and keeps the bytes of every block
/// that has already been recovered, so the object can be reassembled once all are present.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct Decoder {
/// Transmission parameters this decoder was created with.
config: ObjectTransmissionInformation,
/// One decoder per source block, indexed by source block number.
block_decoders: Vec<SourceBlockDecoder>,
/// Recovered bytes of each source block; `None` until that block has been decoded.
blocks: Vec<Option<Vec<u8>>>,
}
impl Decoder {
pub fn new(config: ObjectTransmissionInformation) -> Decoder {
let kt = int_div_ceil(config.transfer_length(), config.symbol_size() as u64);
let (kl, ks, zl, zs) = partition(kt, config.source_blocks());
let mut decoders = vec![];
for i in 0..zl {
decoders.push(SourceBlockDecoder::new(
i as u8,
&config,
u64::from(kl) * u64::from(config.symbol_size()),
));
}
for i in zl..(zl + zs) {
decoders.push(SourceBlockDecoder::new(
i as u8,
&config,
u64::from(ks) * u64::from(config.symbol_size()),
));
}
Decoder {
config,
block_decoders: decoders,
blocks: vec![None; (zl + zs) as usize],
}
}
#[cfg(all(any(test, feature = "benchmarking"), not(feature = "python")))]
pub fn set_sparse_threshold(&mut self, value: u32) {
for block_decoder in self.block_decoders.iter_mut() {
block_decoder.set_sparse_threshold(value);
}
}
pub fn decode(&mut self, packet: EncodingPacket) -> Option<Vec<u8>> {
let block_number = packet.payload_id.source_block_number() as usize;
if self.blocks[block_number].is_none() {
self.blocks[block_number] =
self.block_decoders[block_number].decode(iter::once(packet));
}
for block in self.blocks.iter() {
if block.is_none() {
return None;
}
}
let mut result = vec![];
for block in self.blocks.iter().flatten() {
result.extend(block);
}
result.truncate(self.config.transfer_length() as usize);
Some(result)
}
#[cfg(not(feature = "python"))]
pub fn add_new_packet(&mut self, packet: EncodingPacket) {
let block_number = packet.payload_id.source_block_number() as usize;
if self.blocks[block_number].is_none() {
self.blocks[block_number] =
self.block_decoders[block_number].decode(iter::once(packet));
}
}
#[cfg(not(feature = "python"))]
pub fn get_result(&self) -> Option<Vec<u8>> {
for block in self.blocks.iter() {
if block.is_none() {
return None;
}
}
let mut result = vec![];
for block in self.blocks.iter().flatten() {
result.extend(block);
}
result.truncate(self.config.transfer_length() as usize);
Some(result)
}
}
/// Decoder for a single source block.
///
/// Collects source and repair packets for one block and produces the block's bytes once
/// enough symbols have arrived.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct SourceBlockDecoder {
/// Source block number every packet passed to `decode` must carry.
source_block_id: u8,
/// Symbol size taken from the transmission configuration.
symbol_size: u16,
/// Number of sub-blocks taken from the transmission configuration.
num_sub_blocks: u16,
/// Symbol alignment taken from the transmission configuration.
symbol_alignment: u8,
/// Number of source symbols in this block (block length divided by symbol size, rounded up).
source_block_symbols: u32,
/// Received source symbols indexed by encoding symbol id; `None` while not yet received.
source_symbols: Vec<Option<Symbol>>,
/// Received packets whose encoding symbol id is at or beyond `source_block_symbols`.
repair_packets: Vec<EncodingPacket>,
/// Count of distinct source symbols received so far.
received_source_symbols: u32,
/// Encoding symbol ids seen so far; used to drop duplicate packets.
received_esi: Set<u32>,
/// Set to `true` once the block has been produced; not read anywhere in this file section.
decoded: bool,
/// Extended symbol count at or above which the sparse constraint matrix is used.
sparse_threshold: u32,
}
impl SourceBlockDecoder {
pub fn new(
source_block_id: u8,
config: &ObjectTransmissionInformation,
block_length: u64,
) -> SourceBlockDecoder {
let source_symbols = int_div_ceil(block_length, config.symbol_size() as u64);
SourceBlockDecoder {
source_block_id,
symbol_size: config.symbol_size(),
num_sub_blocks: config.sub_blocks(),
symbol_alignment: config.symbol_alignment(),
source_block_symbols: source_symbols,
source_symbols: vec![None; source_symbols as usize],
repair_packets: vec![],
received_source_symbols: 0,
received_esi: Set::new(),
decoded: false,
sparse_threshold: SPARSE_MATRIX_THRESHOLD,
}
}
/// Overrides the extended symbol count at or above which `decode` builds a sparse
/// constraint matrix instead of a dense one. Only compiled for tests and benchmarking.
#[cfg(any(test, feature = "benchmarking"))]
pub fn set_sparse_threshold(&mut self, value: u32) {
self.sparse_threshold = value;
}
/// Copies one symbol into `result`, scattering its pieces across the sub-block regions.
///
/// The symbol is cut into `nl` long pieces followed by `ns` short pieces, as given by
/// `partition`. Piece `j` is written into the `j`-th region of `result` at the slot
/// belonging to `symbol_index`; each region holds that piece for every source symbol.
///
/// Panics if `result` or the symbol's bytes are too short for the computed ranges.
fn unpack_sub_blocks(&self, result: &mut [u8], symbol: &Symbol, symbol_index: usize) {
    let alignment = self.symbol_alignment as usize;
    let symbols_per_block = self.source_block_symbols as usize;
    let aligned_units = (self.symbol_size / self.symbol_alignment as u16) as u32;
    let (tl, ts, nl, ns) = partition(aligned_units, self.num_sub_blocks);
    let long_len = tl as usize * alignment;
    let short_len = ts as usize * alignment;

    let payload = symbol.as_bytes();
    // Bytes of the symbol consumed so far.
    let mut consumed = 0;
    // Start of the current sub-block region inside `result`.
    let mut region_start = 0;
    for sub_block in 0..(nl + ns) {
        let len = if sub_block < nl { long_len } else { short_len };
        let dst = region_start + len * symbol_index;
        result[dst..dst + len].copy_from_slice(&payload[consumed..consumed + len]);
        consumed += len;
        region_start += len * symbols_per_block;
    }
}
/// Solves for the intermediate symbols and, on success, rebuilds the whole source block.
///
/// `constraint_matrix`, `hdpc_rows` and `symbols` are handed to
/// `fused_inverse_mul_symbols`. If that yields no intermediate symbols, `None` is returned
/// and the decoder state is left unchanged. Otherwise every source symbol that was received
/// is copied as-is, every missing one is regenerated from the intermediate symbols, and the
/// block is marked as decoded.
fn try_pi_decode(
    &mut self,
    constraint_matrix: impl BinaryMatrix,
    hdpc_rows: DenseOctetMatrix,
    symbols: Vec<Symbol>,
) -> Option<Vec<u8>> {
    // Only the first tuple element (the intermediate symbols) is needed here.
    let intermediate_symbols = fused_inverse_mul_symbols(
        constraint_matrix,
        hdpc_rows,
        symbols,
        self.source_block_symbols,
    )
    .0?;

    let mut result = vec![0; self.symbol_size as usize * self.source_block_symbols as usize];
    // These depend only on the block size, so compute them once for all rebuilt symbols.
    let lt_symbols = num_lt_symbols(self.source_block_symbols);
    let pi_symbols = num_pi_symbols(self.source_block_symbols);
    let sys_index = systematic_index(self.source_block_symbols);
    let p1 = calculate_p1(self.source_block_symbols);

    // `source_symbols` is created in `new` with one slot per source symbol.
    for (i, received) in self.source_symbols.iter().enumerate() {
        match received {
            Some(symbol) => self.unpack_sub_blocks(&mut result, symbol, i),
            None => {
                let rebuilt = self.rebuild_source_symbol(
                    &intermediate_symbols,
                    i as u32,
                    lt_symbols,
                    pi_symbols,
                    sys_index,
                    p1,
                );
                self.unpack_sub_blocks(&mut result, &rebuilt, i);
            }
        }
    }

    self.decoded = true;
    Some(result)
}
/// Records `packets` and tries to produce the bytes of this source block.
///
/// Packets whose encoding symbol id was already seen are dropped. Returns `None` while
/// fewer distinct symbols than source symbols have arrived, or when the attempt to solve
/// for the intermediate symbols fails; otherwise returns the decoded block.
///
/// Panics if a packet carries a source block number other than this decoder's.
pub fn decode<T: IntoIterator<Item = EncodingPacket>>(
    &mut self,
    packets: T,
) -> Option<Vec<u8>> {
    for packet in packets {
        assert_eq!(
            self.source_block_id,
            packet.payload_id.source_block_number()
        );
        let (payload_id, payload) = packet.split();
        let esi = payload_id.encoding_symbol_id();
        // `insert` returns false for an id that is already present: a duplicate packet.
        if !self.received_esi.insert(esi) {
            continue;
        }
        if esi < self.source_block_symbols {
            // Source symbol
            self.source_symbols[esi as usize] = Some(Symbol::new(payload));
            self.received_source_symbols += 1;
        } else {
            // Repair symbol
            self.repair_packets
                .push(EncodingPacket::new(payload_id, payload));
        }
    }

    let source_count = self.source_block_symbols;
    let num_extended_symbols = extended_source_block_symbols(source_count);
    let num_padding_symbols = num_extended_symbols - source_count;

    // Case 1: the number of received packets is insufficient for decoding
    if self.received_esi.len() < source_count as usize {
        return None;
    }

    // Case 2: we have all source symbols and can return them without decoding
    if self.received_source_symbols == source_count {
        let mut block = vec![0; self.symbol_size as usize * source_count as usize];
        for (index, slot) in self.source_symbols.iter().enumerate() {
            let symbol = slot.as_ref().unwrap();
            self.unpack_sub_blocks(&mut block, symbol, index);
        }
        self.decoded = true;
        return Some(block);
    }

    // Case 3: we may have sufficient symbols to do a standard decoding
    let ldpc_count = num_ldpc_symbols(source_count) as usize;
    let hdpc_count = num_hdpc_symbols(source_count) as usize;
    // See section 5.3.3.4.2. There are S + H zero symbols to start the D vector
    let mut d = vec![Symbol::zero(self.symbol_size); ldpc_count + hdpc_count];
    let mut encoded_isis: Vec<u32> = Vec::new();

    // Received source symbols, in increasing symbol id order
    let received_sources = self
        .source_symbols
        .iter()
        .enumerate()
        .filter_map(|(index, slot)| slot.as_ref().map(|symbol| (index as u32, symbol)));
    for (isi, symbol) in received_sources {
        encoded_isis.push(isi);
        d.push(symbol.clone());
    }

    // Append the extended padding symbols
    let padding_isis = source_count..num_extended_symbols;
    d.extend(padding_isis.clone().map(|_| Symbol::zero(self.symbol_size)));
    encoded_isis.extend(padding_isis);

    // Append the received repair symbols; ESI is converted to ISI by skipping the padding
    encoded_isis.extend(
        self.repair_packets
            .iter()
            .map(|repair| repair.payload_id.encoding_symbol_id() + num_padding_symbols),
    );
    d.extend(
        self.repair_packets
            .iter()
            .map(|repair| Symbol::new(repair.data.clone())),
    );

    if num_extended_symbols < self.sparse_threshold {
        let (constraint_matrix, hdpc) =
            generate_constraint_matrix::<DenseBinaryMatrix>(source_count, &encoded_isis);
        self.try_pi_decode(constraint_matrix, hdpc, d)
    } else {
        let (constraint_matrix, hdpc) =
            generate_constraint_matrix::<SparseBinaryMatrix>(source_count, &encoded_isis);
        self.try_pi_decode(constraint_matrix, hdpc, d)
    }
}
/// Regenerates the source symbol `source_symbol_id` from the intermediate symbols.
///
/// The result is the sum of the intermediate symbols selected by `enc_indices` for the
/// tuple that `intermediate_tuple` derives from the symbol id. `lt_symbols`, `pi_symbols`,
/// `sys_index` and `p1` are passed in by the caller rather than computed here.
fn rebuild_source_symbol(
    &self,
    intermediate_symbols: &[Symbol],
    source_symbol_id: u32,
    lt_symbols: u32,
    pi_symbols: u32,
    sys_index: u32,
    p1: u32,
) -> Symbol {
    let tuple = intermediate_tuple(source_symbol_id, lt_symbols, sys_index, p1);
    enc_indices(tuple, lt_symbols, pi_symbols, p1)
        .into_iter()
        .fold(Symbol::zero(self.symbol_size), |mut sum, index| {
            sum += &intermediate_symbols[index];
            sum
        })
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod codec_tests {
use std::{
iter,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
vec::Vec,
};
#[cfg(not(feature = "python"))]
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(not(feature = "python"))]
use crate::Decoder;
#[cfg(not(feature = "python"))]
use crate::{Encoder, EncoderBuilder};
use crate::{
ObjectTransmissionInformation, SourceBlockDecoder, SourceBlockEncoder,
SourceBlockEncodingPlan,
};
/// Random-erasure round trip with a very large sparse threshold (the decoder only goes
/// sparse once the extended symbol count reaches the threshold).
#[cfg(not(feature = "python"))]
#[test]
fn random_erasure_dense() {
    let sparse_threshold = 99_999;
    random_erasure(sparse_threshold);
}

/// Random-erasure round trip with a threshold of zero, which every symbol count reaches,
/// so the sparse matrix is always used.
#[cfg(not(feature = "python"))]
#[test]
fn random_erasure_sparse() {
    let sparse_threshold = 0;
    random_erasure(sparse_threshold);
}
/// Encodes a random payload with 15 repair packets per block, shuffles the packets, drops
/// 10 of them and checks that the decoder still reproduces the payload.
#[cfg(not(feature = "python"))]
fn random_erasure(sparse_threshold: u32) {
    let elements: usize = rand::thread_rng().gen_range(1..1_000_000);
    let data: Vec<u8> = (0..elements).map(|_| rand::thread_rng().gen()).collect();

    // MTU is set to not be too small, otherwise this test may take a very long time
    let smallest_mtu = (elements / 100) as u16;
    let mtu = rand::thread_rng().gen_range(smallest_mtu..10_000);

    let encoder = Encoder::with_defaults(&data, mtu);
    let mut packets = encoder.get_encoded_packets(15);
    packets.shuffle(&mut rand::thread_rng());

    // Erase 10 packets at random
    let kept = packets.len() - 10;
    packets.truncate(kept);

    let mut decoder = Decoder::new(encoder.get_config());
    decoder.set_sparse_threshold(sparse_threshold);

    // Feed packets until the decoder reports a result or the packets run out.
    let mut result = None;
    while let Some(packet) = packets.pop() {
        result = decoder.decode(packet);
        if result.is_some() {
            break;
        }
    }

    assert_eq!(result.unwrap(), data);
}
/// Exercises a configuration with more than two sub-blocks: first a loss-free round trip
/// over the source packets only, then a round trip with 10 randomly erased packets.
#[cfg(not(feature = "python"))]
#[test]
fn sub_block_erasure() {
    let elements: usize = 10_000;
    let data: Vec<u8> = (0..elements).map(|_| rand::thread_rng().gen()).collect();

    let mut builder = EncoderBuilder::new();
    builder.set_decoder_memory_requirement(5000);
    builder.set_max_packet_size(500);
    let encoder = builder.build(&data);
    assert!(encoder.get_config().sub_blocks() > 2);

    // Test round trip: the result must appear only with the last packet
    let mut decoder = Decoder::new(encoder.get_config());
    let mut result = None;
    for packet in encoder.get_encoded_packets(0) {
        assert_eq!(result, None);
        result = decoder.decode(packet);
    }
    assert_eq!(result.unwrap(), data);

    // Test repair
    let mut packets = encoder.get_encoded_packets(15);
    packets.shuffle(&mut rand::thread_rng());

    // Erase 10 packets at random
    let kept = packets.len() - 10;
    packets.truncate(kept);

    let mut decoder = Decoder::new(encoder.get_config());
    let mut result = None;
    while let Some(packet) = packets.pop() {
        result = decoder.decode(packet);
        if result.is_some() {
            break;
        }
    }
    assert_eq!(result.unwrap(), data);
}
/// Loss-free round trip for 1..=100 symbols with a very large sparse threshold.
#[test]
fn round_trip_dense() {
    let (sparse_threshold, max_symbols) = (99_999, 100);
    round_trip(sparse_threshold, max_symbols, false);
}

/// Loss-free round trip for 1..=100 symbols with a sparse threshold of zero.
#[test]
fn round_trip_sparse() {
    let (sparse_threshold, max_symbols) = (0, 100);
    round_trip(sparse_threshold, max_symbols, false);
}

/// Long-running variant of `round_trip_dense` covering up to 5000 symbols; ignored by
/// default and prints progress.
#[test]
#[ignore]
fn round_trip_dense_extended() {
    let (sparse_threshold, max_symbols) = (99_999, 5000);
    round_trip(sparse_threshold, max_symbols, true);
}

/// Long-running variant of `round_trip_sparse` covering up to 56403 symbols; ignored by
/// default and prints progress.
#[test]
#[ignore]
fn round_trip_sparse_extended() {
    let (sparse_threshold, max_symbols) = (0, 56403);
    round_trip(sparse_threshold, max_symbols, true);
}
/// For every symbol count from 1 to `max_symbols`, encodes a random block of 8-byte
/// symbols and decodes it from its source packets alone.
///
/// The decoder must stay silent until the final source packet and then return the original
/// data. When `progress` is set, a line is printed at every multiple of 100 symbols.
fn round_trip(sparse_threshold: u32, max_symbols: usize, progress: bool) {
    let symbol_size = 8;
    for symbol_count in 1..=max_symbols {
        let elements = symbol_size * symbol_count;
        let data: Vec<u8> = (0..elements).map(|_| rand::thread_rng().gen()).collect();

        if progress && symbol_count % 100 == 0 {
            println!("Completed {symbol_count} symbols")
        }

        let config = ObjectTransmissionInformation::new(0, symbol_size as u16, 0, 1, 1);
        let encoder = SourceBlockEncoder::new(1, &config, &data);
        let mut decoder = SourceBlockDecoder::new(1, &config, elements as u64);
        decoder.set_sparse_threshold(sparse_threshold);

        let mut result = None;
        for packet in encoder.source_packets() {
            // No result may be produced before the last source packet is fed in.
            assert_eq!(result, None);
            result = decoder.decode(iter::once(packet));
        }

        assert_eq!(result.unwrap(), data);
    }
}
/// Long-running repair test up to 5000 symbols with a very large sparse threshold;
/// ignored by default and prints progress.
#[test]
#[ignore]
fn repair_dense_extended() {
    let (sparse_threshold, max_symbols) = (99_999, 5000);
    repair(sparse_threshold, max_symbols, true, false);
}

/// Long-running repair test up to 56403 symbols with a sparse threshold of zero;
/// ignored by default and prints progress.
#[test]
#[ignore]
fn repair_sparse_extended() {
    let (sparse_threshold, max_symbols) = (0, 56403);
    repair(sparse_threshold, max_symbols, true, false);
}

/// Repair-only decoding for 1..=50 symbols with a very large sparse threshold.
#[test]
fn repair_dense() {
    let (sparse_threshold, max_symbols) = (99_999, 50);
    repair(sparse_threshold, max_symbols, false, false);
}

/// Repair-only decoding for 1..=50 symbols with a sparse threshold of zero.
#[test]
fn repair_sparse() {
    let (sparse_threshold, max_symbols) = (0, 50);
    repair(sparse_threshold, max_symbols, false, false);
}

/// Same as `repair_dense`, but the encoder is built from a pre-generated encoding plan.
#[test]
fn repair_dense_pre_planned() {
    let (sparse_threshold, max_symbols) = (99_999, 50);
    repair(sparse_threshold, max_symbols, false, true);
}

/// Same as `repair_sparse`, but the encoder is built from a pre-generated encoding plan.
#[test]
fn repair_sparse_pre_planned() {
    let (sparse_threshold, max_symbols) = (0, 50);
    repair(sparse_threshold, max_symbols, false, true);
}
/// Repeatedly decodes a 10-symbol block from batches of repair packets only.
///
/// Each iteration builds a fresh decoder and hands it the next 15 repair packets
/// (10 symbols plus 50% overhead). The decode result is not asserted; the test only
/// checks that decoding runs to completion without panicking.
#[test]
fn issue_120() {
    let symbol_size = 1280;
    let overhead = 0.5;
    let symbol_count = 10;
    let elements = symbol_count * symbol_size as usize;
    let mut data: Vec<u8> = vec![0; elements];
    // Fill by iterating the buffer, as the other tests in this module do.
    for element in &mut data {
        *element = rand::thread_rng().gen();
    }

    let total_bytes: usize = 1024 * 1024;
    let iterations = total_bytes / elements;
    let config = ObjectTransmissionInformation::new(0, symbol_size, 0, 1, 1);
    let encoder = SourceBlockEncoder::new(1, &config, &data);
    let elements_and_overhead = (symbol_count as f64 * (1.0 + overhead)) as u32;
    // Generate all repair packets up front, then consume one batch per iteration.
    let mut packets = encoder.repair_packets(0, iterations as u32 * elements_and_overhead);

    for _ in 0..iterations {
        let mut decoder = SourceBlockDecoder::new(1, &config, elements as u64);
        let start = packets.len() - elements_and_overhead as usize;
        decoder.decode(packets.drain(start..));
    }
}
/// Runs `do_repair` for every symbol count from 1 to `max_symbols` on a thread pool.
///
/// The first job that fails stores its symbol count in a shared flag; jobs that start
/// afterwards see the flag and return immediately. The test passes only if the flag is
/// still zero once all jobs have finished.
fn repair(sparse_threshold: u32, max_symbols: usize, progress: bool, pre_plan: bool) {
    let pool = threadpool::Builder::new().build();
    let failed = Arc::new(AtomicU32::new(0));

    for symbol_count in 1..=max_symbols {
        let failed = Arc::clone(&failed);
        pool.execute(move || {
            // Skip the work once any other job has reported a failure.
            let already_failed = failed.load(Ordering::SeqCst) != 0;
            if already_failed {
                return;
            }
            if !do_repair(symbol_count, sparse_threshold, pre_plan) {
                failed.store(symbol_count as u32, Ordering::SeqCst);
            }
            if progress && symbol_count % 100 == 0 {
                println!("[repair] Completed {symbol_count} symbols")
            }
        });
    }

    pool.join();
    assert_eq!(0, failed.load(Ordering::SeqCst));
}
/// Encodes a random block of `symbol_count` 8-byte symbols and decodes it from repair
/// packets only.
///
/// Returns `true` when the decoder reproduces the data after `symbol_count + 4` repair
/// packets. Returns `false` when the decoder claims success before `symbol_count` packets
/// have been fed, when it never produces a result, or when the result differs from the
/// data. When `pre_plan` is set the encoder is built from a pre-generated encoding plan.
fn do_repair(symbol_count: usize, sparse_threshold: u32, pre_plan: bool) -> bool {
    let symbol_size = 8;
    let elements = symbol_size * symbol_count;
    let mut data: Vec<u8> = vec![0; elements];
    for element in &mut data {
        *element = rand::thread_rng().gen();
    }

    let config = ObjectTransmissionInformation::new(0, symbol_size as u16, 0, 1, 1);
    let encoder = if pre_plan {
        let plan = SourceBlockEncodingPlan::generate(symbol_count as u16);
        SourceBlockEncoder::with_encoding_plan(1, &config, &data, &plan)
    } else {
        SourceBlockEncoder::new(1, &config, &data)
    };

    let mut decoder = SourceBlockDecoder::new(1, &config, elements as u64);
    decoder.set_sparse_threshold(sparse_threshold);

    let mut result = None;
    let mut parsed_packets = 0;
    // This test can theoretically fail with ~1/256^5 probability
    for packet in encoder.repair_packets(0, (symbol_count + 4) as u32) {
        // A result before `symbol_count` packets were fed cannot be legitimate.
        if parsed_packets < symbol_count && result.is_some() {
            return false;
        }
        result = decoder.decode(iter::once(packet));
        parsed_packets += 1;
    }

    // A missing result is reported as a failure rather than a panic: `repair` only
    // records a failure when this function returns `false`.
    result == Some(data)
}
}