Skip to content

Commit

Permalink
Update to ChainRules 1 (#198)
Browse files Browse the repository at this point in the history
* Update to ChainRules 1

* Support ChainRules 0.10 (for Julia 1.3 compatibility)

* Project to primal's subspace

* Use LogExpFunctions and IrrationalConstants instead of NNlib and StatsFuns

* Fix deprecations

* Update Project.toml
  • Loading branch information
devmotion authored Sep 13, 2021
1 parent 0d9b8b4 commit e9d289a
Show file tree
Hide file tree
Showing 20 changed files with 100 additions and 93 deletions.
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.9.7"
version = "0.9.8"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ArgCheck = "1, 2"
ChainRulesCore = "0.9, 0.10"
ChainRulesCore = "0.10.11, 1"
Compat = "3"
Distributions = "0.23.3, 0.24, 0.25"
Functors = "0.1, 0.2"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3.3"
MappedArrays = "0.2.2, 0.3, 0.4"
NNlib = "0.6, 0.7"
NonlinearSolve = "0.3"
Reexport = "0.2, 1"
Requires = "0.5, 1"
StatsFuns = "0.8, 0.9.3"
julia = "1.3"
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ logabsdetjac(::Identity{0}, y::Real) = zero(eltype(y)) # ∂ₓid(x) = ∂ₓ x
A slightly more complex example is `Logit`:

```julia
using StatsFuns: logit, logistic
using LogExpFunctions: logit, logistic

struct Logit{T<:Real} <: Bijector{0}
a::T
Expand Down Expand Up @@ -586,7 +586,7 @@ As you can see it's a very contrived example, but you get the idea.
We could also have implemented `Logit` as an `ADBijector`:
```julia
using StatsFuns: logit, logistic
using LogExpFunctions: logit, logistic
using Bijectors: ADBackend

struct ADLogit{T, AD} <: ADBijector{AD, 0}
Expand Down
4 changes: 3 additions & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ module Bijectors

using Reexport, Requires
@reexport using Distributions
using StatsFuns
using LinearAlgebra
using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

import Functors
import IrrationalConstants
import LogExpFunctions
import NonlinearSolve
import ChainRulesCore

Expand Down
4 changes: 3 additions & 1 deletion src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
result = float(zero(eltype(y)))
for j in 2:K, i in 1:(j - 1)
@inbounds abs_y_i_j = abs(y[i, j])
result += (K - i + 1) * (logtwo - (abs_y_i_j + log1pexp(-2 * abs_y_i_j)))
result += (K - i + 1) * (
IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j))
)
end

return result
Expand Down
6 changes: 2 additions & 4 deletions src/bijectors/logit.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
######################
# Logit and Logistic #
######################
using StatsFuns: logit, logistic

struct Logit{N, T<:Real} <: Bijector{N}
a::T
b::T
Expand All @@ -29,11 +27,11 @@ Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b

(b::Logit)(x) = _logit.(x, b.a, b.b)
(b::Logit)(x::AbstractArray{<:AbstractArray}) = map(b, x)
_logit(x, a, b) = logit((x - a) / (b - a))
_logit(x, a, b) = LogExpFunctions.logit((x - a) / (b - a))

(ib::Inverse{<:Logit})(y) = _ilogit.(y, ib.orig.a, ib.orig.b)
(ib::Inverse{<:Logit})(x::AbstractArray{<:AbstractArray}) = map(ib, x)
_ilogit(y, a, b) = (b - a) * logistic(y) + a
_ilogit(y, a, b) = (b - a) * LogExpFunctions.logistic(y) + a

logabsdetjac(b::Logit{0}, x) = logit_logabsdetjac.(x, b.a, b.b)
logabsdetjac(b::Logit{1}, x::AbstractVector) = sum(logit_logabsdetjac.(x, b.a, b.b))
Expand Down
8 changes: 2 additions & 6 deletions src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
using LinearAlgebra
using Random
using NNlib: softplus

################################################################################
# Planar and Radial Flows #
# Ref: Variational Inference with Normalizing Flows, #
Expand Down Expand Up @@ -63,8 +59,8 @@ arXiv:1505.05770
"""
function get_u_hat(u::AbstractVector{<:Real}, w::AbstractVector{<:Real})
wT_u = dot(w, u)
= u .+ ((softplus(-wT_u) - 1) / sum(abs2, w)) .* w
wT_û = softplus(wT_u) - 1
= u .+ ((LogExpFunctions.log1pexp(-wT_u) - 1) / sum(abs2, w)) .* w
wT_û = LogExpFunctions.log1pexp(wT_u) - 1
return û, wT_û
end

Expand Down
16 changes: 6 additions & 10 deletions src/bijectors/radial_layer.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
using LinearAlgebra
using Random
using NNlib: softplus

################################################################################
# Planar and Radial Flows #
# Ref: Variational Inference with Normalizing Flows, #
Expand Down Expand Up @@ -39,8 +35,8 @@ function _transform(flow::RadialLayer, z::AbstractVecOrMat)
return _radial_transform(first(flow.α_), first(flow.β), flow.z_0, z)
end
function _radial_transform(α_, β, z_0, z)
α = softplus(α_) # from A.2
β_hat = -α + softplus(β) # from A.2
α = LogExpFunctions.log1pexp(α_) # from A.2
β_hat = -α + LogExpFunctions.log1pexp(β) # from A.2
if z isa AbstractVector
r = norm(z .- z_0)
else
Expand Down Expand Up @@ -73,8 +69,8 @@ end
function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real})
flow = ib.orig
z0 = flow.z_0
α = softplus(first(flow.α_)) # from A.2
α_plus_β_hat = softplus(first(flow.β)) # from A.2
α = LogExpFunctions.log1pexp(first(flow.α_)) # from A.2
α_plus_β_hat = LogExpFunctions.log1pexp(first(flow.β)) # from A.2

# Compute the norm ``r`` from A.2.
y_minus_z0 = y .- z0
Expand All @@ -87,8 +83,8 @@ end
function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real})
flow = ib.orig
z0 = flow.z_0
α = softplus(first(flow.α_)) # from A.2
α_plus_β_hat = softplus(first(flow.β)) # from A.2
α = LogExpFunctions.log1pexp(first(flow.α_)) # from A.2
α_plus_β_hat = LogExpFunctions.log1pexp(first(flow.β)) # from A.2

# Compute the norm ``r`` from A.2 for each column.
y_minus_z0 = y .- z0
Expand Down
15 changes: 6 additions & 9 deletions src/bijectors/rational_quadratic_spline.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using NNlib

"""
RationalQuadraticSpline{T, 0} <: Bijector{0}
RationalQuadraticSpline{T, 1} <: Bijector{1}
Expand Down Expand Up @@ -111,11 +109,10 @@ function RationalQuadraticSpline(
derivatives::A,
B::T2
) where {T1, T2, A <: AbstractVector{T1}}
# Using `NNLlinb.softax` instead of `StatsFuns.softmax` (which does inplace operations)
return RationalQuadraticSpline(
(cumsum(vcat([zero(T1)], NNlib.softmax(widths))) .- 0.5) * 2 * B,
(cumsum(vcat([zero(T1)], NNlib.softmax(heights))) .- 0.5) * 2 * B,
vcat([one(T1)], softplus.(derivatives), [one(T1)])
(cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .- 0.5) * 2 * B,
(cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .- 0.5) * 2 * B,
vcat([one(T1)], LogExpFunctions.log1pexp.(derivatives), [one(T1)])
)
end

Expand All @@ -125,9 +122,9 @@ function RationalQuadraticSpline(
derivatives::A,
B::T2
) where {T1, T2, A <: AbstractMatrix{T1}}
ws = hcat(zeros(T1, size(widths, 1)), NNlib.softmax(widths; dims = 2))
hs = hcat(zeros(T1, size(widths, 1)), NNlib.softmax(heights; dims = 2))
ds = hcat(ones(T1, size(widths, 1)), softplus.(derivatives), ones(T1, size(widths, 1)))
ws = hcat(zeros(T1, size(widths, 1)), LogExpFunctions.softmax(widths; dims = 2))
hs = hcat(zeros(T1, size(widths, 1)), LogExpFunctions.softmax(heights; dims = 2))
ds = hcat(ones(T1, size(widths, 1)), LogExpFunctions.log1pexp.(derivatives), ones(T1, size(widths, 1)))

return RationalQuadraticSpline(
(2 * B) .* (cumsum(ws; dims = 2) .- 0.5),
Expand Down
28 changes: 14 additions & 14 deletions src/bijectors/simplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{1, proj}) wh
ϵ = _eps(T)
sum_tmp = zero(T)
@inbounds z = x[1] * (one(T) - 2ϵ) + ϵ # z ∈ [ϵ, 1-ϵ]
@inbounds y[1] = StatsFuns.logit(z) + log(T(K - 1))
@inbounds y[1] = LogExpFunctions.logit(z) + log(T(K - 1))
@inbounds @simd for k in 2:(K - 1)
sum_tmp += x[k - 1]
# z ∈ [ϵ, 1-ϵ]
# x[k] = 0 && sum_tmp = 1 -> z ≈ 1
z = (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp)
y[k] = StatsFuns.logit(z) + log(T(K - k))
y[k] = LogExpFunctions.logit(z) + log(T(K - k))
end
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
Expand Down Expand Up @@ -64,11 +64,11 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{1, proj}) wh
@inbounds @simd for n in 1:size(X, 2)
sum_tmp = zero(T)
z = X[1, n] * (one(T) - 2ϵ) + ϵ
Y[1, n] = StatsFuns.logit(z) + log(T(K - 1))
Y[1, n] = LogExpFunctions.logit(z) + log(T(K - 1))
for k in 2:(K - 1)
sum_tmp += X[k - 1, n]
z = (X[k, n] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp)
Y[k, n] = StatsFuns.logit(z) + log(T(K - k))
Y[k, n] = LogExpFunctions.logit(z) + log(T(K - k))
end
sum_tmp += X[K-1, n]
if proj
Expand Down Expand Up @@ -98,11 +98,11 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{1, proj
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(y)
ϵ = _eps(T)
@inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1)))
@inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1)))
@inbounds x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), 0, 1)
sum_tmp = zero(T)
@inbounds @simd for k = 2:(K - 1)
z = StatsFuns.logistic(y[k] - log(T(K - k)))
z = LogExpFunctions.logistic(y[k] - log(T(K - k)))
sum_tmp += x[k-1]
x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
end
Expand Down Expand Up @@ -142,10 +142,10 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{1, proj
T = eltype(Y)
ϵ = _eps(T)
@inbounds @simd for n in 1:size(X, 2)
sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] - log(T(K - 1)))
sum_tmp, z = zero(T), LogExpFunctions.logistic(Y[1, n] - log(T(K - 1)))
X[1, n] = _clamp((z - ϵ) / (one(T) - 2ϵ), 0, 1)
for k in 2:(K - 1)
z = StatsFuns.logistic(Y[k, n] - log(T(K - k)))
z = LogExpFunctions.logistic(Y[k, n] - log(T(K - k)))
sum_tmp += X[k - 1, n]
X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
end
Expand Down Expand Up @@ -383,15 +383,15 @@ function simplex_invlink_jacobian(
@inbounds dxdy .= 0

ϵ = _eps(T)
@inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1)))
@inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1)))
unclamped_x = (z - ϵ) / (one(T) - 2ϵ)
clamped_x = _clamp(unclamped_x, 0, 1)
@inbounds if unclamped_x == clamped_x
dxdy[1,1] = z * (1 - z) / (one(T) - 2ϵ)
end
sum_tmp = zero(T)
@inbounds for k = 2:(K - 1)
z = StatsFuns.logistic(y[k] - log(T(K - k)))
z = LogExpFunctions.logistic(y[k] - log(T(K - k)))
sum_tmp += clamped_x
unclamped_x = ((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ
clamped_x = _clamp(unclamped_x, 0, 1)
Expand Down Expand Up @@ -445,7 +445,7 @@ function add_simplex_invlink_adjoint!(
@inbounds dxdy .= 0
ϵ = _eps(T)
@inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1)))
@inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1)))
unclamped_x = (z - ϵ) / (one(T) - 2ϵ)
clamped_x = _clamp(unclamped_x, 0, 1)
@inbounds if unclamped_x == clamped_x
Expand All @@ -454,7 +454,7 @@ function add_simplex_invlink_adjoint!(
end
sum_tmp = zero(T)
@inbounds for k = 2:(K - 1)
z = StatsFuns.logistic(y[k] - log(T(K - k)))
z = LogExpFunctions.logistic(y[k] - log(T(K - k)))
sum_tmp += clamped_x
unclamped_x = ((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ
clamped_x = _clamp(unclamped_x, 0, 1)
Expand Down Expand Up @@ -503,7 +503,7 @@ function add_simplex_invlink_adjoint!(
@inbounds for col in 1:size(y,2)
dxdy .= 0
ϵ = _eps(T)
z = StatsFuns.logistic(y[1,col] - log(T(K - 1)))
z = LogExpFunction.logistic(y[1,col] - log(T(K - 1)))
unclamped_x = (z - ϵ) / (one(T) - 2ϵ)
clamped_x = _clamp(unclamped_x, 0, 1)
if unclamped_x == clamped_x
Expand All @@ -512,7 +512,7 @@ function add_simplex_invlink_adjoint!(
end
sum_tmp = zero(T)
for k = 2:(K - 1)
z = StatsFuns.logistic(y[k,col] - log(T(K - k)))
z = LogExpFunctions.logistic(y[k,col] - log(T(K - k)))
sum_tmp += clamped_x
unclamped_x = ((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ
clamped_x = _clamp(unclamped_x, 0, 1)
Expand Down
4 changes: 2 additions & 2 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
function truncated_link(x::Real, a, b)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
return StatsFuns.logit((x - a) / (b - a))
return LogExpFunctions.logit((x - a) / (b - a))
elseif lowerbounded
return log(x - a)
elseif upperbounded
Expand Down Expand Up @@ -102,7 +102,7 @@ end
function truncated_invlink(y, a, b)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
return (b - a) * StatsFuns.logistic(y) + a
return (b - a) * LogExpFunctions.logistic(y) + a
elseif lowerbounded
return exp(y) + a
elseif upperbounded
Expand Down
Loading

2 comments on commit e9d289a

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/44811

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.8 -m "<description of version>" e9d289a08e1bd06e9610b3cfdcb2eaa5586f3614
git push origin v0.9.8

Please sign in to comment.