diff --git a/src/diff.jl b/src/diff.jl index 3845ae8b1..df8f500de 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -150,6 +150,127 @@ function recursive_hasoperator(op, O) end end +""" + executediff(D, arg, simplify=false; occurrences=nothing) + +Apply the passed Differential D on the passed argument. + +This function differs to `expand_derivatives` in that in only expands the +passed differential and not any other Differentials it encounters. + +# Arguments +- `D::Differential`: The differential to apply +- `arg::Symbolic`: The symbolic expression to apply the differential on. +- `simplify::Bool=false`: Whether to simplify the resulting expression using + [`SymbolicUtils.simplify`](@ref). +- `occurrences=nothing`: Information about the occurrences of the independent + variable in the argument of the derivative. This is used internally for + optimization purposes. +""" +function executediff(D, arg, simplify=false; occurrences=nothing) + if occurrences == nothing + occurrences = occursin_info(D.x, arg) + end + + _isfalse(occurrences) && return 0 + occurrences isa Bool && return 1 # means it's a `true` + + if !iscall(arg) + return D(arg) # Cannot expand + elseif (op = operation(arg); issym(op)) + inner_args = arguments(arg) + if any(isequal(D.x), inner_args) + return D(arg) # base case if any argument is directly equal to the i.v. + else + return sum(inner_args, init=0) do a + return executediff(Differential(a), arg) * + executediff(D, a) + end + end + elseif op === (IfElse.ifelse) + args = arguments(arg) + O = op(args[1], + executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2]), + executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3])) + return O + elseif isa(op, Differential) + # The recursive expand_derivatives was not able to remove + # a nested Differential. We can attempt to differentiate the + # inner expression wrt to the outer iv. And leave the + # unexpandable Differential outside. + if isequal(op.x, D.x) + return D(arg) + else + inner = executediff(D, arguments(arg)[1], false) + # if the inner expression is not expandable either, return + if iscall(inner) && operation(inner) isa Differential + return D(arg) + else + # otherwise give the nested Differential another try + return executediff(op, inner, simplify) + end + end + elseif isa(op, Integral) + if isa(op.domain.domain, AbstractInterval) + domain = op.domain.domain + a, b = DomainSets.endpoints(domain) + c = 0 + inner_function = arguments(arg)[1] + if iscall(value(a)) + t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) + t2 = D(a) + c -= t1*t2 + end + if iscall(value(b)) + t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b))) + t2 = D(b) + c += t1*t2 + end + inner = executediff(D, arguments(arg)[1]) + c += op(inner) + return value(c) + end + end + + inner_args = arguments(arg) + l = length(inner_args) + exprs = [] + c = 0 + + for i in 1:l + t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i]) + + x = if _iszero(t2) + t2 + elseif _isone(t2) + d = derivative_idx(arg, i) + d isa NoDeriv ? D(arg) : d + else + t1 = derivative_idx(arg, i) + t1 = t1 isa NoDeriv ? D(arg) : t1 + t1 * t2 + end + + if _iszero(x) + continue + elseif x isa Symbolic + push!(exprs, x) + else + c += x + end + end + + if isempty(exprs) + return c + elseif length(exprs) == 1 + term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1]) + return _iszero(c) ? term : c + term + else + x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...) + return simplify ? SymbolicUtils.simplify(x) : x + end +end + """ $(SIGNATURES) @@ -162,9 +283,6 @@ and other derivative rules to expand any derivatives it encounters. - `O::Symbolic`: The symbolic expression to expand. - `simplify::Bool=false`: Whether to simplify the resulting expression using [`SymbolicUtils.simplify`](@ref). -- `occurrences=nothing`: Information about the occurrences of the independent - variable in the argument of the derivative. This is used internally for - optimization purposes. # Examples ```jldoctest @@ -180,111 +298,11 @@ julia> dfx = expand_derivatives(Dx(f)) (k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y ``` """ -function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) +function expand_derivatives(O::Symbolic, simplify=false) if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) arg = expand_derivatives(arg, false) - - if occurrences == nothing - occurrences = occursin_info(operation(O).x, arg) - end - - _isfalse(occurrences) && return 0 - occurrences isa Bool && return 1 # means it's a `true` - - D = operation(O) - - if !iscall(arg) - return D(arg) # Cannot expand - elseif (op = operation(arg); issym(op)) - inner_args = arguments(arg) - if any(isequal(D.x), inner_args) - return D(arg) # base case if any argument is directly equal to the i.v. - else - return sum(inner_args, init=0) do a - return expand_derivatives(Differential(a)(arg)) * - expand_derivatives(D(a)) - end - end - elseif op === (IfElse.ifelse) - args = arguments(arg) - O = op(args[1], D(args[2]), D(args[3])) - return expand_derivatives(O, simplify; occurrences) - elseif isa(op, Differential) - # The recursive expand_derivatives was not able to remove - # a nested Differential. We can attempt to differentiate the - # inner expression wrt to the outer iv. And leave the - # unexpandable Differential outside. - if isequal(op.x, D.x) - return D(arg) - else - inner = expand_derivatives(D(arguments(arg)[1]), false) - # if the inner expression is not expandable either, return - if iscall(inner) && operation(inner) isa Differential - return D(arg) - else - return expand_derivatives(op(inner), simplify) - end - end - elseif isa(op, Integral) - if isa(op.domain.domain, AbstractInterval) - domain = op.domain.domain - a, b = DomainSets.endpoints(domain) - c = 0 - inner_function = expand_derivatives(arguments(arg)[1]) - if iscall(value(a)) - t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) - t2 = D(a) - c -= t1*t2 - end - if iscall(value(b)) - t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b))) - t2 = D(b) - c += t1*t2 - end - inner = expand_derivatives(D(arguments(arg)[1])) - c += op(inner) - return value(c) - end - end - - inner_args = arguments(arg) - l = length(inner_args) - exprs = [] - c = 0 - - for i in 1:l - t2 = expand_derivatives(D(inner_args[i]),false, occurrences=arguments(occurrences)[i]) - - x = if _iszero(t2) - t2 - elseif _isone(t2) - d = derivative_idx(arg, i) - d isa NoDeriv ? D(arg) : d - else - t1 = derivative_idx(arg, i) - t1 = t1 isa NoDeriv ? D(arg) : t1 - t1 * t2 - end - - if _iszero(x) - continue - elseif x isa Symbolic - push!(exprs, x) - else - c += x - end - end - - if isempty(exprs) - return c - elseif length(exprs) == 1 - term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1]) - return _iszero(c) ? term : c + term - else - x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...) - return simplify ? SymbolicUtils.simplify(x) : x - end + return executediff(operation(O), arg, simplify) elseif iscall(O) && isa(operation(O), Integral) return operation(O)(expand_derivatives(arguments(O)[1])) elseif !hasderiv(O) @@ -295,14 +313,14 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) return simplify ? SymbolicUtils.simplify(O1) : O1 end end -function expand_derivatives(n::Num, simplify=false; occurrences=nothing) - wrap(expand_derivatives(value(n), simplify; occurrences=occurrences)) +function expand_derivatives(n::Num, simplify=false) + wrap(expand_derivatives(value(n), simplify)) end -function expand_derivatives(n::Complex{Num}, simplify=false; occurrences=nothing) - wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; occurrences=occurrences), - expand_derivatives(imag(n), simplify; occurrences=occurrences))) +function expand_derivatives(n::Complex{Num}, simplify=false) + wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify), + expand_derivatives(imag(n), simplify))) end -expand_derivatives(x, simplify=false; occurrences=nothing) = x +expand_derivatives(x, simplify=false) = x _iszero(x) = false _isone(x) = false diff --git a/test/diff.jl b/test/diff.jl index fc29f3da5..e977fdbee 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -349,6 +349,34 @@ let @test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im) end +# 1262 +# +let + @variables t b(t) + D = Differential(t) + expr = b - ((D(b))^2) * D(D(b)) + expr2 = D(expr) + @test isequal(expand_derivatives(expr), expr) + @test isequal(expand_derivatives(expr2), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2)) +end + +# 1126 +# +let + @syms y f(y) g(y) h(y) + D = Differential(y) + + expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y)))) + + expr = expr_gen(g(y)) + # just make sure that no errors are thrown in the following, the results are to complicated to compare + expand_derivatives(expr) + expr = expr_gen(h(y)) + expand_derivatives(expr) + + expr = expr_gen(f(y)) + expand_derivatives(expr) +end # Check `is_derivative` function let