Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add arithmetic for UnivariateFinite objects. #12

Merged
merged 12 commits into from
Dec 13, 2021
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CategoricalDistributions"
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
2 changes: 2 additions & 0 deletions src/CategoricalDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ using Random
using UnicodePlots

const Dist = Distributions
const MAX_NUM_LEVELS_TO_SHOW_BARS = 12

import Distributions: pdf, logpdf, support, mode

include("utilities.jl")
include("types.jl")
include("methods.jl")
include("arrays.jl")
include("arithmetic.jl")

export UnivariateFinite, UnivariateFiniteArray

Expand Down
55 changes: 55 additions & 0 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# ## ARITHMETIC

# Error thrown by the binary `+` and `-` methods below when the two operands
# do not report the same `classes`.
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
    "Adding or subtracting two `UnivariateFinite` objects whose "*
    "sample spaces have different labellings is not allowed. ")

import Base: +, *, /, -

# Evaluate `d` at every label in `L`. A single `UnivariateFinite` is
# evaluated by broadcasting `pdf` over `L`; an array of them is handed to
# the array method `pdf(d, L)`.
function pdf_matrix(d::UnivariateFinite, L)
    return pdf.(d, L)
end

function pdf_matrix(d::AbstractArray{<:UnivariateFinite}, L)
    return pdf(d, L)
end

# Sum of two measures of the same kind (both single objects or both
# arrays). The operands must report equal `classes`; otherwise
# `ERR_DIFFERENT_SAMPLE_SPACES` is thrown.
function +(d1::U, d2::U) where U <: SingletonOrArray
    labels = classes(d1)
    if !(labels == classes(d2))
        throw(ERR_DIFFERENT_SAMPLE_SPACES)
    end
    summed = pdf_matrix(d1, labels) + pdf_matrix(d2, labels)
    return UnivariateFinite(labels, summed)
end

# Build, with the constructor `T`, an object having the same `scitype` and
# `decoder` as `d` but with every entry of `prob_given_ref` negated. The
# negation is applied to a `copy` of the dictionary.
function _minus(d, T)
    negated = copy(d.prob_given_ref)
    for ref in keys(negated)
        negated[ref] = -negated[ref]
    end
    return T(d.scitype, d.decoder, negated)
end

# Unary minus, for single objects and for arrays:
function -(d::UnivariateFinite)
    return _minus(d, UnivariateFinite)
end

function -(d::UnivariateFiniteArray)
    return _minus(d, UnivariateFiniteArray)
end

# Difference of two measures of the same kind (both single objects or both
# arrays). The operands must report equal `classes`; otherwise
# `ERR_DIFFERENT_SAMPLE_SPACES` is thrown.
function -(d1::U, d2::U) where U <: SingletonOrArray
    labels = classes(d1)
    if !(labels == classes(d2))
        throw(ERR_DIFFERENT_SAMPLE_SPACES)
    end
    difference = pdf_matrix(d1, labels) - pdf_matrix(d2, labels)
    return UnivariateFinite(labels, difference)
end

# It seems that the restriction `x::Number` below (applying only to the
# array case) is unavoidable because of a method ambiguity with
# `Base.*(::AbstractArray, ::Number)`.

# Build, with the constructor `T`, an object having the same `scitype` and
# `decoder` as `d` but with every entry of `prob_given_ref` multiplied by
# `x`. The scaling is applied to a `copy` of the dictionary, and each entry
# is rebound (`a *= x` means `a = a*x`) rather than mutated in place.
function _times(d, x, T)
    prob_given_ref = copy(d.prob_given_ref)
    for ref in keys(prob_given_ref)
        prob_given_ref[ref] *= x
    end
    return T(d.scitype, d.decoder, prob_given_ref)
end

# Right-multiplication by a scalar. Only the array method restricts `x` to
# `Number`; the single-object method accepts any `x`.
*(d::UnivariateFinite, x) = _times(d, x, UnivariateFinite)
*(d::UnivariateFiniteArray, x::Number) = _times(d, x, UnivariateFiniteArray)

# Left-multiplication by a scalar delegates to right-multiplication:
function *(x, d::UnivariateFinite)
    return d*x
end

function *(x::Number, d::UnivariateFiniteArray)
    return d*x
end

# Division by a scalar is multiplication by its inverse:
function /(d::UnivariateFinite, x)
    return d*inv(x)
end

function /(d::UnivariateFiniteArray, x::Number)
    return d*inv(x)
end
3 changes: 3 additions & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ function Base.Broadcast.broadcasted(::typeof(mode),
return reshape(mode_flat, size(u))
end


## EXTENSION OF CLASSES TO ARRAYS OF UNIVARIATE FINITE

# We already have `classes(::UnivariateFiniteArray)`
Expand All @@ -266,3 +267,5 @@ function classes(yhat::AbstractArray{<:Union{Missing,UnivariateFinite}})
i === nothing && throw(ERR_EMPTY_UNIVARIATE_FINITE)
return classes(yhat[i])
end


67 changes: 13 additions & 54 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,22 @@ function Base.show(stream::IO, d::UnivariateFinite)
print(stream, "UnivariateFinite{$(d.scitype)}($arg_str)")
end

# Fallback for the `text/plain` display: reuse the two-argument `show`.
function Base.show(io::IO, ::MIME"text/plain", d::UnivariateFinite)
    return show(io, d)
end

# in common case of `Real` probabilities we can do a pretty bar plot:
function Base.show(io::IO, mime::MIME"text/plain",
d::UnivariateFinite{S}) where S
d::UnivariateFinite{<:Finite{K},V,R,P}) where {K,V,R,P<:Real}
show_bars = false
if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
all(>=(0), values(d.prob_given_ref))
show_bars = true
end
show_bars || return show(io, d)
s = support(d)
x = string.(CategoricalArrays.DataAPI.unwrap.(s))
y = pdf.(d, s)
S = d.scitype
plt = barplot(x, y, title="UnivariateFinite{$S}")
show(io, mime, plt)
end
Expand Down Expand Up @@ -125,58 +136,6 @@ end

# TODO: It would be useful to define == as well.

# TODO: Now that `UnivariateFinite` is any finite measure, we can
# replace the following nonsense with an overloading of `+`. I think
# it is only used in MLJEnsembles.jl - but need to check. I believe
# this is a private method we can easily remove

"""
    average(dvec; weights=nothing)

Return the (optionally weighted) average of the `UnivariateFinite`
objects in the vector `dvec`, as a new `UnivariateFinite` built from the
decoder of `first(dvec)`.

If `weights` is `nothing`, each element contributes with weight
`1/length(dvec)`. Otherwise `weights` must have the same length as
`dvec`, and the `k`th element contributes with weight
`weights[k]/sum(weights)`.

An error is thrown if the elements of `dvec` do not all have the same
`decoder.classes`.

"""
function average(dvec::AbstractVector{UnivariateFinite{S,V,R,P}};
                 weights=nothing) where {S,V,R,P}

    n = length(dvec)

    Dist.@check_args(UnivariateFinite,
                     weights === nothing || n == length(weights))

    # check all distributions have consistent pool:
    first_index = first(dvec).decoder.classes
    for d in dvec
        d.decoder.classes == first_index ||
            error("Averaging UnivariateFinite distributions with incompatible"*
                  " pools. ")
    end

    # get all refs:
    refs = collect(mapreduce(d -> keys(d.prob_given_ref), union, dvec))

    # initialize the prob dictionary for the distribution sum:
    prob_given_ref = LittleDict{R,P}(copy(refs), zeros(P, length(refs)))

    # make vector of all the distributions dicts padded to have same
    # common keys (must happen while `prob_given_ref` is still all zeros):
    prob_given_ref_vec = map(dvec) do d
        merge(prob_given_ref, d.prob_given_ref)
    end

    # sum up (outer loop over refs, inner loop over distributions):
    if weights === nothing
        scale = 1/n
        for x in refs, padded in prob_given_ref_vec
            prob_given_ref[x] += scale*padded[x]
        end
    else
        scale = 1/sum(weights)
        for x in refs, (w, padded) in zip(weights, prob_given_ref_vec)
            prob_given_ref[x] += w*padded[x]*scale
        end
    end

    d1 = first(dvec)
    return UnivariateFinite(sample_scitype(d1), d1.decoder, prob_given_ref)
end

"""
Dist.pdf(d::UnivariateFinite, x)

Expand Down Expand Up @@ -374,6 +333,6 @@ end
# # BROADCASTING OVER SINGLE UNIVARIATE FINITE

# This mirrors behaviour assigned Distributions.Distribution objects,
# which allows `pdf.(d::UnivariateFinite, support(d))` to work.
# which allows `pdf.(d::UnivariateFinite, support(d))` to work.

# Wrap `d` in a `Ref` so that broadcasting does not iterate over it.
function Broadcast.broadcastable(d::UnivariateFinite)
    return Ref(d)
end
98 changes: 68 additions & 30 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ choosing `probs` to be an array of one higher dimension than the array
generated.

Here the word "probabilities" is an abuse of terminology as there is
no requirement that probabilities actually sum to one, only that they
be non-negative. So `UnivariateFinite` objects actually implement
arbitrary non-negative measures over finite sets of labelled points. A
`UnivariateDistribution` will be a bona fide probability measure when
constructed using the `augment=true` option (see below) or when
`fit` to data.
no requirement that the probabilities actually sum to one. The
only requirement is that the probabilities have a common type `T` for
which `zero(T)` is defined. In particular, `UnivariateFinite` objects
implement arbitrary non-negative, signed, or complex measures over
finite sets of labelled points. A `UnivariateFinite` object will be a
bona fide probability measure when constructed using the
`augment=true` option (see below) or when `fit` to data. And the
probabilities of a `UnivariateFinite` object `d` must be non-negative,
with a non-zero sum, for `rand(d)` to be defined and interpretable.

Unless `pool` is specified, `support` should have type
`AbstractVector{<:CategoricalValue}` and all elements are assumed to
Expand All @@ -37,28 +40,37 @@ constructor then returns an array of `UnivariateFinite` distributions
of size `(n1, n2, ..., nk)`.

```
using CategoricalArrays
v = categorical([:x, :x, :y, :x, :z])

julia> UnivariateFinite(classes(v), [0.2, 0.3, 0.5])
UnivariateFinite{Multiclass{3}}(x=>0.2, y=>0.3, z=>0.5)

julia> d = UnivariateFinite([v[1], v[end]], [0.1, 0.9])
using CategoricalDistributions, CategoricalArrays, Distributions
samples = categorical(['x', 'x', 'y', 'x', 'z'])
julia> Distributions.fit(UnivariateFinite, samples)
UnivariateFinite{Multiclass{3}}
┌ ┐
x ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.6
y ┤■■■■■■■■■■■■ 0.2
z ┤■■■■■■■■■■■■ 0.2
└ ┘

julia> d = UnivariateFinite([samples[1], samples[end]], [0.1, 0.9])
UnivariateFinite{Multiclass{3}(x=>0.1, z=>0.9)
UnivariateFinite{Multiclass{3}}
┌ ┐
x ┤■■■■ 0.1
z ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.9
└ ┘

julia> rand(d, 3)
3-element Array{Any,1}:
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
CategoricalValue{Char,UInt32} 'z'
CategoricalValue{Char,UInt32} 'z'
CategoricalValue{Char,UInt32} 'z'

julia> levels(d)
julia> levels(samples)
3-element Array{Char,1}:
:x
:y
:z
'x'
'y'
'z'

julia> pdf(d, :y)
julia> pdf(d, 'y')
0.0
```

Expand All @@ -77,19 +89,27 @@ In the last case, specify `ordered=true` if the pool is to be
considered ordered.

```
julia> UnivariateFinite([:x, :z], [0.1, 0.9], pool=missing, ordered=true)
UnivariateFinite{OrderedFactor{2}}(x=>0.1, z=>0.9)

julia> d = UnivariateFinite([:x, :z], [0.1, 0.9], pool=v) # v defined above
UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)

julia> pdf(d, :y) # allowed as `:y in levels(v)`
julia> UnivariateFinite(['x', 'z'], [0.1, 0.9], pool=missing, ordered=true)
UnivariateFinite{OrderedFactor{2}}
┌ ┐
x ┤■■■■ 0.1
z ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.9
└ ┘

samples = categorical(['x', 'x', 'y', 'x', 'z'])
julia> d = UnivariateFinite(['x', 'z'], [0.1, 0.9], pool=samples)
┌ ┐
x ┤■■■■ 0.1
z ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.9
└ ┘

julia> pdf(d, 'y') # allowed as `'y' in levels(samples)`
0.0

v = categorical([:x, :x, :y, :x, :z, :w])
v = categorical(['x', 'x', 'y', 'x', 'z', 'w'])
probs = rand(100, 3)
probs = probs ./ sum(probs, dims=2)
julia> UnivariateFinite([:x, :y, :z], probs, pool=v)
julia> d1 = UnivariateFinite(['x', 'y', 'z'], probs, pool=v)
100-element UnivariateFiniteVector{Multiclass{4},Char,UInt32,Float64}:
UnivariateFinite{Multiclass{4}}(x=>0.194, y=>0.3, z=>0.505)
UnivariateFinite{Multiclass{4}}(x=>0.727, y=>0.234, z=>0.0391)
Expand All @@ -107,6 +127,18 @@ for the classes `c2, c3, ..., cn`. The class `c1` probabilities are
chosen so that each `UnivariateFinite` distribution in the returned
array is a bona fide probability distribution.

```julia
julia> UnivariateFinite([0.1, 0.2, 0.3], augment=true, pool=missing)
3-element UnivariateFiniteArray{Multiclass{2}, String, UInt8, Float64, 1}:
UnivariateFinite{Multiclass{2}}(class_1=>0.9, class_2=>0.1)
UnivariateFinite{Multiclass{2}}(class_1=>0.8, class_2=>0.2)
UnivariateFinite{Multiclass{2}}(class_1=>0.7, class_2=>0.3)

d2 = UnivariateFinite(['x', 'y', 'z'], probs[:, 2:end], augment=true, pool=v)
julia> pdf(d1, levels(v)) ≈ pdf(d2, levels(v))
true
```

---

UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
Expand Down Expand Up @@ -142,6 +174,8 @@ struct UnivariateFinite{S,V,R,P}
prob_given_ref::LittleDict{R,P,Vector{R}, Vector{P}}
end

@doc DOC_CONSTRUCTOR UnivariateFinite

"""
UnivariateFiniteArray

Expand All @@ -160,6 +194,10 @@ end

const UnivariateFiniteVector{S,V,R,P} = UnivariateFiniteArray{S,V,R,P,1}

# private: matches either a single `UnivariateFinite` or a
# `UnivariateFiniteArray` carrying the same parameters `S`, `V`, `R`, `P`.
const SingletonOrArray{S,V,R,P} = Union{UnivariateFinite{S,V,R,P},
UnivariateFiniteArray{S,V,R,P}}


# # CHECKS AND ERROR MESSAGES

Expand Down
Loading