Skip to content

Commit

Permalink
Merge pull request #52 from JakubOnderka/zstd-safe
Browse files Browse the repository at this point in the history
Make zstd_safe really safe
  • Loading branch information
gyscos authored Dec 9, 2018
2 parents 15c10cc + e4a8462 commit ef551d2
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 118 deletions.
7 changes: 2 additions & 5 deletions src/block/compressor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use parse_code;
use map_error_code;

use std::io;
use zstd_safe;
Expand Down Expand Up @@ -39,16 +39,13 @@ impl Compressor {
destination: &mut [u8],
level: i32,
) -> io::Result<usize> {
let code = {
zstd_safe::compress_using_dict(
&mut self.context,
destination,
source,
&self.dict[..],
level,
)
};
parse_code(code)
).map_err(map_error_code)
}

/// Compresses a block of data and returns the compressed result.
Expand Down
17 changes: 7 additions & 10 deletions src/block/decompressor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use parse_code;
use map_error_code;

use std::io;
use zstd_safe;
Expand Down Expand Up @@ -35,15 +35,12 @@ impl Decompressor {
source: &[u8],
destination: &mut [u8],
) -> io::Result<usize> {
let code = {
zstd_safe::decompress_using_dict(
&mut self.context,
destination,
source,
&self.dict,
)
};
parse_code(code)
zstd_safe::decompress_using_dict(
&mut self.context,
destination,
source,
&self.dict,
).map_err(map_error_code)
}

/// Decompress a block of data, and return the result in a `Vec<u8>`.
Expand Down
6 changes: 3 additions & 3 deletions src/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
//! [`Encoder::with_dictionary`]: ../struct.Encoder.html#method.with_dictionary
//! [`Decoder::with_dictionary`]: ../struct.Decoder.html#method.with_dictionary
use parse_code;
use map_error_code;
use std::fs;

use std::io::{self, Read};
Expand Down Expand Up @@ -81,11 +81,11 @@ pub fn from_continuous(
let mut result = Vec::with_capacity(max_size);
unsafe {
result.set_len(max_size);
let written = parse_code(zstd_safe::train_from_buffer(
let written = zstd_safe::train_from_buffer(
&mut result,
sample_data,
sample_sizes,
))?;
).map_err(map_error_code)?;
result.set_len(written);
}
Ok(result)
Expand Down
16 changes: 4 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,10 @@ pub use zstd_safe::CLEVEL_DEFAULT as DEFAULT_COMPRESSION_LEVEL;
#[doc(no_inline)]
pub use stream::{decode_all, encode_all, Decoder, Encoder};

/// Parse the result code
///
/// Returns the number of bytes written if the code represents success,
/// or the error message otherwise.
fn parse_code(code: usize) -> io::Result<usize> {
if zstd_safe::is_error(code) == 0 {
Ok(code)
} else {
let msg = zstd_safe::get_error_name(code);
let error = io::Error::new(io::ErrorKind::Other, msg.to_string());
Err(error)
}
/// Returns the error message as io::Error based on error_code.
fn map_error_code(code: usize) -> io::Error {
let msg = zstd_safe::get_error_name(code);
io::Error::new(io::ErrorKind::Other, msg.to_string())
}

// Some helper functions to write full-cycle tests.
Expand Down
36 changes: 18 additions & 18 deletions src/stream/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::io;
use zstd_safe::{self, CStream, DStream, InBuffer, OutBuffer};

use dict::{DecoderDictionary, EncoderDictionary};
use parse_code;
use map_error_code;

/// Represents an abstract compression/decompression operation.
///
Expand Down Expand Up @@ -133,10 +133,10 @@ impl Decoder {
/// Creates a new decoder initialized with the given dictionary.
pub fn with_dictionary(dictionary: &[u8]) -> io::Result<Self> {
let mut context = zstd_safe::create_dstream();
parse_code(zstd_safe::init_dstream_using_dict(
zstd_safe::init_dstream_using_dict(
&mut context,
dictionary,
))?;
).map_err(map_error_code)?;
Ok(Decoder { context })
}

Expand All @@ -145,10 +145,10 @@ impl Decoder {
dictionary: &DecoderDictionary,
) -> io::Result<Self> {
let mut context = zstd_safe::create_dstream();
parse_code(zstd_safe::init_dstream_using_ddict(
zstd_safe::init_dstream_using_ddict(
&mut context,
dictionary.as_ddict(),
))?;
).map_err(map_error_code)?;
Ok(Decoder { context })
}
}
Expand All @@ -159,15 +159,15 @@ impl Operation for Decoder {
input: &mut InBuffer,
output: &mut OutBuffer,
) -> io::Result<usize> {
parse_code(zstd_safe::decompress_stream(
zstd_safe::decompress_stream(
&mut self.context,
output,
input,
))
).map_err(map_error_code)
}

fn reinit(&mut self) -> io::Result<()> {
parse_code(zstd_safe::reset_dstream(&mut self.context))?;
zstd_safe::reset_dstream(&mut self.context).map_err(map_error_code)?;
Ok(())
}
fn finish(
Expand Down Expand Up @@ -200,11 +200,11 @@ impl Encoder {
/// Creates a new encoder initialized with the given dictionary.
pub fn with_dictionary(level: i32, dictionary: &[u8]) -> io::Result<Self> {
let mut context = zstd_safe::create_cstream();
parse_code(zstd_safe::init_cstream_using_dict(
zstd_safe::init_cstream_using_dict(
&mut context,
dictionary,
level,
))?;
).map_err(map_error_code)?;
Ok(Encoder { context })
}

Expand All @@ -213,10 +213,10 @@ impl Encoder {
dictionary: &EncoderDictionary,
) -> io::Result<Self> {
let mut context = zstd_safe::create_cstream();
parse_code(zstd_safe::init_cstream_using_cdict(
zstd_safe::init_cstream_using_cdict(
&mut context,
dictionary.as_cdict(),
))?;
).map_err(map_error_code)?;
Ok(Encoder { context })
}
}
Expand All @@ -227,30 +227,30 @@ impl Operation for Encoder {
input: &mut InBuffer,
output: &mut OutBuffer,
) -> io::Result<usize> {
parse_code(zstd_safe::compress_stream(
zstd_safe::compress_stream(
&mut self.context,
output,
input,
))
).map_err(map_error_code)
}

fn flush(&mut self, output: &mut OutBuffer) -> io::Result<usize> {
parse_code(zstd_safe::flush_stream(&mut self.context, output))
zstd_safe::flush_stream(&mut self.context, output).map_err(map_error_code)
}

fn finish(
&mut self,
output: &mut OutBuffer,
_finished_frame: bool,
) -> io::Result<usize> {
parse_code(zstd_safe::end_stream(&mut self.context, output))
zstd_safe::end_stream(&mut self.context, output).map_err(map_error_code)
}

fn reinit(&mut self) -> io::Result<()> {
parse_code(zstd_safe::reset_cstream(
zstd_safe::reset_cstream(
&mut self.context,
zstd_safe::CONTENTSIZE_UNKNOWN,
))?;
).map_err(map_error_code)?;
Ok(())
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/stream/read/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
use std::io::Read;
use stream::read::{Decoder, Encoder};

#[test]
fn test_error_handling() {
let invalid_input = b"Abcdefghabcdefgh";

let mut decoder = Decoder::new(&invalid_input[..]).unwrap();
let output = decoder.read_to_end(&mut Vec::new());

assert_eq!(output.is_err(), true);
}

#[test]
fn test_cycle() {
let input = b"Abcdefghabcdefgh";
Expand Down
Loading

0 comments on commit ef551d2

Please sign in to comment.