replace MLJ dependency with MLJBase and StatisticalMeasures (#4)
tiemvanderdeure authored Dec 4, 2023
1 parent c0b6c52 commit 366f53a
Showing 8 changed files with 46 additions and 25 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -10,13 +10,14 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Loess = "4345ca2d-374a-55d4-8d30-97f9976e7612"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Rasters = "a3a2b9e3-a471-40c9-b274-f788e487c689"
Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
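The dependency swap above amounts to roughly the following Pkg operations run in the package environment (a minimal sketch; the commit edits Project.toml directly rather than going through Pkg):

using Pkg
Pkg.activate(".")                             # the SpeciesDistributionModels environment
Pkg.rm("MLJ")                                 # drop the heavy MLJ umbrella package
Pkg.add(["MLJBase", "StatisticalMeasures"])   # add the lighter, targeted replacements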
21 changes: 21 additions & 0 deletions playground.jl
@@ -0,0 +1,21 @@
using Revise

using SpeciesDistributionModels, GLMakie
import SpeciesDistributionModels as SDM

presences = (a = rand(200), b = rand(200), c = rand(200))
background = (a = rand(200), b = sqrt.(rand(200)), c = rand(200).^2)

models = [
SDM.linear_model(),
SDM.boosted_regression_tree(),
SDM.random_forest(),
SDM.random_forest(; n_trees = 10, max_depth = 3)]

ensemble = sdm(presences, background, models, [SDM.MLJBase.CV(; shuffle = true)])

interactive_evaluation(ensemble)

shapley = shap(ensemble; n_samples = 5)

interactive_response_curves(shapley)
4 changes: 1 addition & 3 deletions src/SpeciesDistributionModels.jl
@@ -1,9 +1,7 @@
module SpeciesDistributionModels

using MLJ

import Tables, StatsBase, Statistics
import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess
import MLJBase, StatisticalMeasures, GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess

using Rasters: Raster, RasterStack

18 changes: 9 additions & 9 deletions src/ensemble.jl
@@ -74,7 +74,7 @@ Tables.columns(ensemble::SDMensemble) = Tables.columns(ensemble.trained_models)
# Turns models into a NamedTuple with unique keys
function givenames(models::Vector)
names = map(models) do model
replace(MLJ.name(model), r"Classifier$"=>"")
replace(MLJBase.name(model), r"Classifier$"=>"")
end
for (name, n) in StatsBase.countmap(names)
if n > 1
@@ -86,7 +86,7 @@ end

function auc_by_model(ensemble)
mapreduce(vcat, keys(ensemble.models)) do key
mean([model.auc for model in ensemble.trained_models if model.model_key == key])
Statistics.mean([model.auc for model in ensemble.trained_models if model.model_key == key])
end
end

@@ -96,7 +96,7 @@ function sdm(
models,
resamplers;
var_keys::Vector{Symbol} = [key for key in Tables.schema(absence).names if in(key, Tables.schema(presences).names)],
scitypes::Vector{DataType} = [MLJ.scitype(Tables.schema(presences).types) for key in var_keys],
scitypes::Vector{DataType} = [MLJBase.scitype(Tables.schema(presences).types) for key in var_keys],
verbosity::Int = 0
)

@@ -117,16 +117,16 @@

trained_models = mapreduce(vcat, keys(resamplers_)) do resampler_key
resampler = resamplers_[resampler_key]
folds = MLJ.MLJBase.train_test_pairs(resampler, 1:n_total, response_values) ## get indices
folds = MLJBase.train_test_pairs(resampler, 1:n_total, response_values) ## get indices
mapreduce(vcat, keys(models_)) do model_key
model = models_[model_key]
map(enumerate(folds)) do (f, (train, test))
mach = machine(model, predictor_values, response_values)
fit!(mach; rows = train, verbosity = verbosity)
y_hat = MLJ.predict(mach, rows = test)
AUC = auc(y_hat, response_values[test])
mach = MLJBase.machine(model, predictor_values, response_values)
MLJBase.fit!(mach; rows = train, verbosity = verbosity)
y_hat = MLJBase.predict(mach, rows = test)
auc = StatisticalMeasures.auc(y_hat, response_values[test])
machine_key = Symbol(String(model_key) * "_" * String(resampler_key) * "_" * string(f))
return (; machine = mach, auc = AUC, model_key, resampler_key, fold = f, machine_key, train, test)
return (; machine = mach, auc = auc, model_key, resampler_key, fold = f, machine_key, train, test)
# Probably make a Type for this
end
end
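The per-fold loop rewritten above now calls MLJBase and StatisticalMeasures directly instead of MLJ's re-exports. A minimal self-contained sketch of that fit/predict/AUC cycle, using hypothetical random data and one of the package's model types, might look like this:

import MLJBase, StatisticalMeasures
using MLJDecisionTreeInterface: RandomForestClassifier

X = (a = rand(100), b = rand(100))                 # hypothetical predictor table
y = MLJBase.categorical(rand(Bool, 100))           # hypothetical binary response

folds = MLJBase.train_test_pairs(MLJBase.CV(; shuffle = true), 1:100, y)
train, test = first(folds)

mach = MLJBase.machine(RandomForestClassifier(), X, y)
MLJBase.fit!(mach; rows = train, verbosity = 0)
y_hat = MLJBase.predict(mach, rows = test)         # probabilistic predictions
auc = StatisticalMeasures.auc(y_hat, y[test])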
2 changes: 1 addition & 1 deletion src/explain.jl
@@ -18,7 +18,7 @@ end
function shap(ensemble; parallelism = Shapley.CPUThreads(), n_samples = 50)
shapvalues = map(ensemble.trained_models) do model
Shapley.shapley(
x -> Float64.(MLJ.pdf.(MLJ.predict(model.machine, x), true)), # some ml models return float32s - where to handle this?
x -> Float64.(MLJBase.pdf.(MLJBase.predict(model.machine, x), true)), # some ml models return float32s - where to handle this?
Shapley.MonteCarlo(parallelism, n_samples),
ensemble.data.predictor
)
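The Shapley call above wraps prediction in a closure that returns the probability of the positive class. Reusing the fitted machine mach and predictor table X from the ensemble.jl sketch above, the pattern is roughly (a sketch, not the package's public API):

import Shapley, MLJBase

predict_true = x -> Float64.(MLJBase.pdf.(MLJBase.predict(mach, x), true))
shapvalues = Shapley.shapley(predict_true, Shapley.MonteCarlo(Shapley.CPUThreads(), 50), X)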
15 changes: 8 additions & 7 deletions src/models.jl
@@ -1,9 +1,10 @@
# Load in models and wrap them to hide @load
# Load in models and wrap them to hide the interfaces
# Do we want to get rid of these as dependencies?

const lbc = MLJ.@load LinearBinaryClassifier pkg=GLM verbosity = 0
const etc = MLJ.@load EvoTreeClassifier pkg=EvoTrees verbosity = 0
const rf = MLJ.@load RandomForestClassifier pkg=DecisionTree verbosity = 0
using MLJGLMInterface: LinearBinaryClassifier
using EvoTrees: EvoTreeClassifier
using MLJDecisionTreeInterface: RandomForestClassifier

linear_model(; kw...) = lbc(; kw...)
boosted_regression_tree(; kw...) = etc(; kw...)
random_forest(; kw...) = rf(; kw...)
linear_model = LinearBinaryClassifier
boosted_regression_tree = EvoTreeClassifier
random_forest = RandomForestClassifier
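
Because the wrappers are now plain bindings to the classifier types rather than closures over @load results, keyword construction as used in playground.jl still works unchanged. A minimal sketch using only the interface package:

using MLJDecisionTreeInterface: RandomForestClassifier

random_forest = RandomForestClassifier                  # what the alias above amounts to
model = random_forest(; n_trees = 10, max_depth = 3)    # same call as in playground.jl
model isa RandomForestClassifier                        # true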
2 changes: 1 addition & 1 deletion src/plots.jl
@@ -1,5 +1,5 @@
function classification_rates(scores, y)
fpr, tpr, thresholds = roc_curve(scores, y)
fpr, tpr, thresholds = StatisticalMeasures.roc_curve(scores, y)
tnr = 1. .- fpr
fnr = 1. .- tpr

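Here roc_curve moves from the MLJ re-export to StatisticalMeasures. Given probabilistic predictions y_hat and ground truth y[test] as in the ensemble.jl sketch above, the call would be (a sketch):

fpr, tpr, thresholds = StatisticalMeasures.roc_curve(y_hat, y[test])
tnr = 1 .- fpr    # true negative rate, as classification_rates computes above
fnr = 1 .- tpr    # false negative rate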
6 changes: 3 additions & 3 deletions src/predict.jl
@@ -4,7 +4,7 @@ function predict(ensemble::SDMensemble, data::NamedTuple)

map(machines(ensemble)) do mach
# predict each machine and get the probability of true
Float64.(MLJ.predict(mach, data_).prob_given_ref[2])
Float64.(MLJBase.predict(mach, data_).prob_given_ref[2])
end
end

@@ -33,15 +33,15 @@ function predict(ensemble::SDMensemble, data::Rasters.RasterStack)

for (i, mach) in enumerate(machines(ensemble))
# predict each machine and get the probability of true
@views outraster[Rasters.Band(i)][.~missings] .= MLJ.predict(mach, data_).prob_given_ref[2]
@views outraster[Rasters.Band(i)][.~missings] .= MLJBase.predict(mach, data_).prob_given_ref[2]
end

return outraster
end

function predict(ensemble::SDMensemble, rows::Symbol)
y_hat_y = map(ensemble.trained_models) do model
y_hat = MLJ.predict(model.machine, rows = model[rows])
y_hat = MLJBase.predict(model.machine, rows = model[rows])
y = ensemble.data.response[model[rows]]
return (;y_hat, y)
end
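predict.jl extracts the probability of the positive class through the internal prob_given_ref field of the returned UnivariateFinite array, while explain.jl uses the public pdf accessor. Continuing the ensemble.jl sketch, the two forms should agree when the class true has reference index 2 (a sketch; prob_given_ref is an internal field, not a public API):

y_hat_all = MLJBase.predict(mach, X)
p_public = Float64.(MLJBase.pdf.(y_hat_all, true))       # public accessor, as in explain.jl
p_internal = Float64.(y_hat_all.prob_given_ref[2])       # internal field, as in predict.jl
p_public == p_internal                                    # expected to hold here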
