Skip to content

Commit

Permalink
Add rule for with_logger
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Nov 3, 2023
1 parent b8adca6 commit 6d8616d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl")
include("rulesets/Base/sort.jl")
include("rulesets/Base/mapreduce.jl")
include("rulesets/Base/broadcast.jl")
include("rulesets/Base/CoreLogging.jl")

include("rulesets/Distributed/nondiff.jl")

Expand Down
20 changes: 20 additions & 0 deletions src/rulesets/Base/CoreLogging.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib)

function rrule(
rc::RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(Base.CoreLogging.with_logger),
f::Function,
logger::Base.CoreLogging.AbstractLogger
)
y, f_pb = Base.CoreLogging.with_logger(logger) do
rrule_via_ad(rc, f)
end
with_logger_pullback(ȳ) = (NoTangent(), only(f_pb(ȳ)), NoTangent())
return y, with_logger_pullback
end

@non_differentiable Base.CoreLogging.current_logger(args...)
@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...)
@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...)
@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any)
@non_differentiable Base.CoreLogging.handle_message(::Any...)
4 changes: 0 additions & 4 deletions src/rulesets/Base/nondiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,6 @@ end
@non_differentiable Broadcast.result_style(::Any)
@non_differentiable Broadcast.result_style(::Any, ::Any)

@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...)
@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...)
@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any)
@non_differentiable Base.CoreLogging.handle_message(::Any...)

@non_differentiable Libc.free(::Any)
@non_differentiable Libc.getpid()
Expand Down
11 changes: 11 additions & 0 deletions test/rulesets/Base/CoreLogging.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib)
@testset "CoreLogging.jl" begin
@testset "with_logger" begin
test_rrule(
Base.CoreLogging.with_logger,
()->2.0 * 3.0,
Base.CoreLogging.NullLogger();
check_inferred=false
)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ end
test_method_tables() # Check the global method tables are consistent

# Each file puts all tests inside one or more @testset blocks
include_test("rulesets/CoreLogging.jl")
include_test("rulesets/Base/base.jl")
include_test("rulesets/Base/fastmath_able.jl")
include_test("rulesets/Base/evalpoly.jl")
Expand Down

0 comments on commit 6d8616d

Please sign in to comment.