Skip to content

Commit

Permalink
prevent UB from a malformed BitPack at decompression
Browse files Browse the repository at this point in the history
  • Loading branch information
LGFae committed Feb 29, 2024
1 parent 5cd00b2 commit b6b9a2d
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 38 deletions.
5 changes: 3 additions & 2 deletions utils/src/compression/comp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ pub(super) unsafe fn pack_bytes(cur: &[u8], goal: &[u8], v: &mut Vec<u8>) {
}

if !v.is_empty() {
// add one extra zero to prevent access out of bounds later during decompression
v.push(0)
// add two extra bytes to prevent access out of bounds later during decompression
v.push(0);
v.push(0);
}
}

Expand Down
5 changes: 3 additions & 2 deletions utils/src/compression/comp/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ pub(super) unsafe fn pack_bytes(cur: &[u8], goal: &[u8], v: &mut Vec<u8>) {
}

if !v.is_empty() {
// add one extra zero to prevent access out of bounds later during decompression
v.push(0)
// add two extra bytes to prevent access out of bounds later during decompression
v.push(0);
v.push(0);
}
}

Expand Down
102 changes: 85 additions & 17 deletions utils/src/compression/decomp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ pub(super) mod ssse3;
/// buf must have the EXACT expected size by the BitPack
#[inline(always)]
pub(super) fn unpack_bytes_4channels(buf: &mut [u8], diff: &[u8]) {
assert!(
diff[diff.len() - 1] | diff[diff.len() - 2] == 0,
"Poorly formed BitPack"
);
// use the most efficient implementation available:
#[cfg(not(test))] // when testing, we want to use the specific implementation
{
Expand All @@ -17,15 +21,14 @@ pub(super) fn unpack_bytes_4channels(buf: &mut [u8], diff: &[u8]) {
}
}

// The very final byte is just padding to let us read 4 bytes at once without going out of
// bounds
let len = diff.len() - 1;
// The final bytes are just padding to prevent us from going out of bounds
let len = diff.len() - 3;
let buf_ptr = buf.as_mut_ptr();
let diff_ptr = diff.as_ptr();

let mut diff_idx = 0;
let mut pix_idx = 0;
while diff_idx + 1 < len {
while diff_idx < len {
while unsafe { diff_ptr.add(diff_idx).read() } == u8::MAX {
pix_idx += u8::MAX as usize;
diff_idx += 1;
Expand All @@ -41,14 +44,13 @@ pub(super) fn unpack_bytes_4channels(buf: &mut [u8], diff: &[u8]) {
to_cpy += unsafe { diff_ptr.add(diff_idx).read() } as usize;
diff_idx += 1;

assert!(
diff_idx + to_cpy * 3 + 1 < diff.len(),
"copying: {}, diff.len(): {}",
diff_idx + to_cpy * 3 + 1,
diff.len()
);
for _ in 0..to_cpy {
// it is much faster to use this assertion for testing than miri
debug_assert!(
diff_idx + 3 < diff.len(),
"diff_idx + 3: {}, diff.len(): {}",
diff_idx + 3,
diff.len()
);
unsafe {
std::ptr::copy_nonoverlapping(diff_ptr.add(diff_idx), buf_ptr.add(pix_idx * 4), 4)
}
Expand All @@ -61,15 +63,18 @@ pub(super) fn unpack_bytes_4channels(buf: &mut [u8], diff: &[u8]) {

#[inline(always)]
pub(super) fn unpack_bytes_3channels(buf: &mut [u8], diff: &[u8]) {
// The very final byte is just padding to let us read 4 bytes at once without going out of
// bounds
let len = diff.len() - 1;
assert!(
diff[diff.len() - 1] | diff[diff.len() - 2] == 0,
"Poorly formed BitPack"
);
// The final bytes are just padding to prevent us from going out of bounds
let len = diff.len() - 3;
let buf_ptr = buf.as_mut_ptr();
let diff_ptr = diff.as_ptr();

let mut diff_idx = 0;
let mut pix_idx = 0;
while diff_idx + 1 < len {
while diff_idx < len {
while unsafe { diff_ptr.add(diff_idx).read() } == u8::MAX {
pix_idx += u8::MAX as usize;
diff_idx += 1;
Expand All @@ -85,9 +90,9 @@ pub(super) fn unpack_bytes_3channels(buf: &mut [u8], diff: &[u8]) {
to_cpy += unsafe { diff_ptr.add(diff_idx).read() } as usize;
diff_idx += 1;

debug_assert!(
assert!(
diff_idx + to_cpy * 3 <= diff.len(),
"diff_idx: {diff_idx}, to_copy: {to_cpy} diff.len(): {}",
"diff_idx: {diff_idx}, to_copy: {to_cpy}, diff.len(): {}",
diff.len()
);
unsafe {
Expand All @@ -101,3 +106,66 @@ pub(super) fn unpack_bytes_3channels(buf: &mut [u8], diff: &[u8]) {
pix_idx += to_cpy + 1;
}
}

#[cfg(test)]
mod tests {
    //! Regression tests feeding malformed BitPack streams to the unpack functions.
    //!
    //! Every stream built here keeps `u8::MAX` in its last two bytes, so the
    //! padding assertion at the top of `unpack_bytes_4channels` and
    //! `unpack_bytes_3channels` ("Poorly formed BitPack") is the check that
    //! rejects it. Each test pins that message with `expected = ...` so that an
    //! unrelated panic (e.g. a plain index-out-of-bounds) cannot make the test
    //! pass by accident.
    use super::*;

    /// Length in bytes of the destination buffer handed to the unpack functions.
    const BUF_LEN: usize = 9;
    /// Length in bytes of the malformed diff stream.
    const DIFF_LEN: usize = 18;

    /// Destination buffer filled with `u8::MAX`.
    fn output_buf() -> Vec<u8> {
        vec![u8::MAX; BUF_LEN]
    }

    /// Builds a diff stream of `DIFF_LEN` bytes set to `u8::MAX`, with the bytes
    /// at the `zeroed` indices cleared to 0.
    ///
    /// None of the tests clear the final two bytes, which is what makes the
    /// stream fail the trailing-padding check.
    fn malformed_diff(zeroed: &[usize]) -> Vec<u8> {
        let mut diff = vec![u8::MAX; DIFF_LEN];
        for &idx in zeroed {
            diff[idx] = 0;
        }
        diff
    }

    /// A stream made only of `u8::MAX` bytes.
    #[test]
    #[should_panic(expected = "Poorly formed BitPack")]
    fn ub_unpack_bytes4_poorly_formed() {
        let mut bytes = output_buf();
        let diff = malformed_diff(&[]);
        unpack_bytes_4channels(&mut bytes, &diff);
    }

    /// A stream made only of `u8::MAX` bytes.
    #[test]
    #[should_panic(expected = "Poorly formed BitPack")]
    fn ub_unpack_bytes3_poorly_formed() {
        let mut bytes = output_buf();
        let diff = malformed_diff(&[]);
        unpack_bytes_3channels(&mut bytes, &diff);
    }

    /// Zeroes in the middle of the stream (indices 7 and 8), non-zero tail.
    #[test]
    #[should_panic(expected = "Poorly formed BitPack")]
    fn ub_unpack_bytes4_poorly_formed2() {
        let mut bytes = output_buf();
        let diff = malformed_diff(&[7, 8]);
        unpack_bytes_4channels(&mut bytes, &diff);
    }

    /// Zeroes in the middle of the stream (indices 7 and 8), non-zero tail.
    #[test]
    #[should_panic(expected = "Poorly formed BitPack")]
    fn ub_unpack_bytes3_poorly_formed2() {
        let mut bytes = output_buf();
        let diff = malformed_diff(&[7, 8]);
        unpack_bytes_3channels(&mut bytes, &diff);
    }

    /// Zeroes at indices 2, 7 and 8, non-zero tail.
    #[test]
    #[should_panic(expected = "Poorly formed BitPack")]
    fn ub_unpack_bytes4_poorly_formed3() {
        let mut bytes = output_buf();
        let diff = malformed_diff(&[2, 7, 8]);
        unpack_bytes_4channels(&mut bytes, &diff);
    }

    /// Zeroes at indices 2, 7 and 8, non-zero tail.
    #[test]
    #[should_panic(expected = "Poorly formed BitPack")]
    fn ub_unpack_bytes3_poorly_formed3() {
        let mut bytes = output_buf();
        let diff = malformed_diff(&[2, 7, 8]);
        unpack_bytes_3channels(&mut bytes, &diff);
    }
}
19 changes: 9 additions & 10 deletions utils/src/compression/decomp/ssse3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
pub(super) unsafe fn unpack_bytes_4channels(buf: &mut [u8], diff: &[u8]) {
use std::arch::x86_64 as intr;

// The very final byte is just padding to let us read 4 bytes at once without going out of
// bounds
let len = diff.len() - 1;
// The final bytes are just padding to prevent us from going out of bounds
let len = diff.len() - 3;
let buf_ptr = buf.as_mut_ptr();
let diff_ptr = diff.as_ptr();
let mask = intr::_mm_set_epi8(-1, 11, 10, 9, -1, 8, 7, 6, -1, 5, 4, 3, -1, 2, 1, 0);

let mut diff_idx = 0;
let mut pix_idx = 0;
while diff_idx + 1 < len {
while diff_idx < len {
while diff_ptr.add(diff_idx).read() == u8::MAX {
pix_idx += u8::MAX as usize;
diff_idx += 1;
Expand All @@ -28,6 +27,12 @@ pub(super) unsafe fn unpack_bytes_4channels(buf: &mut [u8], diff: &[u8]) {
to_cpy += diff_ptr.add(diff_idx).read() as usize;
diff_idx += 1;

assert!(
diff_idx + to_cpy * 3 + 1 < diff.len(),
"copying: {}, diff.len(): {}",
diff_idx + to_cpy * 3 + 1,
diff.len()
);
while to_cpy > 4 {
let d = intr::_mm_loadu_si128(diff_ptr.add(diff_idx).cast());
let to_store = intr::_mm_shuffle_epi8(d, mask);
Expand All @@ -38,12 +43,6 @@ pub(super) unsafe fn unpack_bytes_4channels(buf: &mut [u8], diff: &[u8]) {
to_cpy -= 4;
}
for _ in 0..to_cpy {
debug_assert!(
diff_idx + 3 < diff.len(),
"diff_idx + 3: {}, diff.len(): {}",
diff_idx + 3,
diff.len()
);
std::ptr::copy_nonoverlapping(diff_ptr.add(diff_idx), buf_ptr.add(pix_idx * 4), 4);
diff_idx += 3;
pix_idx += 1;
Expand Down
23 changes: 16 additions & 7 deletions utils/src/compression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,17 +207,22 @@ impl Decompressor {
bitpack.expected_buf_size
));
}

self.ensure_capacity(bitpack.compressed_size as usize);

// SAFETY: errors will never happen because BitPacked is *always* only produced
// with correct lz4 compression
unsafe {
// with correct lz4 compression, and ptr has the necessary capacity
let size = unsafe {
LZ4_decompress_safe(
bitpack.inner.as_ptr() as _,
self.ptr.as_ptr() as _,
bitpack.inner.len() as c_int,
bitpack.compressed_size as c_int,
);
)
};

if size != bitpack.compressed_size {
return Err("BitPack is malformed!".to_string());
}

// SAFETY: the call to self.ensure_capacity guarantees the pointer has the necessary size
Expand Down Expand Up @@ -256,22 +261,26 @@ impl Decompressor {
expected_len
));
}

let cap: i32 = archived
.compressed_size
.deserialize(&mut rkyv::Infallible)
.unwrap();

self.ensure_capacity(cap as usize);

// SAFETY: errors will never happen because BitPacked is *always* only produced
// with correct lz4 compression
unsafe {
// with correct lz4 compression, and ptr has the necessary capacity
let size = unsafe {
LZ4_decompress_safe(
archived.inner.as_ptr() as _,
self.ptr.as_ptr() as _,
archived.inner.len() as c_int,
cap as c_int,
);
)
};

if size != cap {
return Err("BitPack is malformed!".to_string());
}

// SAFETY: the call to self.ensure_capacity guarantees the pointer has the necessary size
Expand Down

0 comments on commit b6b9a2d

Please sign in to comment.