From b638c4887ab77ad0949aa6381db40e6d438611ff Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Tue, 26 Nov 2019 13:24:03 -0600 Subject: [PATCH] Add support for CheckedArithmetic --- Project.toml | 4 ++++ src/FixedPointNumbers.jl | 4 ++++ src/fixed.jl | 7 +++++++ src/normed.jl | 7 +++++++ test/normed.jl | 22 +++++++++++++++++++++- 5 files changed, 43 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 60fd45cd..0482f844 100644 --- a/Project.toml +++ b/Project.toml @@ -2,8 +2,12 @@ name = "FixedPointNumbers" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.6.1" +[deps] +CheckedArithmetic = "2c4a1fb8-30c1-4c71-8b84-dff8d59868ee" + [compat] julia = "1" +CheckedArithmetic = "0.1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/FixedPointNumbers.jl b/src/FixedPointNumbers.jl index ccd86e50..9ad8b07e 100644 --- a/src/FixedPointNumbers.jl +++ b/src/FixedPointNumbers.jl @@ -12,6 +12,8 @@ import Base: ==, <, <=, -, +, *, /, ~, isapprox, using Base: @pure +using CheckedArithmetic + # T => BaseType # f => Number of bits reserved for fractional part abstract type FixedPoint{T <: Integer, f} <: Real end @@ -199,4 +201,6 @@ end rand(::Type{T}) where {T <: FixedPoint} = reinterpret(T, rand(rawtype(T))) rand(::Type{T}, sz::Dims) where {T <: FixedPoint} = reinterpret(T, rand(rawtype(T), sz)) +CheckedArithmetic.safearg_type(::Type{T}) where T<:FixedPoint = Float64 + end # module diff --git a/src/fixed.jl b/src/fixed.jl index 3576a610..f237a640 100644 --- a/src/fixed.jl +++ b/src/fixed.jl @@ -97,3 +97,10 @@ end # TODO: Document and check that it still does the right thing. decompose(x::Fixed{T,f}) where {T,f} = x.i, -f, 1 + +CheckedArithmetic.accumulatortype(::typeof(+), ::Type{Fixed{T,f}}) where {T,f} = + Fixed{accumulatortype(+, T), f} +CheckedArithmetic.accumulatortype(::typeof(-), ::Type{Fixed{T,f}}) where {T,f} = + Fixed{accumulatortype(-, T), f} +CheckedArithmetic.accumulatortype(::typeof(*), ::Type{Fixed{T,f}}) where {T,f} = + floattype(Fixed{T,f}) diff --git a/src/normed.jl b/src/normed.jl index 737657b4..5d4be612 100644 --- a/src/normed.jl +++ b/src/normed.jl @@ -333,3 +333,10 @@ if !signbit(signed(unsafe_trunc(UInt, -12.345))) unsafe_trunc(T, unsafe_trunc(typeof(signed(zero(T))), x)) end end + +CheckedArithmetic.accumulatortype(::typeof(+), ::Type{Normed{T,f}}) where {T,f} = + Normed{accumulatortype(+, T), f} +CheckedArithmetic.accumulatortype(::typeof(-), ::Type{Normed{T,f}}) where {T,f} = + floattype(Normed{T,f}) +CheckedArithmetic.accumulatortype(::typeof(*), ::Type{Normed{T,f}}) where {T,f} = + floattype(Normed{T,f}) diff --git a/test/normed.jl b/test/normed.jl index 6fb4d125..c3454b50 100644 --- a/test/normed.jl +++ b/test/normed.jl @@ -1,4 +1,4 @@ -using FixedPointNumbers, Test +using FixedPointNumbers, CheckedArithmetic, Test @testset "reinterpret" begin @test reinterpret(N0f8, 0xa2).i === 0xa2 @@ -367,6 +367,22 @@ end @test 1.0*a == bd*ad end +function sum_naive(A::AbstractArray) + s = zero(eltype(A)) + for a in A + s += a + end + return s +end + +function sumsquares(A::AbstractArray) + s = zero(accumulatortype(eltype(A))) + for a in A + s += acc(a)^2 + end + return s +end + @testset "reductions" begin a = N0f8[reinterpret(N0f8, 0xff), reinterpret(N0f8, 0xff)] @test sum(a) == 2.0 @@ -376,6 +392,10 @@ end acmp = Float64(a[1])*Float64(a[2]) @test prod(a) == acmp @test prod(a, dims=1) == [acmp] + + a = reinterpret(N0f8, [0x01:0xff;]) + @test_throws ArgumentError @check sum_naive(a) atol=1e-4 + @check sumsquares(a) atol=1e-4 end @testset "rand" begin