Skip to content

Commit

Permalink
Add literal_pow transformation (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
BioTurboNick authored May 29, 2024
1 parent 117a93d commit 7aa1d39
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 20 deletions.
14 changes: 14 additions & 0 deletions src/OverflowContexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@ module OverflowContexts
const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128}
const UnsignedBitInteger = Union{UInt8, UInt16, UInt32, UInt64, UInt128}

using Base: BitInteger, promote, afoldl, @_inline_meta
import Base: literal_pow
import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs,
checked_div, checked_fld, checked_cld, checked_mod, checked_rem
using Base.Checked: mul_with_overflow

if VERSION v"1.11-alpha"
import Base: power_by_squaring
import Base.Checked: checked_pow
else
using Base: throw_domerr_powbysq, to_power_type
using Base.Checked: throw_overflowerr_binaryop
end

include("macros.jl")
include("checked.jl")
include("unchecked.jl")
Expand Down
29 changes: 16 additions & 13 deletions src/checked.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
using Base: BitInteger, promote, afoldl, @_inline_meta
import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs,
checked_div, checked_fld, checked_cld, checked_mod, checked_rem
using Base.Checked: mul_with_overflow

if VERSION v"1.11-alpha"
import Base: power_by_squaring
import Base.Checked: checked_pow
else
using Base: throw_domerr_powbysq, to_power_type
using Base.Checked: throw_overflowerr_binaryop
end

# resolve ambiguity when `-` used as symbol
checked_negsub(x) = checked_neg(x)
checked_negsub(x, y) = checked_sub(x, y)
Expand Down Expand Up @@ -87,3 +74,19 @@ if VERSION < v"1.11"
return y
end
end

# adapted from Base intfuncs.jl; negative literal powers promote to floating point
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{0}) = one(x)
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{1}) = x
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{2}) = @checked x * x
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{3}) = @checked x * x * x
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{-1}) = literal_pow(^, x, Val(-1))
@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{-2}) = literal_pow(^, x, Val(-2))

@inline function literal_pow(f::typeof(checked_pow), x, ::Val{p}) where {p}
if p < 0
literal_pow(^, x, Val(p))
else
f(x, p)
end
end
6 changes: 6 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,18 @@ function replace_op!(expr::Expr, op_map::Dict)
Expr(:tuple, expr.args[2:end]...)]
end
else # arbitrary call
op_orig = op
op = get(op_map, op, op)
if isexpr(f, :.)
f.args[2] = QuoteNode(op)
expr.args[1] = f
else
expr.args[1] = op
if op_orig == :^ && expr.args[3] isa Integer
# literal_pow transformation
pushfirst!(expr.args, :(Base.literal_pow))
expr.args[4] = :(Val($(expr.args[4])))
end
end
end
for i in 2:length(expr.args)
Expand Down
23 changes: 16 additions & 7 deletions src/saturating.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
import Base: BitInteger
import Base.Checked: mul_with_overflow

if VERSION v"1.11-alpha"
using Base: power_by_squaring
end

# resolve ambiguity when `-` used as symbol
saturating_negsub(x) = saturating_neg(x)
saturating_negsub(x, y) = saturating_sub(x, y)
Expand Down Expand Up @@ -156,3 +149,19 @@ function saturating_mod(x::T, y::T) where T <: SignedBitInteger
end

saturating_mod(x::T, y::T) where T <: UnsignedBitInteger = @saturating rem(x, y)

# adapted from Base intfuncs.jl; negative literal powers promote to floating point
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{0}) = one(x)
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{1}) = x
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{2}) = @saturating x * x
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{3}) = @saturating x * x * x
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{-1}) = literal_pow(^, x, Val(-1))
@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{-2}) = literal_pow(^, x, Val(-2))

@inline function literal_pow(f::typeof(saturating_pow), x, ::Val{p}) where {p}
if p < 0
literal_pow(^, x, Val(p))
else
f(x, p)
end
end
3 changes: 3 additions & 0 deletions src/unchecked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ unchecked_rem(x::T, y::T) where T <: UnsignedBitInteger =

unchecked_mod(x::T, y::T) where T <: SignedBitInteger = x - unchecked_fld(x, y) * y
unchecked_mod(x::T, y::T) where T <: UnsignedBitInteger = unchecked_rem(x, y)

# adapted from Base intfuncs.jl; negative literal powers promote to floating point
@inline literal_pow(::typeof(unchecked_pow), x, ::Val{p}) where {p} = literal_pow(^, x, Val(p))
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -792,3 +792,9 @@ end
@test_throws ErrorException @saturating aa * bb'
@test_throws ErrorException @saturating dd ^ 2
end

@testset "literal_pow transformation" begin
expr = @macroexpand @checked 5 ^ 2
@test expr.args[1] == :(Base.literal_pow)
@test expr.args[2] == :(OverflowContexts.checked_pow)
end

0 comments on commit 7aa1d39

Please sign in to comment.