-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from aicenter/chamfer
Chamfer VAE/NS, HMill classifier, LHCO processing.
- Loading branch information
Showing
25 changed files
with
1,360 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
capsule_together | ||
hazelnut_together | ||
pill_together | ||
screw_together | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
using DrWatson | ||
@quickactivate | ||
using ArgParse | ||
using GroupAD | ||
import StatsBase: fit!, predict | ||
using StatsBase | ||
using BSON | ||
using Flux | ||
using Distributions | ||
using ValueHistories | ||
using MLDataPattern: RandomBatches | ||
using Random | ||
|
||
# [0, 10, 20, full] | ||
|
||
# Command-line interface: number of seeds, dataset name, contamination rate.
s = ArgParseSettings()
@add_arg_table! s begin
    "max_seed"
        help = "seed"
        arg_type = Int
        default = 1
    "dataset"
        help = "dataset"
        arg_type = String
        default = "Fox"
    "contamination"
        help = "training data contamination rate"
        arg_type = Float64
        default = 0.0
end
parsed_args = parse_args(ARGS, s)
@unpack dataset, max_seed, contamination = parsed_args

#######################################################################################
################ THIS PART IS TO BE PROVIDED FOR EACH MODEL SEPARATELY ################
modelname = "hmil_classifier"
# sample parameters, should return a Dict of model kwargs
# fix seed to always choose the same hyperparameters | ||
"""
    sample_params()

Draw one random hyperparameter configuration for the HMill classifier.
Returns a named tuple with keys `mdim`, `activation`, `aggregation`, `nlayers`.
"""
function sample_params()
    widths = [8, 16, 32, 64, 128, 256]
    activations = ["sigmoid", "tanh", "relu", "swish"]
    aggregations = ["SegmentedMeanMax", "SegmentedMax", "SegmentedMean"]
    # fields are evaluated in declaration order, so the RNG draw sequence
    # matches the original (mdim, activation, aggregation, nlayers)
    return (
        mdim = sample(widths),
        activation = sample(activations),
        aggregation = sample(aggregations),
        nlayers = sample(1:3),
    )
end
# Supervised classification loss: logit cross-entropy between the model's
# raw (un-normalised) outputs on bag `x` and the labels `y`.
function loss(model, x, y)
    return Flux.logitcrossentropy(model(x), y)
end
"""
    fit(data, parameters, seed)

This is the most important function - returns `training_info` and a tuple or a vector of tuples `(score_fun, final_parameters)`.
`training_info` contains additional information on the training process that should be saved, the same for all anomaly score functions.
Each element of the return vector contains a specific anomaly score function - there can be multiple for each trained model.
Final parameters is a named tuple of names and parameter values that are used for creation of the savefile name.
"""
function fit(data, parameters, seed)
    # construct model - constructor should only accept kwargs
    model = GroupAD.Models.hmil_constructor(data[1][1]; parameters...)

    # Declared up-front so assignments inside the `try` block land in
    # function-local variables. The original used `global`, which leaks
    # state into the module and breaks repeated fits in one session.
    local info, fit_t, new_data

    # fit train data
    # max. train time: 24 hours, shared across seeds (plus a 1/4 safety factor)
    try
        _info, fit_t, _, _, _ = @timed GroupAD.Models.fit_hmil!(model, data, loss;
            max_train_time=23*3600/max_seed/4,
            patience=200, check_interval=5, seed=seed, parameters...)
        info = _info[1]
        # pass the (re-split) train/validation parts through together with test data
        new_data = (_info[2], _info[3], data[3])
    catch e
        # best-effort: an empty score list means nothing downstream is computed
        # NOTE(review): this failure path returns 2 values while the success path
        # returns 3 — confirm the caller handles both arities
        @info "Failed training due to \n$e"
        return (fit_t = NaN, history=nothing, npars=nothing, model=nothing), []
    end

    # construct return information - put e.g. the model structure here for generative models
    training_info = (
        fit_t = fit_t,
        history = info.history,
        npars = info.npars,
        model = info.model,
    )

    # now return the info to be saved and an array of tuples (anomaly score function, hyperparameters)
    # the score functions themselves are evaluated inside the experimental loop
    return training_info, [
        (x -> GroupAD.Models.score_hmil(info.model, x),
            merge(parameters, (score = "normal_prob",))),
        (x -> GroupAD.Models.get_label_hmil(info.model, x),
            merge(parameters, (score = "get_label",))),
    ], new_data
end
"""
    edit_params(data, parameters)

Hook for data-dependent parameter adjustment. This default implementation is
the identity on `parameters`; overload it for models that need it.
"""
edit_params(data, parameters) = parameters
####################################################################
################ THIS PART IS COMMON FOR ALL MODELS ################
# only execute this if run directly - so it can be included in other files
if abspath(PROGRAM_FILE) == @__FILE__
    # MIL benchmarks take precedence if a dataset name appears in both lists;
    # unknown datasets are silently skipped (same as the original branches).
    subfolder = if in(dataset, mill_datasets)
        "MIL"
    elseif in(dataset, mvtec_datasets)
        "mv_tec"
    else
        nothing
    end
    if subfolder !== nothing
        GroupAD.Models.hmil_basic_loop(
            sample_params,
            fit,
            edit_params,
            max_seed,
            modelname,
            dataset,
            contamination,
            datadir("experiments/contamination-$(contamination)/$(subfolder)"),
        )
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#!/bin/bash
#SBATCH --partition=cpulong
#SBATCH --time=35:00:00
#SBATCH --nodes=1 --ntasks-per-node=2 --cpus-per-task=1
#SBATCH --mem=12G

# Positional arguments: number of seeds, dataset name, contamination rate.
MAX_SEED=$1
DATASET=$2
CONTAMINATION=$3

module load Python/3.8
module load Julia/1.7.3-linux-x86_64

# Quote the arguments so dataset names containing spaces or glob
# characters survive word splitting intact.
julia --project ./hmil_classifier.jl "$MAX_SEED" "$DATASET" "$CONTAMINATION"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#!/bin/bash
# This runs parallel experiments over all datasets.
# USAGE EXAMPLE
# ./run_parallel_mill.sh vae_basic 3 1 2 datasets_mill.txt 0.05
# Run from this folder only.
MODEL=$1         # which model to run
NUM_SAMPLES=$2   # how many repetitions
MAX_SEED=$3      # how many folds over dataset
NUM_CONC=$4      # number of concurrent tasks in the array job
DATASET_FILE=$5  # file with dataset list

LOG_DIR="${HOME}/logs/${MODEL}"

# -p creates missing parents (e.g. ~/logs on a fresh machine) and is a no-op
# when the directory exists; the original unquoted `mkdir $LOG_DIR` also broke
# for model names containing spaces.
mkdir -p "$LOG_DIR"

# -r stops `read` from mangling backslashes in dataset names.
while read -r d; do
	# submit one array job per contamination level
	for na in 0 10 20 100
	do
		sbatch \
		--array=1-${NUM_SAMPLES}%${NUM_CONC} \
		--output="${LOG_DIR}/${d}-%A_%a.out" \
		./${MODEL}.sh "$MAX_SEED" "$d" "$na"

		# for local testing
		# ./${MODEL}_run.sh $MAX_SEED $d
	done
done < "${DATASET_FILE}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
using DrWatson | ||
@quickactivate | ||
using ArgParse | ||
using GroupAD | ||
import StatsBase: fit!, predict | ||
using StatsBase | ||
using BSON | ||
using Flux | ||
using GroupAD.GenerativeModels | ||
using Distributions | ||
|
||
# Command-line interface: number of seeds, dataset name, contamination rate.
s = ArgParseSettings()
@add_arg_table! s begin
    "max_seed"
        help = "seed"
        arg_type = Int
        default = 1
    "dataset"
        help = "dataset"
        arg_type = String
        default = "Fox"
    "contamination"
        help = "training data contamination rate"
        arg_type = Float64
        default = 0.0
end
parsed_args = parse_args(ARGS, s)
@unpack dataset, max_seed, contamination = parsed_args

#######################################################################################
################ THIS PART IS TO BE PROVIDED FOR EACH MODEL SEPARATELY ################
modelname = "statistician_chamfer"
# sample parameters, should return a Dict of model kwargs
"""
    sample_params()

Should return a named tuple that contains a sample of model parameters.
For NeuralStatistician, latent dimensions must be strictly smaller than
the hidden dimension (the resampling loops below enforce `<`, not `<=`):
- `vdim` < `hdim`
- `cdim` < `hdim`
- `zdim` < `hdim`
"""
function sample_params()
    # candidate values per hyperparameter, listed in `argnames` order;
    # the draw order is kept identical to preserve RNG reproducibility
    par_vec = (2 .^(4:9), 2 .^(3:8), 2 .^(3:8), 2 .^(3:8), ["scalar", "diagonal"],
        10f0 .^(-4:-3), 3:4, 2 .^(5:7), ["relu", "swish", "tanh"], 1:Int(1e8))
    argnames = (:hdim, :vdim, :cdim, :zdim, :var, :lr, :nlayers, :batchsize, :activation, :init_seed)
    parameters = (;zip(argnames, map(x->sample(x, 1)[1], par_vec))...)

    # resample until each latent dimension is strictly below hdim
    while parameters.vdim >= parameters.hdim
        parameters = merge(parameters, (vdim = sample(par_vec[2]),))
    end
    while parameters.cdim >= parameters.hdim
        parameters = merge(parameters, (cdim = sample(par_vec[3]),))
    end
    while parameters.zdim >= parameters.hdim
        parameters = merge(parameters, (zdim = sample(par_vec[4]),))
    end
    return parameters
end
"""
    loss(model::GenerativeModels.NeuralStatistician, batch)

Negative ELBO for training of a Neural Statistician model:
the Chamfer ELBO averaged over the bags in `batch`.
"""
function loss(model::GenerativeModels.NeuralStatistician, batch)
    return mean(x -> GroupAD.Models.chamfer_elbo1(model, x), batch)
end
"""
    fit(data, parameters)

This is the most important function - returns `training_info` and a tuple or a vector of tuples `(score_fun, final_parameters)`.
`training_info` contains additional information on the training process that should be saved, the same for all anomaly score functions.
Each element of the return vector contains a specific anomaly score function - there can be multiple for each trained model.
Final parameters is a named tuple of names and parameter values that are used for creation of the savefile name.
"""
function fit(data, parameters)
    # construct model - constructor should only accept kwargs
    model = GroupAD.Models.statistician_constructor(;idim=size(data[1][1],1), parameters...)

    # Declared up-front so assignments inside the `try` block land in
    # function-local variables. The original used `global`, which leaks
    # state into the module and breaks repeated fits in one session.
    local info, fit_t

    # fit train data; max. train time: 23 hours (82800 s) split across seeds
    try
        info, fit_t, _, _, _ = @timed fit!(model, data, loss; max_train_time=82800/max_seed,
            patience=200, check_interval=5, parameters...)
    catch e
        # best-effort: an empty score list means nothing downstream is computed
        @info "Failed training due to \n$e"
        return (fit_t = NaN, history=nothing, npars=nothing, model=nothing), []
    end

    # construct return information - put e.g. the model structure here for generative models
    training_info = (
        fit_t = fit_t,
        history = info.history,
        npars = info.npars,
        model = info.model,
    )

    # now return the info to be saved and an array of tuples (anomaly score function, hyperparameters)
    return training_info, [
        (x -> GroupAD.Models.reconstruct_input(info.model, x),
            merge(parameters, (score = "reconstructed_input", L=1))),
    ]
end
"""
    edit_params(data, parameters)

Data-dependent parameter adjustment hook. By default no adjustment is made
and `parameters` is returned unchanged; overload for models that need it.
"""
edit_params(data, parameters) = parameters
####################################################################
################ THIS PART IS COMMON FOR ALL MODELS ################
# only execute this if run directly - so it can be included in other files
if abspath(PROGRAM_FILE) == @__FILE__
    # MIL benchmarks take precedence if a dataset name appears in both lists;
    # unknown datasets are silently skipped (same as the original branches).
    subfolder = if in(dataset, mill_datasets)
        "MIL"
    elseif in(dataset, mvtec_datasets)
        "mv_tec"
    else
        nothing
    end
    if subfolder !== nothing
        GroupAD.basic_experimental_loop(
            sample_params,
            fit,
            edit_params,
            max_seed,
            modelname,
            dataset,
            contamination,
            datadir("experiments/contamination-$(contamination)/$(subfolder)"),
        )
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#!/bin/bash
#SBATCH --partition=cpulong
#SBATCH --time=48:00:00
#SBATCH --nodes=1 --ntasks-per-node=2 --cpus-per-task=1
#SBATCH --mem=30G

# Positional arguments: number of seeds, dataset name, contamination rate.
MAX_SEED=$1
DATASET=$2
CONTAMINATION=$3

module load Python/3.8
module load Julia/1.7.3-linux-x86_64

# Quote the arguments so dataset names containing spaces or glob
# characters survive word splitting intact.
julia --project ./statistician_chamfer.jl "$MAX_SEED" "$DATASET" "$CONTAMINATION"
Oops, something went wrong.