Skip to content

Commit

Permalink
Add 'neutral' argument to benchmark CLI and update processing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
khb7840 committed Nov 4, 2024
1 parent 028350a commit dcf6859
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/cli/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ fn parse_arg() -> Result<AppArgs, Box<dyn std::error::Error>> {
Some("benchmark") => Ok(AppArgs::Benchmark {
result: args.opt_value_from_str(["-r", "--result"])?,
answer: args.opt_value_from_str(["-a", "--answer"])?,
neutral: args.opt_value_from_str(["-n", "--neutral"])?,
index: args.opt_value_from_str(["-i", "--index"])?,
input: args.opt_value_from_str("--input")?,
format: args.value_from_str(["-f", "--format"]).unwrap_or("tsv".into()),
Expand Down
1 change: 1 addition & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub enum AppArgs {
Benchmark {
result: Option<String>,
answer: Option<String>,
neutral: Option<String>,
index: Option<String>,
input: Option<String>,
format: String,
Expand Down
42 changes: 34 additions & 8 deletions src/cli/workflows/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::index::lookup::load_lookup_from_file;
use crate::prelude::*;

use crate::cli::config::read_index_config_from_file;
use crate::utils::benchmark::measure_up_to_k_fp;
use crate::utils::benchmark::{compare_target_answer_neutral_set, measure_up_to_k_fp, measure_up_to_k_fp_with_neutral};

// usage: folddisco benchmark -r <result.tsv> -a <answer.tsv> -i <index> -f tsv
// usage: folddisco benchmark -r <result.tsv> -a <answer.tsv> -i <index> -f default
Expand All @@ -17,6 +17,7 @@ pub fn benchmark(env: AppArgs) {
AppArgs::Benchmark {
result,
answer,
neutral,
index,
input,
format,
Expand All @@ -31,7 +32,11 @@ pub fn benchmark(env: AppArgs) {
}
// If input is given, read from file
let input_vector = if input.is_none() {
vec![(result.unwrap(), answer.unwrap())]
if neutral.is_some() {
vec![(result.unwrap(), answer.unwrap(), neutral)]
} else {
vec![(result.unwrap(), answer.unwrap(), None)]
}
} else {
let input = input.unwrap();
let file = std::fs::File::open(&input).expect(
Expand All @@ -41,7 +46,15 @@ pub fn benchmark(env: AppArgs) {
reader.lines().map(|line| {
let line = line.expect(&log_msg(FAIL, "Failed to read line"));
let row = line.split('\t').collect::<Vec<_>>();
(row[0].to_string(), row[1].to_string())
if row.len() == 2 {
(row[0].to_string(), row[1].to_string(), None)
} else if row.len() >= 3 {
(row[0].to_string(), row[1].to_string(), Some(row[2].to_string()))
} else {
print_log_msg(FAIL, "Invalid input format");
std::process::exit(1);
}
// (row[0].to_string(), row[1].to_string())
}).collect::<Vec<_>>()
};

Expand All @@ -59,7 +72,7 @@ pub fn benchmark(env: AppArgs) {
parse_path_set_as_set(&raw_lookup, &mut lookup);
let config = read_index_config_from_file(&config_path);

input_vector.par_iter().for_each(|(result_path, answer_path)| {
input_vector.par_iter().for_each(|(result_path, answer_path, neutral)| {
// Parse path by id type
let raw_result = read_one_column_of_tsv_as_vec(&result_path, 0);
let mut result = Vec::with_capacity(raw_result.len());
Expand All @@ -70,12 +83,23 @@ pub fn benchmark(env: AppArgs) {
parse_path_set_as_set(&raw_answer, &mut answer);

let result_set = HashSet::from_iter(result.iter().cloned());
let metric = if let Some(fp) = fp {
measure_up_to_k_fp(&result, &answer, &lookup, fp)
let metric = if let Some(neutral) = neutral {
let raw_neutral = read_one_column_of_tsv_as_set(neutral, 0);
let mut neutral = HashSet::with_capacity(raw_neutral.len());
parse_path_set_as_set(&raw_neutral, &mut neutral);
if let Some(fp) = fp {
measure_up_to_k_fp_with_neutral(&result, &answer, &neutral, &lookup, fp)
} else {
compare_target_answer_neutral_set(&result_set, &answer, &neutral, &lookup)
}
} else {
compare_target_answer_set(&result_set, &answer, &lookup)
if let Some(fp) = fp {
measure_up_to_k_fp(&result, &answer, &lookup, fp)
} else {
compare_target_answer_set(&result_set, &answer, &lookup)
}
};

match format {
"tsv" => {
// lookup, result, answer, lookup_len, result_len, answer_len,
Expand Down Expand Up @@ -202,6 +226,7 @@ mod tests {
fn test_benchmark() {
let result = Some("data/zinc_folddisco.tsv".to_string());
let answer = Some("data/zinc_answer.tsv".to_string());
let neutral = None;
let index = Some("analysis/h_sapiens/d16a4/index_id".to_string());
let input = None;
let format = "tsv";
Expand All @@ -210,6 +235,7 @@ mod tests {
let env = AppArgs::Benchmark {
result,
answer,
neutral,
index,
input,
format: format.to_string(),
Expand Down
48 changes: 48 additions & 0 deletions src/utils/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ pub fn compare_target_answer_vec<T: Eq + PartialEq>(target: &Vec<T>, answer: &Ve
Metrics::new(true_pos, true_neg, false_pos, false_neg)
}

pub fn compare_target_answer_neutral_vec<T: Eq + PartialEq>(target: &Vec<T>, answer: &Vec<T>, neutral: &Vec<T>, all: &Vec<T>) -> Metrics {
// True Positive: Elements in both target and answer
let true_pos = target.iter().filter(|&x| answer.contains(x)).count() as f64;
// True Negative: Elements in neither target nor answer
let true_neg = all.iter().filter(|&x| !target.contains(x) && !answer.contains(x) && !neutral.contains(x)).count() as f64;
// False Positive: Elements in target but not in answer
let false_pos = target.iter().filter(|&x| !answer.contains(x) && !neutral.contains(x)).count() as f64;
// False Negative: Elements in answer but not in target
let false_neg = answer.iter().filter(|&x| !target.contains(x)).count() as f64;

Metrics::new(true_pos, true_neg, false_pos, false_neg)
}

pub fn compare_target_answer_set<T: Eq + PartialEq + Hash>(target: &HashSet<T>, answer: &HashSet<T>, all: &HashSet<T>) -> Metrics {
// True Positive: Elements in both target and answer
let true_pos = target.iter().filter(|&x| answer.contains(x)).count() as f64;
Expand All @@ -71,6 +84,19 @@ pub fn compare_target_answer_set<T: Eq + PartialEq + Hash>(target: &HashSet<T>,
Metrics::new(true_pos, true_neg, false_pos, false_neg)
}

pub fn compare_target_answer_neutral_set<T: Eq + PartialEq + Hash>(target: &HashSet<T>, answer: &HashSet<T>, neutral: &HashSet<T>, all: &HashSet<T>) -> Metrics {
// True Positive: Elements in both target and answer
let true_pos = target.iter().filter(|&x| answer.contains(x)).count() as f64;
// True Negative: Elements in neither target nor answer
let true_neg = all.iter().filter(|&x| !target.contains(x) && !answer.contains(x) && !neutral.contains(x)).count() as f64;
// False Positive: Elements in target but not in answer
let false_pos = target.iter().filter(|&x| !answer.contains(x) && !neutral.contains(x)).count() as f64;
// False Negative: Elements in answer but not in target
let false_neg = answer.iter().filter(|&x| !target.contains(x)).count() as f64;

Metrics::new(true_pos, true_neg, false_pos, false_neg)
}

pub fn measure_up_to_k_fp<T: Eq + PartialEq + Hash>(target: &Vec<T>, answer: &HashSet<T>, all: &HashSet<T>, k: f64) -> Metrics {
// Iter until k false positives are found
let mut true_pos = 0.0;
Expand All @@ -91,4 +117,26 @@ pub fn measure_up_to_k_fp<T: Eq + PartialEq + Hash>(target: &Vec<T>, answer: &Ha
// True negatives
let true_neg = all.len() as f64 - (true_pos + false_pos + false_neg);
Metrics::new(true_pos, true_neg, false_pos, false_neg)
}

pub fn measure_up_to_k_fp_with_neutral<T: Eq + PartialEq + Hash>(target: &Vec<T>, answer: &HashSet<T>, neutral: &HashSet<T>, all: &HashSet<T>, k: f64) -> Metrics {
// Iter until k false positives are found
let mut true_pos = 0.0;
let mut false_pos = 0.0;
// Iterate over target
for t in target {
if answer.contains(t) {
true_pos += 1.0;
} else if !neutral.contains(t) {
false_pos += 1.0;
}
if false_pos >= k {
break;
}
}
// False negatives
let false_neg = answer.len() as f64 - true_pos;
// True negatives
let true_neg = all.len() as f64 - (true_pos + false_pos + false_neg);
Metrics::new(true_pos, true_neg, false_pos, false_neg)
}

0 comments on commit dcf6859

Please sign in to comment.