Skip to content

Commit

Permalink
switch to IntervalSets and StatsDiscretizations
Browse files Browse the repository at this point in the history
  • Loading branch information
nignatiadis committed May 21, 2024
1 parent 4ff2401 commit 7f1f33a
Show file tree
Hide file tree
Showing 22 changed files with 146 additions and 550 deletions.
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
name = "Empirikos"
uuid = "cab608d6-c565-4ea1-96d6-ce5441ba21b0"
authors = ["Nikos Ignatiadis <[email protected]> and contributors"]
version = "0.5.3"
version = "0.6"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,18 +19,17 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RangeHelpers = "3a07dd3d-1c52-4395-8858-40c6328157db"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsDiscretizations = "1d0cfea5-fabc-4e25-85a8-945fa8abc3c9"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
CSV = "0.8, 0.9, 0.10"
DataStructures = "0.17,0.18"
Distributions = "0.24.7, 0.25"
Intervals = "1.4, 1.5, 1.6"
JuMP = "^1"
KernelDensity = "0.6"
LinearAlgebra = "1.9"
Expand All @@ -44,11 +42,11 @@ QuadGK = "2.0"
Random = "1.9"
RangeHelpers = "0.1.9"
RecipesBase = "1.2, 1.3"
Reexport = "1"
Setfield = "1"
SpecialFunctions = "2"
Statistics = "1.9"
StatsBase = "0.33, 0.34"
StatsDiscretizations = "0.2"
UnPack = "1"
julia = "1.9"

Expand Down
15 changes: 3 additions & 12 deletions src/Empirikos.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
module Empirikos

using Reexport

import Base: broadcast, broadcast!, broadcasted, eltype, zero, <=
using DataStructures
@reexport using Distributions
using Distributions
import Distributions:
ntrials, pdf, support, location, cf, cdf, ccdf, logpdf, logdiffcdf, logccdf, components

import Intervals: Interval, Closed, Open, Unbounded, Bounded, AbstractInterval, isbounded,
RightEndpoint
export Interval, Closed, Open, Unbounded # instead of @reexport

import JuMP
import JuMP: @constraint, @variable, set_lower_bound, @expression,
Expand All @@ -36,13 +32,12 @@ import Statistics: std, var
using StatsBase
import StatsBase: loglikelihood, response, fit, nobs, weights, confint

using StatsDiscretizations
using UnPack

include("utils.jl")
include("ebayes_samples.jl")
include("compound.jl")
include("interval_discretizer.jl")
include("dict_function.jl")
include("ebayes_methods.jl")
include("ebayes_targets.jl")
include("mixtures.jl")
Expand Down Expand Up @@ -120,9 +115,7 @@ export EBayesSample,
MixturePriorClass,
GaussianScaleMixtureClass,
NPMLE,
nominal_alpha,
integer_discretizer,
interval_discretizer
nominal_alpha

export loglikelihood,
response,
Expand All @@ -136,8 +129,6 @@ export DvoretzkyKieferWolfowitz,
ChiSquaredFLocalization,
InfinityNormDensityBand

# utilities
export DictFunction


export FLocalizationInterval,
Expand Down
47 changes: 24 additions & 23 deletions src/amari.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ function initialize_modulus_model(method::AMARI, ::Type{ModulusModelWithF}, targ
throw(ArgumentError("ModulusModelWithF only works for Homoskedastic samples."))
Z = first(representative_eb_samples.vec)
if isa(method.plugin_G, Distribution) #TODO SPECIAL CASE this elsewhere?
estimated_marginal_density = Empirikos.dictfun(discretizer, Z, z-> pdf(method.plugin_G, z))
estimated_marginal_density = StatsDiscretizations.dictfun(discretizer, z-> pdf(method.plugin_G, z), x->set_response(Z,x))
else
estimated_marginal_density = Empirikos.dictfun(discretizer, Z, z-> pdf(method.plugin_G.prior, z))
estimated_marginal_density = StatsDiscretizations.dictfun(discretizer, z-> pdf(method.plugin_G.prior, z), x->set_response(Z,x))
end

model = Model(solver)
Expand Down Expand Up @@ -223,7 +223,7 @@ function modulus_cholesky_factor(convexclass::AbstractMixturePriorClass, plugin_
cache_vec = zeros(K)
fill!(chr.factors, 0)
cache_vec = zeros(K)
for _interval in discr.sorted_intervals
for _interval in discr
for (z, pr) in zip(eb_samples.vec, eb_samples.probs)
z = set_response(z, _interval)
cache_vec .= sqrt(pr) .* exp.(logpdf.(components(convexclass), z) .- logpdf(plugin_G, z)/2)
Expand Down Expand Up @@ -281,19 +281,19 @@ function set_target!(modulus_model::AbstractModulusModel, target::Empirikos.Line
modulus_model
end

function default_support_discretizer(Zs::AbstractVector{<:AbstractNormalSample})
_low,_up = quantile(response.(Zs), (0.005, 0.995))
_step = mean( std.(Zs))/100
interval_discretizer(RangeHelpers.range(_low; stop=above(_up), step=_step))
end
#function default_support_discretizer(Zs::AbstractVector{<:AbstractNormalSample})
# _low,_up = quantile(response.(Zs), (0.005, 0.995))
# _step = mean( std.(Zs))/100
# interval_discretizer(RangeHelpers.range(_low; stop=above(_up), step=_step))
#end

function default_support_discretizer(Zs::AbstractVector{<:FoldedNormalSample})
_up = quantile(response.(Zs), 0.995)
_low = zero(_up)
_step = mean( std.(Zs) )/100
interval_discretizer(RangeHelpers.range(start=_low, stop=above(_up), step=_step);
closed=:left, unbounded=:right)
end
#function default_support_discretizer(Zs::AbstractVector{<:FoldedNormalSample})
# _up = quantile(response.(Zs), 0.995)
# _low = zero(_up)
# _step = mean( std.(Zs) )/100
# interval_discretizer(RangeHelpers.range(start=_low, stop=above(_up), step=_step);
# closed=:left, unbounded=:right)
#end


function initialize_method(method::AMARI, target::Empirikos.LinearEBayesTarget, Zs; kwargs...)
Expand Down Expand Up @@ -385,7 +385,7 @@ function SteinMinimaxEstimator(modulus_model::ModulusModelWithoutF)
L1 = target(g1)
L2 = target(g2)

offset_sum = sum(discretizer.sorted_intervals) do _int
offset_sum = sum(discretizer) do _int
sum(zip(representative_eb_samples.vec, representative_eb_samples.probs)) do (z,pr)
z = set_response(z, _int)
f2_z = pdf(g2, z)
Expand Down Expand Up @@ -422,8 +422,9 @@ end

function SteinMinimaxEstimator(modulus_model::ModulusModelWithF)
@unpack model, method, target, estimated_marginal_density = modulus_model
@unpack convexclass = method

@unpack convexclass, representative_eb_samples = method
Z = first(representative_eb_samples.vec)

δ = get_δ(modulus_model)
ω_δ = objective_value(model)
ω_δ_prime = -JuMP.dual(modulus_model.bound_delta)
Expand All @@ -445,7 +446,7 @@ function SteinMinimaxEstimator(modulus_model::ModulusModelWithF)
Q_0 = (L1+L2)/2 -
ω_δ_prime/(2*δ)*sum( (f2 .- f1).* (f2 .+ f1) ./ f̄s)

Q = Empirikos.dictfun(method.discretizer, Zs, Q .+ Q_0)
Q = StatsDiscretizations.dictfun(method.discretizer, Q .+ Q_0, x->set_response(Z,x))

max_bias = (ω_δ - δ*ω_δ_prime)/2
unit_var_proxy = ω_δ_prime^2
Expand Down Expand Up @@ -497,7 +498,7 @@ function confint(Q::SteinMinimaxEstimator, target, Zs; level=0.95, tail=:both)
error("Target has changed")
α = 1- level
_bias = Q.max_bias
_Qs = Q.Q.(Zs)
_Qs = collect(Q.Q.(Zs))
_wts = StatsBase.weights(Zs)
_se = std(_Qs, _wts; corrected=true)/sqrt(nobs(Zs))
point_estimate = mean(_Qs, _wts)
Expand Down Expand Up @@ -596,14 +597,14 @@ function confint(method::AMARI, target::Empirikos.AbstractPosteriorTarget, Zs;


fit_lower = fit_initialized!(method, target_lower, Zs) #SteinMinimax
Q_lower = fit_lower.Q.(Zs)
Q_lower = collect(fit_lower.Q.(Zs))
confint_lower = confint(fit_lower, target_lower, Zs; level=level)
max_bias_lower = confint_lower.maxbias
var_Q_lower = abs2(confint_lower.se)
estimate_lower = confint_lower.estimate

fit_upper = fit_initialized!(method, target_upper, Zs)
Q_upper = fit_upper.Q.(Zs)
Q_upper = collect(fit_upper.Q.(Zs))
confint_upper = confint(fit_upper, target_upper, Zs; level=level)
max_bias_upper = confint_upper.maxbias
var_Q_upper = abs2(confint_upper.se)
Expand Down Expand Up @@ -650,7 +651,7 @@ end
function worst_case_bias_lp(fitted_amari::AMARI, Q::QDonoho, target; max=true)
@unpack convexclass, solver, discretizer, flocalization, representative_eb_samples = fitted_amari

transposed_intervals = reshape(discretizer.sorted_intervals, 1, length(discretizer.sorted_intervals))
transposed_intervals = reshape(discretizer, 1, length(discretizer))
Zs = set_response.(representative_eb_samples.vec, transposed_intervals)
model = Model(solver)

Expand Down
2 changes: 1 addition & 1 deletion src/datasets/LordCressie/LordCressie.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end

function ebayes_samples()
tbl = load_table()
Zs = summarize(BinomialSample.(tbl.x, 20), tbl.N1)
summarize(BinomialSample.(tbl.x, 20), tbl.N1)
end

end
56 changes: 0 additions & 56 deletions src/dict_function.jl

This file was deleted.

Loading

0 comments on commit 7f1f33a

Please sign in to comment.