Skip to content

Commit

Permalink
feat: streamlined dataset reading, shuffling, and splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Nov 12, 2023
1 parent 2bf9c1e commit 90e467b
Showing 1 changed file with 29 additions and 40 deletions.
69 changes: 29 additions & 40 deletions cakes-results/src/genomic/read_silva.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,19 @@ pub fn silva_to_dataset(
.to_str()
.ok_or_else(|| format!("Could not convert file stem to string for {unaligned_path:?}"))?;

// Open the file and read the lines.
// Open the unaligned sequences file and read the lines.
let file = File::open(unaligned_path)
.map_err(|e| format!("Could not open file {unaligned_path:?}: {e}"))?;
let reader = BufReader::new(file);
let lines = reader
let sequences = reader
.lines()
.map(|line| line.map_err(|e| format!("Could not read line: {e}")))
.collect::<Result<Vec<_>, _>>()?;

// shuffle the lines and keep track of the original indices.
let mut lines = lines.into_iter().enumerate().collect::<Vec<_>>();
lines.shuffle(&mut rand::thread_rng());

// Collect the first 1000 lines for the query set. The remaining lines are
// for the training set.
let (train_indices, train_sequences): (Vec<_>, Vec<_>) =
lines.split_off(1000).into_iter().unzip();
let train = VecDataset::new(
format!("{stem}-train"),
train_sequences,
metric,
is_expensive,
info!(
"Read {} sequences from {unaligned_path:?}.",
sequences.len()
);

// Collect the lines for the query set.
let (query_indices, queries): (Vec<_>, Vec<_>) = lines.into_iter().unzip();
let queries = VecDataset::new(format!("{stem}-queries"), queries, metric, is_expensive);

// Read the headers file.
let file = File::open(headers_path)
.map_err(|e| format!("Could not open file {headers_path:?}: {e}"))?;
Expand All @@ -89,34 +74,38 @@ pub fn silva_to_dataset(
.collect::<Result<Vec<_>, _>>()?;
info!("Read {} headers from {headers_path:?}.", headers.len());

// Split the headers into the training and query sets.
let (query_headers, train_headers) = headers
.into_iter()
.enumerate()
.partition::<Vec<_>, _>(|(i, _)| query_indices.contains(i));
// Join the sequences and headers into a single vector of (sequence, header) pairs.
let mut sequences = sequences.into_iter().zip(headers).collect::<Vec<_>>();
sequences.shuffle(&mut thread_rng());
info!("Shuffled sequences and headers.");

let train_headers = train_headers
.into_iter()
.filter(|(i, _)| train_indices.contains(i))
.map(|(_, h)| h)
.collect::<Vec<_>>();
let train_headers = VecDataset::new(
format!("{stem}-train-headers"),
train_headers,
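// Split the (sequence, header) pairs into the training and query sets.
// NOTE(review): `Vec::split_off(1000)` returns the elements from index 1000
// onwards, so `queries` receives everything after the first 1000 pairs and
// only the first 1000 pairs remain in `sequences` for training. It also
// panics if there are fewer than 1000 pairs. Confirm this split is intended.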
let queries = sequences.split_off(1000);
let (queries, query_headers): (Vec<_>, Vec<_>) = queries.into_iter().unzip();
let queries = VecDataset::new(format!("{stem}-queries"), queries, metric, is_expensive);
let query_headers = VecDataset::new(
format!("{stem}-query-headers"),
query_headers,
metric,
is_expensive,
);
info!(
"Using {} sequences for queries.",
query_headers.cardinality()
);

let query_headers = query_headers
.into_iter()
.map(|(_, h)| h)
.collect::<Vec<_>>();
let query_headers = VecDataset::new(
format!("{stem}-query-headers"),
query_headers,
let (train, train_headers): (Vec<_>, Vec<_>) = sequences.into_iter().unzip();
let train = VecDataset::new(format!("{stem}-train"), train, metric, is_expensive);
let train_headers = VecDataset::new(
format!("{stem}-train-headers"),
train_headers,
metric,
is_expensive,
);
info!(
"Using {} sequences for training.",
train_headers.cardinality()
);

assert_eq!(train.cardinality(), train_headers.cardinality());
assert_eq!(queries.cardinality(), query_headers.cardinality());
Expand Down

0 comments on commit 90e467b

Please sign in to comment.