Skip to content

Commit

Permalink
fix muladd for various number types (baggepinnen#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
aplavin authored Jul 2, 2024
1 parent fc9c448 commit 5edba75
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
25 changes: 6 additions & 19 deletions src/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,25 +395,12 @@ for PT in ParticleSymbols

end

# for XT in (:T, :($PT{T,N})), YT in (:T, :($PT{T,N})), ZT in (:T, :($PT{T,N}))
# XT == YT == ZT == :T && continue
# @eval function Base.muladd(x::$XT,y::$YT,z::$ZT) where {T<:Number,N}
# res = muladd.(maybe_particles(x),maybe_particles(y),maybe_particles(z))
# $PT{eltype(res),N}(res)
# end
# end

@eval function Base.muladd(x::$PT{T,N},y::$PT{T,N},z::$PT{T,N}) where {T<:Number,N}
res = muladd.(x.particles,y.particles,z.particles)
$PT{T,N}(res)
end
@eval function Base.muladd(x::T,y::$PT{T,N},z::$PT{T,N}) where {T<:Number,N}
res = muladd.(x,y.particles,z.particles)
$PT{T,N}(res)
end
@eval function Base.muladd(x::T,y::T,z::$PT{T,N}) where {T<:Number,N}
res = muladd.(x,y,z.particles)
$PT{T,N}(res)
for XT in (:Number, :($PT{<:Number,N})), YT in (:Number, :($PT{<:Number,N})), ZT in (:Number, :($PT{<:Number,N}))
XT == YT == ZT == :Number && continue
@eval function Base.muladd(x::$XT,y::$YT,z::$ZT) where {N}
res = muladd.(maybe_particles(x),maybe_particles(y),maybe_particles(z))
$PT{eltype(res),N}(res)
end
end

@eval Base.promote_rule(::Type{S}, ::Type{$PT{T,N}}) where {S<:Number,T,N} = $PT{promote_type(S,T),N} # This is hard to hit due to method for real 3 lines down
Expand Down
5 changes: 5 additions & 0 deletions test/test_unitful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ register_primitive(unitful_testfunction) # must be outside testset

typeof(promote(1u"V", (1.0 ± 0.1)u"V")) <: Tuple{Particles{<:Quantity}, Particles{<:Quantity}}

@test muladd(p1, 1, p1) == p1 + p1
@test muladd(p1, 1, p2) == p1 + p2
@test muladd(1, p1, p2) == p1 + p2
@test muladd(p1, 1/p1, 1) == 2

ρ = (2.7 ± 0.2)u"g/cm^3"
mass = (250 ± 10)u"g"
width = (30.5 ± 0.2)u"cm"
Expand Down

0 comments on commit 5edba75

Please sign in to comment.