diff --git a/Project.toml b/Project.toml
index dc489b5..a74e07c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "1.0.0-DEV"
 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
+ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
 DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
@@ -18,6 +19,7 @@ MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
 Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Rasters = "a3a2b9e3-a471-40c9-b274-f788e487c689"
 ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
 Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065"
@@ -28,12 +30,17 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
 
 [compat]
-julia = "1.6"
+CategoricalDistributions = "0.1.14"
+Rasters = "0.10.1"
 StatsModels = "0.7.3"
+julia = "1.6"
 
 [extras]
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
diff --git a/src/SpeciesDistributionModels.jl b/src/SpeciesDistributionModels.jl
index c04d28d..b9deae9 100644
--- a/src/SpeciesDistributionModels.jl
+++ b/src/SpeciesDistributionModels.jl
@@ -1,27 +1,35 @@
 module SpeciesDistributionModels
 
-import Tables, StatsBase, Statistics, StatsAPI, StatsModels, LinearAlgebra
+import Tables, StatsBase, Statistics, StatsAPI, StatsModels, LinearAlgebra, Random, ThreadsX
 import MLJBase, StatisticalMeasures, StatisticalMeasuresBase, ScientificTypesBase, CategoricalArrays
 import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess
 
 using MLJBase: pdf
-using Rasters: Raster, RasterStack
+using Rasters: Raster, RasterStack, Band
 using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift
 using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count
 
-export SDMensemble, predict, sdm, select, machines, machine_keys, shap,
+using ComputationalResources: CPU1, CPUThreads
+
+export SDMensemble, predict, sdm, select, machines, machine_keys,
     interactive_evaluation, interactive_response_curves,
-    remove_collinear
+    remove_collinear,
+    explain, variable_importance, ShapleyValues,
+    SDMmachineExplanation, SDMgroupExplanation, SDMensembleExplanation,
+    SDMmachineEvaluation, SDMgroupEvaluation, SDMensembleEvaluation
 
+include("data_utils.jl")
 include("collinearity.jl")
 include("models.jl")
 include("ensemble.jl")
 include("predict.jl")
-include("explain.jl")
+include("explain/explain.jl")
+include("explain/shapley.jl")
 include("evaluate.jl")
+include("interface.jl")
 include("plots.jl")
 
 end
diff --git a/src/collinearity.jl b/src/collinearity.jl
index 594b42e..8d4d911 100644
--- a/src/collinearity.jl
+++ b/src/collinearity.jl
@@ -85,7 +85,7 @@ function _vifstep(data, datakeys, threshold, verbose, vifmethod, remove_perfectl
         if verbose
             @info "Removing $(datakeys[maxvif[2]]), $(length(datakeys)-1) variables remaining"
         end
-        datakeys = datakeys[Base.setdiff(1:length(datakeys), maxvif[2])] # not very elegant!
+        datakeys = datakeys[Base.setdiff(1:length(datakeys), maxvif[2])]
     end
 end
diff --git a/src/data_utils.jl b/src/data_utils.jl
new file mode 100644
index 0000000..7308186
--- /dev/null
+++ b/src/data_utils.jl
@@ -0,0 +1,17 @@
+### Miscellaneous utilities to deal with data issues such as names and missing values
+
+# Convert a BitArray to a CategoricalArray. Faster and type-stable version of `categorical`
+function boolean_categorical(A::BitArray{N}) where N
+    CategoricalArrays.CategoricalArray{Bool, N, UInt8}(A, levels=[false, true], ordered=false)
+end
+boolean_categorical(A::AbstractVector{Bool}) = boolean_categorical(BitArray(A))
+
+function _get_predictor_names(p, a)
+    predictors = Base.intersect(Tables.schema(a).names, Tables.schema(p).names)
+    predictors = filter!(!=(:geometry), predictors) # geometry is never a variable
+    length(predictors) > 0 || error("Presence and absence data have no common variable names - can't fit the ensemble.")
+    return predictors
+end
+
+_map(::CPU1) = Base.map
+_map(::CPUThreads) = ThreadsX.map
diff --git a/src/ensemble.jl b/src/ensemble.jl
index b179d77..cbcda0e 100644
--- a/src/ensemble.jl
+++ b/src/ensemble.jl
@@ -44,15 +44,16 @@ machines(group::SDMgroup) = map(m -> m.machine, group)
 sdm_machines(group::SDMgroup) = group.sdm_machines
 
 # machine_key generates a unique key for a machine
-machine_keys(group::SDMgroup) = ["$(group.model_name)_$(group.resampler_name)_$(m.fold)" for m in group]
+machine_keys(group::SDMgroup) = [Symbol("$(group.model_name)_$(group.resampler_name)_$(m.fold)") for m in group]
 
 # A bunch of functions are applied to an ensemble by applying to each group and reducing with vcat
 for f in (:machines, :machine_keys, :sdm_machines)
     @eval ($f)(ensemble::SDMensemble) = mapreduce(group -> ($f)(group), vcat, ensemble)
 end
 
-## Select methods
+model_names(ensemble) = getfield.(ensemble.groups, :model_name)
 
+## Select methods
 # Function to convienently select some models from groups or ensembles
 function select(group::SDMgroup, machine_indices::Vector{<:Int})
     if length(machine_indices) == 0
@@ -111,10 +112,10 @@ function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble)
     println(io, "Occurence data: Presence-Absence with $n_presence presences and $n_absence absences")
     println(io, "Predictors: $(join(["$key ($scitype)" for (key, scitype) in zip(nam, sci)], ", "))")
 
-    model_names = getfield.(ensemble.groups, :model_name)
+    m_names = model_names(ensemble)
     resampler_names = getfield.(ensemble.groups, :resampler_name)
     n_models = Base.length.(ensemble.groups)
-    table_cols = hcat(model_names, resampler_names, n_models)
+    table_cols = hcat(m_names, resampler_names, n_models)
     header = (["model", "resampler", "machines"])
     PrettyTables.pretty_table(io, table_cols; header = header)
 
@@ -135,7 +136,7 @@ Tables.columns(ensemble::SDMensemble) = Tables.columns(ensemble.groups)
 # Turns models into a NamedTuple with unique keys
 function _givenames(models::Vector)
     names = map(models) do model
-        replace(MLJBase.name(model), r"Classifier$"=>"")
+        Base.replace(MLJBase.name(model), r"Classifier$"=>"", r"Binary$"=>"")
     end
     for (name, n) in StatsBase.countmap(names)
         if n > 1
@@ -159,10 +160,11 @@ function _fit_sdm_group(
     folds,
     model_name,
     resampler_name,
-    verbosity
+    verbosity,
+    cpu_backend
 )
 
-    machines = map(enumerate(folds)) do (f, (train, test))
+    machines = _map(cpu_backend)(enumerate(folds)) do (f, (train, test))
         _fit_sdm_model(predictor_values, response_values, model, f, train, test, verbosity)
     end
 
@@ -170,14 +172,14 @@ end
 
-function sdm(
+function _fit_sdm_ensemble(
     presences,
     absence,
     models,
-    resamplers;
-    var_keys::Vector{Symbol} = intersect(Tables.schema(absence).names, Tables.schema(presences).names),
-    scitypes::Vector{DataType} = [MLJBase.scitype(Tables.schema(presences).types) for key in var_keys],
-    verbosity::Int = 0
+    resamplers,
+    predictors::Vector{Symbol},
+    verbosity::Int,
+    cpu_backend
 )
 
     @assert Tables.istable(presences) && Tables.istable(absence)
@@ -186,7 +188,7 @@ function sdm(
     n_total = n_presence + n_absence
 
     # merge presence and absence data into one namedtuple of vectors
-    predictor_values = NamedTuple{Tuple(var_keys)}([[Tables.columns(absence)[var]; Tables.columns(presences)[var]] for var in var_keys])
+    predictor_values = NamedTuple{Tuple(predictors)}([[Tables.columns(absence)[var]; Tables.columns(presences)[var]] for var in predictors])
     response_values = CategoricalArrays.categorical(
         [falses(n_absence); trues(n_presence)];
         levels = [false, true], ordered = false
@@ -198,7 +200,7 @@ function sdm(
     sdm_groups = mapreduce(vcat, collect(keys(resamplers_))) do resampler_key
         resampler = resamplers_[resampler_key]
         folds = MLJBase.train_test_pairs(resampler, 1:n_total, response_values) ## get indices
-        map(collect(keys(models_))) do model_key
+        _map(cpu_backend)(collect(keys(models_))) do model_key
             model = models_[model_key]
             _fit_sdm_group(
                 predictor_values,
@@ -208,7 +210,8 @@ function sdm(
                 folds,
                 model_key,
                 resampler_key,
-                verbosity
+                verbosity,
+                cpu_backend
             )
         end
     end
diff --git a/src/evaluate.jl b/src/evaluate.jl
index 067b079..cea8f5e 100644
--- a/src/evaluate.jl
+++ b/src/evaluate.jl
@@ -18,48 +18,55 @@ struct SDMensembleEvaluation <: AbstractVector{SDMgroupEvaluation}
     results
 end
 
-SDMgroupOrEnsembleEvaluation = Union{SDMgroupEvaluation, SDMensembleEvaluation}
-
 ScoreType = NamedTuple{(:score, :threshold), Tuple{Float64, Union{Missing, Float64}}}
 
+SDMevaluation = Union{SDMmachineEvaluation, SDMgroupEvaluation, SDMensembleEvaluation}
+SDMgroupOrEnsembleEvaluation = Union{SDMgroupEvaluation, SDMensembleEvaluation}
+
+# Basic operations on evaluate objects
 Base.getindex(ensemble::SDMensembleEvaluation, i) = ensemble.group_evaluations[i]
 Base.getindex(group::SDMgroupEvaluation, i) = group.machine_evaluations[i]
 Base.size(ensemble::SDMensembleEvaluation) = Base.size(ensemble.group_evaluations)
 Base.size(group::SDMgroupEvaluation) = Base.size(group.machine_evaluations)
 
-function machine_evaluations(groupeval::SDMgroupEvaluation; mean = false)
+"""
+    machine_evaluations(eval)
+
+Get the scores for each machine in an evaluation, which can be either an
+`SDMgroupEvaluation` or an `SDMensembleEvaluation`.
+
+The return type is a nested structure of `NamedTuple`s.
+The returned `NamedTuple` has two keys, `train` and `test`, each of which has keys
+corresponding to the measures specified in [`evaluate`](@ref).
+
+## Example
+```julia
+evaluation = SDM.evaluate(ensemble; measures = (; accuracy, auc))
+machine_aucs = SDM.machine_evaluations(evaluation).train.auc
+```
+"""
+machine_evaluations
+
+function machine_evaluations(groupeval::SDMgroupEvaluation)
     map((:train, :test)) do set
         map(keys(groupeval.measures)) do key
-            r = map(groupeval) do e
+            map(groupeval) do e
                 e.results[set][key].score
             end
-
-            if mean
-                Statistics.mean(r)
-            else
-                r
-            end
-
         end |> NamedTuple{keys(groupeval.measures)}
     end |> NamedTuple{(:train, :test)}
 end
-
-function machine_evaluations(ensembleeval::SDMensembleEvaluation; mean = false)
+function machine_evaluations(ensembleeval::SDMensembleEvaluation)
     map((:train, :test)) do set
         map(keys(ensembleeval.measures)) do key
-            r = mapreduce(vcat, ensembleeval) do groupeval
+            mapreduce(vcat, ensembleeval) do groupeval
                 map(groupeval) do e
                     e.results[set][key].score
                 end
             end
-
-            if mean
-                Statistics.mean(r)
-            else
-                r
-            end
-
         end |> NamedTuple{keys(ensembleeval.measures)}
     end |> NamedTuple{(:train, :test)}
 end
@@ -77,37 +84,69 @@ function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMmachineEvaluat
     PrettyTables.pretty_table(io, table_cols; header = header)
 end
 
-function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMgroupOrEnsembleEvaluation)
+function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMgroupEvaluation)
     measures = collect(keys(evaluation.measures))
-    train_scores, test_scores = machine_evaluations(evaluation, mean = true)
-
-    group_scores = map(measures) do key
-        evaluation.results[key].score
-    end
+    train_scores, test_scores = machine_evaluations(evaluation)
+    folds = getfield.(evaluation.group, :fold)
 
     println(io, "$(typeof(evaluation)) with $(length(measures)) performance measures")
-    table_cols = hcat(measures, collect(group_scores), collect(train_scores), collect(test_scores))
-    header = (["measure", "performance of avg", "avg. train performance", "avg. test performance"])
-    PrettyTables.pretty_table(io, table_cols; header = header)
+    println(io, "Testing data")
+    PrettyTables.pretty_table(io, merge((; fold = folds), test_scores))
+    println(io, "Training data")
+    PrettyTables.pretty_table(io, merge((; fold = folds), train_scores))
+end
+
+function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMensembleEvaluation)
+    measures = collect(keys(evaluation.measures))
+    models = getfield.(evaluation.ensemble, :model_name)
+
+    # get scores from each group
+    scores = machine_evaluations.(evaluation)
+    # get mean test and train from each group for each measure,
+    # then invert to a namedtuple where measures are keys
+    test_scores = map(scores) do score
+        map(Statistics.mean, score.test)
+    end |> Tables.columntable
+    train_scores = map(scores) do score
+        map(Statistics.mean, score.train)
+    end |> Tables.columntable
+
+    println(io, "$(typeof(evaluation)) with $(length(measures)) performance measures")
+
+    println(io, "Testing data")
+    PrettyTables.pretty_table(io, merge((; model = models), test_scores))
+    println(io, "Training data")
+    PrettyTables.pretty_table(io, merge((; model = models), train_scores))
 end
 
-## Core evuator
+## Core evaluator
 # internal method to get a vector of scores from y_hats, ys, and a namedtuple of measures
-function _evaluate(y_hat, y, measures)
-    map(measures) do measure
+function _evaluate(y_hat::MLJBase.UnivariateFiniteArray, y::CategoricalArrays.CategoricalArray, measures)
+    kinds_of_proxy = map(StatisticalMeasuresBase.kind_of_proxy, measures)
+
+    # if any are literal targets (threshold-dependent), compute the confusion matrices outside the loop
+    if any(map(kind -> kind == StatisticalMeasures.LearnAPI.LiteralTarget(), kinds_of_proxy))
+        scores = pdf.(y_hat, true)
+        thresholds = unique(scores)
+        levels = [false, true]
+        # use the internal method to avoid constructing indexer every time
+        indexer = StatisticalMeasures.LittleDict(levels[i] => i for i in eachindex(levels)) |> StatisticalMeasures.freeze
+        conf_mats = broadcast(thresholds) do t
+            y_ = boolean_categorical(scores .>= t)
+            StatisticalMeasures.ConfusionMatrices._confmat(y_, y, indexer, levels, false)
+        end
+    else
+        conf_mats = nothing
+    end
+
+    map(measures, kinds_of_proxy) do measure, kind
         # If the measures is threshold independent
-        if StatisticalMeasuresBase.kind_of_proxy(measure) == StatisticalMeasures.LearnAPI.Distribution()
+        if kind == StatisticalMeasures.LearnAPI.Distribution()
             return ScoreType((score = measure(y_hat, y), threshold = missing))
-        else # else the measure uses thresholds
-            # first get all possible thresholded values
-            scores = pdf.(y_hat, true)
-            thresholds = unique(scores)
-            thresholded_scores = map(t -> CategoricalArrays.categorical(scores .>= t, levels = [false, true]), thresholds)
-
+        else # else the measure uses thresholds
             # find the max value and corresponding threshold for measure
-            all_scores = measure.(thresholded_scores, Ref(y))
+            all_scores = measure.(conf_mats)
             max_score = findmax(all_scores)
             return ScoreType((score = max_score[1], threshold = thresholds[max_score[2]]))
         end
@@ -115,14 +154,7 @@ function _evaluate(y_hat, y, measures)
 end
 
 # Evaluate a single SDMmachine
-function evaluate(
-    sdm_machine::SDMmachine;
-    measures = (;
-        auc = StatisticalMeasures.auc,
-        log_loss = StatisticalMeasures.log_loss,
-        kappa = StatisticalMeasures.kappa
-    )
-)
+function _evaluate(sdm_machine::SDMmachine, measures::NamedTuple)
     results = map((train = sdm_machine.train_rows, test = sdm_machine.test_rows)) do rows
         y_hat = MLJBase.predict(sdm_machine.machine, rows = rows)
         y = data(sdm_machine).response[rows]
@@ -133,20 +165,12 @@ end
 end
 
 # Evaluate a group
-function evaluate(
-    group::SDMgroup;
-    measures = (;
-        auc = StatisticalMeasures.auc,
-        log_loss = StatisticalMeasures.log_loss,
-        kappa = StatisticalMeasures.kappa
-    )
-)
+function _evaluate(group::SDMgroup, measures)
     machine_evaluations = map(m -> (evaluate(m; measures = measures)), group)
 
     # average group prediction
-    y_hat = mapreduce(+, machines(group)) do mach
-        MLJBase.predict(mach) # MLJBase.predict because StatisticalMeasures expect UniverateFiniteArrays.
-    end / length(group)
+    p = predict(group, data(group).predictor, reducer = Statistics.mean)
+    y_hat = MLJBase.UnivariateFinite(boolean_categorical([false, true]), p, augment = true)
     y = data(group).response
 
     group_evaluation = _evaluate(y_hat, y, measures)
@@ -159,21 +183,12 @@ end
     )
 end
 
-function evaluate(
-    ensemble::SDMensemble,
-    measures = (;
-        auc = StatisticalMeasures.auc,
-        log_loss = StatisticalMeasures.log_loss,
-        kappa = StatisticalMeasures.kappa)
-    )
-
+function _evaluate(ensemble::SDMensemble, measures)
     group_evaluations = map(m -> (evaluate(m; measures = measures)), ensemble)
 
     # average ensemble prediction
-    y_hat = mapreduce(+, machines(ensemble)) do mach
-        MLJBase.predict(mach)
-    end / n_machines(ensemble)
-
+    p = predict(ensemble, data(ensemble).predictor, reducer = Statistics.mean)
+    y_hat = MLJBase.UnivariateFinite(boolean_categorical([false, true]), p, augment = true)
     y = data(ensemble).response
 
     ensemble_evaluation = _evaluate(y_hat, y, measures)
diff --git a/src/explain.jl b/src/explain.jl
deleted file mode 100644
index 7c1aaea..0000000
--- a/src/explain.jl
+++ /dev/null
@@ -1,35 +0,0 @@
-struct SDMshapley
-    values::Vector{<:NamedTuple} # Contains all shap values for all models
-    importances::Vector{<:NamedTuple} # Contais mean absolute shap for each variable for each model
-    ensemble::SDMensemble
-    summary
-end
-
-Base.size(shap::SDMshapley) = Base.size(shap.values)
-Base.length(shap::SDMshapley) = Base.length(shap.values)
-
-function Base.show(io::IO, mime::MIME"text/plain", shap::SDMshapley)
-    println(io, "Shapley evaluation for SDM ensemble with $(Base.length(shap)) models")
-
-    println(io, "Mean feature importance:")
-    Base.show(io, mime, shap.summary)
-end
-
-function shap(ensemble; parallelism = Shapley.CPUThreads(), n_samples = 50)
-    shapvalues = map(ensemble.trained_models) do model
-        Shapley.shapley(
-            x -> Float64.(MLJBase.pdf.(MLJBase.predict(model.machine, x), true)), # some ml models return float32s - where to handle this?
-            Shapley.MonteCarlo(parallelism, n_samples),
-            ensemble.data.predictor
-        )
-    end
-
-    importances = map(vals -> map(val -> mapreduce(abs, +, val) / Base.length(val), vals), shapvalues)
-
-    summary = NamedTuple(var => mapreduce(x -> getfield(x, var), +, importances) / Base.length(importances) for var in ensemble.predictors)
-
-    return SDMshapley(shapvalues, importances, ensemble, summary)
-end
-
-
-
diff --git a/src/explain/explain.jl b/src/explain/explain.jl
new file mode 100644
index 0000000..dc04d7c
--- /dev/null
+++ b/src/explain/explain.jl
@@ -0,0 +1,63 @@
+abstract type SDMexplainMethod end
+
+# Type definitions for explanation objects
+struct SDMmachineExplanation
+    machine::SDMmachine
+    method::SDMexplainMethod
+    values::NamedTuple # Contains the explanation values for each predictor
+    data::NamedTuple # Contains the data used to explain
+end
+
+struct SDMgroupExplanation <: AbstractVector{SDMmachineExplanation}
+    group::SDMgroup
+    machine_explanations::Vector{SDMmachineExplanation}
+end
+
+struct SDMensembleExplanation <: AbstractVector{SDMgroupExplanation}
+    ensemble::SDMensemble
+    group_explanations::Vector{SDMgroupExplanation}
+end
+
+#### Basic operations on Explanation objects ####
+Base.getindex(e::SDMgroupExplanation, i::Integer) = e.machine_explanations[i]
+Base.getindex(e::SDMensembleExplanation, i::Integer) = e.group_explanations[i]
+Base.size(e::SDMgroupExplanation) = Base.size(e.machine_explanations)
+Base.size(e::SDMensembleExplanation) = Base.size(e.group_explanations)
+
+variables(e::SDMmachineExplanation) = keys(e.values)
+data(e::SDMmachineExplanation) = e.data
+method(e::SDMmachineExplanation) = e.method
+
+for f in [:data, :variables, :method]
+    @eval $f(e::Union{SDMgroupExplanation, SDMensembleExplanation}) = $f(first(e))
+end
+
+machine_explanations(e::SDMgroupExplanation) = e.machine_explanations
+machine_explanations(e::SDMensembleExplanation) = reduce(vcat, machine_explanations.(e))
+
+#### Show methods ####
+function Base.show(io::IO, mime::MIME"text/plain", expl::SDMmachineExplanation)
+    println(io, "$(typeof(expl)) using method $(typeof(method(expl)))")
+end
+function Base.show(io::IO, mime::MIME"text/plain", expl::SDMgroupExplanation)
+    println(io, "$(typeof(expl)) using method $(typeof(method(expl)))")
+end
+function Base.show(io::IO, mime::MIME"text/plain", expl::SDMensembleExplanation)
+    println(io, "$(typeof(expl)) using method $(typeof(method(expl)))")
+end
+
+## By default, variable importance is the mean absolute value for each variable in the explanation values
+function variable_importance(expl::SDMmachineExplanation)
+    map(vals -> Statistics.mean(abs, vals), expl.values)
+end
+
+function variable_importance(expl::Union{SDMgroupExplanation, SDMensembleExplanation})
+    group_var_imp = map(variable_importance, machine_explanations(expl))
+    map(Statistics.mean, Tables.columntable(group_var_imp))
+end
+
+#summary = NamedTuple(var => mapreduce(x -> getfield(x, var), +, importances) / Base.length(importances) for var in ensemble.predictors)
diff --git a/src/explain/shapley.jl b/src/explain/shapley.jl
new file mode 100644
index 0000000..0c2922d
--- /dev/null
+++ b/src/explain/shapley.jl
@@ -0,0 +1,46 @@
+"""
+    ShapleyValues(algorithm::Shapley.Algorithm)
+    ShapleyValues(N::Integer; threaded = true, rng = Random.GLOBAL_RNG)
+
+Use to specify Shapley values as the method in [`explain`](@ref).
+If an integer `N` (and optionally `threaded` and `rng`) is supplied, `MonteCarlo` sampling is used,
+where `N` is the number of iterations (samples). More samples give more accurate results,
+but take more time to compute.
+"""
+struct ShapleyValues <: SDMexplainMethod
+    algorithm::Shapley.Algorithm
+end
+# Convenience constructor for a MonteCarlo algorithm with N samples
+function ShapleyValues(N::Integer; threaded = true, rng = Random.GLOBAL_RNG)
+    resource = threaded ? CPUThreads() : CPU1()
+    algorithm = Shapley.MonteCarlo(resource, N, rng)
+    ShapleyValues(algorithm)
+end
+
+function _explain(mach::SDMmachine, method::ShapleyValues, d, predictors)
+    shapvalues = map(predictors) do predictor
+        Shapley.shapley(
+            x -> _reformat_and_predict(mach, x, false), # some ml models return float32s - where to handle this?
+            method.algorithm,
+            d,
+            predictor,
+            d
+        )
+    end |> NamedTuple{predictors}
+    return SDMmachineExplanation(mach, method, shapvalues, d)
+end
+
+function _explain(group::SDMgroup, method::ShapleyValues, d, predictors)
+    machine_explanations = map(group) do mach
+        _explain(mach, method, d, predictors)
+    end
+    return SDMgroupExplanation(group, machine_explanations)
+end
+
+function _explain(ensemble::SDMensemble, method::ShapleyValues, d, predictors)
+    group_explanations = map(ensemble) do group
+        _explain(group, method, d, predictors)
+    end
+
+    return SDMensembleExplanation(ensemble, group_explanations)
+end
diff --git a/src/interface.jl b/src/interface.jl
new file mode 100644
index 0000000..d6536cf
--- /dev/null
+++ b/src/interface.jl
@@ -0,0 +1,111 @@
+"""
+    sdm(presences, absences; models, [resampler], [predictors], [verbosity], [threaded])
+
+Construct an ensemble with input data specified in `presences` and `absences`.
+
+The first input argument is species presences and the second (pseudo-)absences. Both presence and absence data must be Tables-compatible (e.g., a `DataFrame`, or a `Vector` of `NamedTuple`, but not an `Array`).
+
+## Keywords
+`models`: a `Vector` of the models to be used in the ensemble. All models must be MLJ-supported Classifiers.
+For a full list of supported models, see https://alan-turing-institute.github.io/MLJ.jl/stable/model_browser/#Classification
+`resampler`: the resampling strategy to be used, of type `MLJBase.ResamplingStrategy`. Defaults to 5-fold cross validation.
+`predictors`: a `Vector` of `Symbols` with the names of the predictor variables to be used. By default, all variables that occur in both `presences` and `absences` (except `:geometry`) are used.
+`verbosity`: an `Int` value that regulates how much information is printed.
+`threaded`: a `Bool` that controls whether models are fitted using multiple threads. Defaults to `false`.
+
+## Example
+
+"""
+function sdm(
+    presences,
+    absences;
+    models,
+    resampler = MLJBase.CV(; nfolds = 5, shuffle = true),
+    predictors = _get_predictor_names(presences, absences),
+    verbosity = 0,
+    threaded = false
+)
+
+    predictors = collect(predictors)
+
+    # Check the predictor values are valid
+    :geometry in predictors && error("Predictors cannot be called :geometry")
+    Base.intersect(predictors, Tables.schema(presences).names) == predictors ||
+        error("The presence data does not contain all predictors specified")
+    Base.intersect(predictors, Tables.schema(absences).names) == predictors ||
+        error("The absence data does not contain all predictors specified")
+
+    backend = threaded ? CPUThreads() : CPU1()
+
+    _fit_sdm_ensemble(presences, absences, models, [resampler], predictors, verbosity, backend)
+end
+
+"""
+    evaluate(x; measures)
+
+Evaluate `x`, which can be an `SDMmachine`, `SDMgroup`, or `SDMensemble`,
+by applying the measures provided to the data used to build the ensemble,
+and return an evaluation object.
+
+`measures` is a `NamedTuple` of measures. The keys are used to identify the measures.
+This defaults to accuracy, auc, log_loss, and kappa.
+
+For threshold-dependent measures, the highest score is reported, together with the threshold at which it is reached.
+
+A list of measures is available here: https://juliaai.github.io/StatisticalMeasures.jl/dev/auto_generated_list_of_measures/#aliases. However, note that not all measures are useful.
+"""
+function evaluate( # Define this as an extension of MLJBase.evaluate??
+    x;
+    measures = (;
+        StatisticalMeasures.accuracy,
+        StatisticalMeasures.auc,
+        StatisticalMeasures.log_loss,
+        StatisticalMeasures.kappa
+    )
+)
+    _evaluate(x, measures)
+end
+
+"""
+    explain(ensemble::SDMensemble; method, [data], [predictors])
+
+Compute explanations (variable contributions) for the models in `ensemble`.
+
+## Keywords
+- `method` is the algorithm to use. See [`ShapleyValues`](@ref).
+- `data` is the data used to compute explanations; defaults to the data used to train the ensemble.
+- `predictors`: which predictors to explain. Defaults to all variables in `data`.
+
+"""
+function explain(e::SDMensemble; method, data = data(e).predictor, predictors = keys(data))
+    _explain(e, method, data, predictors)
+end
+
+"""
+    predict(SDMobject, newdata; clamp = false, [reducer], [by_group])
+
+Use an `SDMmachine`, `SDMgroup`, or `SDMensemble` to predict habitat suitability for some data, optionally summarized for the entire ensemble or for each `SDMgroup`.
+
+`newdata` can be either a `RasterStack` or some other Tables.jl-compatible data. It must have all predictor variables used to train the models as columns (or as layers in the case of a `RasterStack`).
+
+If `clamp` is set to `true`, predictor values in `newdata` are clamped to the range seen during training of `SDMobject`.
+
+Optionally provide a function to summarize the output as the `reducer` argument. This would typically be `Statistics.mean` or `Statistics.median`.
+If `by_group` is set to `true`, the output is reduced for each `SDMgroup`; if it is set to `false` (the default), it is reduced across the entire ensemble.
+
+If `newdata` is a `RasterStack`, `predict` returns a `Raster`; otherwise, it returns a `NamedTuple` of `Vector`s.
+Habitat suitability is always reported as a floating-point number between 0 and 1.
+"""
+function predict(m::SDMmachine, d; clamp = false)
+    _reformat_and_predict(m, d, clamp)
+end
+function predict(g::SDMgroup, d; clamp = false, reducer = nothing)
+    _reformat_and_predict(g, d, clamp, reducer)
+end
+function predict(e::SDMensemble, d; clamp = false, reducer = nothing, by_group = false)
+    by_group && isnothing(reducer) && error("If by_group is true, reducer must be specified")
+    _reformat_and_predict(e, d, clamp, reducer, by_group)
+end
diff --git a/src/predict.jl b/src/predict.jl
index 5763ba9..2aad156 100644
--- a/src/predict.jl
+++ b/src/predict.jl
@@ -1,61 +1,103 @@
-function _unsafe_predict(mach::SDMmachine, data)
-    CategoricalDistributions.pdf.(MLJBase.predict(mach.machine, data), true)
-end
+#### Helper functions ####
+# Reformat data so that it can be used in predict. Different models use different data types
+function _reformat_data(m::SDMmachine, d, clamp::Bool)
+    traindata = data(m).predictor
+    newdata = Tables.columntable(d)[keys(traindata)]
+    if clamp
+        for k in keys(traindata)
+            if !(MLJBase.scitype(traindata[k]) <: AbstractVector{<:MLJBase.Finite}) # if data is categorical, don't clamp
+                newdata[k] .= Base.clamp.(newdata[k], Base.extrema(traindata[k])...)
+            end
+        end
+    end
 
-function predict(mach::SDMmachine, d)
-    data_ = Tables.columntable(d)[keys(data(mach).predictor)]
-    _unsafe_predict(mach, data_)
+    return MLJBase.reformat(m.machine.old_model, newdata)[1]
 end
 
+#### _predict methods ####
+# _predict uses already-reformatted data.
+# _reformat_and_predict methods first reformat and then call _predict
 
-function predict(s::SDMgroupOrEnsemble, d)
-    data_ = Tables.columntable(d)[keys(data(s).predictor)]
-
-    mapreduce(hcat, sdm_machines(s)) do mach
-        _unsafe_predict(mach, data_)
-    end
+# Machine-level _predict method. All other _predict methods eventually call this
+function _predict(m::SDMmachine, data)
+    # predict
+    prediction = MLJBase.predict(m.machine.old_model, m.machine.fitresult, data)
+    # convert to Floats
+    MLJBase.pdf.(prediction, true)
 end
 
-#= on pause until RasterStacks are compatible with Tables.jl
-function predict(ensemble::SDMensemble, data::Rasters.RasterStack)
-    preds = Tuple(ensemble.predictors)
-
-    # Check dimensions match and variables exist
-    data1 = data[first(preds)]
-    dims1 = Rasters.dims(data1)
-    if ~all(p -> Rasters.dims(data[p]) == dims1, preds) error("Dimensions of data do not match") end
+function _reformat_and_predict(m::SDMmachine, data, clamp)
+    _predict(m, _reformat_data(m, data, clamp))
+end
 
-    # Find missing values -- maybe add this as method to RasterStack?
-    missings = falses(dims1)
-    for l in data[preds]
-        missings .|= l .=== Rasters.missingval(l)
-    end
+## Group level _predict methods
+function _predict(g::SDMgroup, data, ::Nothing)
+    pr = map(m -> _predict(m, data), g)
+    return NamedTuple{Tuple(machine_keys(g))}(pr)
+end
 
-    # Take non-missing data and convert to namedtuple of vectors
-    data_ = NamedTuple{preds}(map(p -> data[p][.~missings], preds))
+function _predict(g::SDMgroup, data, reducer::Function)
+    pr = map(m -> _predict(m, data), g)
+    return map((d...) -> reducer(d), pr...)
+end
 
-    # Reformat data to named tuple of vectors
-    @time data_ = NamedTuple{Tuple(ensemble.predictors)}([vec(data[pre]) for pre in preds])
+function _reformat_and_predict(g::SDMgroup, data, clamp, reducer)
+    _predict(g, _reformat_data(first(g), data, clamp), reducer)
+end
 
-    # Allocate Raster to save results
-    outraster = Raster(fill(NaN, (dims1..., Rasters.Band(machine_keys(ensemble)))); missingval = NaN, crs = Rasters.crs(data1))
+# ensemble-level methods
+# For ensemble, there are no _predict methods. Data has to be reformatted for each group
+function _reformat_and_predict(e::SDMensemble, data, clamp::Bool, ::Nothing, ::Bool)
+    mapreduce(
+        g -> _reformat_and_predict(g, data, clamp, nothing),
+        merge,
+        e
+    )
+end
 
-    for (i, mach) in enumerate(machines(ensemble))
-        # predict each machine and get the probability of true
-        @views outraster[Rasters.Band(i)][.~missings] .= MLJBase.predict(mach, data_).prob_given_ref[2]
+function _reformat_and_predict(e::SDMensemble, data, clamp::Bool, reducer::Function, by_group::Bool)
+    if by_group
+        # pass the reducer to each group, then combine into a namedtuple
+        group_pr = (map(g -> _reformat_and_predict(g, data, clamp::Bool, reducer), e))
+        NamedTuple{Tuple(model_names(e))}(group_pr)
+    else
+        # predict without reducing, then apply the reducer
+        pr = mapreduce(g -> _reformat_and_predict(g, data, clamp::Bool, nothing), merge, e)
+        map((d...) -> reducer(d), pr...)
     end
+end
+
+# Dispatch on RasterStacks
+_reformat_and_predict(e::SDMensemble, rs::Rasters.AbstractRasterStack, clamp::Bool, reducer::Function, by_group::Bool) =
+    _reformat_and_predict_raster(e, rs, clamp, reducer, by_group)
+_reformat_and_predict(g::SDMgroup, rs::Rasters.AbstractRasterStack, clamp::Bool, reducer::Union{<:Function, <:Nothing}) =
+    _reformat_and_predict_raster(g, rs, clamp, reducer)
+_reformat_and_predict(m::SDMmachine, rs::Rasters.AbstractRasterStack, clamp::Bool) =
+    _reformat_and_predict_raster(m, rs, clamp)
+
+function _reformat_and_predict_raster(s::Union{<:SDMensemble, SDMgroup, SDMmachine}, rs::Rasters.AbstractRasterStack, args...)
+    missing_mask = Rasters.boolmask(rs; alllayers = true)
+    d = rs[missing_mask]
+    pr = _reformat_and_predict(s, d, args...)
+    return _build_raster(missing_mask, pr)
 end
-=#
 
-# inernal convenience function to predict just train or test rows for each machine
-function _predict(s::SDMgroupOrEnsemble, rows::Symbol)
-    y_hat_y = map(sdm_machines(s)) do sdm_mach
-        y_hat = MLJBase.predict(sdm_mach.machine, rows = sdm_mach[rows])
-        y = data(s).response[sdm_mach[rows]]
-        return (;y_hat, y)
+# Build Raster with the models/machines in Band dimension
+function _build_raster(missing_mask::Rasters.AbstractRaster, pr::NamedTuple)
+    r_dims = (Rasters.dims(missing_mask)..., Band(collect(keys(pr))))
+    T = eltype(first(pr)) # usually Float64, but could be something else depending on the reducer
+    output = Raster(Array{Union{Missing, T}}(missing, size(r_dims)); dims = r_dims)
+    for k in keys(pr)
+        @views output[Band = Rasters.At(k)][missing_mask] .= pr[k]
     end
+    return output
+end
 
-    return (y_hat_y)
-end
\ No newline at end of file
+# Build Raster with no additional layer
+function _build_raster(missing_mask::Rasters.AbstractRaster, pr::Vector)
+    r_dims = Rasters.dims(missing_mask)
+    T = eltype(pr) # usually Float64, but could be something else depending on the reducer
+    output = Raster(Array{Union{Missing, T}}(missing, size(r_dims)), dims = r_dims)
+    output[missing_mask] .= pr
+    return output
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 956a660..1f556e5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,38 +1,62 @@
-using SpeciesDistributionModels, CategoricalArrays
+using SpeciesDistributionModels, MLJBase
 import SpeciesDistributionModels as SDM
+using StableRNGs, Distributions, Test
 
-import GLM: Distributions
-
+rng = StableRNG(0)
 # some mock data
 n = 500
-backgrounddata = (a = rand(n), b = rand(n), c = categorical(rand(0:3, n)))
-presencedata = (a = rand(n), b = rand(n).^2, c = categorical(rand(Distributions.Binomial(3, 0.5), n)))
-
-using Test
+backgrounddata = (a = rand(rng, n), b = rand(rng, n), c = rand(rng, n))
+presencedata = (a = rand(rng, n), b = rand(rng, n).^2, c = sqrt.(rand(rng, n)))
 
 @testset "SpeciesDistributionModels.jl" begin
-    models = [SDM.random_forest(), SDM.random_forest(; max_depth = 3), SDM.linear_model(), SDM.boosted_regression_tree()]
-    resamplers = [SDM.MLJBase.CV(; shuffle = true, nfolds = 5)]
+    models = [SDM.random_forest(; rng), SDM.random_forest(; max_depth = 3, rng), SDM.linear_model(), SDM.boosted_regression_tree()]
 
     ensemble = sdm(
-        presencedata, backgrounddata,
-        models,
-        resamplers
+        presencedata, backgrounddata;
+        models = models, resampler = SDM.MLJBase.CV(; shuffle = true, nfolds = 5, rng), threaded = false
     )
 
     evaluation = SDM.evaluate(ensemble)
+    @test evaluation isa SDM.SDMensembleEvaluation
+    @test evaluation[1] isa SDM.SDMgroupEvaluation
+    @test evaluation[1][1] isa SDM.SDMmachineEvaluation
+    @test evaluation.measures isa NamedTuple
+    mach_evals = SDM.machine_evaluations(evaluation)
+    @test mach_evals isa NamedTuple{(:train, :test)}
+    @test mach_evals.train isa NamedTuple{(keys(evaluation.measures))}
+
+    machine_aucs = SDM.machine_evaluations(evaluation).test.auc
+
+    pr1 = SDM.predict(ensemble, backgrounddata)
+    pr2 = SDM.predict(ensemble, backgrounddata; reducer = maximum)
+    pr3 = SDM.predict(ensemble, backgrounddata; reducer = x -> sum(x .> 0.5), by_group = true)
+
+    @test pr2 isa Vector
+    @test collect(keys(pr1)) == SDM.machine_keys(ensemble)
+    @test collect(keys(pr3)) == SDM.model_names(ensemble)
+    eltype(pr3) == Vector{Int64}
+
+    @test_throws ArgumentError SDM.predict(ensemble, backgrounddata.a)
+    @test_throws Exception SDM.predict(ensemble, backgrounddata[(:a,)])
+    @test_throws Exception SDM.predict(ensemble, backgrounddata; by_group = true)
+
+    # explain
+    expl = explain(ensemble; method = ShapleyValues(10; rng))
+    varimp = variable_importance(expl)
+    @test varimp.b > varimp.a
+    @test varimp.c > varimp.a
 end
 
 @testset "collinearity" begin
     # mock data with a collinearity problem
-    data_with_collinearity = merge(backgrounddata, (; d = backgrounddata.a .+ rand(n), e = backgrounddata.a .+ rand(n), f = f = categorical(rand(Distributions.Binomial(3, 0.5), 500)) ))
+    data_with_collinearity = merge(backgrounddata, (; d = backgrounddata.a .+ rand(rng, n), e = backgrounddata.a .+ rand(rng, n), f = categorical(rand(Distributions.Binomial(3, 0.5), 500)) ))
 
-    rm_col_gvif = remove_collinear(data_with_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
+    rm_col_gvif = remove_collinear(data_with_collinearity; method = SDM.Gvif(; threshold = 2.), silent = false)
     rm_col_vif = remove_collinear(data_with_collinearity; method = SDM.Vif(; threshold = 2.), silent = true)
     rm_col_pearson = remove_collinear(data_with_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true)
     @test rm_col_gvif == (:b, :c, :d, :e, :f)
-    @test rm_col_vif == (:b, :d, :e, :c, :f)
-    @test rm_col_pearson == (:b, :d, :e, :c, :f)
+    @test rm_col_vif == (:b, :c, :d, :e, :f)
+    @test rm_col_pearson == (:b, :c, :d, :e, :f)
 
     data_with_perfect_collinearity = (a = [1,2,3], b = [1,2,3])
     Test.@test_throws Exception remove_collinear(data_with_perfect_collinearity; method = SDM.Gvif(; threshold = 2., remove_perfectly_collinear = false), silent = true)
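For reviewers trying out this branch, here is a minimal end-to-end usage sketch of the revised API. It only uses calls exercised in test/runtests.jl and the docstrings above (mock NamedTuple data; model constructors `SDM.random_forest`/`SDM.linear_model` and the `models`/`resampler` keywords are taken from this diff), so treat it as an illustration rather than canonical documentation.

```julia
using SpeciesDistributionModels
import SpeciesDistributionModels as SDM
import Statistics

# mock presence/absence tables, as in test/runtests.jl
n = 500
backgrounddata = (a = rand(n), b = rand(n), c = rand(n))
presencedata = (a = rand(n), b = rand(n).^2, c = sqrt.(rand(n)))

# fit an ensemble with the new keyword-based sdm interface (src/interface.jl)
ensemble = sdm(
    presencedata, backgrounddata;
    models = [SDM.random_forest(), SDM.linear_model()],
    resampler = SDM.MLJBase.CV(; nfolds = 5, shuffle = true),
)

# evaluate with the default measures (accuracy, auc, log_loss, kappa)
evaluation = SDM.evaluate(ensemble)
aucs = SDM.machine_evaluations(evaluation).test.auc

# predict habitat suitability, reduced across the whole ensemble
pr = SDM.predict(ensemble, backgrounddata; reducer = Statistics.mean)

# explain the ensemble with Monte Carlo Shapley values and summarize variable importance
expl = explain(ensemble; method = ShapleyValues(10))
varimp = variable_importance(expl)
```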