Skip to content

Commit

Permalink
Merge pull request #429 from JuliaSymbolics/s/literal
Browse files Browse the repository at this point in the history
LiteralReal
  • Loading branch information
shashi authored Jan 24, 2022
2 parents e2a6f8a + 07877fb commit 01946d5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 22 deletions.
53 changes: 31 additions & 22 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ const diadic = [max, min, hypot, atan, mod, rem, copysign,
polygamma, beta, logbeta]
const previously_declared_for = Set([])

const basic_monadic = [-, +]
const basic_diadic = [+, -, *, /, //, \, ^]
#################### SafeReal #########################
export SafeReal
export SafeReal, LiteralReal

# ideally the relationship should be the other way around
abstract type SafeReal <: Real end

################### LiteralReal #######################

abstract type LiteralReal <: Real end

#######################################################

assert_like(f, T) = nothing
Expand All @@ -41,14 +47,13 @@ function number_methods(T, rhs1, rhs2, options=nothing)
exprs = []

skip_basics = options !== nothing ? options == :skipbasics : false
only_basics = options !== nothing ? options == :onlybasics : false
skips = Meta.isexpr(options, [:vcat, :hcat, :vect]) ? Set(options.args) : []
basic_monadic = [-, +]
basic_diadic = [+, -, *, /, //, \, ^]

rhs2 = :($assert_like(f, Number, a, b); $rhs2)
rhs1 = :($assert_like(f, Number, a); $rhs1)

for f in (skip_basics ? diadic : vcat(basic_diadic, diadic))
for f in (skip_basics ? diadic : only_basics ? basic_diadic : vcat(basic_diadic, diadic))
nameof(f) in skips && continue
for S in previously_declared_for
push!(exprs, quote
Expand All @@ -69,7 +74,7 @@ function number_methods(T, rhs1, rhs2, options=nothing)
push!(exprs, expr)
end

for f in (skip_basics ? monadic : vcat(basic_monadic, monadic))
for f in (skip_basics ? monadic : only_basics ? basic_monadic : vcat(basic_monadic, monadic))
nameof(f) in skips && continue
push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1))
end
Expand All @@ -87,28 +92,31 @@ end
@number_methods(Mul, term(f, a), term(f, a, b), skipbasics)
@number_methods(Pow, term(f, a), term(f, a, b), skipbasics)
@number_methods(Div, term(f, a), term(f, a, b), skipbasics)
@number_methods(Sym{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)
@number_methods(Term{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)

for f in vcat(diadic, [+, -, *, \, /, ^])
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
@eval function promote_symtype(::$(typeof(f)),
T::Type{<:SafeReal},
S::Type{<:Real})
X = promote_type(T, Real)
X == Real ? SafeReal : X
end
@eval function promote_symtype(::$(typeof(f)),
T::Type{<:Real},
S::Type{<:SafeReal})
X = promote_type(Real, S)
X == Real ? SafeReal : X
end
@eval function promote_symtype(::$(typeof(f)),
T::Type{<:SafeReal},
S::Type{<:SafeReal})
X = promote_type(Real, Real)
X == Real ? SafeReal : X
for R in [SafeReal, LiteralReal]
@eval function promote_symtype(::$(typeof(f)),
T::Type{<:$R},
S::Type{<:Real})
X = promote_type(T, Real)
X == Real ? $R : X
end
@eval function promote_symtype(::$(typeof(f)),
T::Type{<:Real},
S::Type{<:$R})
X = promote_type(Real, S)
X == Real ? $R : X
end
@eval function promote_symtype(::$(typeof(f)),
T::Type{<:$R},
S::Type{<:$R})
$R
end
end
end

Expand All @@ -118,6 +126,7 @@ Base.rem2pi(x::Symbolic{<:Number}, mode::Base.RoundingMode) = term(rem2pi, x, mo
for f in monadic
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = promote_type(T, Real)
@eval promote_symtype(::$(typeof(f)), T::Type{<:SafeReal}) = SafeReal
@eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralReal}) = LiteralReal
@eval (::$(typeof(f)))(a::Symbolic{<:Number}) = term($f, a)
end

Expand Down
11 changes: 11 additions & 0 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,14 @@ end
@test (2.5x/3x).den == 3
@test (x/3x) == 1//3
end

@testset "LiteralReal" begin
@syms x::LiteralReal y::LiteralReal z::LiteralReal
@test repr(x+x) == "x + x"
@test repr(x*x) == "x*x"
@test repr(x*x + x*x) == "x*x + x*x"
for ex in [sin(x), x+x, x*x, x\x, x/x]
@test typeof(sin(x)) <: Term{LiteralReal}
end
@test repr(sin(x) + sin(x)) == "sin(x) + sin(x)"
end

0 comments on commit 01946d5

Please sign in to comment.