Skip to content

Commit

Permalink
evaluation update - minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
maskomic committed Feb 1, 2022
1 parent 9e5b52a commit 6e79a3d
Show file tree
Hide file tree
Showing 12 changed files with 522 additions and 7 deletions.
Binary file removed MGMM.png
Binary file not shown.
Binary file added data/results/toy/toy_results_collection.bson
Binary file not shown.
Binary file added data/results/toy/toy_results_names_scores.bson
Binary file not shown.
126 changes: 126 additions & 0 deletions scripts/PoolAE/PoolAE_script.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
using DrWatson
@quickactivate
include(srcdir("models", "utils.jl"))
include(srcdir("models", "PoolAE.jl"))

using Plots
using StatsPlots
ENV["GKSwstype"] = "100"

# Toy bags: every bag is a 2×n matrix of standard-normal instances whose column
# count n is drawn from a Poisson distribution; some groups are shifted by a
# fixed 2-element offset.
data1 = map(1:100) do _
    n = rand(Poisson(20))
    randn(2, n) .+ [2.1, -1.4]
end
data2 = map(1:100) do _
    n = rand(Poisson(20))
    randn(2, n) .+ [-2.1, 1.4]
end
# Unshifted bags (later plotted as "anomalous 1").
data3 = map(1:100) do _
    n = rand(Poisson(20))
    randn(2, n)
end
# Same shift as data1 but Poisson(50) bag sizes (later plotted as "anomalous 2").
data4 = map(1:100) do _
    n = rand(Poisson(50))
    randn(2, n) .+ [2.1, -1.4]
end

# Training set: the two shifted groups.
train_data = vcat(data1, data2)

# Validation set: freshly sampled bags, 100 per shifted group, in the same order
# as the training groups.
val_shift_a = map(1:100) do _
    n = rand(Poisson(20))
    randn(2, n) .+ [2.1, -1.4]
end
val_shift_b = map(1:100) do _
    n = rand(Poisson(20))
    randn(2, n) .+ [-2.1, 1.4]
end
val_data = vcat(val_shift_a, val_shift_b)

# First experiment: model built with the `mean_max` pooling function.
model = pm_constructor(; idim=2, hdim=8, zdim=2, poolf=mean_max)
opt = ADAM()
ps = Flux.params(model)

# Per-bag loss; reads the global `model`.
function loss(x)
    return pm_variational_loss(model, x)
end

# 100 calls to `Flux.train!` over the training bags, logging the mean
# validation loss after each call.
for epoch in 1:100
    Flux.train!(loss, ps, train_data, opt)
    @info epoch mean(loss.(val_data))
end

# Bind the plotting functions explicitly to the Plots implementations.
scatter = Plots.scatter
scatter! = Plots.scatter!

# All validation instances and their reconstructions, concatenated column-wise.
X = reduce(hcat, val_data)
reconstructions = map(x -> reconstruct(model, x), val_data)
Y = reduce(hcat, reconstructions)

scatter(X[1, :], X[2, :], markersize=2, markerstrokewidth=0)
scatter!(Y[1, :], Y[2, :], markersize=2, markerstrokewidth=0)
savefig("val_data.png")

# One encoding column per validation bag, coloured by group:
# the first 100 bags get label 0, the next 100 get label 1.
val_codes = map(x -> encoding(model, x), val_data)
E = reduce(hcat, val_codes)
group_labels = repeat([0, 1], inner=100)
scatter(E[1, :], E[2, :], zcolor=group_labels)
savefig("enc.png")

# Encodings of the two held-out groups plotted over the validation encodings.
E_an1 = reduce(hcat, map(x -> encoding(model, x), data3))
E_an2 = reduce(hcat, map(x -> encoding(model, x), data4))
scatter(E[1, :], E[2, :], label="normal")
scatter!(E_an1[1, :], E_an1[2, :], label="anomalous 1")
scatter!(E_an2[1, :], E_an2[2, :], label="anomalous 2")
savefig("enc_anomaly.png")

# Second experiment: different pooling function (with cardinality).
model = pm_constructor(;idim=2, hdim=8, zdim=2, poolf=mean_max_card)
opt = ADAM()
ps = Flux.params(model)
# Per-bag loss; reads the global `model`, i.e. the model constructed just above.
loss(x) = pm_variational_loss(model, x)

# 100 calls to `Flux.train!` over the training bags, logging the mean
# validation loss after each call.
for i in 1:100
    Flux.train!(loss, ps, train_data, opt)
    @info "$i: $(mean(loss.(val_data)))"
end

# All validation instances and their reconstructions, concatenated column-wise.
X = hcat(val_data...)
Y = hcat([reconstruct(model, x) for x in val_data]...)

scatter(X[1,:],X[2,:], markersize=2, markerstrokewidth=0)
scatter!(Y[1,:],Y[2,:], markersize=2, markerstrokewidth=0)
savefig("val_data_card.png")

# Validation encodings coloured by group: val_data is built from two groups of
# 100 bags each, hence the 100 zeros followed by 100 ones.
E = hcat([encoding(model, x) for x in val_data]...)
scatter(E[1,:],E[2,:],zcolor=vcat(zeros(Int, 100),ones(Int, 100)))
savefig("enc_card.png")

# Encodings of the two held-out groups plotted over the validation encodings.
E_an1 = hcat([encoding(model, x) for x in data3]...)
E_an2 = hcat([encoding(model, x) for x in data4]...)
scatter(E[1,:],E[2,:];label="normal", legend=:bottomright)
scatter!(E_an1[1,:],E_an1[2,:],label="anomalous 1")
scatter!(E_an2[1,:],E_an2[2,:],label="anomalous 2")
savefig("enc_anomaly_card.png")

# All encodings coloured by bag cardinality (number of columns of each bag).
# `card` is concatenated in the same order as the columns of `E_all`.
E_all = hcat(E, E_an1, E_an2)
card = vcat(
    map(x -> size(x, 2), val_data),
    map(x -> size(x, 2), data3),
    map(x -> size(x, 2), data4)
)
scatter(E_all[1,:], E_all[2,:], zcolor=card, color=:jet)
# Saved under its own name: this plot used to be written to "enc_card.png" as
# well, which silently overwrote the group-coloured encoding plot saved above.
savefig("enc_card_cardinality.png")


# Third experiment: `mean_max` pooling, loss evaluated with β=10, 200 training calls.
# NOTE(review): the output files below carry a `_card_` infix although this model
# is built with `poolf=mean_max`, not `mean_max_card` — confirm which was intended.
model = pm_constructor(;idim=2, hdim=8, zdim=2, poolf=mean_max)
opt = ADAM()
ps = Flux.params(model)
# Per-bag loss; reads the global `model` and passes β=10 to the loss function.
loss(x) = pm_variational_loss(model, x; β=10)

# 200 calls to `Flux.train!` over the training bags, logging the mean
# validation loss after each call.
for i in 1:200
    Flux.train!(loss, ps, train_data, opt)
    @info i mean(loss.(val_data))
end

# All validation instances and their reconstructions, concatenated column-wise.
X = hcat(val_data...)
Y = hcat([reconstruct(model, x) for x in val_data]...)

scatter(X[1,:],X[2,:], markersize=2, markerstrokewidth=0)
scatter!(Y[1,:],Y[2,:], markersize=2, markerstrokewidth=0)
savefig("val_data_card_β=10.png")

# Validation encodings coloured by group: val_data is built from two groups of
# 100 bags each, hence the 100 zeros followed by 100 ones.
E = hcat([encoding(model, x) for x in val_data]...)
scatter(E[1,:],E[2,:],zcolor=vcat(zeros(Int, 100),ones(Int, 100)))
savefig("enc_card_β=10.png")

# Encodings of the two held-out groups plotted over the validation encodings.
E_an1 = hcat([encoding(model, x) for x in data3]...)
E_an2 = hcat([encoding(model, x) for x in data4]...)
scatter(E[1,:],E[2,:];label="normal", legend=:bottomright)
scatter!(E_an1[1,:],E_an1[2,:],label="anomalous 1")
scatter!(E_an2[1,:],E_an2[2,:],label="anomalous 2")
savefig("enc_anomaly_card_β=10.png")

# All encodings coloured by bag cardinality (number of columns of each bag).
# `card` is concatenated in the same order as the columns of `E_all`.
E_all = hcat(E, E_an1, E_an2)
card = vcat(
    map(x -> size(x, 2), val_data),
    map(x -> size(x, 2), data3),
    map(x -> size(x, 2), data4)
)
scatter(E_all[1,:], E_all[2,:], zcolor=card, color=:jet)
# Saved under its own name: this plot used to be written to "enc_card_β=10.png"
# as well, which silently overwrote the group-coloured encoding plot saved above.
savefig("enc_card_β=10_cardinality.png")
11 changes: 9 additions & 2 deletions scripts/evaluation/MIL/mill_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ using Plots
using StatsPlots
ENV["GKSwstype"] = "100"

#include(scriptsdir("evaluation", "MIL", "workflow.jl"))

# names
mill_datasets = [
"BrownCreeper", "CorelBeach", "CorelAfrican", "Elephant", "Fox", "Musk1", "Musk2",
"Mutagenesis1", "Mutagenesis2", "Newsgroups1", "Newsgroups2", "Newsgroups3", "Protein",
Expand All @@ -27,6 +26,10 @@ mill_names = [
modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"]
modelscores = [:distance, :score, :type, :type, :type, :score]

#######################################
### First time results calculations ###
#######################################

# MIL results - finding the best model
# if calculated for the first time
mill_results_collection = Dict()
Expand Down Expand Up @@ -58,10 +61,14 @@ for (modelname, score) in map((x, y) -> (x, y), modelnames, modelscores_agg)
end
save(datadir("dataframes", "mill_results_scores_agg.bson"), mill_results_scores_agg)

#############################################
### Load results from existing data files ###
#############################################

# if already calculated, just load the data
# NOTE(review): the first-time section above saves mill_results_scores_agg.bson
# under datadir("dataframes", ...), while the loads below read from
# datadir("results", "MIL", ...) — confirm the files are moved there by hand,
# or align the save and load paths.
mill_results_collection = load(datadir("results", "MIL", "mill_results_collection.bson"))
mill_results_scores_agg = load(datadir("results", "MIL", "mill_results_scores_agg.bson"))
mill_results_scores = load(datadir("results", "MIL", "mill_results_scores.bson"))


###################################################
Expand Down
7 changes: 7 additions & 0 deletions scripts/evaluation/MIL/mill_results_table.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ using Statistics
using EvalMetrics
using BSON

# Names of the MIL datasets.
# NOTE(review): "CorelBeach" precedes "CorelAfrican" here, whereas the short
# display names in `mill_names` (src/evaluation/plotting.jl) list CorelAfrican
# first — confirm the two lists are never matched by position.
mill_datasets = [
"BrownCreeper", "CorelBeach", "CorelAfrican", "Elephant", "Fox", "Musk1", "Musk2",
"Mutagenesis1", "Mutagenesis2", "Newsgroups1", "Newsgroups2", "Newsgroups3", "Protein",
"Tiger", "UCSBBreastCancer", "Web1", "Web2", "Web3", "Web4", "WinterWren"
]

# load results dataframes
modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"]
mill_results_collection = load(datadir("results", "MIL", "mill_results_collection.bson"))
Expand Down
Binary file removed scripts/evaluation/toy/test.png
Binary file not shown.
4 changes: 2 additions & 2 deletions scripts/evaluation/toy/toy_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ toy_results_names_scores = Dict(map((x, y) -> x => y, modelnames, modelscores))
safesave(datadir("dataframes", "toy_results_names_scores.bson"), toy_results_names_scores)

# load results collection
# Both variables are assigned twice; the later pair of loads, reading from
# datadir("results/toy", ...), is the one that takes effect.
# NOTE(review): the safesave call above writes toy_results_names_scores.bson under
# datadir("dataframes", ...), not "results/toy" — confirm the files are moved
# there, or align the save and load paths.
toy_results_collection = load(datadir("dataframes", "toy_results_collection.bson"))
toy_results_names_scores = load(datadir("dataframes", "toy_results_names_scores.bson"))
toy_results_collection = load(datadir("results/toy", "toy_results_collection.bson"))
toy_results_names_scores = load(datadir("results/toy", "toy_results_names_scores.bson"))


### BARPLOTS
Expand Down
63 changes: 63 additions & 0 deletions scripts/evaluation/toy/toy_summary.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using DrWatson
@quickactivate
using GroupAD
using GroupAD: Evaluation
using DataFrames
using Statistics
using EvalMetrics
using PrettyTables

using Plots
using StatsPlots
#using PlotlyJS
ENV["GKSwstype"] = "100"

# Models included in the summary table and the score type listed for each of them.
modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"]
modelscores = [:distance, :score, :type, :type, :type, :score]

# load results collection
toy_results_collection = load(datadir("results/toy", "toy_results_collection.bson"))

# One results data frame per model, in `modelnames` order.
df_vec = map(name -> toy_results_collection[name], modelnames)
# Tag each frame with its model name. The column is added to a copy so that the
# loaded collection is left untouched and this section can be re-run in the same
# session (a second in-place `insertcols!` would fail on the existing column).
df_vec2 = map((df, name) -> insertcols!(copy(df), :model => name), df_vec, modelnames)
df_full = vcat(df_vec2..., cols=:union)

# Model selection: sort by validation AUC (best first) and keep the first row of
# every (model, scenario) group.
sort!(df_full, :val_AUC_mean, rev=true)
g = groupby(df_full, [:model, :scenario])
# Iterate over the groups instead of calling `map` on the GroupedDataFrame,
# which newer DataFrames releases no longer support.
df_best = [DataFrame(sdf[1, [:model, :scenario, :test_AUC_mean]]) for sdf in g]
df_red = vcat(df_best...)

# Per-scenario tables of (model, test AUC).
s1 = filter(:scenario => scenario -> scenario == 1, df_red)[:, [:model, :test_AUC_mean]]
s2 = filter(:scenario => scenario -> scenario == 2, df_red)[:, [:model, :test_AUC_mean]]
s3 = filter(:scenario => scenario -> scenario == 3, df_red)[:, [:model, :test_AUC_mean]]

# For every model, the test AUCs of scenarios 1, 2 and 3 stacked into one vector.
H = [
    vcat((s[s[:, :model] .== modelname, :test_AUC_mean] for s in (s1, s2, s3))...)
    for modelname in modelnames
]

# Rows = scenarios plus a final row of per-model averages, columns = models.
H2 = hcat(H...)
H3 = vcat(H2, mean(H2, dims=1))
_final = DataFrame(hcat(["1","2","3","Average"],H3))
nice_modelnames = ["scenario", "kNNagg", "VAEagg", "VAE", "NS", "PoolModel", "MGMM"]
final = rename(_final, nice_modelnames)

# Column 1 holds the scenario label; every remaining column holds model scores.
# Derived from the table instead of the former hard-coded `2:7`.
score_cols = 2:ncol(final)

# Highlight the row-wise maximum (bold, blue) and minimum (red) among the score
# columns; String cells (the scenario labels) are never highlighted.
l_max = LatexHighlighter(
    (data, i, j) -> !(data[i, j] isa String) &&
        data[i, j] == maximum(final[i, score_cols]),
    ["textbf", "textcolor{blue}"]
)
l_min = LatexHighlighter(
    (data, i, j) -> !(data[i, j] isa String) &&
        data[i, j] == minimum(final[i, score_cols]),
    ["textcolor{red}"]
)

# Render the summary as a LaTeX booktabs table with three decimal places.
t = pretty_table(
    final,
    highlighters = (l_max, l_min),
    formatters = ft_printf("%5.3f"),
    backend=:latex, tf=tf_latex_booktabs, nosubheader=true
)
8 changes: 8 additions & 0 deletions src/evaluation/plotting.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
using StatsPlots

# Short display names of the MIL datasets.
# NOTE(review): the entries are in alphabetical order (CorelAfrican before
# CorelBeach), whereas the `mill_datasets` lists in the evaluation scripts put
# CorelBeach first — confirm the two lists are never matched by position.
mill_names = [
"BrownCreeper", "CorelAfrican", "CorelBeach", "Elephant", "Fox", "Musk1", "Musk2",
"Mut1", "Mut2", "News1", "News2", "News3", "Protein",
"Tiger", "UCSB-BC", "Web1", "Web2", "Web3", "Web4", "WinterWren"
]

"""
groupedbar_matrix(df::DataFrame; group::Symbol, cols::Symbol, value::Symbol, groupnamefull=true)
Expand Down
Loading

0 comments on commit 6e79a3d

Please sign in to comment.