more consistent interfaces #7

Merged Feb 20, 2024
Changes from 15 commits
Commits (26)
0049300  draft interface (tiemvanderdeure, Jan 3, 2024)
1459868  draft interface (tiemvanderdeure, Jan 3, 2024)
b19f009  Merge branch 'interface' of https://github.com/tiemvanderdeure/Specie… (tiemvanderdeure, Feb 3, 2024)
9a568d4  interace for ensemble, evaluate, and predict (tiemvanderdeure, Feb 4, 2024)
6cafca0  much improved predict functions (tiemvanderdeure, Feb 4, 2024)
5fff143  add data utils (tiemvanderdeure, Feb 4, 2024)
784c15c  tweaks to helper function and names (tiemvanderdeure, Feb 4, 2024)
4cde268  more consistent evaluate interface (tiemvanderdeure, Feb 5, 2024)
8a9ec05  test for the new interfaces (tiemvanderdeure, Feb 5, 2024)
9ddde20  remove a bunch of old code (tiemvanderdeure, Feb 5, 2024)
3a40ec2  fix a typo (tiemvanderdeure, Feb 5, 2024)
e358fa2  use Rasters.missingmask (tiemvanderdeure, Feb 15, 2024)
f300cb4  cleaner show method for evaluations (tiemvanderdeure, Feb 17, 2024)
eaec1ce  add accuracy as a default measure for evaluate (tiemvanderdeure, Feb 17, 2024)
59ae0d3  use Rasters.boolmask in predict (tiemvanderdeure, Feb 17, 2024)
5ae6c64  implement clamp (tiemvanderdeure, Feb 17, 2024)
a4ed643  add a fast method for boolean categorical values (tiemvanderdeure, Feb 17, 2024)
db0ecc9  interface for explain (tiemvanderdeure, Feb 19, 2024)
afdf8a1  rename sdm() and add threaded keyword (tiemvanderdeure, Feb 20, 2024)
436acfe  import threadsx (tiemvanderdeure, Feb 20, 2024)
1179124  enable threaded sdm fitting (tiemvanderdeure, Feb 20, 2024)
ad21265  change some comments (tiemvanderdeure, Feb 20, 2024)
8de28eb  add threadsx to toml (tiemvanderdeure, Feb 20, 2024)
56e418b  pass cpu backend (tiemvanderdeure, Feb 20, 2024)
e94c859  dispatch categorical_boolean on vectors of bool (tiemvanderdeure, Feb 20, 2024)
a104ce3  update tests (tiemvanderdeure, Feb 20, 2024)
4 changes: 3 additions & 1 deletion Project.toml
@@ -30,8 +30,10 @@ StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
-julia = "1.6"
Rasters = "0.10.2"
+CategoricalDistributions = "0.1.14"
+StatsModels = "0.7.3"
+julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6 changes: 4 additions & 2 deletions src/SpeciesDistributionModels.jl
@@ -6,22 +6,24 @@ import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess

using MLJBase: pdf

-using Rasters: Raster, RasterStack
+using Rasters: Raster, RasterStack, Band

using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift

using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count

-export SDMensemble, predict, sdm, select, machines, machine_keys, shap,
+export SDMensemble, predict, sdm_ensemble, select, machines, machine_keys, shap,
interactive_evaluation, interactive_response_curves,
remove_collinear

+include("data_utils.jl")
include("collinearity.jl")
include("models.jl")
include("ensemble.jl")
include("predict.jl")
include("explain.jl")
include("evaluate.jl")
+include("interface.jl")
include("plots.jl")

end
9 changes: 9 additions & 0 deletions src/data_utils.jl
@@ -0,0 +1,9 @@
+### Miscellaneous utilities to deal with data issues such as names, missing values
+
+function _get_predictor_names(p, a)
+predictors = Base.intersect(Tables.schema(a).names, Tables.schema(p).names)
+predictors = filter!(!=(:geometry), predictors) # geometry is never a variable
+length(predictors) > 0 || error("Presence and absence data have no common variable names - can't fit the ensemble.")
+return predictors
+end
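To make the new helper concrete, here is an illustrative call. This is not part of the diff: the tables and values are hypothetical, and the function is restated so the snippet runs standalone.

```julia
using Tables

# Restated from src/data_utils.jl above.
function _get_predictor_names(p, a)
    predictors = Base.intersect(Tables.schema(a).names, Tables.schema(p).names)
    predictors = filter!(!=(:geometry), predictors) # geometry is never a variable
    length(predictors) > 0 || error("Presence and absence data have no common variable names - can't fit the ensemble.")
    return predictors
end

# Hypothetical column tables; only :temp and :precip are shared predictors.
presence = (geometry = [(0.0, 0.0)], temp = [19.2], precip = [510.0])
absence  = (geometry = [(1.0, 1.0)], temp = [4.7], precip = [1210.0], elevation = [880.0])

_get_predictor_names(presence, absence) # [:temp, :precip]
```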

20 changes: 10 additions & 10 deletions src/ensemble.jl
@@ -44,15 +44,16 @@ machines(group::SDMgroup) = map(m -> m.machine, group)
sdm_machines(group::SDMgroup) = group.sdm_machines

# machine_key generates a unique key for a machine
-machine_keys(group::SDMgroup) = ["$(group.model_name)_$(group.resampler_name)_$(m.fold)" for m in group]
+machine_keys(group::SDMgroup) = [Symbol("$(group.model_name)_$(group.resampler_name)_$(m.fold)") for m in group]

# A bunch of functions are applied to an ensemble by applying to each group and reducing with vcat
for f in (:machines, :machine_keys, :sdm_machines)
@eval ($f)(ensemble::SDMensemble) = mapreduce(group -> ($f)(group), vcat, ensemble)
end

-## Select methods
+model_names(ensemble) = getfield.(ensemble.groups, :model_name)

+## Select methods
# Function to conveniently select some models from groups or ensembles
function select(group::SDMgroup, machine_indices::Vector{<:Int})
if length(machine_indices) == 0
@@ -111,10 +112,10 @@ function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble)
println(io, "Occurrence data: Presence-Absence with $n_presence presences and $n_absence absences")
println(io, "Predictors: $(join(["$key ($scitype)" for (key, scitype) in zip(nam, sci)], ", "))")

-model_names = getfield.(ensemble.groups, :model_name)
+m_names = model_names(ensemble)
resampler_names = getfield.(ensemble.groups, :resampler_name)
n_models = Base.length.(ensemble.groups)
-table_cols = hcat(model_names, resampler_names, n_models)
+table_cols = hcat(m_names, resampler_names, n_models)
header = (["model", "resampler", "machines"])
PrettyTables.pretty_table(io, table_cols; header = header)

Expand All @@ -135,7 +136,7 @@ Tables.columns(ensemble::SDMensemble) = Tables.columns(ensemble.groups)
# Turns models into a NamedTuple with unique keys
function _givenames(models::Vector)
names = map(models) do model
-replace(MLJBase.name(model), r"Classifier$"=>"")
+Base.replace(MLJBase.name(model), r"Classifier$"=>"", r"Binary$"=>"")
end
for (name, n) in StatsBase.countmap(names)
if n > 1
@@ -170,13 +171,12 @@ function _fit_sdm_group(

end

-function sdm(
+function _fit_sdm_ensemble(
presences,
absence,
models,
-resamplers;
-var_keys::Vector{Symbol} = intersect(Tables.schema(absence).names, Tables.schema(presences).names),
-scitypes::Vector{DataType} = [MLJBase.scitype(Tables.schema(presences).types) for key in var_keys],
+resamplers,
+predictors::Vector{Symbol},
verbosity::Int = 0
)
@assert Tables.istable(presences) && Tables.istable(absence)
@@ -186,7 +186,7 @@ function sdm(
n_total = n_presence + n_absence

# merge presence and absence data into one namedtuple of vectors
-predictor_values = NamedTuple{Tuple(var_keys)}([[Tables.columns(absence)[var]; Tables.columns(presences)[var]] for var in var_keys])
+predictor_values = NamedTuple{Tuple(predictors)}([[Tables.columns(absence)[var]; Tables.columns(presences)[var]] for var in predictors])
response_values = CategoricalArrays.categorical(
[falses(n_absence); trues(n_presence)];
levels = [false, true], ordered = false
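For illustration only (not part of the diff): the merge pattern above, applied to two tiny hypothetical tables, builds the predictor columns and the matching categorical response.

```julia
using CategoricalArrays, Tables

# Hypothetical inputs: two absences, three presences, one shared predictor.
absence    = (temp = [4.7, 6.1],)
presences  = (temp = [19.2, 18.4, 21.0],)
predictors = [:temp]
n_absence, n_presence = 2, 3

# Absence columns first, then presence columns, as in the hunk above.
predictor_values = NamedTuple{Tuple(predictors)}(
    [[Tables.columns(absence)[var]; Tables.columns(presences)[var]] for var in predictors]
)

# Matching response: false for each absence, true for each presence.
response_values = CategoricalArrays.categorical(
    [falses(n_absence); trues(n_presence)];
    levels = [false, true], ordered = false
)
```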
140 changes: 76 additions & 64 deletions src/evaluate.jl
@@ -18,6 +18,7 @@ struct SDMensembleEvaluation <: AbstractVector{SDMgroupEvaluation}
results
end

+SDMevaluation = Union{SDMmachineEvaluation, SDMgroupEvaluation, SDMensembleEvaluation}
SDMgroupOrEnsembleEvaluation = Union{SDMgroupEvaluation, SDMensembleEvaluation}

ScoreType = NamedTuple{(:score, :threshold), Tuple{Float64, Union{Missing, Float64}}}
@@ -28,38 +29,43 @@ Base.getindex(group::SDMgroupEvaluation, i) = group.machine_evaluations[i]
Base.size(ensemble::SDMensembleEvaluation) = Base.size(ensemble.group_evaluations)
Base.size(group::SDMgroupEvaluation) = Base.size(group.machine_evaluations)

-function machine_evaluations(groupeval::SDMgroupEvaluation; mean = false)
+"""
+    machine_evaluations(eval)
+
+Get the scores for each machine in an evaluation, which can be either an
+`SDMgroupEvaluation` or an `SDMensembleEvaluation`.
+
+The result is a nested `NamedTuple`: the outer keys are `train` and `test`,
+and each of these holds one key per measure specified in [`evaluate`](@ref).
+
+## Example
+```julia
+evaluation = SDM.evaluate(ensemble; measures = (; accuracy, auc))
+machine_aucs = SDM.machine_evaluations(evaluation).train.auc
+```
+"""
+machine_evaluations
+
+function machine_evaluations(groupeval::SDMgroupEvaluation)
map((:train, :test)) do set
map(keys(groupeval.measures)) do key
-r = map(groupeval) do e
+map(groupeval) do e
e.results[set][key].score
end

-if mean
-Statistics.mean(r)
-else
-r
-end

end |> NamedTuple{keys(groupeval.measures)}
end |> NamedTuple{(:train, :test)}
end

-function machine_evaluations(ensembleeval::SDMensembleEvaluation; mean = false)
+function machine_evaluations(ensembleeval::SDMensembleEvaluation)
map((:train, :test)) do set
map(keys(ensembleeval.measures)) do key
-r = mapreduce(vcat, ensembleeval) do groupeval
+mapreduce(vcat, ensembleeval) do groupeval
map(groupeval) do e
e.results[set][key].score
end
end

-if mean
-Statistics.mean(r)
-else
-r
-end

end |> NamedTuple{keys(ensembleeval.measures)}
end |> NamedTuple{(:train, :test)}
end
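Both methods now always return the raw per-machine scores; the old `mean` keyword is gone, so callers aggregate explicitly. A shape sketch with made-up numbers:

```julia
using Statistics

# Shape sketch: one per-machine score vector per measure, nested under
# :train and :test (numbers are made up).
scores = (
    train = (auc = [0.91, 0.88, 0.90], log_loss = [0.31, 0.35, 0.33]),
    test  = (auc = [0.84, 0.80, 0.82], log_loss = [0.42, 0.47, 0.44]),
)

mean(scores.test.auc)  # 0.82; aggregate explicitly now that the `mean` keyword is gone
```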
@@ -77,52 +83,72 @@ function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMmachineEvaluat
PrettyTables.pretty_table(io, table_cols; header = header)
end

-function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMgroupOrEnsembleEvaluation)
+function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMgroupEvaluation)
measures = collect(keys(evaluation.measures))
-train_scores, test_scores = machine_evaluations(evaluation, mean = true)

-group_scores = map(measures) do key
-evaluation.results[key].score
-end
+train_scores, test_scores = machine_evaluations(evaluation)
+folds = getfield.(evaluation.group, :fold)

println(io, "$(typeof(evaluation)) with $(length(measures)) performance measures")

-table_cols = hcat(measures, collect(group_scores), collect(train_scores), collect(test_scores))
-header = (["measure", "performance of avg", "avg. train performance", "avg. test performance"])
-PrettyTables.pretty_table(io, table_cols; header = header)
+println(io, "Testing data")
+PrettyTables.pretty_table(io, merge((; fold = folds), test_scores))
+println(io, "Training data")
+PrettyTables.pretty_table(io, merge((; fold = folds), train_scores))
end

+function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMensembleEvaluation)
+measures = collect(keys(evaluation.measures))
+models = getfield.(evaluation.ensemble, :model_name)
+
+# get scores from each group
+scores = machine_evaluations.(evaluation)
+# get mean test and train from each group for each measure.
+# then invert to a namedtuple where measures are keys
+test_scores = map(scores) do score
+map(Statistics.mean, score.test)
+end |> Tables.columntable
+train_scores = map(scores) do score
+map(Statistics.mean, score.train)
+end |> Tables.columntable
+
+println(io, "$(typeof(evaluation)) with $(length(measures)) performance measures")
+
+println(io, "Testing data")
+PrettyTables.pretty_table(io, merge((; model = models), test_scores))
+println(io, "Training data")
+PrettyTables.pretty_table(io, merge((; model = models), train_scores))
+end

-## Core evuator
+## Core evaluator
# internal method to get a vector of scores from y_hats, ys, and a namedtuple of measures
-function _evaluate(y_hat, y, measures)
-map(measures) do measure
+function _evaluate(y_hat::MLJBase.UnivariateFiniteArray, y::CategoricalArrays.CategoricalArray, measures)
+kinds_of_proxy = map(StatisticalMeasuresBase.kind_of_proxy, measures)
+
+# if any are literal targets (threshold-dependent), compute the confusion matrices outside the loop
+if any(map(kind -> kind == StatisticalMeasures.LearnAPI.LiteralTarget(), kinds_of_proxy))
+scores = pdf.(y_hat, true)
+thresholds = unique(scores)
+thresholded_scores = map(t -> CategoricalArrays.categorical(scores .>= t, levels = [false, true]), thresholds)
+conf_mats = StatisticalMeasures.ConfusionMatrix(; levels = [false, true], checks = false).(thresholded_scores, Ref(y))
+else
+conf_mats = nothing
+end
+
+map(measures, kinds_of_proxy) do measure, kind
# If the measure is threshold-independent
-if StatisticalMeasuresBase.kind_of_proxy(measure) == StatisticalMeasures.LearnAPI.Distribution()
+if kind == StatisticalMeasures.LearnAPI.Distribution()
return ScoreType((score = measure(y_hat, y), threshold = missing))
-else # else the measure uses thresholds
-# first get all possible thresholded values
-scores = pdf.(y_hat, true)
-thresholds = unique(scores)
-thresholded_scores = map(t -> CategoricalArrays.categorical(scores .>= t, levels = [false, true]), thresholds)
-
+else # else the measure uses thresholds
# find the max value and corresponding threshold for measure
-all_scores = measure.(thresholded_scores, Ref(y))
+all_scores = measure.(conf_mats)
max_score = findmax(all_scores)
return ScoreType((score = max_score[1], threshold = thresholds[max_score[2]]))
end
end
end
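The core change here: thresholded confusion matrices are built once, outside the per-measure loop, and every threshold-dependent measure then reduces to a sweep over them. A standalone sketch of the pattern, with hypothetical data; the `ConfusionMatrix` call mirrors the hunk above, and `accuracy` is one of the threshold-dependent defaults added in commit eaec1ce:

```julia
using CategoricalArrays, StatisticalMeasures

scores = [0.9, 0.7, 0.7, 0.2]  # hypothetical predicted probabilities of `true`
y = categorical([true, true, false, false]; levels = [false, true])

# One confusion matrix per candidate threshold, computed once.
thresholds = unique(scores)
conf_mats = map(thresholds) do t
    y_t = categorical(scores .>= t; levels = [false, true])
    StatisticalMeasures.ConfusionMatrix(; levels = [false, true], checks = false)(y_t, y)
end

# Any threshold-dependent measure is then a cheap sweep over conf_mats.
accs = StatisticalMeasures.accuracy.(conf_mats)
best_score, i = findmax(accs)
best_score, thresholds[i]  # best achievable accuracy and the threshold reaching it
```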

# Evaluate a single SDMmachine
-function evaluate(
-sdm_machine::SDMmachine;
-measures = (;
-auc = StatisticalMeasures.auc,
-log_loss = StatisticalMeasures.log_loss,
-kappa = StatisticalMeasures.kappa
-)
-)
+function _evaluate(sdm_machine::SDMmachine, measures::NamedTuple)
results = map((train = sdm_machine.train_rows, test = sdm_machine.test_rows)) do rows
y_hat = MLJBase.predict(sdm_machine.machine, rows = rows)
y = data(sdm_machine).response[rows]
Expand All @@ -133,19 +159,12 @@ function evaluate(
end

# Evaluate a group
-function evaluate(
-group::SDMgroup;
-measures = (;
-auc = StatisticalMeasures.auc,
-log_loss = StatisticalMeasures.log_loss,
-kappa = StatisticalMeasures.kappa
-)
-)
+function _evaluate(group::SDMgroup, measures)
machine_evaluations = map(m -> (evaluate(m; measures = measures)), group)

# average group prediction
y_hat = mapreduce(+, machines(group)) do mach
-MLJBase.predict(mach) # MLJBase.predict because StatisticalMeasures expect UniverateFiniteArrays.
+MLJBase.predict(mach)
end / length(group)

y = data(group).response
Expand All @@ -159,14 +178,7 @@ function evaluate(
)
end

-function evaluate(
-ensemble::SDMensemble,
-measures = (;
-auc = StatisticalMeasures.auc,
-log_loss = StatisticalMeasures.log_loss,
-kappa = StatisticalMeasures.kappa)
-)
-
+function _evaluate(ensemble::SDMensemble, measures)
group_evaluations = map(m -> (evaluate(m; measures = measures)), ensemble)

# average ensemble prediction
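The deleted keyword defaults above all carried the same measures `NamedTuple`, and the public `evaluate` entry points became internal `_evaluate` methods, so a single wrapper presumably supplies the defaults now, in the newly included src/interface.jl (not shown in this view). A hedged sketch of what that wrapper could look like; `DEFAULT_MEASURES` and the exact signature are assumptions, not code from this PR:

```julia
import StatisticalMeasures

# Hypothetical: the shared defaults pulled out of the three deleted blocks.
const DEFAULT_MEASURES = (;
    auc = StatisticalMeasures.auc,
    log_loss = StatisticalMeasures.log_loss,
    kappa = StatisticalMeasures.kappa,
)

# Hypothetical public entry point forwarding to the internal _evaluate methods.
evaluate(x; measures = DEFAULT_MEASURES) = _evaluate(x, measures)
```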