Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not require that function arguments are ::Function. #228

Merged
merged 3 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,13 @@ function FiniteDifferenceMethod(grid::AbstractVector{Int}, q::Int; kw_args...)
end

"""
(m::FiniteDifferenceMethod)(f::Function, x::T) where T<:AbstractFloat
(m::FiniteDifferenceMethod)(f, x::T) where T<:AbstractFloat

Estimate the derivative of `f` at `x` using the finite differencing method `m` and an
automatically determined step size.

# Arguments
- `f::Function`: Function to estimate derivative of.
- `f`: Function to estimate derivative of.
- `x::T`: Input to estimate derivative at.

# Returns
Expand Down Expand Up @@ -188,12 +188,12 @@ julia> FiniteDifferences.estimate_step(fdm, sin, 1.0) # Computes step size and
# We loop over all concrete subtypes of `FiniteDifferenceMethod` for Julia v1.0 compatibility.
for T in (UnadaptedFiniteDifferenceMethod, AdaptedFiniteDifferenceMethod)
@eval begin
function (m::$T)(f::TF, x::Real) where TF<:Function
function (m::$T)(f::TF, x::Real) where TF
x = float(x) # Assume that converting to float is desired, if it isn't already.
step = first(estimate_step(m, f, x))
return m(f, x, step)
end
function (m::$T{P,0})(f::TF, x::Real) where {P,TF<:Function}
function (m::$T{P,0})(f::TF, x::Real) where {P,TF}
# The automatic step size calculation fails if `Q == 0`, so handle that edge
# case.
return f(x)
Expand All @@ -202,13 +202,13 @@ for T in (UnadaptedFiniteDifferenceMethod, AdaptedFiniteDifferenceMethod)
end

"""
(m::FiniteDifferenceMethod)(f::Function, x::T, step::Real) where T<:AbstractFloat
(m::FiniteDifferenceMethod)(f, x::T, step::Real) where T<:AbstractFloat

Estimate the derivative of `f` at `x` using the finite differencing method `m` and a given
step size.

# Arguments
- `f::Function`: Function to estimate derivative of.
- `f`: Function to estimate derivative of.
- `x::T`: Input to estimate derivative at.
- `step::Real`: Step size.

Expand All @@ -235,7 +235,7 @@ julia> fdm(sin, 1, 1e-3) - cos(1) # Check the error.
# We loop over all concrete subtypes of `FiniteDifferenceMethod` for 1.0 compatibility.
for T in (UnadaptedFiniteDifferenceMethod, AdaptedFiniteDifferenceMethod)
@eval begin
function (m::$T{P,Q})(f::TF, x::Real, step::Real) where {P,Q,TF<:Function}
function (m::$T{P,Q})(f::TF, x::Real, step::Real) where {P,Q,TF}
x = float(x) # Assume that converting to float is desired, if it isn't already.
fs = _eval_function(m, f, x, step)
return _compute_estimate(m, fs, x, step, m.coefs)
Expand All @@ -245,7 +245,7 @@ end

function _eval_function(
m::FiniteDifferenceMethod, f::TF, x::T, step::Real,
) where {TF<:Function,T<:AbstractFloat}
) where {TF,T<:AbstractFloat}
return f.(x .+ T(step) .* m.grid)
end

Expand Down Expand Up @@ -336,7 +336,7 @@ end
"""
function estimate_step(
m::FiniteDifferenceMethod,
f::Function,
f,
x::T
) where T<:AbstractFloat

Expand All @@ -345,7 +345,7 @@ estimate of the derivative.

# Arguments
- `m::FiniteDifferenceMethod`: Finite difference method to estimate the step size for.
- `f::Function`: Function to evaluate the derivative of.
- `f`: Function to evaluate the derivative of.
- `x::T`: Point to estimate the derivative at.

# Returns
Expand All @@ -355,13 +355,13 @@ estimate of the derivative.
"""
function estimate_step(
m::UnadaptedFiniteDifferenceMethod, f::TF, x::T,
) where {TF<:Function,T<:AbstractFloat}
) where {TF,T<:AbstractFloat}
step, acc = _compute_step_acc_default(m, x)
return _limit_step(m, x, step, acc)
end
function estimate_step(
m::AdaptedFiniteDifferenceMethod{P,Q}, f::TF, x::T,
) where {P,Q,TF<:Function,T<:AbstractFloat}
) where {P,Q,TF,T<:AbstractFloat}
∇f_magnitude, f_magnitude = _estimate_magnitudes(m.bound_estimator, f, x)
if ∇f_magnitude == 0.0 || f_magnitude == 0.0
step, acc = _compute_step_acc_default(m, x)
Expand All @@ -373,7 +373,7 @@ end

function _estimate_magnitudes(
m::FiniteDifferenceMethod{P,Q}, f::TF, x::T,
) where {P,Q,TF<:Function,T<:AbstractFloat}
) where {P,Q,TF,T<:AbstractFloat}
step = first(estimate_step(m, f, x))
fs = _eval_function(m, f, x, step)
# Estimate magnitude of `∇f` in a neighbourhood of `x`.
Expand Down Expand Up @@ -551,7 +551,7 @@ end
"""
extrapolate_fdm(
m::FiniteDifferenceMethod,
f::Function,
f,
x::Real,
initial_step::Real=10,
power::Int=1,
Expand All @@ -567,7 +567,7 @@ automatically sets `power = 2` if `m` is symmetric and `power = 1`. Moreover, it

# Arguments
- `m::FiniteDifferenceMethod`: Finite difference method to estimate the step size for.
- `f::Function`: Function to evaluate the derivative of.
- `f`: Function to evaluate the derivative of.
- `x::Real`: Point to estimate the derivative at.
- `initial_step::Real=10`: Initial step size.

Expand All @@ -576,7 +576,7 @@ automatically sets `power = 2` if `m` is symmetric and `power = 1`. Moreover, it
"""
function extrapolate_fdm(
m::FiniteDifferenceMethod,
f::Function,
f,
x::Real,
initial_step::Real=10;
power::Int=1,
Expand Down
14 changes: 14 additions & 0 deletions test/methods.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
Wrapper used in test set “do not require f::Function” below, moved outside so that it
works on Julia 1.0.
"""
struct NotAFunction end # not <: Function on purpose, cf #224
(::NotAFunction)(x) = abs2(x)

@testset "Methods" begin
@testset "Correctness" begin
# Finite difference methods to test.
Expand Down Expand Up @@ -162,4 +169,11 @@
end
end
end

@testset "do not require f::Function" begin
x = 0.7
for f in [forward_fdm, central_fdm, backward_fdm]
f(5, 1)(NotAFunction(), x) ≈ 2 * x
end
end
end