diff --git a/MGMM.png b/MGMM.png
deleted file mode 100644
index 6a6913f..0000000
Binary files a/MGMM.png and /dev/null differ
diff --git a/data/results/toy/toy_results_collection.bson b/data/results/toy/toy_results_collection.bson
new file mode 100644
index 0000000..e210027
Binary files /dev/null and b/data/results/toy/toy_results_collection.bson differ
diff --git a/data/results/toy/toy_results_names_scores.bson b/data/results/toy/toy_results_names_scores.bson
new file mode 100644
index 0000000..6b981b7
Binary files /dev/null and b/data/results/toy/toy_results_names_scores.bson differ
diff --git a/scripts/PoolAE/PoolAE_script.jl b/scripts/PoolAE/PoolAE_script.jl
new file mode 100644
index 0000000..01b8db9
--- /dev/null
+++ b/scripts/PoolAE/PoolAE_script.jl
@@ -0,0 +1,126 @@
+using DrWatson
+@quickactivate
+include(srcdir("models", "utils.jl"))
+include(srcdir("models", "PoolAE.jl"))
+
+using Plots
+using StatsPlots
+ENV["GKSwstype"] = "100"
+
+# toy bags: two normal clusters (data1, data2), a shifted anomalous cluster (data3)
+# and a higher-cardinality anomalous cluster (data4)
+data1 = [randn(2, rand(Poisson(20))) .+ [2.1, -1.4] for _ in 1:100]
+data2 = [randn(2, rand(Poisson(20))) .+ [-2.1, 1.4] for _ in 1:100]
+data3 = [randn(2, rand(Poisson(20))) for _ in 1:100]
+data4 = [randn(2, rand(Poisson(50))) .+ [2.1, -1.4] for _ in 1:100]
+
+train_data = vcat(data1, data2)
+val_data = vcat(
+    [randn(2, rand(Poisson(20))) .+ [2.1, -1.4] for _ in 1:100],
+    [randn(2, rand(Poisson(20))) .+ [-2.1, 1.4] for _ in 1:100]
+)
+
+model = pm_constructor(;idim=2, hdim=8, zdim=2, poolf=mean_max)
+opt = ADAM()
+ps = Flux.params(model)
+loss(x) = pm_variational_loss(model, x)
+
+for i in 1:100
+    Flux.train!(loss, ps, train_data, opt)
+    @info "$i: $(mean(loss.(val_data)))"
+end
+
+X = hcat(val_data...)
+Y = hcat([reconstruct(model, x) for x in val_data]...)
+
+scatter(X[1,:], X[2,:], markersize=2, markerstrokewidth=0)
+scatter!(Y[1,:], Y[2,:], markersize=2, markerstrokewidth=0)
+savefig("val_data.png")
+
+E = hcat([encoding(model, x) for x in val_data]...)
+scatter(E[1,:], E[2,:], zcolor=vcat(zeros(Int, 100), ones(Int, 100)))
+savefig("enc.png")
+
+E_an1 = hcat([encoding(model, x) for x in data3]...)
+E_an2 = hcat([encoding(model, x) for x in data4]...)
+scatter(E[1,:], E[2,:], label="normal")
+scatter!(E_an1[1,:], E_an1[2,:], label="anomalous 1")
+scatter!(E_an2[1,:], E_an2[2,:], label="anomalous 2")
+savefig("enc_anomaly.png")
+
+# different pooling function (with cardinality)
+model = pm_constructor(;idim=2, hdim=8, zdim=2, poolf=mean_max_card)
+opt = ADAM()
+ps = Flux.params(model)
+loss(x) = pm_variational_loss(model, x)
+
+for i in 1:100
+    Flux.train!(loss, ps, train_data, opt)
+    @info "$i: $(mean(loss.(val_data)))"
+end
+
+X = hcat(val_data...)
+Y = hcat([reconstruct(model, x) for x in val_data]...)
+
+scatter(X[1,:], X[2,:], markersize=2, markerstrokewidth=0)
+scatter!(Y[1,:], Y[2,:], markersize=2, markerstrokewidth=0)
+savefig("val_data_card.png")
+
+E = hcat([encoding(model, x) for x in val_data]...)
+scatter(E[1,:], E[2,:], zcolor=vcat(zeros(Int, 100), ones(Int, 100)))
+savefig("enc_card.png")
+
+E_an1 = hcat([encoding(model, x) for x in data3]...)
+E_an2 = hcat([encoding(model, x) for x in data4]...)
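+# compare the encodings of normal validation bags with both anomaly types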
+scatter(E[1,:], E[2,:]; label="normal", legend=:bottomright)
+scatter!(E_an1[1,:], E_an1[2,:], label="anomalous 1")
+scatter!(E_an2[1,:], E_an2[2,:], label="anomalous 2")
+savefig("enc_anomaly_card.png")
+
+E_all = hcat(E, E_an1, E_an2)
+card = vcat(
+    map(x -> size(x, 2), val_data),
+    map(x -> size(x, 2), data3),
+    map(x -> size(x, 2), data4)
+)
+scatter(E_all[1,:], E_all[2,:], zcolor=card, color=:jet)
+savefig("enc_card_all.png")
+
+# mean_max pooling again, this time with a stronger KL term (β=10)
+model = pm_constructor(;idim=2, hdim=8, zdim=2, poolf=mean_max)
+opt = ADAM()
+ps = Flux.params(model)
+loss(x) = pm_variational_loss(model, x; β=10)
+
+for i in 1:200
+    Flux.train!(loss, ps, train_data, opt)
+    @info "$i: $(mean(loss.(val_data)))"
+end
+
+X = hcat(val_data...)
+Y = hcat([reconstruct(model, x) for x in val_data]...)
+
+scatter(X[1,:], X[2,:], markersize=2, markerstrokewidth=0)
+scatter!(Y[1,:], Y[2,:], markersize=2, markerstrokewidth=0)
+savefig("val_data_β=10.png")
+
+E = hcat([encoding(model, x) for x in val_data]...)
+scatter(E[1,:], E[2,:], zcolor=vcat(zeros(Int, 100), ones(Int, 100)))
+savefig("enc_β=10.png")
+
+E_an1 = hcat([encoding(model, x) for x in data3]...)
+E_an2 = hcat([encoding(model, x) for x in data4]...)
+scatter(E[1,:], E[2,:]; label="normal", legend=:bottomright)
+scatter!(E_an1[1,:], E_an1[2,:], label="anomalous 1")
+scatter!(E_an2[1,:], E_an2[2,:], label="anomalous 2")
+savefig("enc_anomaly_β=10.png")
+
+E_all = hcat(E, E_an1, E_an2)
+card = vcat(
+    map(x -> size(x, 2), val_data),
+    map(x -> size(x, 2), data3),
+    map(x -> size(x, 2), data4)
+)
+scatter(E_all[1,:], E_all[2,:], zcolor=card, color=:jet)
+savefig("enc_card_β=10.png")
\ No newline at end of file
diff --git a/scripts/evaluation/MIL/mill_results.jl b/scripts/evaluation/MIL/mill_results.jl
index bb1f95d..8804541 100644
--- a/scripts/evaluation/MIL/mill_results.jl
+++ b/scripts/evaluation/MIL/mill_results.jl
@@ -11,8 +11,7 @@ using Plots
 using StatsPlots
 ENV["GKSwstype"] = "100"
 
-#include(scriptsdir("evaluation", "MIL", "workflow.jl"))
-
+# dataset names
 mill_datasets = [
     "BrownCreeper", "CorelBeach", "CorelAfrican", "Elephant", "Fox", "Musk1", "Musk2",
     "Mutagenesis1", "Mutagenesis2", "Newsgroups1", "Newsgroups2", "Newsgroups3", "Protein",
@@ -27,6 +26,10 @@ mill_names = [
 modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"]
 modelscores = [:distance, :score, :type, :type, :type, :score]
 
+######################################
+### First-time results calculation ###
+######################################
+
 # MIL results - finding the best model
 # if calculated for the first time
 mill_results_collection = Dict()
@@ -58,10 +61,14 @@ for (modelname, score) in map((x, y) -> (x, y), modelnames, modelscores_agg)
 end
 save(datadir("dataframes", "mill_results_scores_agg.bson"), mill_results_scores_agg)
 
+#############################################
+### Load results from existing data files ###
+#############################################
 # if already calculated, just load the data
 mill_results_collection = load(datadir("results", "MIL", "mill_results_collection.bson"))
 mill_results_scores_agg = load(datadir("results", "MIL", "mill_results_scores_agg.bson"))
+mill_results_scores = load(datadir("results", "MIL", "mill_results_scores.bson"))
 
 
 ###################################################
diff --git a/scripts/evaluation/MIL/mill_results_table.jl b/scripts/evaluation/MIL/mill_results_table.jl
index d26d9b4..eab674a 100644
--- a/scripts/evaluation/MIL/mill_results_table.jl
+++ b/scripts/evaluation/MIL/mill_results_table.jl
@@ -11,6 +11,13 @@
 using Statistics
 using EvalMetrics
 using BSON
+# Mill dataset names
+mill_datasets = [
+    "BrownCreeper", "CorelBeach", "CorelAfrican", "Elephant", "Fox", "Musk1", "Musk2",
+    "Mutagenesis1", "Mutagenesis2", "Newsgroups1", "Newsgroups2", "Newsgroups3", "Protein",
+    "Tiger", "UCSBBreastCancer", "Web1", "Web2", "Web3", "Web4", "WinterWren"
+]
+
 # load results dataframes
 modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"]
 mill_results_collection = load(datadir("results", "MIL", "mill_results_collection.bson"))
diff --git a/scripts/evaluation/toy/test.png b/scripts/evaluation/toy/test.png
deleted file mode 100644
index 969ba04..0000000
Binary files a/scripts/evaluation/toy/test.png and /dev/null differ
diff --git a/scripts/evaluation/toy/toy_results.jl b/scripts/evaluation/toy/toy_results.jl
index 7ceb93c..7dd639b 100644
--- a/scripts/evaluation/toy/toy_results.jl
+++ b/scripts/evaluation/toy/toy_results.jl
@@ -43,8 +43,8 @@ toy_results_names_scores = Dict(map((x, y) -> x => y, modelnames, modelscores))
 safesave(datadir("dataframes", "toy_results_names_scores.bson"), toy_results_names_scores)
 
 # load results collection
-toy_results_collection = load(datadir("dataframes", "toy_results_collection.bson"))
-toy_results_names_scores = load(datadir("dataframes", "toy_results_names_scores.bson"))
+toy_results_collection = load(datadir("results", "toy", "toy_results_collection.bson"))
+toy_results_names_scores = load(datadir("results", "toy", "toy_results_names_scores.bson"))
 
 ### BARPLOTS
diff --git a/scripts/evaluation/toy/toy_summary.jl b/scripts/evaluation/toy/toy_summary.jl
new file mode 100644
index 0000000..d64ce17
--- /dev/null
+++ b/scripts/evaluation/toy/toy_summary.jl
@@ -0,0 +1,63 @@
+using DrWatson
+@quickactivate
+using GroupAD
+using GroupAD: Evaluation
+using DataFrames
+using Statistics
+using EvalMetrics
+using PrettyTables
+
+using Plots
+using StatsPlots
+#using PlotlyJS
+ENV["GKSwstype"] = "100"
+
+modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"]
+modelscores = [:distance, :score, :type, :type, :type, :score]
+
+# load results collection
+toy_results_collection = load(datadir("results", "toy", "toy_results_collection.bson"))
+
+# add a :model column to each model's dataframe and merge them
+df_vec = map(name -> insertcols!(toy_results_collection[name], :model => name), modelnames)
+df_full = vcat(df_vec..., cols=:union)
+# pick the best row (highest validation AUC) for each model and scenario
+sort!(df_full, :val_AUC_mean, rev=true)
+g = groupby(df_full, [:model, :scenario])
+df_best = map(df -> DataFrame(df[1, [:model, :scenario, :test_AUC_mean]]), g)
+df_red = vcat(df_best...)
+
+s1 = filter(:scenario => scenario -> scenario == 1, df_red)[:, [:model, :test_AUC_mean]]
+s2 = filter(:scenario => scenario -> scenario == 2, df_red)[:, [:model, :test_AUC_mean]]
+s3 = filter(:scenario => scenario -> scenario == 3, df_red)[:, [:model, :test_AUC_mean]]
+
+H = []
+for modelname in modelnames
+    v1 = s1[s1[:, :model] .== modelname, :test_AUC_mean]
+    v2 = s2[s2[:, :model] .== modelname, :test_AUC_mean]
+    v3 = s3[s3[:, :model] .== modelname, :test_AUC_mean]
+    V = vcat(v1, v2, v3)
+    push!(H, V)
+end
+
+H2 = hcat(H...)
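+# H2 is a 3×6 matrix: rows are the three scenarios, columns follow the order of `modelnames`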
+H3 = vcat(H2, mean(H2, dims=1))
+_final = DataFrame(hcat(["1", "2", "3", "Average"], H3))
+nice_modelnames = ["scenario", "kNNagg", "VAEagg", "VAE", "NS", "PoolModel", "MGMM"]
+final = rename(_final, nice_modelnames)
+
+# bold + blue for the row maximum, red for the row minimum (numeric cells only)
+l_max = LatexHighlighter(
+    (data, i, j) -> (data[i, j] == maximum(final[i, 2:7])) && typeof(data[i, j]) !== String,
+    ["textbf", "textcolor{blue}"]
+)
+l_min = LatexHighlighter(
+    (data, i, j) -> (data[i, j] == minimum(final[i, 2:7])) && typeof(data[i, j]) !== String,
+    ["textcolor{red}"]
+)
+
+t = pretty_table(
+    final,
+    highlighters = (l_max, l_min),
+    formatters = ft_printf("%5.3f"),
+    backend = :latex, tf = tf_latex_booktabs, nosubheader = true
+)
\ No newline at end of file
diff --git a/src/evaluation/plotting.jl b/src/evaluation/plotting.jl
index 8113053..fa7e2cd 100644
--- a/src/evaluation/plotting.jl
+++ b/src/evaluation/plotting.jl
@@ -1,3 +1,11 @@
+using StatsPlots
+
+# abbreviated dataset names used as plot labels
+mill_names = [
+    "BrownCreeper", "CorelAfrican", "CorelBeach", "Elephant", "Fox", "Musk1", "Musk2",
+    "Mut1", "Mut2", "News1", "News2", "News3", "Protein",
+    "Tiger", "UCSB-BC", "Web1", "Web2", "Web3", "Web4", "WinterWren"
+]
+
 """
     groupedbar_matrix(df::DataFrame; group::Symbol, cols::Symbol, value::Symbol, groupnamefull=true)
diff --git a/src/models/PoolAE.jl b/src/models/PoolAE.jl
new file mode 100644
index 0000000..07fa316
--- /dev/null
+++ b/src/models/PoolAE.jl
@@ -0,0 +1,304 @@
+using Flux
+using Flux3D: chamfer_distance
+using ConditionalDists, Distributions, DistributionsAD
+using MLDataPattern: RandomBatches
+using StatsBase
+using Random
+using Mill
+
+"""
+PoolAE is a generative model which reconstructs and generates
+output from a single-vector summary of the input set.
+
+PoolAE has 6 components:
+- prepool_net
+- poolf
+- encoder
+- prior
+- generator
+- decoder
+
+The pre-pool net is a neural network which transforms all vectors in a given set.
+A summary is created with a pooling function, which has to be permutation-invariant.
+Possible functions include mean, sum, maximum, etc.
+"""
+struct PoolAE{pre <: Chain, fun <: Function, e <: ConditionalMvNormal, p <: ContinuousMultivariateDistribution, g <: ConditionalMvNormal, d <: Chain}
+    prepool_net::pre
+    poolf::fun
+    encoder::e
+    prior::p
+    generator::g
+    decoder::d
+end
+
+Flux.@functor PoolAE
+
+function Flux.trainable(m::PoolAE)
+    (prepool_net = m.prepool_net, encoder = m.encoder, generator = m.generator, decoder = m.decoder)
+end
+
+function PoolAE(pre, fun, enc::ConditionalMvNormal, gen, dec, plength::Int)
+    W = first(Flux.params(enc))
+    μ = fill!(similar(W, plength), 0)
+    σ = fill!(similar(W, plength), 1)
+    prior = DistributionsAD.TuringMvNormal(μ, σ)
+    PoolAE(pre, fun, enc, prior, gen, dec)
+end
+
+function Base.show(io::IO, pm::PoolAE)
+    nm = "PoolAE($(pm.poolf))"
+    print(io, nm)
+end
+
+
+"""
+    pm_constructor(;idim, hdim=32, predim=8, zdim=8, activation="swish", nlayers=3,
+        var="scalar", poolf=bag_mean, init_seed=nothing, kwargs...)
+
+Constructs a PoolAE. The encoder input dimension is calculated automatically
+from the output of the chosen pooling function.
+
+Dimensions:
+- idim: input dimension
+- hdim: hidden dimension in all networks
+- predim: output dimension of the pre-pool network and input dimension of the pooling function
+- zdim: latent dimension (output of encoder and generator, input of decoder)
+"""
+function pm_constructor(;idim, hdim=32, predim=8, zdim=8, activation="swish", nlayers=3, var="scalar",
+    poolf=bag_mean, init_seed=nothing, kwargs...)
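+
+    # resolve the pooling function, so it can be passed either as a Function or by name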
+    fun = eval(:($(Symbol(poolf))))
+
+    # if seed is given, set it
+    (init_seed !== nothing) ? Random.seed!(init_seed) : nothing
+
+    # pre-pool network
+    pre = Chain(
+        build_mlp(idim, hdim, hdim, nlayers-1, activation=activation)...,
+        Dense(hdim, predim)
+    )
+    # dimension after pooling
+    pooldim = length(fun(randn(predim)))
+
+    if var == "scalar"
+        # encoder
+        enc = Chain(
+            build_mlp(pooldim, hdim, hdim, nlayers-1, activation=activation)...,
+            SplitLayer(hdim, [zdim, 1])
+        )
+        enc_dist = ConditionalMvNormal(enc)
+
+        gen = Chain(
+            build_mlp(zdim, hdim, hdim, nlayers-1, activation=activation)...,
+            SplitLayer(hdim, [zdim, 1])
+        )
+        gen_dist = ConditionalMvNormal(gen)
+    else
+        enc = Chain(
+            build_mlp(pooldim, hdim, hdim, nlayers-1, activation=activation)...,
+            SplitLayer(hdim, [zdim, zdim])
+        )
+        enc_dist = ConditionalMvNormal(enc)
+
+        gen = Chain(
+            build_mlp(zdim, hdim, hdim, nlayers-1, activation=activation)...,
+            SplitLayer(hdim, [zdim, zdim])
+        )
+        gen_dist = ConditionalMvNormal(gen)
+    end
+
+    dec = Chain(
+        build_mlp(zdim, hdim, hdim, nlayers-1, activation=activation)...,
+        Dense(hdim, idim)
+    )
+
+    pm = PoolAE(pre, fun, enc_dist, gen_dist, dec, zdim)
+    return pm
+end
+
+#################################
+### Special pooling functions ###
+#################################
+
+bag_mean(x) = mean(x, dims=2)
+bag_maximum(x) = maximum(x, dims=2)
+
+"""
+    mean_max(x)
+
+Concatenates mean and maximum pooling.
+"""
+function mean_max(x)
+    m1 = mean(x, dims=2)
+    m2 = maximum(x, dims=2)
+    return vcat(m1, m2)
+end
+
+"""
+    mean_max_card(x)
+
+Concatenates mean and maximum pooling and the set cardinality.
+"""
+function mean_max_card(x)
+    m1 = mean(x, dims=2)
+    m2 = maximum(x, dims=2)
+    return vcat(m1, m2, size(x, 2))
+end
+
+"""
+    sum_stat(x)
+
+Calculates a summary vector as a concatenation of mean, maximum, minimum, and variance pooling.
+"""
+function sum_stat(x)
+    m1 = mean(x, dims=2)
+    m2 = maximum(x, dims=2)
+    m3 = minimum(x, dims=2)
+    m4 = var(x, dims=2)
+    # the variance of a single-instance bag is NaN; replace it with zeros
+    if any(isnan.(m4))
+        m4 = zeros(length(m1))
+    end
+    return vcat(m1, m2, m3, m4)
+end
+
+"""
+    sum_stat_card(x)
+
+Same as `sum_stat`, with the set cardinality appended.
+"""
+function sum_stat_card(x)
+    m1 = mean(x, dims=2)
+    m2 = maximum(x, dims=2)
+    m3 = minimum(x, dims=2)
+    m4 = var(x, dims=2)
+    if any(isnan.(m4))
+        m4 = zeros(length(m1))
+    end
+    return vcat(m1, m2, m3, m4, size(x, 2))
+end
+
+"""
+    pm_variational_loss(m::PoolAE, x; β=1)
+
+Loss function for the PoolAE which mirrors the ELBO of a VAE and
+should map the latent space to a standard Gaussian. Uses the
+Chamfer distance and KL divergence; β weights the KL term.
+"""
+function pm_variational_loss(m::PoolAE, x; β=1)
+    # pre-pool network transformation of x
+    v = m.prepool_net(x)
+    # pooling
+    p = m.poolf(v)
+    # pool encoder
+    z = rand(m.encoder, p)
+    kld = mean(kl_divergence(condition(m.encoder, p), m.prior))
+
+    # decode one latent sample per instance in the bag
+    Z = hcat([rand(m.generator, z) for i in 1:size(x, 2)]...)
+    dz = m.decoder(Z)
+
+    return chamfer_distance(x, dz) + β*kld
+end
+
+"""
+    StatsBase.fit!(model::PoolAE, data::Tuple, loss::Function; max_iters=10000,
+        max_train_time=82800, lr=0.001, batchsize=64, patience=30, check_interval::Int=10, kwargs...)
+
+Fits a PoolAE model with early stopping on the validation loss.
+"""
+function StatsBase.fit!(model::PoolAE, data::Tuple, loss::Function;
+    max_iters=10000, max_train_time=82800, lr=0.001, batchsize=64, patience=30,
+    check_interval::Int=10, kwargs...)
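+
+    # train a working copy of the model; the best checkpoint
+    # (by validation loss) is kept in `model`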
+    history = MVHistory()
+    opt = ADAM(lr)
+
+    tr_model = deepcopy(model)
+    ps = Flux.params(tr_model)
+    _patience = patience
+
+    # prepare data for bag model
+    tr_x, tr_l = unpack_mill(data[1])
+    vx, vl = unpack_mill(data[2])
+    val_x = vx[vl .== 0]
+
+    best_val_loss = Inf
+    i = 1
+    start_time = time()
+
+    lossf(x) = loss(tr_model, x)
+
+    # infinite for loop via RandomBatches
+    for batch in RandomBatches(tr_x, 10)
+        # classic training
+        bag_batch = RandomBagBatches(tr_x, batchsize=batchsize, randomize=true)
+        Flux.train!(lossf, ps, bag_batch, opt)
+        # only batch training loss
+        batch_loss = mean(lossf.(bag_batch))
+
+        push!(history, :training_loss, i, batch_loss)
+        if mod(i, check_interval) == 0
+
+            # validation/early stopping
+            val_loss = mean(lossf.(val_x))
+
+            @info "$i - loss: $(batch_loss) (batch) | $(val_loss) (validation)"
+
+            if isnan(val_loss) || isnan(batch_loss)
+                error("Encountered invalid values in loss function.")
+            end
+
+            push!(history, :validation_likelihood, i, val_loss)
+
+            if val_loss < best_val_loss
+                best_val_loss = val_loss
+                _patience = patience
+
+                # this should save the model at least once
+                # when the validation loss is decreasing
+                model = deepcopy(tr_model)
+            else # else stop if the model has not improved for `patience` iterations
+                _patience -= 1
+                # @info "Patience is: $_patience."
+                if _patience == 0
+                    @info "Stopped training after $(i) iterations."
+                    break
+                end
+            end
+        end
+        if (time() - start_time > max_train_time) || (i > max_iters) # stop early if time is running out
+            model = deepcopy(tr_model)
+            @info "Stopped training after $(i) iterations, $((time() - start_time) / 3600) hours."
+            break
+        end
+        i += 1
+    end
+    # again, this is not optimal; the model should be passed by reference and only the reference edited
+    (history = history, iterations = i, model = model, npars = sum(map(p -> length(p), Flux.params(model))))
+end
+
+######################################
+### Score functions and evaluation ###
+######################################
+
+"""
+    reconstruct(m::PoolAE, x)
+
+Reconstructs the input bag.
+"""
+function reconstruct(m::PoolAE, x)
+    v = m.prepool_net(x)
+    p = m.poolf(v)
+    z = mean(m.encoder, p)
+    Z = hcat([rand(m.generator, z) for i in 1:size(x, 2)]...)
+    m.decoder(Z)
+end
+
+"""
+    encoding(m::PoolAE, x)
+
+Returns the single-vector summary encoding of a bag: the bag is passed
+through the pre-pool network and the pooling function, and the mean of
+the encoder distribution is returned.
+"""
+function encoding(m::PoolAE, x)
+    v = m.prepool_net(x)
+    p = m.poolf(v)
+    mean(m.encoder, p)
+end
\ No newline at end of file
diff --git a/src/models/PoolModel.jl b/src/models/PoolModel.jl
index 16eda84..e5420f3 100644
--- a/src/models/PoolModel.jl
+++ b/src/models/PoolModel.jl
@@ -123,9 +123,9 @@ function pm_constructor(;idim, hdim, predim, postdim, edim, activation="swish",
 	return pm
 end
 
-#################
-### Functions ###
-#################
+#################################
+### Special pooling functions ###
+#################################
 
 bag_mean(x) = mean(x, dims=2)
 bag_maximum(x) = maximum(x, dims=2)
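
A minimal usage sketch of the new PoolAE API, mirroring the calls in scripts/PoolAE/PoolAE_script.jl. The toy bags are illustrative, and treating the variational loss as an anomaly score is an assumption, not code from this diff:

using DrWatson
@quickactivate
include(srcdir("models", "utils.jl"))
include(srcdir("models", "PoolAE.jl"))

# toy bags: 2-dimensional instances, Poisson-distributed cardinality
train_data = [randn(2, rand(Poisson(20))) .+ [2.1, -1.4] for _ in 1:100]

model = pm_constructor(;idim=2, hdim=8, zdim=2, poolf=mean_max)
loss(x) = pm_variational_loss(model, x)   # β=1 by default
Flux.train!(loss, Flux.params(model), train_data, ADAM())

rec = reconstruct(model, train_data[1])   # reconstructed bag (2×n matrix)
z = encoding(model, train_data[1])        # zdim-dimensional bag summary
scores = loss.(train_data)                # assumption: higher loss = more anomalous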