diff --git a/Cargo.lock b/Cargo.lock index 611e29d..7c1ddb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.6.13" @@ -50,6 +59,17 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -58,11 +78,13 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "binary-ensemble" -version = "0.1.3" +version = "0.2.0" dependencies = [ "byteorder", - "clap", + "clap 4.5.2", "lipsum", + "pcompress", + "pipe", "rand", "rand_chacha", "rand_distr", @@ -70,6 +92,12 @@ dependencies = [ "xz2", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "byteorder" version = "1.5.0" @@ -88,6 +116,21 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "ansi_term", + "atty", + "bitflags", + "strsim 0.8.0", + "textwrap", + "unicode-width", + "vec_map", +] + [[package]] name = "clap" version = "4.5.2" @@ -107,7 +150,7 @@ dependencies = [ "anstream", 
"anstyle", "clap_lex", - "strsim", + "strsim 0.11.0", ] [[package]] @@ -116,10 +159,10 @@ version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", - "syn", + "syn 2.0.52", ] [[package]] @@ -134,6 +177,21 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "crossbeam-channel" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "getrandom" version = "0.2.12" @@ -145,18 +203,42 @@ dependencies = [ "wasi", ] +[[package]] +name = "heck" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "itoa" version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "lazy_static" +version = "1.4.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.153" @@ -200,6 +282,26 @@ dependencies = [ "libm", ] +[[package]] +name = "pcompress" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59023a57cca37e60515d3f7e981890b5956c390c0b2fbf5291c234bd26db3105" +dependencies = [ + "serde", + "serde_json", + "structopt", +] + +[[package]] +name = "pipe" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c7b8f27da217eb966df4c58d4159ea939431950ca03cf782c22bd7c5c1d8d75" +dependencies = [ + "crossbeam-channel", +] + [[package]] name = "pkg-config" version = "0.3.30" @@ -212,6 +314,30 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.79" @@ -293,7 +419,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.52", ] [[package]] @@ -307,12 +433,53 @@ dependencies = [ "serde", ] +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + [[package]] name = "strsim" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" +[[package]] +name = "structopt" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c6b5c64445ba8094a6ab0c3cd2ad323e07171012d9c98b0b15651daf1787a10" +dependencies = [ + "clap 2.34.0", + "lazy_static", + "structopt-derive", +] + +[[package]] +name = "structopt-derive" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0" +dependencies = [ + "heck 0.3.3", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.52" @@ -324,24 +491,79 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-segmentation" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" + +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + [[package]] name = "utf8parse" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 71b7fc5..f4de914 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "binary-ensemble" -version = "0.1.3" +version = "0.2.0" edition = "2021" authors = ["Peter Rock "] exclude = ["example/"] @@ -9,7 +9,7 @@ readme = "README.md" repository = "https://github.com/peterrrock2/binary-ensemble" 
-description = "A CLI tool for working with and compressing ensambles of districting plans" +description = "A CLI tool for working with and compressing ensembles of districting plans" [lib] name = "ben" @@ -17,6 +17,8 @@ name = "ben" [dependencies] byteorder = "1.5.0" clap = { version = "^4.5.2", features = ["derive"] } +pcompress = "1.0.7" +pipe = "0.4.0" serde_json = "^1.0.107" xz2 = "0.1.7" diff --git a/TODO.md b/TODO.md index bdbe687..8714635 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,6 @@ # Things to add -- [ ] Add a flag that allows for the transformation of all assignment vectors +- [x] Add a flag that allows for the transformation of all assignment vectors so that the first item is assigned to district 1, so something like [2,2,4,4,3,1,1,3] would turn into [1,1,2,2,3,4,4,3]. This will improve xben even further, but would technically alter the data @@ -10,10 +10,10 @@ - [ ] Make tests for all of the errors -- [ ] Maybe change the encoder and decoder into things that are their own structs with +- [x] Maybe change the encoder and decoder into things that are their own structs with implementations? -- [ ] Make a special MCMC writer for ben that add a self-loop counter to the start of +- [x] Make a special MCMC writer for ben that adds a self-loop counter to the start of the next item. This will be really useful for reducing the size of any chain that has a high rejection ratio (e.g. reversible) @@ -24,4 +24,13 @@ - [ ] Add a reverse mode to reben to make reverting the labeling a little bit easier for the end user -- [ ] Add a `jsonl` mode to reben to relabel the `jsonl` file. \ No newline at end of file +- [ ] Add a `jsonl` mode to reben to relabel the `jsonl` file. + + +- [ ] Add a method to `read` that allows for reading a chunk of assignment vectors. + Will need a cursor to do this so that we can read ahead to the end and then + chunk it up.
Might want to make this into a struct that implements `Iterator` + +- [ ] Finish out the robust suite of tests for the MkvChain mode. The pipeline is + already tested, but it probably would be good to duplicate all of the tests that + were written for the standard mode even if the adaptation is really simple. \ No newline at end of file diff --git a/src/bin/ben.rs b/src/bin/ben.rs index 4f39b78..812d433 100644 --- a/src/bin/ben.rs +++ b/src/bin/ben.rs @@ -1,7 +1,7 @@ use ben::decode::read::extract_assignment_ben; use ben::decode::*; use ben::encode::*; -use ben::logln; +use ben::{logln, BenVariant}; use clap::{Parser, ValueEnum}; use std::{ fs::File, @@ -23,9 +23,9 @@ enum Mode { /// Defines the command line arguments accepted by the program. #[derive(Parser, Debug)] #[command( - name = "Binary Ensamble CLI Tool", - about = "This is a command line tool for encoding and decoding binary ensamble files.", - version = "0.1.3" + name = "Binary Ensemble CLI Tool", + about = "This is a command line tool for encoding and decoding binary ensemble files.", + version = "0.2.0" )] struct Args { /// Mode to run the program in (encode, decode, or read). @@ -73,6 +73,14 @@ struct Args { #[arg(short = 'j', long)] jsonl_and_ben: bool, + /// When saving a file in the BEN format, the default is to have + /// an assignment vector saved followed by the number of repetitions + /// of that assignment vector (this is useful for Markov chain methods + /// like ReCom). This flag will cause the program to forgo the repetition + /// count and just save all of the assignment vectors as they are encountered. + #[arg(short = 'a', long)] + save_all: bool, + /// If the output file already exists, this flag /// will cause the program to overwrite it without /// asking the user for confirmation.
@@ -229,8 +237,17 @@ fn main() { } }; - if let Err(e) = jsonl_encode_ben(reader, writer) { - eprintln!("Error: {:?}", e); + let possible_error = if args.save_all { + jsonl_encode_ben(reader, writer, BenVariant::Standard) + } else { + jsonl_encode_ben(reader, writer, BenVariant::MkvChain) + }; + + match possible_error { + Ok(_) => {} + Err(err) => { + eprintln!("Error: {:?}", err); + } } } Mode::XEncode => { @@ -294,8 +311,13 @@ fn main() { eprintln!("Error: {:?}", err); } } else if jsonl_and_xben { - if let Err(err) = jsonl_encode_xben(reader, writer) { - eprintln!("Error: {:?}", err); + let possible_error = if args.save_all { + jsonl_encode_xben(reader, writer, BenVariant::Standard) + } else { + jsonl_encode_xben(reader, writer, BenVariant::MkvChain) + }; + if let Err(e) = possible_error { + eprintln!("Error: {:?}", e); } } else { eprintln!("Error: Unsupported file type(s) for xencode mode"); diff --git a/src/bin/pben.rs b/src/bin/pben.rs new file mode 100644 index 0000000..73530c3 --- /dev/null +++ b/src/bin/pben.rs @@ -0,0 +1,174 @@ +use ben::decode::*; +use ben::encode::*; +use ben::{logln, BenVariant}; +use clap::{Parser, ValueEnum}; +use pcompress; +use pipe::pipe; +use std::{ + fs::File, + io::{self, BufRead, BufReader, BufWriter, Read, Result, Write}, +}; +use xz2::write::XzEncoder; + +/// Defines the mode of operation. +#[derive(Parser, Debug, Clone, ValueEnum, PartialEq)] +enum Mode { + BenToPc, + PcToBen, + PcToXben, +} + +#[derive(Parser, Debug)] +#[command( + name = "Conversion tool for BEN and PCOMPRESS formats", + about = "This is a CLI tool that allows for the conversion between BEN and PCOMPRESS formats.", + version = "0.2.0" +)] +struct Args { + /// Mode to run the program in + #[arg(short, long, value_enum)] + mode: Mode, + + /// Input file to read from. + #[arg(short, long)] + input_file: Option, + + /// Output file to write to. Optional. 
+ /// If not provided, the output file will be determined + /// based on the input file and the mode of operation. + #[arg(short, long)] + output_file: Option, + + /// If the output file already exists, this flag + /// will cause the program to overwrite it without + /// asking the user for confirmation. + #[arg(short = 'w', long)] + overwrite: bool, + + /// Enables verbose printing for the CLI. Optional. + #[arg(short, long)] + verbose: bool, +} + +fn main() -> Result<()> { + let args = Args::parse(); + + if args.verbose { + std::env::set_var("RUST_LOG", "trace"); + } + + match args.mode { + Mode::BenToPc => { + logln!("Converting BEN to PCOMPRESS"); + + let ben_reader: Box = match args.input_file { + Some(file) => Box::new(BufReader::new(File::open(&file).unwrap())), + None => Box::new(io::stdin()), + }; + + let mut pcompress_writer: BufWriter> = match args.output_file { + Some(file) => BufWriter::new(Box::new(File::create(&file).unwrap())), + None => BufWriter::new(Box::new(io::stdout())), + }; + + let (pipe_reader, pipe_writer) = pipe(); + + let _ = std::thread::spawn(move || -> io::Result<()> { + assignment_decode_ben(ben_reader, pipe_writer) + }); + + let mut buf_pipe_reader = BufReader::new(pipe_reader); + + pcompress::encode::encode(&mut buf_pipe_reader, &mut pcompress_writer, false); + + Ok(()) + } + Mode::PcToBen => { + logln!("Converting PCOMPRESS to BEN"); + + let mut pcompress_reader: BufReader> = match args.input_file { + Some(file) => BufReader::new(Box::new(BufReader::new(File::open(&file).unwrap()))), + None => BufReader::new(Box::new(io::stdin())), + }; + + let mut ben_writer: BufWriter> = match args.output_file { + Some(file) => BufWriter::new(Box::new(File::create(&file).unwrap())), + None => BufWriter::new(Box::new(io::stdout())), + }; + + let (pipe_reader, pipe_writer) = pipe(); + + let mut buf_pipe_writer = BufWriter::new(pipe_writer); + + let _ = std::thread::spawn(move || { + pcompress::decode::decode(&mut pcompress_reader, &mut 
buf_pipe_writer, 0, false) + }); + + let mut buf_pipe_reader = BufReader::new(pipe_reader); + + assignment_encode_ben(&mut buf_pipe_reader, &mut ben_writer) + } + Mode::PcToXben => { + logln!("Converting PCOMPRESS to XBEN"); + + let mut pcompress_reader: BufReader> = match args.input_file { + Some(file) => BufReader::new(Box::new(BufReader::new(File::open(&file).unwrap()))), + None => BufReader::new(Box::new(io::stdin())), + }; + + let mut ben_writer: BufWriter> = match args.output_file { + Some(file) => BufWriter::new(Box::new(File::create(&file).unwrap())), + None => BufWriter::new(Box::new(io::stdout())), + }; + + let (pipe_reader, pipe_writer) = pipe(); + + let mut buf_pipe_writer = BufWriter::new(pipe_writer); + + let _ = std::thread::spawn(move || { + pcompress::decode::decode(&mut pcompress_reader, &mut buf_pipe_writer, 0, false) + }); + + let mut buf_pipe_reader = BufReader::new(pipe_reader); + + assignment_encode_xben(&mut buf_pipe_reader, &mut ben_writer) + } + } +} + +fn assignment_decode_ben(mut reader: R, mut writer: W) -> io::Result<()> { + let ben_reader = BenDecoder::new(&mut reader)?; + + for result in ben_reader { + match result { + Ok(assignment) => { + write!(writer, "{}\n", serde_json::to_string(&assignment).unwrap())?; + } + Err(e) => return Err(e), + } + } + + Ok(()) +} + +fn assignment_encode_ben(reader: R, writer: W) -> io::Result<()> { + let mut ben_writer = BenEncoder::new(writer, BenVariant::MkvChain); + + for line in reader.lines() { + let assignment: Vec = serde_json::from_str::>(&line.unwrap()) + .unwrap() + .into_iter() + .map(|x| x as u16 + 1) + .collect(); + ben_writer.write_assignment(assignment)?; + } + Ok(()) +} + +fn assignment_encode_xben(reader: R, writer: W) -> io::Result<()> { + let encoder = XzEncoder::new(writer, 9); + let mut xben_writer = XBenEncoder::new(encoder, BenVariant::MkvChain); + + xben_writer.write_ben_file(reader)?; + Ok(()) +} diff --git a/src/bin/reben.rs b/src/bin/reben.rs index 72cde8d..4dd2cac 100644 --- 
a/src/bin/reben.rs +++ b/src/bin/reben.rs @@ -19,12 +19,12 @@ enum Mode { /// Defines the command line arguments accepted by the program. #[derive(Parser, Debug)] #[command( - name = "Relabeling Binary Ensamble CLI Tool", + name = "Relabeling Binary Ensemble CLI Tool", about = concat!( - "This is a command line tool for relabeling binary ensambles ", + "This is a command line tool for relabeling binary ensembles ", "to help improve compression ratios for BEN and XBEN files." ), - version = "0.1.3" + version = "0.2.0" )] // TODO: Change the name of shape_file to dual_graph_file. @@ -51,6 +51,11 @@ struct Args { map_file: Option, /// Mode to run the program in (either JSON or BEN). + /// The JSON mode will sort a JSON file by a given key. + /// The BEN mode will relabel a BEN file according to a map file + /// or a key (the latter also requires a dual-graph file). If no + /// map file or key is provided, the BEN mode will canonicalize + /// the assignment vectors in the BEN file. #[arg(short, long)] mode: Mode, diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 0c4bd31..724ec9b 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -21,7 +21,154 @@ use std::io::{self, BufRead, Error, Read, Write}; use crate::utils::rle_to_vec; use super::encode::translate::*; -use super::{log, logln}; +use super::{log, logln, BenVariant}; + +#[derive(Debug)] +pub enum DecoderInitError { + InvalidFileFormat(String), + Io(io::Error), +} + +impl std::fmt::Display for DecoderInitError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DecoderInitError::Io(e) => write!(f, "IO error: {}", e), + DecoderInitError::InvalidFileFormat(msg) => { + write!(f, "Invalid file format. 
Found header {:?}", msg) + } + } + } +} + +impl std::error::Error for DecoderInitError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DecoderInitError::Io(e) => Some(e), + DecoderInitError::InvalidFileFormat(_) => None, + } + } +} + +impl From for DecoderInitError { + fn from(error: io::Error) -> Self { + DecoderInitError::Io(error) + } +} + +impl From for io::Error { + fn from(error: DecoderInitError) -> Self { + match error { + DecoderInitError::Io(e) => e, + DecoderInitError::InvalidFileFormat(msg) => { + io::Error::new(io::ErrorKind::InvalidData, msg) + } + } + } +} + +// Note: This will make Read easier to use since +// I can now implement the read chunk with a Cursor +// object. +pub struct BenDecoder { + reader: R, + sample_count: usize, + variant: BenVariant, +} + +impl BenDecoder { + pub fn new(mut reader: R) -> Result { + let mut check_buffer = [0u8; 17]; + + if let Err(e) = reader.read_exact(&mut check_buffer) { + return Err(DecoderInitError::Io(e)); + } + + match &check_buffer { + b"STANDARD BEN FILE" => Ok(BenDecoder { + reader, + sample_count: 0, + variant: BenVariant::Standard, + }), + b"MKVCHAIN BEN FILE" => Ok(BenDecoder { + reader, + sample_count: 0, + variant: BenVariant::MkvChain, + }), + _ => Err(DecoderInitError::InvalidFileFormat(format!( + "Invalid file format. 
Found header bytes {:?}", + check_buffer + ))), + } + } + + fn write_all_jsonl(&mut self, mut writer: impl Write) -> io::Result<()> { + while let Some(result_tuple) = self.next() { + match result_tuple { + Ok((assignment, count)) => { + for _ in 0..count { + self.sample_count += 1; + let line = json!({ + "assignment": assignment, + "sample": self.sample_count, + }) + .to_string() + + "\n"; + writer.write_all(line.as_bytes()).unwrap(); + } + } + Err(e) => { + return Err(e); + } + } + } + Ok(()) + } +} + +impl Iterator for BenDecoder { + type Item = io::Result<(Vec, u16)>; + + fn next(&mut self) -> Option, u16)>> { + let mut tmp_buffer = [0u8]; + let max_val_bits: u8 = match self.reader.read_exact(&mut tmp_buffer) { + Ok(()) => tmp_buffer[0], + Err(e) => { + if e.kind() == io::ErrorKind::UnexpectedEof { + logln!(); + logln!("Done!"); + return None; + } + return Some(Err(e)); + } + }; + + let max_len_bits = self + .reader + .read_u8() + .expect(format!("Error when reading sample {}.", self.sample_count).as_str()); + let n_bytes = self + .reader + .read_u32::() + .expect(format!("Error when reading sample {}.", self.sample_count).as_str()); + + let assignment = + match decode_ben_line(&mut self.reader, max_val_bits, max_len_bits, n_bytes) { + Ok(output_rle) => rle_to_vec(output_rle), + Err(e) => return Some(Err(e)), + }; + + let count = if self.variant == BenVariant::MkvChain { + self.reader + .read_u16::() + .expect(format!("Error when reading sample {}.", self.sample_count).as_str()) + } else { + 1 + }; + + log!("Decoding sample: {}\r", self.sample_count + count as usize); + Some(Ok((assignment, count))) + } +} /// This function takes a reader containing a single ben32 encoded assignment /// vector and decodes it into a full assignment vector of u16s. @@ -40,15 +187,17 @@ use super::{log, logln}; /// bytes long since each assignment vector is an run-length encoded as a 32 bit /// integer (2 bytes for the value and 2 bytes for the count). 
/// -fn decode_ben32_line(mut reader: R) -> io::Result> { +fn decode_ben32_line( + mut reader: R, + variant: BenVariant, +) -> io::Result<(Vec, u16)> { let mut buffer = [0u8; 4]; let mut output_vec: Vec = Vec::new(); loop { - // Read 4 bytes (u32) from the encoded file - // https://stackoverflow.com/questions/30412521/how-to-read-a-specific-number-of-bytes-from-a-stream match reader.read_exact(&mut buffer) { Ok(()) => { + println!("found {:?}", buffer); let encoded = u32::from_be_bytes(buffer); if encoded == 0 { // Check for separator (all 0s) @@ -68,7 +217,16 @@ fn decode_ben32_line(mut reader: R) -> io::Result> { } } } - Ok(output_vec) + + let count = if variant == BenVariant::MkvChain { + reader + .read_u16::() + .expect("Error when reading sample.") + } else { + 1 + }; + + Ok((output_vec, count)) } /// This function takes a reader containing a file encoded with the @@ -94,37 +252,37 @@ fn decode_ben32_line(mut reader: R) -> io::Result> { /// This function will return an error if the input reader contains invalid ben32 /// data or if the the decode method encounters while trying to extract a single /// assignment vector, that error is propagated. 
-fn jsonl_decode_ben32(mut reader: R, mut writer: W) -> io::Result<()> { +fn jsonl_decode_ben32( + mut reader: R, + mut writer: W, + starting_sample: usize, + variant: BenVariant, +) -> io::Result<()> { let mut sample_number = 1; - let mut check_buffer = [0u8; 17]; - reader.read_exact(&mut check_buffer)?; - - if &check_buffer != b"STANDARD BEN FILE" { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )); - } - loop { - let output_vec = decode_ben32_line(&mut reader); - if let Err(e) = output_vec { + let result = decode_ben32_line(&mut reader, variant); + println!("In jsonl_decode_ben32 result {:?}", result); + if let Err(e) = result { if e.kind() == io::ErrorKind::UnexpectedEof { return Ok(()); } return Err(e); } - // Write the reconstructed vector as JSON to the output file - let line = json!({ - "assignment": output_vec.unwrap(), - "sample": sample_number, - }) - .to_string() - + "\n"; + let (output_vec, count) = result.unwrap(); + + for _ in 0..count { + // Write the reconstructed vector as JSON to the output file + let line = json!({ + "assignment": output_vec, + "sample": sample_number + starting_sample, + }) + .to_string() + + "\n"; - writer.write_all(line.as_bytes())?; - sample_number += 1; + writer.write_all(line.as_bytes())?; + sample_number += 1; + } } } @@ -150,20 +308,26 @@ pub fn decode_xben_to_ben(reader: R, mut writer: W) -> io: let mut first_buffer = [0u8; 17]; - match decoder.read(&mut first_buffer) { - Ok(_) => { - if &first_buffer[..17] != b"STANDARD BEN FILE" { - return Err(Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )); - } + if let Err(e) = decoder.read_exact(&mut first_buffer) { + return Err(e); + } + + let variant = match &first_buffer { + b"STANDARD BEN FILE" => { writer.write_all(b"STANDARD BEN FILE")?; + BenVariant::Standard } - Err(e) => { - return Err(e); + b"MKVCHAIN BEN FILE" => { + writer.write_all(b"MKVCHAIN BEN FILE")?; + BenVariant::MkvChain } - } + _ => { + return 
Err(Error::new( + io::ErrorKind::InvalidData, + "Invalid file format", + )); + } + }; let mut buffer = [0u8; 1048576]; // 1MB buffer let mut overflow: Vec = Vec::new(); @@ -181,13 +345,28 @@ pub fn decode_xben_to_ben(reader: R, mut writer: W) -> io: // It is technically faster to read backwards from the last // multiple of 4 smaller than the length of the overflow buffer // but this provides only a minute speedup in almost all cases (maybe a - // few seconds). Reading form the front is both safer from a + // few seconds). Reading from the front is both safer from a // maintenance perspective and allows for a better progress indicator - for i in (3..overflow.len()).step_by(4) { - if overflow[i - 3..=i] == [0, 0, 0, 0] { - last_valid_assignment = i + 1; - line_count += 1; - log!("Decoding sample: {}\r", line_count); + match variant { + BenVariant::Standard => { + for i in (3..overflow.len()).step_by(4) { + if overflow[i - 3..=i] == [0, 0, 0, 0] { + last_valid_assignment = i + 1; + line_count += 1; + log!("Decoding sample: {}\r", line_count); + } + } + } + BenVariant::MkvChain => { + for i in (3..overflow.len() - 2).step_by(2) { + if overflow[i - 3..=i] == [0, 0, 0, 0] { + last_valid_assignment = i + 3; + let lines = &overflow[i + 1..i + 3]; + let n_lines = u16::from_be_bytes([lines[0], lines[1]]); + line_count += n_lines as usize; + log!("Decoding sample: {}\r", line_count); + } + } } } @@ -195,7 +374,7 @@ pub fn decode_xben_to_ben(reader: R, mut writer: W) -> io: continue; } - ben32_to_ben_lines(&overflow[0..last_valid_assignment], &mut writer)?; + ben32_to_ben_lines(&overflow[0..last_valid_assignment], &mut writer, variant)?; overflow = overflow[last_valid_assignment..].to_vec(); } logln!(); @@ -364,46 +543,9 @@ pub fn decode_ben_line( /// This function will return an error if the input reader contains invalid ben /// data or if the the decode method encounters while trying to extract a single /// assignment vector, that error is then propagated. 
-pub fn jsonl_decode_ben(mut reader: R, mut writer: W) -> io::Result<()> { - let mut sample_number = 1; - let mut check_buffer = [0u8; 17]; - reader.read_exact(&mut check_buffer)?; - - if &check_buffer != b"STANDARD BEN FILE" { - return Err(Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )); - } - - loop { - let mut tmp_buffer = [0u8]; - let max_val_bits: u8 = match reader.read_exact(&mut tmp_buffer) { - Ok(()) => tmp_buffer[0], - Err(e) => { - if e.kind() == io::ErrorKind::UnexpectedEof { - logln!(); - logln!("Done!"); - return Ok(()); - } - return Err(e); - } - }; - log!("Decoding sample: {}\r", sample_number); - let max_len_bits = reader.read_u8()?; - let n_bytes = reader.read_u32::()?; - - let output_rle = decode_ben_line(&mut reader, max_val_bits, max_len_bits, n_bytes)?; - - let line = json!({ - "assignment": rle_to_vec(output_rle), - "sample": sample_number, - }) - .to_string() - + "\n"; - writer.write_all(line.as_bytes())?; - sample_number += 1; - } +pub fn jsonl_decode_ben(reader: R, writer: W) -> io::Result<()> { + let mut ben_decoder = BenDecoder::new(reader)?; + ben_decoder.write_all_jsonl(writer) } /// This function takes a reader containing a file encoded in the XBEN format @@ -434,24 +576,26 @@ pub fn jsonl_decode_xben(reader: R, mut writer: W) -> io:: let mut first_buffer = [0u8; 17]; - match decoder.read(&mut first_buffer) { - Ok(_) => { - if &first_buffer[..17] != b"STANDARD BEN FILE" { - return Err(Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )); - } - } - Err(e) => { - return Err(e); - } + if let Err(e) = decoder.read_exact(&mut first_buffer) { + return Err(e); } + let variant = match &first_buffer { + b"STANDARD BEN FILE" => BenVariant::Standard, + b"MKVCHAIN BEN FILE" => BenVariant::MkvChain, + _ => { + return Err(Error::new( + io::ErrorKind::InvalidData, + "Invalid file format", + )); + } + }; + let mut buffer = [0u8; 1048576]; // 1MB buffer let mut overflow: Vec = Vec::new(); let mut line_count: 
usize = 0; + let mut starting_sample: usize = 0; while let Ok(count) = decoder.read(&mut buffer) { if count == 0 { break; @@ -464,13 +608,31 @@ pub fn jsonl_decode_xben(reader: R, mut writer: W) -> io:: // It is technically faster to read backwards from the last // multiple of 4 smaller than the length of the overflow buffer // but this provides only a minute speedup in almost all cases (maybe a - // few seconds). Reading form the front is both safer from a + // few seconds). Reading from the front is both safer from a // maintenance perspective and allows for a better progress indicator - for i in (3..overflow.len()).step_by(4) { - if overflow[i - 3..=i] == [0, 0, 0, 0] { - last_valid_assignment = i + 1; - line_count += 1; - log!("Decoding sample: {}\r", line_count); + match variant { + BenVariant::Standard => { + for i in (3..overflow.len()).step_by(4) { + if overflow[i - 3..=i] == [0, 0, 0, 0] { + last_valid_assignment = i + 1; + line_count += 1; + log!("Decoding sample: {}\r", line_count); + } + } + } + BenVariant::MkvChain => { + // Need a different step size here because each assignment + // vector is no longer guaranteed to be a multiple of 4 bytes + // due to the 2-byte repetition count appended at the end + for i in (last_valid_assignment + 3..overflow.len() - 2).step_by(2) { + if overflow[i - 3..=i] == [0, 0, 0, 0] { + last_valid_assignment = i + 3; + let lines = &overflow[i + 1..i + 3]; + let n_lines = u16::from_be_bytes([lines[0], lines[1]]); + line_count += n_lines as usize; + log!("Decoding sample: {}\r", line_count); + } + } } } @@ -478,10 +640,14 @@ pub fn jsonl_decode_xben(reader: R, mut writer: W) -> io:: continue; } - let mut new_vec: Vec = b"STANDARD BEN FILE".to_vec(); - new_vec.extend(&overflow[0..last_valid_assignment]); - jsonl_decode_ben32(&new_vec[..], &mut writer)?; + jsonl_decode_ben32( + &overflow[0..last_valid_assignment], + &mut writer, + starting_sample, + variant, + )?; overflow = overflow[last_valid_assignment..].to_vec(); + 
starting_sample = line_count; } logln!(); logln!("Done!"); @@ -489,6 +655,5 @@ pub fn jsonl_decode_xben(reader: R, mut writer: W) -> io:: } #[cfg(test)] -mod tests { - include!("tests/decode_tests.rs"); -} +#[path = "tests/decode_tests.rs"] +mod tests; diff --git a/src/decode/read.rs b/src/decode/read.rs index 97824bd..47c2099 100644 --- a/src/decode/read.rs +++ b/src/decode/read.rs @@ -144,14 +144,18 @@ pub fn extract_assignment_ben( let mut check_buffer = [0u8; 17]; reader.read_exact(&mut check_buffer)?; - if &check_buffer != b"STANDARD BEN FILE" { - return Err(SampleError { - kind: SampleErrorKind::IoError(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )), - }); - } + let variant = match &check_buffer { + b"STANDARD BEN FILE" => BenVariant::Standard, + b"MKVCHAIN BEN FILE" => BenVariant::MkvChain, + _ => { + return Err(SampleError { + kind: SampleErrorKind::IoError(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid file format", + )), + }) + } + }; let mut r_sample = 1; let mut writer = Vec::new(); @@ -176,10 +180,16 @@ pub fn extract_assignment_ben( let mut assign_bits: Vec = vec![0; n_bytes as usize]; reader.read_exact(&mut assign_bits)?; + let count_samples = if variant == BenVariant::MkvChain { + reader.read_u16::()? + } else { + 1 + }; + // Reader buffer gets thrown away after each iteration // and only decoded if we are in the right sample. // This speeds up the process significantly by not decoding all samples. 
- if r_sample == sample_number { + if r_sample == sample_number || r_sample + count_samples as usize > sample_number { // Write the ben header that is expected by jsonl_decode_ben let mut tmp_reader = b"STANDARD BEN FILE".to_vec(); // Write the actual ben data @@ -204,7 +214,11 @@ pub fn extract_assignment_ben( Ok(assignment) } +// #[cfg(test)] +// mod tests { +// include!("tests/read_tests.rs"); +// } + #[cfg(test)] -mod tests { - include!("tests/read_tests.rs"); -} +#[path = "tests/read_tests.rs"] +mod tests; diff --git a/src/decode/tests/decode_tests.rs b/src/decode/tests/decode_tests.rs index 60c18eb..c4f302a 100644 --- a/src/decode/tests/decode_tests.rs +++ b/src/decode/tests/decode_tests.rs @@ -401,14 +401,14 @@ fn test_decode_ben_multiple_simple_lines() { #[test] fn test_jsonl_decode_ben32_simple() { - let mut input: Vec = b"STANDARD BEN FILE".to_vec(); - input.extend(vec![0, 1, 0, 4, 0, 2, 0, 1, 0, 3, 0, 3, 0, 0, 0, 0]); + let input = vec![0, 1, 0, 4, 0, 2, 0, 1, 0, 3, 0, 3, 0, 0, 0, 0]; let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); let writer = &mut output; - let result = jsonl_decode_ben32(&mut reader, writer); + let result = jsonl_decode_ben32(&mut reader, writer, 0, BenVariant::Standard); + if let Err(e) = result { panic!("Error: {}", e); } @@ -425,14 +425,13 @@ fn test_jsonl_decode_ben32_simple() { #[test] fn test_jsonl_decode_ben32_16_bit_val() { - let mut input: Vec = b"STANDARD BEN FILE".to_vec(); - input.extend(vec![0, 1, 0, 4, 2, 0, 0, 1, 0, 3, 0, 3, 0, 0, 0, 0]); + let input = vec![0, 1, 0, 4, 2, 0, 0, 1, 0, 3, 0, 3, 0, 0, 0, 0]; let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); let writer = &mut output; - let result = jsonl_decode_ben32(&mut reader, writer); + let result = jsonl_decode_ben32(&mut reader, writer, 0, BenVariant::Standard); if let Err(e) = result { panic!("Error: {}", e); } @@ -449,14 +448,13 @@ fn test_jsonl_decode_ben32_16_bit_val() { #[test] fn test_jsonl_decode_ben32_16_bit_len() { - 
let mut input: Vec = b"STANDARD BEN FILE".to_vec(); - input.extend(vec![0, 1, 0, 4, 0, 2, 2, 0, 0, 3, 0, 3, 0, 0, 0, 0]); + let input = vec![0, 1, 0, 4, 0, 2, 2, 0, 0, 3, 0, 3, 0, 0, 0, 0]; let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); let writer = &mut output; - let result = jsonl_decode_ben32(&mut reader, writer); + let result = jsonl_decode_ben32(&mut reader, writer, 0, BenVariant::Standard); if let Err(e) = result { panic!("Error: {}", e); } @@ -473,14 +471,13 @@ fn test_jsonl_decode_ben32_16_bit_len() { #[test] fn test_jsonl_decode_ben32_max_val_65535() { - let mut input: Vec = b"STANDARD BEN FILE".to_vec(); - input.extend(vec![0, 23, 0, 4, 255, 255, 0, 15, 0, 8, 0, 3, 0, 0, 0, 0]); + let input = vec![0, 23, 0, 4, 255, 255, 0, 15, 0, 8, 0, 3, 0, 0, 0, 0]; let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); let writer = &mut output; - let result = jsonl_decode_ben32(&mut reader, writer); + let result = jsonl_decode_ben32(&mut reader, writer, 0, BenVariant::Standard); if let Err(e) = result { panic!("Error: {}", e); } @@ -497,14 +494,13 @@ fn test_jsonl_decode_ben32_max_val_65535() { #[test] fn test_jsonl_decode_ben32_max_len_65535() { - let mut input: Vec = b"STANDARD BEN FILE".to_vec(); - input.extend(vec![0, 23, 0, 4, 0, 60, 255, 255, 0, 8, 0, 3, 0, 0, 0, 0]); + let input = vec![0, 23, 0, 4, 0, 60, 255, 255, 0, 8, 0, 3, 0, 0, 0, 0]; let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); let writer = &mut output; - let result = jsonl_decode_ben32(&mut reader, writer); + let result = jsonl_decode_ben32(&mut reader, writer, 0, BenVariant::Standard); if let Err(e) = result { panic!("Error: {}", e); } @@ -521,14 +517,14 @@ fn test_jsonl_decode_ben32_max_len_65535() { #[test] fn test_decode_ben32_single_element() { - let mut input: Vec = b"STANDARD BEN FILE".to_vec(); - input.extend(vec![0, 23, 0, 1, 0, 0, 0, 0]); + let input: Vec = vec![0, 23, 0, 1, 0, 0, 0, 0]; let mut reader = input.as_slice(); let mut 
output: Vec = Vec::new(); let writer = &mut output; - let result = jsonl_decode_ben32(&mut reader, writer); + let result = jsonl_decode_ben32(&mut reader, writer, 0, BenVariant::Standard); + println!("result {:?}", result); if let Err(e) = result { panic!("Error: {}", e); } @@ -543,18 +539,17 @@ fn test_decode_ben32_single_element() { #[test] fn test_decode_ben32_multiple_simple_lines() { - let mut input: Vec = b"STANDARD BEN FILE".to_vec(); - input.extend(vec![ + let input = vec![ 0, 1, 0, 4, 0, 2, 0, 4, 0, 3, 0, 4, 0, 4, 0, 4, 0, 0, 0, 0, 0, 2, 0, 2, 0, 3, 0, 7, 0, 1, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 4, 0, 1, 0, 5, 0, 1, 0, 6, 0, 1, 0, 7, 0, 1, 0, 8, 0, 1, 0, 9, 0, 1, 0, 10, 0, 1, 0, 0, 0, 0, - ]); + ]; let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); let writer = &mut output; - let result = jsonl_decode_ben32(&mut reader, writer); + let result = jsonl_decode_ben32(&mut reader, writer, 0, BenVariant::Standard); if let Err(e) = result { panic!("Error: {}", e); } diff --git a/src/encode/mod.rs b/src/encode/mod.rs index 7d9ae8a..f39c83b 100644 --- a/src/encode/mod.rs +++ b/src/encode/mod.rs @@ -29,7 +29,7 @@ use std::io::{self, BufRead, Cursor, Read, Result, Write}; use xz2::write::XzEncoder; use self::translate::ben_to_ben32_lines; -use super::{log, logln}; +use super::{log, logln, BenVariant}; /// A struct to make the writing of BEN files easier /// and more ergonomic. 
@@ -37,31 +37,105 @@ use super::{log, logln}; /// # Example /// /// ``` -/// use ben::encode::BenEncoder; +/// use ben::{encode::BenEncoder, BenVariant}; /// /// let mut buffer = Vec::new(); -/// let mut ben_encoder = BenEncoder::new(&mut buffer); +/// let mut ben_encoder = BenEncoder::new(&mut buffer, BenVariant::Standard); /// /// ben_encoder.write_assignment(vec![1, 1, 1, 2, 2, 2]); /// ``` +// pub struct BenEncoder { +// writer: W, +// } + +// impl BenEncoder { +// /// Create a new BenEncoder instance and handles +// /// the BEN file header. +// pub fn new(mut writer: W) -> Self { +// writer.write_all(b"STANDARD BEN FILE").unwrap(); +// BenEncoder { writer } +// } + +// /// Write a run-length encoded assignment vector to the +// /// BEN file. +// pub fn write_rle(&mut self, rle_vec: Vec<(u16, u16)>) -> Result<()> { +// let encoded = encode_ben_vec_from_rle(rle_vec); +// self.writer.write_all(&encoded)?; +// Ok(()) +// } + +// /// Write an assignment vector to the BEN file. +// pub fn write_assignment(&mut self, assign_vec: Vec) -> Result<()> { +// let rle_vec = assign_to_rle(assign_vec); +// self.write_rle(rle_vec)?; +// Ok(()) +// } + +// /// Write a JSON value containing an assignment vector to the BEN file. +// pub fn write_json_value(&mut self, data: Value) -> Result<()> { +// let assign_vec = data["assignment"].as_array().unwrap(); +// let rle_vec = assign_to_rle( +// assign_vec +// .into_iter() +// .map(|x| x.as_u64().unwrap() as u16) +// .collect(), +// ); +// self.write_rle(rle_vec)?; +// Ok(()) +// } +// } + pub struct BenEncoder { writer: W, + previous_sample: Vec, + count: u16, + variant: BenVariant, } impl BenEncoder { /// Create a new BenEncoder instance and handles /// the BEN file header. 
- pub fn new(mut writer: W) -> Self { - writer.write_all(b"STANDARD BEN FILE").unwrap(); - BenEncoder { writer } + pub fn new(mut writer: W, variant: BenVariant) -> Self { + match variant { + BenVariant::Standard => { + writer.write_all(b"STANDARD BEN FILE").unwrap(); + } + BenVariant::MkvChain => { + writer.write_all(b"MKVCHAIN BEN FILE").unwrap(); + } + } + BenEncoder { + writer, + previous_sample: Vec::new(), + count: 0, + variant, + } } /// Write a run-length encoded assignment vector to the /// BEN file. pub fn write_rle(&mut self, rle_vec: Vec<(u16, u16)>) -> Result<()> { - let encoded = encode_ben_vec_from_rle(rle_vec); - self.writer.write_all(&encoded)?; - Ok(()) + match self.variant { + BenVariant::Standard => { + let encoded = encode_ben_vec_from_rle(rle_vec); + self.writer.write_all(&encoded)?; + Ok(()) + } + BenVariant::MkvChain => { + let encoded = encode_ben_vec_from_rle(rle_vec); + if encoded == self.previous_sample { + self.count += 1; + } else { + if self.count > 0 { + self.writer.write_all(&self.previous_sample)?; + self.writer.write_all(&self.count.to_be_bytes())?; + } + self.previous_sample = encoded; + self.count = 1; + } + Ok(()) + } + } } /// Write an assignment vector to the BEN file. @@ -85,25 +159,73 @@ impl BenEncoder { } } +impl Drop for BenEncoder { + fn drop(&mut self) { + if self.variant == BenVariant::MkvChain && self.count > 0 { + self.writer + .write_all(&self.previous_sample) + .expect("Error writing last line to file"); + self.writer + .write_all(&self.count.to_be_bytes()) + .expect("Error writing last line count to file"); + } + } +} + /// A struct to make the writing of XBEN files easier /// and more ergonomic. pub struct XBenEncoder { encoder: XzEncoder, + previous_sample: Vec, + count: u16, + variant: BenVariant, } impl XBenEncoder { - /// Create a new XBenEncoder instance and handles - /// the BEN file header. 
- pub fn new(mut encoder: XzEncoder) -> Self { - encoder.write_all(b"STANDARD BEN FILE").unwrap(); - XBenEncoder { encoder } + pub fn new(mut encoder: XzEncoder, variant: BenVariant) -> Self { + match variant { + BenVariant::Standard => { + encoder.write_all(b"STANDARD BEN FILE").unwrap(); + XBenEncoder { + encoder, + previous_sample: Vec::new(), + count: 0, + variant: BenVariant::Standard, + } + } + BenVariant::MkvChain => { + encoder.write_all(b"MKVCHAIN BEN FILE").unwrap(); + XBenEncoder { + encoder, + previous_sample: Vec::new(), + count: 0, + variant: BenVariant::MkvChain, + } + } + } } /// Write a an assigment vector encoded as a JSON value /// to the XBEN file. pub fn write_json_value(&mut self, data: Value) -> Result<()> { let encoded = encode_ben32_line(data); - self.encoder.write_all(&encoded)?; + match self.variant { + BenVariant::Standard => { + self.encoder.write_all(&encoded)?; + } + BenVariant::MkvChain => { + if encoded == self.previous_sample { + self.count += 1; + } else { + if self.count > 0 { + self.encoder.write_all(&self.previous_sample)?; + self.encoder.write_all(&self.count.to_be_bytes())?; + } + self.previous_sample = encoded; + self.count = 1; + } + } + } Ok(()) } @@ -115,15 +237,29 @@ impl XBenEncoder { reader.read_exact(&mut buff)?; // Create a new reader that prepends buff back onto the original reader - let mut reader = if buff != b"STANDARD BEN FILE".as_slice() { - let cursor = Cursor::new(buff.to_vec()); - let reader = cursor.chain(reader); - Box::new(reader) as Box - } else { - Box::new(reader) - }; + let mut reader = + if buff != b"STANDARD BEN FILE".as_slice() || buff != b"MKVCHAIN BEN FILE".as_slice() { + let cursor = Cursor::new(buff.to_vec()); + let reader = cursor.chain(reader); + Box::new(reader) as Box + } else { + Box::new(reader) + }; + + ben_to_ben32_lines(&mut *reader, &mut self.encoder, self.variant) + } +} - ben_to_ben32_lines(&mut *reader, &mut self.encoder) +impl Drop for XBenEncoder { + fn drop(&mut self) { + if 
self.variant == BenVariant::MkvChain && self.count > 0 { + self.encoder + .write_all(&self.previous_sample) + .expect("Error writing last line to file"); + self.encoder + .write_all(&self.count.to_be_bytes()) + .expect("Error writing last line count to file"); + } } } @@ -192,9 +328,13 @@ fn encode_ben32_line(data: Value) -> Vec { /// the byte level to achieve better compression ratios. In order /// to use XBEN files, the `decode_xben_to_ben` function must be /// used to decode the file back into a BEN format. -pub fn jsonl_encode_xben(reader: R, writer: W) -> Result<()> { +pub fn jsonl_encode_xben( + reader: R, + writer: W, + variant: BenVariant, +) -> Result<()> { let encoder = XzEncoder::new(writer, 9); - let mut ben_encoder = XBenEncoder::new(encoder); + let mut ben_encoder = XBenEncoder::new(encoder, variant); let mut line_num = 1; @@ -361,7 +501,7 @@ fn encode_ben_vec_from_rle(rle_vec: Vec<(u16, u16)>) -> Vec { /// ``` /// use std::io::{BufReader, BufWriter}; /// use serde_json::json; -/// use ben::encode::jsonl_encode_ben; +/// use ben::{encode::jsonl_encode_ben, BenVariant}; /// /// let input = r#"{"assignment": [1,1,1,2,2,2], "sample": 1}"#.to_string() /// + "\n" @@ -371,7 +511,7 @@ fn encode_ben_vec_from_rle(rle_vec: Vec<(u16, u16)>) -> Vec { /// let mut write_buffer = Vec::new(); /// let mut writer = BufWriter::new(&mut write_buffer); /// -/// jsonl_encode_ben(reader, writer).unwrap(); +/// jsonl_encode_ben(reader, writer, BenVariant::Standard).unwrap(); /// /// println!("{:?}", write_buffer); /// // This will output @@ -381,9 +521,13 @@ fn encode_ben_vec_from_rle(rle_vec: Vec<(u16, u16)>) -> Vec { /// // 2, 106, 89] /// ``` /// -pub fn jsonl_encode_ben(reader: R, writer: W) -> Result<()> { +pub fn jsonl_encode_ben( + reader: R, + writer: W, + variant: BenVariant, +) -> Result<()> { let mut line_num = 1; - let mut ben_encoder = BenEncoder::new(writer); + let mut ben_encoder = BenEncoder::new(writer, variant); for line_result in reader.lines() { 
log!("Encoding line: {}\r", line_num); line_num += 1; @@ -412,15 +556,18 @@ pub fn ben_encode_xben(mut reader: R, writer: W) -> Result let mut check_buffer = [0u8; 17]; reader.read_exact(&mut check_buffer)?; - if &check_buffer != b"STANDARD BEN FILE" { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )); - } - let encoder = XzEncoder::new(writer, 9); - let mut ben_encoder = XBenEncoder::new(encoder); + + let mut ben_encoder = match &check_buffer { + b"STANDARD BEN FILE" => XBenEncoder::new(encoder, BenVariant::Standard), + b"MKVCHAIN BEN FILE" => XBenEncoder::new(encoder, BenVariant::MkvChain), + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid file format", + )); + } + }; ben_encoder.write_ben_file(reader)?; @@ -428,6 +575,5 @@ pub fn ben_encode_xben(mut reader: R, writer: W) -> Result } #[cfg(test)] -mod tests { - include!("tests/encode_tests.rs"); -} +#[path = "tests/encode_tests.rs"] +mod tests; diff --git a/src/encode/relabel.rs b/src/encode/relabel.rs index d8e759f..8401e4f 100644 --- a/src/encode/relabel.rs +++ b/src/encode/relabel.rs @@ -23,8 +23,12 @@ use std::io::Error; /// /// Returns an error if the file format is invalid or if there is an issue reading or writing /// the file. 
-pub fn relabel_ben_lines(mut reader: R, mut writer: W) -> io::Result<()> { - let mut sample_number = 1; +pub fn relabel_ben_lines( + mut reader: R, + mut writer: W, + variant: BenVariant, +) -> io::Result<()> { + let mut sample_number = 0; loop { let mut tmp_buffer = [0u8]; let max_val_bits = match reader.read_exact(&mut tmp_buffer) { @@ -42,9 +46,6 @@ pub fn relabel_ben_lines(mut reader: R, mut writer: W) -> io: let mut ben_line = decode_ben_line(&mut reader, max_val_bits, max_len_bits, n_bytes)?; - log!("Relabeling line: {}\r", sample_number); - sample_number += 1; - // relabel the line let mut label = 0; let mut label_map = HashMap::new(); @@ -62,6 +63,18 @@ pub fn relabel_ben_lines(mut reader: R, mut writer: W) -> io: let relabeled = encode_ben_vec_from_rle(ben_line); writer.write_all(&relabeled)?; + + let count_occurrences = if variant == BenVariant::MkvChain { + let count = reader.read_u16::()?; + writer.write_all(&count.to_be_bytes())?; + count + } else { + 1 + }; + + sample_number += count_occurrences as usize; + + log!("Relabeling line: {}\r", sample_number); } logln!(); logln!("Done!"); @@ -88,16 +101,20 @@ pub fn relabel_ben_file(mut reader: R, mut writer: W) -> io:: let mut check_buffer = [0u8; 17]; reader.read_exact(&mut check_buffer)?; - if &check_buffer != b"STANDARD BEN FILE" { - return Err(Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )); - } + let variant = match &check_buffer { + b"STANDARD BEN FILE" => BenVariant::Standard, + b"MKVCHAIN BEN FILE" => BenVariant::MkvChain, + _ => { + return Err(Error::new( + io::ErrorKind::InvalidData, + "Invalid file format", + )); + } + }; - writer.write_all(b"STANDARD BEN FILE")?; + writer.write_all(&check_buffer)?; - relabel_ben_lines(&mut reader, &mut writer)?; + relabel_ben_lines(&mut reader, &mut writer, variant)?; Ok(()) } @@ -124,8 +141,9 @@ pub fn relabel_ben_lines_with_map( mut reader: R, mut writer: W, new_to_old_node_map: HashMap, + variant: BenVariant, ) -> io::Result<()> { - 
let mut sample_number = 1; + let mut sample_number = 0; loop { let mut tmp_buffer = [0u8]; let max_val_bits = match reader.read_exact(&mut tmp_buffer) { @@ -158,11 +176,19 @@ pub fn relabel_ben_lines_with_map( let new_rle = assign_to_rle(new_assignment_vec); - log!("Relabeling line: {}\r", sample_number); - sample_number += 1; - let relabeled = encode_ben_vec_from_rle(new_rle); writer.write_all(&relabeled)?; + + let count_occurrences = if variant == BenVariant::MkvChain { + let count = reader.read_u16::()?; + writer.write_all(&count.to_be_bytes())?; + count + } else { + 1 + }; + + sample_number += count_occurrences as usize; + log!("Relabeling line: {}\r", sample_number); } logln!(); logln!("Done!"); @@ -196,16 +222,20 @@ pub fn relabel_ben_file_with_map( let mut check_buffer = [0u8; 17]; reader.read_exact(&mut check_buffer)?; - if &check_buffer != b"STANDARD BEN FILE" { - return Err(Error::new( - io::ErrorKind::InvalidData, - "Invalid file format", - )); - } + let variant = match &check_buffer { + b"STANDARD BEN FILE" => BenVariant::Standard, + b"MKVCHAIN BEN FILE" => BenVariant::MkvChain, + _ => { + return Err(Error::new( + io::ErrorKind::InvalidData, + "Invalid file format", + )); + } + }; - writer.write_all(b"STANDARD BEN FILE")?; + writer.write_all(&check_buffer)?; - relabel_ben_lines_with_map(&mut reader, &mut writer, new_to_old_node_map)?; + relabel_ben_lines_with_map(&mut reader, &mut writer, new_to_old_node_map, variant)?; Ok(()) } @@ -244,7 +274,7 @@ mod tests { let expected = encode_ben_vec_from_rle(out_rle); let mut buf = Vec::new(); - relabel_ben_lines(input.as_slice(), &mut buf).unwrap(); + relabel_ben_lines(input.as_slice(), &mut buf, BenVariant::Standard).unwrap(); assert_eq!(buf, expected); } @@ -267,7 +297,7 @@ mod tests { let mut output = Vec::new(); let writer = io::BufWriter::new(&mut output); - jsonl_encode_ben(input, writer).unwrap(); + jsonl_encode_ben(input, writer, BenVariant::Standard).unwrap(); let mut output2 = Vec::new(); let writer2 = 
io::BufWriter::new(&mut output2); @@ -293,6 +323,56 @@ mod tests { assert_eq!(output_str, out_file); } + #[test] + fn test_relabel_simple_file_mkv() { + let file = format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":1}", + "{\"assignment\":[2,1,3,4,5,5,3,4,2],\"sample\":2}", + "{\"assignment\":[3,3,1,1,2,2,3,3,4],\"sample\":3}", + "{\"assignment\":[4,3,2,1,4,3,2,1,1],\"sample\":4}", + "{\"assignment\":[3,2,2,4,1,3,1,4,3],\"sample\":5}", + "{\"assignment\":[3,2,2,4,1,3,1,4,3],\"sample\":6}", + "{\"assignment\":[3,2,2,4,1,3,1,4,3],\"sample\":7}", + "{\"assignment\":[2,2,3,3,4,4,5,5,1],\"sample\":8}", + "{\"assignment\":[2,4,1,5,2,4,3,1,3],\"sample\":9}", + "{\"assignment\":[2,4,1,5,2,4,3,1,3],\"sample\":10}" + ); + + let input = file.as_bytes(); + + let mut output = Vec::new(); + let writer = io::BufWriter::new(&mut output); + + jsonl_encode_ben(input, writer, BenVariant::MkvChain).unwrap(); + + let mut output2 = Vec::new(); + let writer2 = io::BufWriter::new(&mut output2); + relabel_ben_file(output.as_slice(), writer2).unwrap(); + + let mut output3 = Vec::new(); + let writer3 = io::BufWriter::new(&mut output3); + jsonl_decode_ben(output2.as_slice(), writer3).unwrap(); + + let output_str = String::from_utf8(output3).unwrap(); + + let out_file = format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":1}", + "{\"assignment\":[1,2,3,4,5,5,3,4,1],\"sample\":2}", + "{\"assignment\":[1,1,2,2,3,3,1,1,4],\"sample\":3}", + "{\"assignment\":[1,2,3,4,1,2,3,4,4],\"sample\":4}", + "{\"assignment\":[1,2,2,3,4,1,4,3,1],\"sample\":5}", + "{\"assignment\":[1,2,2,3,4,1,4,3,1],\"sample\":6}", + "{\"assignment\":[1,2,2,3,4,1,4,3,1],\"sample\":7}", + "{\"assignment\":[1,1,2,2,3,3,4,4,5],\"sample\":8}", + "{\"assignment\":[1,2,3,4,1,2,5,3,5],\"sample\":9}", + "{\"assignment\":[1,2,3,4,1,2,5,3,5],\"sample\":10}" + ); + + assert_eq!(output_str, out_file); + } + #[test] fn 
test_relabel_ben_line_with_map() { let in_assign = vec![2, 3, 1, 4, 5, 5, 3, 4, 2]; @@ -316,7 +396,13 @@ mod tests { new_to_old_map.insert(8, 5); let mut buf = Vec::new(); - relabel_ben_lines_with_map(input.as_slice(), &mut buf, new_to_old_map).unwrap(); + relabel_ben_lines_with_map( + input.as_slice(), + &mut buf, + new_to_old_map, + BenVariant::Standard, + ) + .unwrap(); assert_eq!(buf, expected); } @@ -334,7 +420,13 @@ mod tests { let expected = encode_ben_vec_from_rle(out_rle); let mut buf = Vec::new(); - relabel_ben_lines_with_map(input.as_slice(), &mut buf, new_to_old_map).unwrap(); + relabel_ben_lines_with_map( + input.as_slice(), + &mut buf, + new_to_old_map, + BenVariant::Standard, + ) + .unwrap(); assert_eq!(buf, expected); } @@ -359,7 +451,13 @@ mod tests { let expected = encode_ben_vec_from_rle(out_rle); let mut buf = Vec::new(); - relabel_ben_lines_with_map(input.as_slice(), &mut buf, new_to_old_map).unwrap(); + relabel_ben_lines_with_map( + input.as_slice(), + &mut buf, + new_to_old_map, + BenVariant::Standard, + ) + .unwrap(); assert_eq!(buf, expected); } @@ -397,7 +495,7 @@ mod tests { let mut output = Vec::new(); let writer = io::BufWriter::new(&mut output); - jsonl_encode_ben(input, writer).unwrap(); + jsonl_encode_ben(input, writer, BenVariant::Standard).unwrap(); let mut output2 = Vec::new(); let writer2 = io::BufWriter::new(&mut output2); @@ -422,4 +520,69 @@ mod tests { assert_eq!(output_str, out_file); } + + #[test] + fn test_relabel_simple_file_with_map_mkv() { + let file = format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":1}", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":2}", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":3}", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":4}", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":5}", + "{\"assignment\":[1,2,3,4,5,5,3,4,2],\"sample\":6}", + "{\"assignment\":[2,1,3,4,5,5,3,4,2],\"sample\":7}", + 
"{\"assignment\":[2,1,3,4,5,5,3,4,2],\"sample\":8}", + "{\"assignment\":[2,1,3,4,5,5,3,4,2],\"sample\":9}", + "{\"assignment\":[2,4,1,5,2,4,3,1,3],\"sample\":10}", + ); + + let new_to_old_map: HashMap = [ + (0, 2), + (1, 3), + (2, 4), + (3, 5), + (4, 6), + (5, 7), + (6, 8), + (7, 0), + (8, 1), + ] + .iter() + .cloned() + .collect(); + + let input = file.as_bytes(); + + let mut output = Vec::new(); + let writer = io::BufWriter::new(&mut output); + + jsonl_encode_ben(input, writer, BenVariant::MkvChain).unwrap(); + + let mut output2 = Vec::new(); + let writer2 = io::BufWriter::new(&mut output2); + relabel_ben_file_with_map(output.as_slice(), writer2, new_to_old_map).unwrap(); + + let mut output3 = Vec::new(); + let writer3 = io::BufWriter::new(&mut output3); + jsonl_decode_ben(output2.as_slice(), writer3).unwrap(); + + let output_str = String::from_utf8(output3).unwrap(); + + let out_file = format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n", + "{\"assignment\":[3,4,5,5,3,4,2,1,2],\"sample\":1}", + "{\"assignment\":[3,4,5,5,3,4,2,1,2],\"sample\":2}", + "{\"assignment\":[3,4,5,5,3,4,2,1,2],\"sample\":3}", + "{\"assignment\":[3,4,5,5,3,4,2,1,2],\"sample\":4}", + "{\"assignment\":[3,4,5,5,3,4,2,1,2],\"sample\":5}", + "{\"assignment\":[3,4,5,5,3,4,2,1,2],\"sample\":6}", + "{\"assignment\":[3,4,5,5,3,4,2,2,1],\"sample\":7}", + "{\"assignment\":[3,4,5,5,3,4,2,2,1],\"sample\":8}", + "{\"assignment\":[3,4,5,5,3,4,2,2,1],\"sample\":9}", + "{\"assignment\":[1,5,2,4,3,1,3,2,4],\"sample\":10}", + ); + + assert_eq!(output_str, out_file); + } } diff --git a/src/encode/tests/encode_tests.rs b/src/encode/tests/encode_tests.rs index 4ecde57..d60bb0c 100644 --- a/src/encode/tests/encode_tests.rs +++ b/src/encode/tests/encode_tests.rs @@ -27,7 +27,11 @@ fn test_jsonl_encode_ben_underflow() { 0b01_11011_0, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + 
BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -72,7 +76,11 @@ fn test_jsonl_encode_ben_exact() { 0b001_11001_, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -108,7 +116,11 @@ fn test_jsonl_encode_ben_16_bit_val() { 0b0011011_0, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -144,7 +156,11 @@ fn test_jsonl_encode_ben_16_bit_len() { 0b0011_0000, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -183,7 +199,11 @@ fn test_jsonl_encode_ben_max_val_65535() { 0b0011_0000, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -223,7 +243,11 @@ fn test_jsonl_encode_ben_len_65535() { 0b11_000000, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -266,7 +290,11 @@ fn jsonl_encode_ben_max_val_and_len_at_65535() { 0b00000100_, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -298,7 +326,11 @@ 
fn jsonl_encode_ben_single_element() { 0b101111_00, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -330,7 +362,11 @@ fn jsonl_encode_ben_single_zero() { 0b01_000000, ]); - let output = jsonl_encode_ben(json!(data).to_string().as_bytes(), writer); + let output = jsonl_encode_ben( + json!(data).to_string().as_bytes(), + writer, + BenVariant::Standard, + ); if let Err(e) = output { panic!("Error: {}", e); } @@ -408,7 +444,7 @@ fn jsonl_encode_ben_multiple_simple_lines() { 0b01_000000, ]); - let output = jsonl_encode_ben(full_data.as_bytes(), writer); + let output = jsonl_encode_ben(full_data.as_bytes(), writer, BenVariant::Standard); if let Err(e) = output { panic!("Error {}", e); } diff --git a/src/encode/tests/translate_tests.rs b/src/encode/tests/translate_tests.rs index fe5fb7d..4ad0ab2 100644 --- a/src/encode/tests/translate_tests.rs +++ b/src/encode/tests/translate_tests.rs @@ -29,7 +29,7 @@ fn translate_ben32_to_ben_file(mut reader: R, mut writer: W) } writer.write_all(b"STANDARD BEN FILE")?; - ben32_to_ben_lines(reader, writer) + ben32_to_ben_lines(reader, writer, BenVariant::Standard) } fn translate_ben_to_ben32_file(mut reader: R, mut writer: W) -> io::Result<()> { @@ -44,7 +44,7 @@ fn translate_ben_to_ben32_file(mut reader: R, mut writer: W) } writer.write_all(b"STANDARD BEN FILE")?; - ben_to_ben32_lines(reader, writer) + ben_to_ben32_lines(reader, writer, BenVariant::Standard) } #[test] @@ -80,7 +80,7 @@ fn test_simple_translation_ben32_to_ben() { let mut buffer: Vec = Vec::new(); let writer2 = &mut buffer; - jsonl_encode_ben(full_data.as_bytes(), writer2).unwrap(); + jsonl_encode_ben(full_data.as_bytes(), writer2, BenVariant::Standard).unwrap(); assert_eq!(writer, &buffer); } @@ -134,7 +134,7 @@ fn test_random_translation_ben32_to_ben() { let mut buffer: Vec = 
Vec::new(); let writer2 = &mut buffer; - jsonl_encode_ben(full_data.as_bytes(), writer2).unwrap(); + jsonl_encode_ben(full_data.as_bytes(), writer2, BenVariant::Standard).unwrap(); assert_eq!(writer, &buffer); } @@ -159,7 +159,7 @@ fn test_simple_translation_ben_to_ben32() { let mut input: Vec = Vec::new(); let input_writer = &mut input; - jsonl_encode_ben(full_data.as_bytes(), input_writer).unwrap(); + jsonl_encode_ben(full_data.as_bytes(), input_writer, BenVariant::Standard).unwrap(); let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); @@ -214,7 +214,7 @@ fn test_random_translation_ben_to_ben32() { let mut input: Vec = Vec::new(); let input_writer = &mut input; - jsonl_encode_ben(full_data.as_bytes(), input_writer).unwrap(); + jsonl_encode_ben(full_data.as_bytes(), input_writer, BenVariant::Standard).unwrap(); let mut reader = input.as_slice(); let mut output: Vec = Vec::new(); diff --git a/src/encode/translate.rs b/src/encode/translate.rs index 15364a3..c7b17e3 100644 --- a/src/encode/translate.rs +++ b/src/encode/translate.rs @@ -6,7 +6,7 @@ use byteorder::{BigEndian, ReadBytesExt}; use std::io::{self, Error, Read, Write}; -use super::{log, logln}; +use super::{log, logln, BenVariant}; use crate::decode::decode_ben_line; use crate::encode::encode_ben_vec_from_rle; @@ -81,17 +81,26 @@ fn ben32_to_ben_line(ben32_vec: Vec) -> io::Result> { /// /// This function will return an error if the input reader contains invalid ben32 /// data or if the writer encounters an error while writing the ben data. 
-pub fn ben32_to_ben_lines(mut reader: R, mut writer: W) -> io::Result<()> { +pub fn ben32_to_ben_lines( + mut reader: R, + mut writer: W, + variant: BenVariant, +) -> io::Result<()> { 'outer: loop { let mut ben32_vec: Vec = Vec::new(); let mut ben32_read_buff: [u8; 4] = [0u8; 4]; + let mut n_reps = 0; + // extract the ben32 data 'inner: loop { match reader.read_exact(&mut ben32_read_buff) { Ok(()) => { ben32_vec.extend(ben32_read_buff); if ben32_read_buff == [0u8; 4] { + if variant == BenVariant::MkvChain { + n_reps = reader.read_u16::()?; + } break 'inner; } } @@ -106,6 +115,9 @@ pub fn ben32_to_ben_lines(mut reader: R, mut writer: W) -> io let ben_vec = ben32_to_ben_line(ben32_vec)?; writer.write_all(&ben_vec)?; + if variant == BenVariant::MkvChain { + writer.write_all(&n_reps.to_be_bytes())?; + } } Ok(()) @@ -160,7 +172,11 @@ fn ben_to_ben32_line( /// /// This function will return an error if the input reader contains invalid ben /// data or if the writer encounters an error while writing the ben32 data. 
-pub fn ben_to_ben32_lines(mut reader: R, mut writer: W) -> io::Result<()> { +pub fn ben_to_ben32_lines( + mut reader: R, + mut writer: W, + variant: BenVariant, +) -> io::Result<()> { let mut sample_number = 1; 'outer: loop { let mut tmp_buffer = [0u8]; @@ -178,10 +194,25 @@ pub fn ben_to_ben32_lines(mut reader: R, mut writer: W) -> io let n_bytes = reader.read_u32::()?; log!("Encoding line: {}\r", sample_number); - sample_number += 1; - let ben32_vec = ben_to_ben32_line(&mut reader, max_val_bits, max_len_bits, n_bytes)?; - writer.write_all(&ben32_vec)?; + match variant { + BenVariant::Standard => { + sample_number += 1; + let ben32_vec = + ben_to_ben32_line(&mut reader, max_val_bits, max_len_bits, n_bytes)?; + writer.write_all(&ben32_vec)?; + } + BenVariant::MkvChain => { + let ben32_vec = + ben_to_ben32_line(&mut reader, max_val_bits, max_len_bits, n_bytes)?; + + // Read the number of repetitions AFTER the ben32 data + let n_reps = reader.read_u16::()?; + sample_number += n_reps as usize; + writer.write_all(&ben32_vec)?; + writer.write_all(&n_reps.to_be_bytes())?; + } + } } logln!(); @@ -190,6 +221,5 @@ pub fn ben_to_ben32_lines(mut reader: R, mut writer: W) -> io } #[cfg(test)] -mod tests { - include!("tests/translate_tests.rs"); -} +#[path = "tests/translate_tests.rs"] +mod tests; diff --git a/src/lib.rs b/src/lib.rs index 0566314..22b7cea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,3 +43,9 @@ macro_rules! 
logln { } }} } + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BenVariant { + Standard, + MkvChain, +} diff --git a/tests/test_pipeline.rs b/tests/test_pipeline.rs index 8303370..7263a96 100644 --- a/tests/test_pipeline.rs +++ b/tests/test_pipeline.rs @@ -1,6 +1,7 @@ use ben::decode::*; use ben::encode::*; use ben::utils::*; +use ben::BenVariant; use serde_json::json; use std::io::{Cursor, Read, Write}; @@ -53,7 +54,7 @@ fn test_ben_pipeline() { let mut output_writer = Vec::new(); // Assume these functions are adapted to work with streams - jsonl_encode_ben(&mut buffer, &mut input_writer).unwrap(); + jsonl_encode_ben(&mut buffer, &mut input_writer, BenVariant::Standard).unwrap(); buffer.set_position(0); // Reset if needed for reuse jsonl_decode_ben(&input_writer[..], &mut output_writer).unwrap(); @@ -66,12 +67,73 @@ fn test_ben_pipeline() { } #[test] -fn test_xben_pipeline() { +fn test_mkvben_pipeline() { let seed = 129530786u64; let mut rng = ChaCha8Rng::seed_from_u64(seed); let n_samples = 100; + let shape = 2.0; + let scale = 50.0; + let gamma = Gamma::new(shape, scale).unwrap(); + + let mu = Uniform::new(1, 51); + let count = Uniform::new(1, 11); + + // In-memory buffer for streaming + let mut buffer = Cursor::new(Vec::new()); + + eprintln!(); + let mut sample_count = 0; + while sample_count < n_samples { + eprint!("Generating sample: {}\r", sample_count + 1); + let mut rle_vec = Vec::new(); + while rle_vec.len() < 500 { + rle_vec.push((mu.sample(&mut rng) as u16, gamma.sample(&mut rng) as u16)); + } + + for _ in 0..count.sample(&mut rng) { + sample_count += 1; + // Directly write each JSON line to the buffer + writeln!( + &mut buffer, + "{}", + json!({ + "assignment": rle_to_vec(rle_vec.clone()), + "sample": sample_count, + }) + ) + .unwrap(); + } + } + eprintln!(); + + // Reset buffer cursor to the start + buffer.set_position(0); + + let mut input_writer = Vec::new(); + let mut output_writer = Vec::new(); + + // Assume these functions are adapted to 
work with streams + jsonl_encode_ben(&mut buffer, &mut input_writer, BenVariant::MkvChain).unwrap(); + buffer.set_position(0); // Reset if needed for reuse + jsonl_decode_ben(&input_writer[..], &mut output_writer).unwrap(); + + // Reset buffer to compare + buffer.set_position(0); + let mut original_data = Vec::new(); + buffer.read_to_end(&mut original_data).unwrap(); + + assert_eq!(original_data, output_writer); +} + +#[test] +fn test_xben_pipeline() { + let seed = 129530786u64; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + let n_samples = 50; + let shape = 2.0; let scale = 200.0; let gamma = Gamma::new(shape, scale).unwrap(); @@ -114,7 +176,69 @@ fn test_xben_pipeline() { let mut output_writer = Vec::new(); // Assume these functions are adapted to work with streams - jsonl_encode_xben(sample_writer, &mut input_writer).unwrap(); + jsonl_encode_xben(sample_writer, &mut input_writer, BenVariant::Standard).unwrap(); + decode_xben_to_ben(&input_writer[..], &mut output_writer).unwrap(); + + let mut xoutput_writer = Vec::new(); + jsonl_decode_ben(&output_writer[..], &mut xoutput_writer).unwrap(); + + assert_eq!(original_data, xoutput_writer); +} + +#[test] +fn test_xmkvben_pipeline() { + let seed = 129530786u64; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + let n_samples = 50; + + let shape = 2.0; + let scale = 200.0; + let gamma = Gamma::new(shape, scale).unwrap(); + + let mu = Uniform::new(1, 51); + let count = Uniform::new(1, 11); + + // In-memory buffer for streaming + let mut buffer = Vec::new(); + let mut sample_writer = Cursor::new(&mut buffer); + + eprintln!(); + let mut sample_count = 0; + while sample_count < n_samples { + eprint!("Generating sample: {}\r", sample_count + 1); + let mut rle_vec = Vec::new(); + while rle_vec.len() < 500 { + rle_vec.push((mu.sample(&mut rng) as u16, gamma.sample(&mut rng) as u16)); + } + + for _ in 0..count.sample(&mut rng) { + sample_count += 1; + // Directly write each JSON line to the buffer + writeln!( + &mut 
sample_writer, + "{}", + json!({ + "assignment": rle_to_vec(rle_vec.clone()), + "sample": sample_count, + }) + ) + .unwrap(); + } + } + eprintln!(); + + sample_writer.set_position(0); + let mut original_data = Vec::new(); + sample_writer.read_to_end(&mut original_data).unwrap(); + + sample_writer.set_position(0); + + let mut input_writer = Vec::new(); + let mut output_writer = Vec::new(); + + // Assume these functions are adapted to work with streams + jsonl_encode_xben(sample_writer, &mut input_writer, BenVariant::MkvChain).unwrap(); decode_xben_to_ben(&input_writer[..], &mut output_writer).unwrap(); let mut xoutput_writer = Vec::new();