From 185c01237c968076841b0ddfe9355ec1efb92241 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 9 Jan 2025 12:50:20 +0100 Subject: [PATCH] Fix ForwardDiff derivative of `NaNMath.pow` --- ext/SymbolicsForwardDiffExt.jl | 6 +++--- test/forwarddiff_symbolic_dual_ops.jl | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ext/SymbolicsForwardDiffExt.jl b/ext/SymbolicsForwardDiffExt.jl index 9882bec12..6bb7117e2 100644 --- a/ext/SymbolicsForwardDiffExt.jl +++ b/ext/SymbolicsForwardDiffExt.jl @@ -199,7 +199,7 @@ end # exponentiation # #----------------# -for f in (:(Base.:^), :(NaNMath.pow)) +for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log))) @eval begin @define_binary_dual_op( $f, @@ -212,7 +212,7 @@ for f in (:(Base.:^), :(NaNMath.pow)) elseif iszero(vx) && vy > 0 logval = zero(vx) else - logval = expv * log(vx) + logval = expv * ($log)(vx) end new_partials = _mul_partials(partials(x), partials(y), powval, logval) return Dual{Txy}(expv, new_partials) @@ -230,7 +230,7 @@ for f in (:(Base.:^), :(NaNMath.pow)) begin v = value(y) expv = ($f)(x, v) - deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x) + deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x) return Dual{Ty}(expv, deriv * partials(y)) end, $AMBIGUOUS_TYPES diff --git a/test/forwarddiff_symbolic_dual_ops.jl b/test/forwarddiff_symbolic_dual_ops.jl index e132f9a50..eba74b7c1 100644 --- a/test/forwarddiff_symbolic_dual_ops.jl +++ b/test/forwarddiff_symbolic_dual_ops.jl @@ -114,3 +114,9 @@ end y(x) = isequal(z, x) ? 0 : x @test ForwardDiff.derivative(y, 0) == 1 # expect ∂(x)/∂x end + +@testset "NaNMath.pow (issue #1399)" begin + @variables x + @test_throws DomainError substitute(ForwardDiff.derivative(z -> x^z, 0.5), x => -1.0) + @test isnan(Symbolics.value(substitute(ForwardDiff.derivative(z -> NaNMath.pow(x, z), 0.5), x => -1.0))) +end