Skip to content

Commit

Permalink
Multi-threaded benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
khb7840 committed Nov 4, 2024
1 parent 3a6597b commit 028350a
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 109 deletions.
2 changes: 2 additions & 0 deletions src/cli/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ fn parse_arg() -> Result<AppArgs, Box<dyn std::error::Error>> {
result: args.opt_value_from_str(["-r", "--result"])?,
answer: args.opt_value_from_str(["-a", "--answer"])?,
index: args.opt_value_from_str(["-i", "--index"])?,
input: args.opt_value_from_str("--input")?,
format: args.value_from_str(["-f", "--format"]).unwrap_or("tsv".into()),
fp: args.opt_value_from_str("--fp")?,
threads: args.value_from_str(["-t", "--threads"]).unwrap_or(1),
}),
Some("test") => Ok(AppArgs::Test {
index_path: args.value_from_str(["-i", "--index"])?,
Expand Down
2 changes: 2 additions & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ pub enum AppArgs {
result: Option<String>,
answer: Option<String>,
index: Option<String>,
input: Option<String>,
format: String,
fp: Option<f64>,
threads: usize,
},
Test {
index_path: String,
Expand Down
242 changes: 133 additions & 109 deletions src/cli/workflows/benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashSet;
use std::io::BufRead;

use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

use crate::cli::*;
use crate::index::lookup::load_lookup_from_file;
use crate::prelude::*;
Expand All @@ -16,124 +18,117 @@ pub fn benchmark(env: AppArgs) {
result,
answer,
index,
input,
format,
fp,
threads,
} => {
if result.is_none() || answer.is_none() || index.is_none() {
print_log_msg(FAIL, "Result, answer, and index files must be provided");
std::process::exit(1);
if input.is_none() {
if result.is_none() || answer.is_none() || index.is_none() {
print_log_msg(FAIL, "Result, answer, and index files must be provided");
std::process::exit(1);
}
}
let result_path = result.unwrap();
let answer_path = answer.unwrap();
// If input is given, read from file
let input_vector = if input.is_none() {
vec![(result.unwrap(), answer.unwrap())]
} else {
let input = input.unwrap();
let file = std::fs::File::open(&input).expect(
&log_msg(FAIL, &format!("Failed to open input file: {}", input))
);
let reader = std::io::BufReader::new(file);
reader.lines().map(|line| {
let line = line.expect(&log_msg(FAIL, "Failed to read line"));
let row = line.split('\t').collect::<Vec<_>>();
(row[0].to_string(), row[1].to_string())
}).collect::<Vec<_>>()
};

// TODO: Add false list (3rd column), neutral list (4th column)

let _pool = rayon::ThreadPoolBuilder::new().num_threads(threads).build_global().unwrap();

let index_path = index.unwrap();
let lookup_path = format!("{}.lookup", index_path);
let config_path = format!("{}.type", index_path);
let format = format.as_str();

// let result = read_one_column_of_tsv(&result_path, 0);
let result = read_one_column_of_tsv_as_vec(&result_path, 0);
let answer = read_one_column_of_tsv(&answer_path, 0);
let lookup = load_lookup_from_file(&lookup_path);
let lookup = lookup.into_iter().map(|(id, _, _, _)| id).collect::<HashSet<_>>();
// Parse path by id type
let result = result.into_iter().map(|id| {
let id = id.split('/').last().unwrap();
// remove extension. Split by '.' and get the all elements except the last one
let id_split = id.split('.').collect::<Vec<_>>();
// If last element in pdb, cif, fcz, ent pdb.gz, cif.gz, fcz.gz, ent.gz; return except last element
if id_split.last().unwrap() == &"pdb" || id_split.last().unwrap() == &"cif" || id_split.last().unwrap() == &"fcz" || id_split.last().unwrap() == &"ent" {
id_split[..id_split.len()-1].join(".")
} else if id_split.last().unwrap() == &".gz" {
id_split[..id_split.len()-2].join(".")
} else {
id_split.join(".")
}
}).collect::<Vec<_>>();
let answer = answer.into_iter().map(|id| {
let id = id.split('/').last().unwrap();
// remove extension. Split by '.' and get the all elements except the last one
let id_split = id.split('.').collect::<Vec<_>>();
// If last element in pdb, cif, fcz, ent pdb.gz, cif.gz, fcz.gz, ent.gz; return except last element
if id_split.last().unwrap() == &"pdb" || id_split.last().unwrap() == &"cif" || id_split.last().unwrap() == &"fcz" || id_split.last().unwrap() == &"ent" {
id_split[..id_split.len()-1].join(".")
} else if id_split.last().unwrap() == &".gz" {
id_split[..id_split.len()-2].join(".")
} else {
id_split.join(".")
}
}).collect::<HashSet<_>>();
let lookup = lookup.into_iter().map(|id| {
let id = id.split('/').last().unwrap();
// remove extension. Split by '.' and get the all elements except the last one
let id_split = id.split('.').collect::<Vec<_>>();
// If last element in pdb, cif, fcz, ent pdb.gz, cif.gz, fcz.gz, ent.gz; return except last element
if id_split.last().unwrap() == &"pdb" || id_split.last().unwrap() == &"cif" || id_split.last().unwrap() == &"fcz" || id_split.last().unwrap() == &"ent" {
id_split[..id_split.len()-1].join(".")
} else if id_split.last().unwrap() == &".gz" {
id_split[..id_split.len()-2].join(".")
} else {
id_split.join(".")
}
}).collect::<HashSet<_>>();
let raw_lookup = load_lookup_from_file(&lookup_path);
let raw_lookup = raw_lookup.into_iter().map(|(id, _, _, _)| id).collect::<HashSet<_>>();
let mut lookup = HashSet::with_capacity(raw_lookup.len());
parse_path_set_as_set(&raw_lookup, &mut lookup);
let config = read_index_config_from_file(&config_path);
let result_set = HashSet::from_iter(result.iter().cloned());
let metric = if let Some(fp) = fp {
measure_up_to_k_fp(&result, &answer, &lookup, fp)
} else {
compare_target_answer_set(&result_set, &answer, &lookup)
};

match format {
"tsv" => {
// lookup, result, answer, lookup_len, result_len, answer_len,
// hash_type, num_bin_dist, num_bin_angle,
// true_pos, true_neg, false_pos, false_neg, precision, recall, accuracy, f1_score,
println!(
"{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.4}\t{:.4}\t{:.4}\t{:.4}",
index_path,
result_path,
answer_path,
lookup.len(),
result.len(),
answer.len(),
config.hash_type.to_string(),
config.num_bin_dist,
config.num_bin_angle,
metric.true_pos,
metric.true_neg,
metric.false_pos,
metric.false_neg,
metric.precision(),
metric.recall(),
metric.accuracy(),
metric.f1_score(),
);
}
"default" => {
println!("Index: {}", lookup_path);
println!("Result: {}", result_path);
println!("Answer: {}", answer_path);
println!("Total ids: {}", lookup.len());
println!("Result length: {}", result.len());
println!("Answer length: {}", answer.len());
println!("Hash type: {}", config.hash_type.to_string());
println!("Number of distance bins: {}", config.num_bin_dist);
println!("Number of angle bins: {}", config.num_bin_angle);
println!("TP: {}", metric.true_pos);
println!("TN: {}", metric.true_neg);
println!("FP: {}", metric.false_pos);
println!("FN: {}", metric.false_neg);
// Print float with 4 decimal places
println!("Precision: {:.4}", metric.precision());
println!("Recall: {:.4}", metric.recall());
println!("Accuracy: {:.4}", metric.accuracy());
println!("F1 score: {:.4}", metric.f1_score());
}
_ => {
print_log_msg(FAIL, "Invalid format");
std::process::exit(1);

input_vector.par_iter().for_each(|(result_path, answer_path)| {
// Parse path by id type
let raw_result = read_one_column_of_tsv_as_vec(&result_path, 0);
let mut result = Vec::with_capacity(raw_result.len());
parse_path_vector_as_vec(&raw_result, &mut result);

let raw_answer = read_one_column_of_tsv_as_set(&answer_path, 0);
let mut answer = HashSet::with_capacity(raw_answer.len());
parse_path_set_as_set(&raw_answer, &mut answer);

let result_set = HashSet::from_iter(result.iter().cloned());
let metric = if let Some(fp) = fp {
measure_up_to_k_fp(&result, &answer, &lookup, fp)
} else {
compare_target_answer_set(&result_set, &answer, &lookup)
};

match format {
"tsv" => {
// lookup, result, answer, lookup_len, result_len, answer_len,
// hash_type, num_bin_dist, num_bin_angle,
// true_pos, true_neg, false_pos, false_neg, precision, recall, accuracy, f1_score,
println!(
"{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.4}\t{:.4}\t{:.4}\t{:.4}",
index_path,
result_path,
answer_path,
lookup.len(),
result.len(),
answer.len(),
config.hash_type.to_string(),
config.num_bin_dist,
config.num_bin_angle,
metric.true_pos,
metric.true_neg,
metric.false_pos,
metric.false_neg,
metric.precision(),
metric.recall(),
metric.accuracy(),
metric.f1_score(),
);
}
"default" => {
println!("Index: {}", lookup_path);
println!("Result: {}", result_path);
println!("Answer: {}", answer_path);
println!("Total ids: {}", lookup.len());
println!("Result length: {}", result.len());
println!("Answer length: {}", answer.len());
println!("Hash type: {}", config.hash_type.to_string());
println!("Number of distance bins: {}", config.num_bin_dist);
println!("Number of angle bins: {}", config.num_bin_angle);
println!("TP: {}", metric.true_pos);
println!("TN: {}", metric.true_neg);
println!("FP: {}", metric.false_pos);
println!("FN: {}", metric.false_neg);
// Print float with 4 decimal places
println!("Precision: {:.4}", metric.precision());
println!("Recall: {:.4}", metric.recall());
println!("Accuracy: {:.4}", metric.accuracy());
println!("F1 score: {:.4}", metric.f1_score());
}
_ => {
print_log_msg(FAIL, "Invalid format");
std::process::exit(1);
}
}
}
});

}
_ => {
eprintln!("Invalid subcommand");
Expand All @@ -142,7 +137,7 @@ pub fn benchmark(env: AppArgs) {
}
}

fn read_one_column_of_tsv(file: &str, col_index: usize) -> HashSet<String> {
fn read_one_column_of_tsv_as_set(file: &str, col_index: usize) -> HashSet<String> {
let mut set = HashSet::new();
// Open file and get specific column
let file = std::fs::File::open(file).expect(
Expand Down Expand Up @@ -172,6 +167,31 @@ fn read_one_column_of_tsv_as_vec(file: &str, col_index: usize) -> Vec<String> {
vec
}

#[inline]
fn parse_path(path: &str) -> &str {
let path = path.split('/').last().unwrap();
if path.ends_with(".pdb") || path.ends_with(".cif") || path.ends_with(".fcz") || path.ends_with(".ent") {
// Return slice of string from start to end-4
&path[..path.len()-4]
} else if path.ends_with(".pdb.gz") || path.ends_with(".cif.gz") || path.ends_with(".fcz.gz") || path.ends_with(".ent.gz") {
// Return slice of string from start to end-7
&path[..path.len()-7]
} else {
path
}
}

fn parse_path_vector_as_vec<'a>(path_vector: &'a Vec<String>, parsed_vector: &mut Vec<&'a str>) {
// parallel
for path in path_vector {
parsed_vector.push(parse_path(path));
}
}
fn parse_path_set_as_set<'a>(path_set: &'a HashSet<String>, parsed_set: &mut HashSet<&'a str>) {
for path in path_set {
parsed_set.insert(parse_path(path));
}
}

#[cfg(test)]
mod tests {
Expand All @@ -183,14 +203,18 @@ mod tests {
let result = Some("data/zinc_folddisco.tsv".to_string());
let answer = Some("data/zinc_answer.tsv".to_string());
let index = Some("analysis/h_sapiens/d16a4/index_id".to_string());
let input = None;
let format = "tsv";
let fp = None;
let threads = 1;
let env = AppArgs::Benchmark {
result,
answer,
index,
input,
format: format.to_string(),
fp,
threads,
};
benchmark(env);
}
Expand Down

0 comments on commit 028350a

Please sign in to comment.