Skip to content

Commit

Permalink
lhco tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
maskomic committed Aug 26, 2022
1 parent 2189516 commit 7867be7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
5 changes: 5 additions & 0 deletions experimental/lhco_evaluation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using DrWatson
include("experimental", "lhco_results.jl")

model = ARGS[1]

27 changes: 23 additions & 4 deletions experimental/lhco_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ mill_datasets_wo_Web = [
"Tiger", "UCSBBreastCancer", "WinterWren"
]

"""
collect_mill(model::String, mill_datasets=mill_datasets)
Collects the results from all folders of MIL datasets using multi-threading.
"""
function collect_mill(model::String, mill_datasets=mill_datasets)
len = length(mill_datasets)
dfs = repeat([DataFrame()], len)
Expand All @@ -103,14 +108,28 @@ function collect_mill(model::String, mill_datasets=mill_datasets)
return vcat(dfs...)
end

"""
collect_mill(model::String, mill_datasets=mill_datasets)
Collects the results from LHCO using multi-threading.
*Note: It is recommended to use the same number of threads as the number of seeds.*
"""
function collect_lhco(model::String, dataset="events_anomalydetection_v2.h5")
df = collect_results(datadir("experiments", "contamination-0.0", "LHCO", model, dataset), subfolders=true, rexclude=[r"model_.*"])
dir = readdir(datadir("experiments", "contamination-0.0", "LHCO", model, dataset), join=true)
len = length(dir)
dfs = repeat([DataFrame()], len)
Threads.@threads for i in 1:len
_df = collect_results(dir[i], subfolders=true, rexclude=[r"model_.*"])
dfs[i] = _df
end
return vcat(dfs...)
end

"""
calculate_results(model::String; dataset::String="MIL", metric::Symbol=:val_AUC, show=false, tf=tf_unicode, filter_fun=nothing, max_seed=10)
Collects results for given model, filters only models with completed run over `max_seed` seeds.
Collects results for given model, filters only models with completed at least `max_seed` runs over the seeds.
Returns a grouped dataframe, where groups are dataset results aggregated over seeds.
Uses parallel processes for collecting results and calculating scores.
Expand Down Expand Up @@ -202,5 +221,5 @@ function lhco_model_results(model::String; metric::Symbol=:val_AUC, show=false,
if show
pretty_table(R2, nosubheader=true, tf = tf)
end
R2, g2
end
R2, g2[1]
end
2 changes: 2 additions & 0 deletions scripts/evaluate_performance_single.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# in the second case, it will recursively search for all compatible files in subdirectories
target = ARGS[1]

using Pkg
Pkg.activate(split(pwd(), ".jl")[1]*".jl")
using DrWatson
@quickactivate
using EvalMetrics
Expand Down

0 comments on commit 7867be7

Please sign in to comment.