From 24fa396f44084dae0f9751d740a8b01d4862a893 Mon Sep 17 00:00:00 2001 From: David Xu <42751767+zuhengxu@users.noreply.github.com> Date: Sat, 10 Jun 2023 20:22:06 +0800 Subject: [PATCH] making rqs compatible with float32 input (#267) * making rqs compatible with float32 input * minor format edit * Update Format.yml * fix format * Update Format.yml * Update Format.yml * save additional allocations in rqs layer Co-authored-by: David Widmann * rm allocations in rqs layer Co-authored-by: David Widmann * bump version to 0.12.6 * add tests for rqs --------- Co-authored-by: Tor Erlend Fjelde Co-authored-by: David Widmann --- .github/workflows/Format.yml | 2 +- Project.toml | 2 +- src/bijectors/rational_quadratic_spline.jl | 6 +-- test/bijectors/rational_quadratic_spline.jl | 44 +++++++++++++++++++++ 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml index 6a6df765..5847fee5 100644 --- a/.github/workflows/Format.yml +++ b/.github/workflows/Format.yml @@ -35,4 +35,4 @@ jobs: if: github.event_name == 'pull_request' with: tool_name: JuliaFormatter - fail_on_error: true \ No newline at end of file + fail_on_error: true diff --git a/Project.toml b/Project.toml index 98f1908f..b4757af6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.12.5" +version = "0.12.6" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 03e10a8a..1ad8acef 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -100,8 +100,8 @@ function RationalQuadraticSpline( widths::A, heights::A, derivatives::A, B::T2 ) where {T1,T2,A<:AbstractVector{T1}} return RationalQuadraticSpline( - (cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .- 0.5) * 2 * B, - (cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .- 0.5) * 2 * B, + cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .* (2 * B) .- B, + cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .* (2 * B) .- B, vcat([one(T1)], LogExpFunctions.log1pexp.(derivatives), [one(T1)]), ) end @@ -118,7 +118,7 @@ function RationalQuadraticSpline( ) return RationalQuadraticSpline( - (2 * B) .* (cumsum(ws; dims=2) .- 0.5), (2 * B) .* (cumsum(hs; dims=2) .- 0.5), ds + (2 * B) .* cumsum(ws; dims=2) .- B, (2 * B) .* cumsum(hs; dims=2) .- B, ds ) end diff --git a/test/bijectors/rational_quadratic_spline.jl b/test/bijectors/rational_quadratic_spline.jl index 4b8be14b..a936c01c 100644 --- a/test/bijectors/rational_quadratic_spline.jl +++ b/test/bijectors/rational_quadratic_spline.jl @@ -1,6 +1,7 @@ using Test using Bijectors using Bijectors: RationalQuadraticSpline +using LogExpFunctions @testset "RationalQuadraticSpline" begin # Monotonic spline on '[-B, B]' with `K` intermediate knots/"connection points". @@ -59,4 +60,47 @@ using Bijectors: RationalQuadraticSpline x = [-5.0, 5.0] test_bijector(b, x; y=x, logjac=zero(eltype(x))) end + + @testset "Float32 support" begin + ws = randn(Float32, K) + hs = randn(Float32, K) + ds = randn(Float32, K - 1) + + Ws = randn(Float32, d, K) + Hs = randn(Float32, d, K) + Ds = randn(Float32, d, K - 1) + + # success of construction + b = RationalQuadraticSpline(ws, hs, ds, B) + bb = RationalQuadraticSpline(Ws, Hs, Ds, B) + end + + @testset "consistency after commit" begin + ws = randn(K) + hs = randn(K) + ds = randn(K - 1) + + Ws = randn(d, K) + Hs = randn(d, K) + Ds = randn(d, K - 1) + + Ws_t = hcat(zeros(size(Ws, 1)), LogExpFunctions.softmax(Ws; dims=2)) + Hs_t = hcat(zeros(size(Ws, 1)), LogExpFunctions.softmax(Hs; dims=2)) + + # success of construction + b = RationalQuadraticSpline(ws, hs, ds, B) + b_mv = RationalQuadraticSpline(Ws, Hs, Ds, B) + + # consistency of evaluation + @test all( + (cumsum(vcat([zero(Float64)], LogExpFunctions.softmax(ws))) .- 0.5) * 2 * B .≈ + b.widths, + ) + @test all( + (cumsum(vcat([zero(Float64)], LogExpFunctions.softmax(hs))) .- 0.5) * 2 * B .≈ + b.heights, + ) + @test all((2 * B) .* (cumsum(Ws_t; dims=2) .- 0.5) .≈ b_mv.widths) + @test all((2 * B) .* (cumsum(Hs_t; dims=2) .- 0.5) .≈ b_mv.heights) + end end