Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For a 0.1.2 release #14

Merged
merged 20 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.3'
- '1' # automatically expands to the latest stable 1.x release of Julia.
os:
- ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CategoricalDistributions"
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -19,7 +19,7 @@ Missings = "0.4, 1"
OrderedCollections = "1.1"
ScientificTypesBase = "2"
UnicodePlots = "2"
julia = "1.0"
julia = "1.3"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ levels(d2)
julia> pdf(d2, "maybe")
0.0

julia> pdf(d2, "okay")https://github.com/JuliaAI/CategoricalDistributions.jl#measures-over-finite-labeled-sets
julia> pdf(d2, "okay")
ERROR: DomainError with Value okay not in pool. :
```

Expand Down Expand Up @@ -122,10 +122,10 @@ julia> pdf(v, L)
## Measures over finite labeled sets

There is, in fact, no enforcement that probabilities in a
`UnivariateFinite` distribution sum to one, only that they be
non-negative. Thus `UnivariateFinite` objects can be more properly
understood as an implementation of arbitrary non-negative measures
over finite labeled sets.
`UnivariateFinite` distribution sum to one, only that they be belong
to a type `T` for which `zero(T)` is defined. In particular
`UnivariateFinite` objects implement arbitrary non-negative, signed,
or complex measures over a finite labeled set.


## What does this package provide?
Expand Down
2 changes: 2 additions & 0 deletions src/CategoricalDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ using Random
using UnicodePlots

const Dist = Distributions
const MAX_NUM_LEVELS_TO_SHOW_BARS = 12

import Distributions: pdf, logpdf, support, mode

include("utilities.jl")
include("types.jl")
include("methods.jl")
include("arrays.jl")
include("arithmetic.jl")

export UnivariateFinite, UnivariateFiniteArray

Expand Down
55 changes: 55 additions & 0 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# ## ARITHMETIC

const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
"Adding two `UnivariateFinite` objects whose "*
"sample spaces have different labellings is not allowed. ")

import Base: +, *, /, -

pdf_matrix(d::UnivariateFinite, L) = pdf.(d, L)
pdf_matrix(d::AbstractArray{<:UnivariateFinite}, L) = pdf(d, L)

function +(d1::U, d2::U) where U <: SingletonOrArray
L = classes(d1)
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
return UnivariateFinite(L, pdf_matrix(d1, L) + pdf_matrix(d2, L))
end

function _minus(d, T)
S = d.scitype
decoder = d.decoder
prob_given_ref = copy(d.prob_given_ref)
for ref in keys(prob_given_ref)
prob_given_ref[ref] = -prob_given_ref[ref]
end
return T(S, decoder, prob_given_ref)
end
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)

function -(d1::U, d2::U) where U <: SingletonOrArray
L = classes(d1)
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
return UnivariateFinite(L, pdf_matrix(d1, L) - pdf_matrix(d2, L))
end

# It seems that the restriction `x::Number` below (applying only to the
# array case) is unavoidable because of a method ambiguity with
# `Base.*(::AbstractArray, ::Number)`.

function _times(d, x, T)
S = d.scitype
decoder = d.decoder
prob_given_ref = copy(d.prob_given_ref)
for ref in keys(prob_given_ref)
prob_given_ref[ref] *= x
end
return T(d.scitype, decoder, prob_given_ref)
end
*(d::UnivariateFinite, x) = _times(d, x, UnivariateFinite)
*(d::UnivariateFiniteArray, x::Number) = _times(d, x, UnivariateFiniteArray)

*(x, d::UnivariateFinite) = d*x
*(x::Number, d::UnivariateFiniteArray) = d*x
/(d::UnivariateFinite, x) = d*inv(x)
/(d::UnivariateFiniteArray, x::Number) = d*inv(x)
3 changes: 3 additions & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ function Base.Broadcast.broadcasted(::typeof(mode),
return reshape(mode_flat, size(u))
end


## EXTENSION OF CLASSES TO ARRAYS OF UNIVARIATE FINITE

# We already have `classes(::UnivariateFininiteArray)
Expand All @@ -266,3 +267,5 @@ function classes(yhat::AbstractArray{<:Union{Missing,UnivariateFinite}})
i === nothing && throw(ERR_EMPTY_UNIVARIATE_FINITE)
return classes(yhat[i])
end


71 changes: 18 additions & 53 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,22 @@ function Base.show(stream::IO, d::UnivariateFinite)
print(stream, "UnivariateFinite{$(d.scitype)}($arg_str)")
end

Base.show(io::IO, mime::MIME"text/plain",
d::UnivariateFinite) = show(io, d)

# in common case of `Real` probabilities we can do a pretty bar plot:
function Base.show(io::IO, mime::MIME"text/plain",
d::UnivariateFinite{S}) where S
d::UnivariateFinite{<:Finite{K},V,R,P}) where {K,V,R,P<:Real}
show_bars = false
if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
all(>=(0), values(d.prob_given_ref))
show_bars = true
end
show_bars || return show(io, d)
s = support(d)
x = string.(CategoricalArrays.DataAPI.unwrap.(s))
y = pdf.(d, s)
S = d.scitype
plt = barplot(x, y, title="UnivariateFinite{$S}")
show(io, mime, plt)
end
Expand Down Expand Up @@ -125,58 +136,6 @@ end

# TODO: It would be useful to define == as well.

# TODO: Now that `UnivariateFinite` is any finite measure, we can
# replace the following nonsense with an overloading of `+`. I think
# it is only used in MLJEnsembles.jl - but need to check. I believe
# this is a private method we can easily remove

function average(dvec::AbstractVector{UnivariateFinite{S,V,R,P}};
weights=nothing) where {S,V,R,P}

n = length(dvec)

Dist.@check_args(UnivariateFinite, weights == nothing || n==length(weights))

# check all distributions have consistent pool:
first_index = first(dvec).decoder.classes
for d in dvec
d.decoder.classes == first_index ||
error("Averaging UnivariateFinite distributions with incompatible"*
" pools. ")
end

# get all refs:
refs = reduce(union, [keys(d.prob_given_ref) for d in dvec]) |> collect

# initialize the prob dictionary for the distribution sum:
prob_given_ref = LittleDict{R,P}([refs...], zeros(P, length(refs)))

# make vector of all the distributions dicts padded to have same common keys:
prob_given_ref_vec = map(dvec) do d
merge(prob_given_ref, d.prob_given_ref)
end

# sum up:
if weights == nothing
scale = 1/n
for x in refs
for k in 1:n
prob_given_ref[x] += scale*prob_given_ref_vec[k][x]
end
end
else
scale = 1/sum(weights)
for x in refs
for k in 1:n
prob_given_ref[x] +=
weights[k]*prob_given_ref_vec[k][x]*scale
end
end
end
d1 = first(dvec)
return UnivariateFinite(sample_scitype(d1), d1.decoder, prob_given_ref)
end

"""
Dist.pdf(d::UnivariateFinite, x)

Expand Down Expand Up @@ -371,3 +330,9 @@ function Dist.fit(d::Type{<:UnivariateFinite},
end


# # BROADCASTING OVER SINGLE UNIVARIATE FINITE

# This mirrors behaviour assigned Distributions.Distribution objects,
# which allows `pdf.(d::UnivariateFinite, support(d))` to work.

Broadcast.broadcastable(d::UnivariateFinite) = Ref(d)
Loading