Skip to content

Commit

Permalink
add boostrtap plot to docs
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Feb 11, 2025
1 parent 2d89e19 commit 1595ed7
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,102 @@ CurrentModule = TMLE

TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in leveraging the power of modern machine-learning methods while preserving interpretability and statistical inference guarantees, you are in the right place. TMLE.jl is compatible with any [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) compliant algorithm and any dataset respecting the [Tables](https://tables.juliadata.org/stable/) interface.

The following plot illustrates the bias reduction achieved by TMLE over a mis-specified linear model in the presence of confounding. Note that in this case, TMLE also uses mis-specified models but still achieves a lower bias due to the targeting step.

```@eval
using Distributions
using Random
using DataFrames
using CairoMakie
using TMLE
using CategoricalArrays
using MLJLinearModels
"""
    generate_confounded(; α=0, β=1, γ=-1, n=100)

Simulate `n` observations in which the treatment depends on the covariate `W`.

`W` is drawn from `Normal(1, 1)`. The treatment `T` is `true` wherever the sum of `W`
and an independent `Normal(0, 1)` draw is below `0.7`. The outcome is
`Y = β*T + α*W + γ*T*W + ϵ` with `ϵ` drawn from `Normal(0, 0.1)`.

Returns a `DataFrame` with columns `W`, `T` and `Y`.
"""
function generate_confounded(; α=0, β=1, γ=-1, n=100)
    covariate = rand(Normal(1, 1), n)
    treatment_noise = rand(Normal(0, 1), n)
    # Treatment is a function of the covariate, which is what confounds T and Y.
    treated = (treatment_noise .+ covariate) .< 0.7
    outcome_noise = rand(Normal(0, 0.1), n)
    outcome = β .* treated .+ α .* covariate .+ γ .* treated .* covariate .+ outcome_noise
    return DataFrame(W=covariate, T=treated, Y=outcome)
end
"""
    generate_unconfounded(; α=0, β=1, γ=-1, n=100)

Simulate `n` observations in which the treatment is drawn independently of `W`.

`W` is drawn from `Normal(1, 1)` and the treatment `T` from `Bernoulli(0.2)`. The outcome
is `Y = β*T + α*W + γ*T*W + ϵ` with `ϵ` drawn from `Normal(0, 1)`.

Returns a `DataFrame` with columns `W`, `T` and `Y`.
"""
function generate_unconfounded(; α=0, β=1, γ=-1, n=100)
    covariate = rand(Normal(1, 1), n)
    # This draw is not used below; it is kept so that the random stream is consumed
    # exactly as before and seeded runs stay reproducible.
    rand(Normal(0, 1), n)
    treated = rand(Bernoulli(0.2), n)
    outcome_noise = rand(Normal(0, 1), n)
    outcome = β .* treated .+ α .* covariate .+ γ .* treated .* covariate .+ outcome_noise
    return DataFrame(W=covariate, T=treated, Y=outcome)
end
"""
    linear_model_coef(data)

Return the ordinary-least-squares coefficient of `T` in the regression of `Y` on an
intercept, `T` and `W`.

`data` is any object exposing `T`, `W` and `Y` as properties holding equally long vectors
(for instance a `DataFrame` or a `NamedTuple`). `T` must be numeric or `Bool`.

The fit is a direct least-squares solve rather than a call to `lm`/`@formula`/`coef`:
those names are provided by GLM/StatsModels, which this snippet does not load.
"""
function linear_model_coef(data)
    # Design matrix with columns: intercept, treatment, covariate.
    design = hcat(ones(length(data.Y)), data.T, data.W)
    # Left-division by a tall matrix yields the least-squares solution.
    coefficients = design \ data.Y
    # Index 2 is the treatment column, matching the `Y ~ T + W` term order.
    return coefficients[2]
end
"""
    tmle_estimates(data)

Return the point estimate of the average treatment effect of `T` on `Y`, with `W` as
treatment confounder, computed by a weighted `TMLEE` estimator.

The nuisance models are a linear regressor for a continuous outcome and logistic
classifiers for a binary outcome and for the treatment mechanism.

`data` is expected to be a `DataFrame` with columns `W`, `T` and `Y`. It is not modified:
the categorical conversion of `T` is applied to a copy.
"""
function tmle_estimates(data)
    models = default_models(;
        Q_continuous=MLJLinearModels.LinearRegressor(),
        Q_binary=MLJLinearModels.LogisticClassifier(),
        G=MLJLinearModels.LogisticClassifier()
    )
    Ψ̂ = TMLEE(models=models, weighted=true)
    Ψ = TMLE.ATE(;
        outcome=:Y,
        treatment_values=(T=(case=true, control=false),),
        treatment_confounders=(:W,)
    )
    # Work on a copy so the caller's table keeps its original `T` column.
    dataset = copy(data)
    dataset.T = categorical(dataset.T)
    # The second returned value (the estimator cache) is not needed here.
    Ψ̂ₙ, _ = Ψ̂(Ψ, dataset; verbosity=0)
    return TMLE.estimate(Ψ̂ₙ)
end
"""
    bootstrap_analysis(; B=100, α=0, β=1, γ=-1, n=100, ATE=β+γ)

Run `B` simulation rounds and collect, for each round, the linear-model coefficient and
the TMLE estimate on one confounded and one unconfounded dataset of size `n`.

The random number generator is re-seeded with `123` at the start of every call.

# Keywords
- `B`: number of simulation rounds.
- `α`, `β`, `γ`: parameters forwarded to `generate_confounded` / `generate_unconfounded`.
- `n`: number of observations per simulated dataset.
- `ATE`: accepted for compatibility but not used in the computation. Its default is
  `β + γ`, the average treatment effect of the simulated outcome model
  `Y = β*T + α*W + γ*T*W + ϵ` when `W` has mean 1.

# Returns
A 4-tuple of `Vector{Float64}` of length `B`, in this order: linear-model estimates
(confounded), linear-model estimates (unconfounded), TMLE estimates (confounded),
TMLE estimates (unconfounded).
"""
function bootstrap_analysis(; B=100, α=0, β=1, γ=-1, n=100, ATE=β+γ)
    Random.seed!(123)
    β̂s_confounded = Vector{Float64}(undef, B)
    tmles_confounded = Vector{Float64}(undef, B)
    β̂s_unconfounded = Vector{Float64}(undef, B)
    tmles_unconfounded = Vector{Float64}(undef, B)
    for b in 1:B
        @info(string("Bootstrap iteration ", b))
        # The confounded dataset is generated and analysed first; the order of the
        # generator calls determines how the seeded random stream is consumed.
        data_confounded = generate_confounded(; α=α, β=β, γ=γ, n=n)
        β̂s_confounded[b] = linear_model_coef(data_confounded)
        tmles_confounded[b] = tmle_estimates(data_confounded)
        data_unconfounded = generate_unconfounded(; α=α, β=β, γ=γ, n=n)
        β̂s_unconfounded[b] = linear_model_coef(data_unconfounded)
        tmles_unconfounded[b] = tmle_estimates(data_unconfounded)
    end
    return β̂s_confounded, β̂s_unconfounded, tmles_confounded, tmles_unconfounded
end
"""
    plot(β̂s_confounded, β̂s_unconfounded, tmles_confounded, tmles_unconfounded, β, ATE)

Build a figure comparing the distributions of the linear-model estimates (blue) and the
TMLE estimates (orange) for the "Confounded" and "Unconfounded" scenarios, drawn as
horizontal raincloud plots. Vertical lines mark `ATE` (green) and `β` (red), and a legend
is placed to the right of the axis.

The same scenario labels are used for both raincloud layers, so the TMLE vectors are
expected to have the same lengths as the corresponding linear-model vectors.

Returns the `Figure`.

NOTE(review): CairoMakie also exports a function named `plot`; presumably this
definition shadows it inside the docs sandbox. Confirm that this is intended.
"""
function plot(β̂s_confounded, β̂s_unconfounded, tmles_confounded, tmles_unconfounded, β, ATE)
    fig = Figure(size=(1000, 800))
    ax = Axis(
        fig[1, 1];
        title="Distribution of Linear Model's and TMLE's Estimates",
        yticks=(1:2, ["Confounded", "Unconfounded"]),
    )

    # One scenario label per estimate: confounded block first, then unconfounded.
    labels = vcat(
        fill("Confounded", length(β̂s_confounded)),
        fill("Unconfounded", length(β̂s_unconfounded)),
    )
    linear_estimates = vcat(β̂s_confounded, β̂s_unconfounded)
    tmle_values = vcat(tmles_confounded, tmles_unconfounded)

    rainclouds!(ax, labels, linear_estimates; orientation=:horizontal, color=(:blue, 0.5))
    rainclouds!(ax, labels, tmle_values; orientation=:horizontal, color=(:orange, 0.5))
    vlines!(ax, ATE; label="ATE", color=:green)
    vlines!(ax, β; label="β", color=:red)

    legend_markers = [
        PolyElement(color=:blue),
        PolyElement(color=:orange),
        LineElement(color=:green),
        LineElement(color=:red),
    ]
    legend_names = ["Linear", "TMLE", "ATE", "β"]
    Legend(fig[1, 2], legend_markers, legend_names; framevisible=false)
    return fig
end
# Driver: fix the seed, set the simulation parameters, run the analysis and draw the figure.
Random.seed!(123)
# Number of simulation rounds and observations per simulated dataset.
B, n = 1000, 1000
# Parameters of the simulated outcome model.
α, β, γ = 0, 1, -1
ATE = β + γ
β̂s_confounded, β̂s_unconfounded, tmles_confounded, tmles_unconfounded = bootstrap_analysis(;
    B=B, α=α, β=β, γ=γ, n=n, ATE=ATE
)
# The figure is the last expression, i.e. the value of this code block.
plot(β̂s_confounded, β̂s_unconfounded, tmles_confounded, tmles_unconfounded, β, ATE)
```
## Installation

TMLE.jl can be installed via the Package Manager and supports Julia `v1.6` and greater.
Expand Down

0 comments on commit 1595ed7

Please sign in to comment.