Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add collinearity methods #6

Merged
merged 6 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Lasso = "b4fcebef-c861-5a0f-a7e2-ba9dc32b180a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Loess = "4345ca2d-374a-55d4-8d30-97f9976e7612"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
Expand All @@ -22,11 +24,14 @@ Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
julia = "1.6"
StatsModels = "0.7.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
7 changes: 4 additions & 3 deletions src/SpeciesDistributionModels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SpeciesDistributionModels

import Tables, StatsBase, Statistics
import Tables, StatsBase, Statistics, StatsAPI, StatsModels, LinearAlgebra
import MLJBase, StatisticalMeasures, StatisticalMeasuresBase, ScientificTypesBase, CategoricalArrays
import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess

Expand All @@ -13,14 +13,15 @@ using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift
using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count

export SDMensemble, predict, sdm, select, machines, machine_keys, shap,
interactive_evaluation, interactive_response_curves
interactive_evaluation, interactive_response_curves,
remove_collinear

include("collinearity.jl")
include("models.jl")
include("ensemble.jl")
include("predict.jl")
include("explain.jl")
include("evaluate.jl")
include("plots.jl")


end
107 changes: 107 additions & 0 deletions src/collinearity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
abstract type AbstractCollinearityMethod end

struct Gvif <: AbstractCollinearityMethod
threshold
end
Gvif(; threshold) = Gvif(threshold)

struct Vif <: AbstractCollinearityMethod
threshold
end
Vif(; threshold) = Vif(threshold)

struct Pearson <: AbstractCollinearityMethod
threshold
end
Pearson(; threshold) = Pearson(threshold)

# Need to add a method for RasterStack, unless it will be Tables.jl compatible
"""
remove_collinear(data; method, silent = false)

Removes strongly correlated variables in `data`, until correlation is below a threshold specified in `method`.

`method` can currently be either `Gvif`, `Vif` or `Pearson`, which use GVIF, VIF, or Pearson's r, respectively.
GVIF and VIF are similar method, but GVIF includes categorical variables whereas VIF ignores them.

To run without showing information about collinearity scores, set `silent = true`.

## Example
```julia
julia> import SpeciesDistributionModels as SDM
julia> mydata = (a = 1:100, b = sqrt.(1:100), c = rand(100))
julia> SDM.remove_collinear(mydata; method = SDM.Vif(10))
[ Info: a has highest GVIF of 28.367942095054225
[ Info: Removing a, 2 variables remaining
[ Info: b has highest GVIF of 1.0077618445543057
[ Info: All variables are below threshold, returning remaining variables
(:b, :c)
```

"""
function remove_collinear(data; method, silent::Bool = false)
schema = Tables.schema(data)
datakeys = schema.names
iscategorical = collect(schema.types .<: CategoricalArrays.CategoricalValue)
_remove_collinear(data, datakeys, method, ~silent, iscategorical)
end

_remove_collinear(data, datakeys, v::Vif, verbose, iscategorical) = (_vifstep(data, datakeys[.~iscategorical], v.threshold, verbose, StatsAPI.vif)..., datakeys[iscategorical]...)
_remove_collinear(data, datakeys, v::Gvif, verbose, iscategorical) = _vifstep(data, datakeys, v.threshold, verbose, StatsAPI.gvif)
_remove_collinear(data, datakeys, p::Pearson, verbose, iscategorical) = (_pearsonstep(data, datakeys[.~iscategorical], p.threshold, verbose)..., datakeys[iscategorical]...)

function _vifstep(data, datakeys, threshold, verbose, vifmethod)
highest_vif = threshold + 1.
while highest_vif > threshold
# make a custom implementation of gvif that works without the useless model
m = GLM.lm(StatsModels.FormulaTerm(StatsModels.term(1), StatsModels.term.(datakeys)), data)
vifresult = vifmethod(m)
maxvif = Base.findmax(vifresult)
highest_vif = maxvif[1]
if verbose
@info "$(datakeys[maxvif[2]]) has highest VIF score: $(maxvif[1])"
end
if isnan(highest_vif)
error("Cannot compute VIF. Possible some variables have perfect collinearity")
end

if highest_vif > threshold
if verbose
@info "Removing $(datakeys[maxvif[2]]), $(length(datakeys)-1) variables remaining"
end
datakeys = datakeys[filter(x -> x != maxvif[2], 1:length(datakeys))] # not very elegant!
end
end

if verbose
@info "All variables are below threshold, returning remaining variables"
end

return datakeys
end

# to break ties it j
function _pearsonstep(data, datakeys, threshold, verbose)
data = Tables.columntable(data)[datakeys]
datamatrix = reduce(hcat, data)
c = abs.(Statistics.cor(datamatrix) - LinearAlgebra.I)
correlated_vars_idx = findall(LinearAlgebra.LowerTriangular(c) .> threshold)
if verbose
@info "Found $(length(correlated_vars_idx)) correlated variable pairs"
for idx in correlated_vars_idx
println("$(keys(data)[idx.I[1]]) ~ $(keys(data)[idx.I[2]]): $(c[idx])")
end
end

correlated_vars = Tuple.(correlated_vars_idx)
vars_to_remove = Int[]
while correlated_vars != []
cm = mapreduce(x -> collect(x), vcat, correlated_vars) |> StatsBase.countmap # count how often each occurs
to_remove = findmax(cm)[2]
correlated_vars = [c for c in correlated_vars if ~in(to_remove, c)]
append!(vars_to_remove, to_remove)
end
vars_to_remove
vars_remaining = keys(data)[setdiff(1:length(keys(data)), vars_to_remove)]
return vars_remaining
end
38 changes: 36 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,40 @@
using SpeciesDistributionModels
using SpeciesDistributionModels, CategoricalArrays
import SpeciesDistributionModels as SDM

import GLM: Distributions

# some mock data
n = 500
backgrounddata = (a = rand(n), b = rand(n), c = categorical(rand(0:3, n)))
presencedata = (a = rand(n), b = rand(n).^2, c = categorical(rand(Distributions.Binomial(3, 0.5), n)))

using Test

@testset "SpeciesDistributionModels.jl" begin
# Write your tests here.
models = [SDM.random_forest(), SDM.random_forest(; max_depth = 3), SDM.linear_model(), SDM.boosted_regression_tree()]
resamplers = [SDM.MLJBase.CV(; shuffle = true, nfolds = 5)]

ensemble = sdm(
presencedata, backgrounddata,
models,
resamplers
)

evaluation = SDM.evaluate(ensemble)
end

@testset "collinearity" begin
# mock data with a collinearity problem
data_with_collinearity = merge(backgrounddata, (; d = backgrounddata.a .+ rand(n), e = backgrounddata.a .+ rand(n), f = f = categorical(rand(Distributions.Binomial(3, 0.5), 500)) ))

rm_col_gvif = remove_collinear(data_with_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
rm_col_vif = remove_collinear(data_with_collinearity; method = SDM.Vif(; threshold = 2.), silent = true)
rm_col_pearson = remove_collinear(data_with_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true)
@test rm_col_gvif == (:b, :c, :d, :e, :f)
@test rm_col_vif == (:b, :d, :e, :c, :f)
@test rm_col_pearson == (:b, :d, :e, :c, :f)

data_with_perfect_collinearity = (a = [1,2,3], b = [1,2,3])
Test.@test_throws Exception remove_collinear(data_with_perfect_collinearity; method = SDM.Gvif(; threshold = 2.), verbose = true)
@test remove_collinear(data_with_perfect_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true) == (:a, )
end
Loading