Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle broadcasted operators #12

Merged
merged 10 commits into from
May 17, 2024
68 changes: 51 additions & 17 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ const op_checked = Dict(
:- => :(checked_sub),
:* => :(checked_mul),
:^ => :(checked_pow),
:+= => :(checked_add),
:-= => :(checked_sub),
:*= => :(checked_mul),
:^= => :(checked_pow),
:abs => :(checked_abs),
)

Expand All @@ -93,13 +89,27 @@ const op_unchecked = Dict(
:- => :(unchecked_sub),
:* => :(unchecked_mul),
:^ => :(unchecked_pow),
:+= => :(unchecked_add),
:-= => :(unchecked_sub),
:*= => :(unchecked_mul),
:^= => :(unchecked_pow),
:abs => :(unchecked_abs)
)

const broadcast_op_map = Dict(
:.+ => :+,
:.- => :-,
:.* => :*,
:.^ => :^
)

const assignment_op_map = Dict(
:+= => :+,
:-= => :-,
:*= => :*,
:^= => :^,
:.+= => :.+,
:.-= => :.-,
:.*= => :.*,
:.^= => :.^,
)

# resolve ambiguity when `-` used as symbol
unchecked_negsub(x) = unchecked_neg(x)
unchecked_negsub(x, y) = unchecked_sub(x, y)
Expand All @@ -110,18 +120,34 @@ checked_negsub(x, y) = checked_sub(x, y)
function replace_op!(expr::Expr, op_map::Dict)
if isexpr(expr, :call)
f, len = expr.args[1], length(expr.args)
op = isexpr(f, :.) ? f.args[2].value : f # handle module-scoped functions
if op === :+ && len == 2 # unary +
op = isexpr(f, :.) ? f.args[2].value : f # handle module-scoped functions
if op === :+ && len == 2 # unary +
# no action required
elseif op === :- && len == 2 # unary -
elseif op === :- && len == 2 # unary -
op = get(op_map, Symbol("unary-"), op)
if isexpr(f, :.)
f.args[2] = QuoteNode(op)
expr.args[1] = f
else
expr.args[1] = op
end
else # arbitrary call
elseif op ∈ keys(broadcast_op_map) # broadcast operators
op = get(broadcast_op_map, op, op)
if length(expr.args) == 2 # unary operator
if op == :-
expr.head = :.
expr.args = [
get(op_map, Symbol("unary-"), op),
Expr(:tuple, expr.args[2])]
end
# no action required for .+
else
expr.head = :.
expr.args = [
get(op_map, op, op),
Expr(:tuple, expr.args[2:end]...)]
end
else # arbitrary call
op = get(op_map, op, op)
if isexpr(f, :.)
f.args[2] = QuoteNode(op)
Expand All @@ -134,7 +160,7 @@ function replace_op!(expr::Expr, op_map::Dict)
a = expr.args[i]
if isa(a, Expr)
replace_op!(a, op_map)
elseif isa(a, Symbol) # operator as symbol function argument, e.g. `fold(+, ...)`
elseif isa(a, Symbol) # operator as symbol function argument, e.g. `fold(+, ...)`
op = if a == :-
get(op_map, Symbol("ambig-"), a)
else
Expand All @@ -146,13 +172,16 @@ function replace_op!(expr::Expr, op_map::Dict)
expr.args[i] = op
end
end
elseif isexpr(expr, (:+=, :-=, :*=, :^=)) # in-place operator
elseif isexpr(expr, keys(assignment_op_map)) # assignment operators
target = expr.args[1]
arg = expr.args[2]
op = expr.head
op = get(op_map, op, op)
expr.head = :(=)
expr.args[2] = Expr(:call, op, target, arg)
op = get(assignment_op_map, op, op)
expr.head = startswith(string(op), ".") ? :.= : :(=) # is there a better test?
expr.args[2] = replace_op!(Expr(:call, op, target, arg), op_map)
elseif isexpr(expr, :.) # broadcast function
op = expr.args[1]
expr.args[1] = get(op_map, op, op)
elseif !isexpr(expr, :macrocall) || expr.args[1] ∉ (Symbol("@checked"), Symbol("@unchecked"))
for a in expr.args
if isa(a, Expr)
Expand All @@ -162,3 +191,8 @@ function replace_op!(expr::Expr, op_map::Dict)
end
return expr
end

if VERSION < v"1.6"
import Base.Meta: isexpr
isexpr(expr, heads) = isexpr(expr, collect(heads))
end
34 changes: 34 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,40 @@ using SaferIntegers
end))
end

@testset "Broadcasted operators replaced" begin
aa = fill(typemax(Int), 2)
bb = fill(2, 2)
cc = fill(typemin(Int), 2)
@unchecked(.+cc) == cc
@unchecked(.-cc) == cc
@checked(.+cc) == cc
@test_throws OverflowError @checked(.-cc)
@unchecked(aa .+ bb) == fill(typemin(Int) + 1, 2)
@test_throws OverflowError @checked aa .+ bb
@unchecked(cc .- bb) == fill(typemax(Int) - 1, 2)
@test_throws OverflowError @checked cc .- bb
@unchecked(aa .* bb) == fill(-2, 2)
@test_throws OverflowError @checked aa .* bb
@unchecked(aa .^ bb) == fill(1, 2)
@test_throws OverflowError @checked aa .^ bb
@unchecked(abs.(cc)) == cc
@test_throws OverflowError @checked abs.(cc)
end

@testset "Broadcasted assignment operators replaced" begin
aa = fill(typemax(Int), 2)
bb = fill(2, 2)
cc = fill(typemin(Int), 2)
@unchecked(copy(aa) .+= bb) == fill(typemin(Int) + 1, 2)
@test_throws OverflowError @checked aa .+ bb
@unchecked(copy(cc) .-= bb) == fill(typemax(Int) - 1, 2)
@test_throws OverflowError @checked cc .- bb
@unchecked(copy(aa) .* bb) == fill(-2, 2)
@test_throws OverflowError @checked aa .* bb
@unchecked(copy(aa) .^ bb) == fill(1, 2)
@test_throws OverflowError @checked aa .^ bb
end

@testset "Elementwise array methods are replaced, and others throw" begin
aa = fill(typemax(Int), 2)
bb = fill(2, 2)
Expand Down
Loading