From 23269d8faa156a3900b103a028cdec3eafb375f3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 14:27:10 +0530 Subject: [PATCH 1/4] feat: add `Code.create_array` method for `TrackedArray` in ReverseDiffExt --- Project.toml | 5 +++++ ext/SymbolicUtilsReverseDiffExt.jl | 10 ++++++++++ src/SymbolicUtils.jl | 2 ++ 3 files changed, 17 insertions(+) create mode 100644 ext/SymbolicUtilsReverseDiffExt.jl diff --git a/Project.toml b/Project.toml index bcf4758ef..9d2b7755f 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "3.5.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" @@ -28,12 +29,15 @@ Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" [weakdeps] LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] SymbolicUtilsLabelledArraysExt = "LabelledArrays" +SymbolicUtilsReverseDiffExt = "ReverseDiff" [compat] AbstractTrees = "0.4" +ArrayInterface = "7.8" Bijections = "0.1.2" ChainRulesCore = "1" Combinatorics = "1.0" @@ -45,6 +49,7 @@ IfElse = "0.1" LabelledArrays = "1.5" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1" +ReverseDiff = "1" Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" StaticArrays = "0.12, 1.0" diff --git a/ext/SymbolicUtilsReverseDiffExt.jl b/ext/SymbolicUtilsReverseDiffExt.jl new file mode 100644 index 000000000..f1c3b3670 --- /dev/null +++ b/ext/SymbolicUtilsReverseDiffExt.jl @@ -0,0 +1,10 @@ +module SymbolicUtilsReverseDiffExt + +using ReverseDiff +using SymbolicUtils + +@inline function SymbolicUtils.Code.create_array(::Type{<:ReverseDiff.TrackedArray}, T, v1::Val, v2::Val{dims}, elems...) where dims + SymbolicUtils.ArrayInterface.aos_to_soa(SymbolicUtils.Code.create_array(Array, T, v1, v2, elems...)) +end + +end diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 32e94ac18..2bf52507e 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -18,6 +18,8 @@ using ConstructionBase using TermInterface import TermInterface: iscall, isexpr, head, children, operation, arguments, metadata, maketerm, sorted_arguments +# For ReverseDiffExt +import ArrayInterface Base.@deprecate istree iscall export istree, operation, arguments, sorted_arguments, similarterm, iscall From 5f5533b8b29cb04d4a81975e3a37d7313162f26e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 14:34:04 +0530 Subject: [PATCH 2/4] test: test ReverseDiffExt --- test/code.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/code.jl b/test/code.jl index 7956aa59b..c05200167 100644 --- a/test/code.jl +++ b/test/code.jl @@ -5,6 +5,7 @@ using SymbolicUtils.Code: LazyState using StaticArrays using LabelledArrays using SparseArrays +using ReverseDiff using LinearAlgebra test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b)) @@ -158,6 +159,17 @@ nanmath_st.rewrites[:nanmath] = true @test eval(toexpr(Let([a ← 1, b ← 2, arr ← @SLVector((:a, :b))(@SVector[1,2])], MakeArray([a+b,a/b], arr)))) === @SLVector((:a, :b))(@SVector [3, 1/2]) + trackedarr = eval(toexpr(Let([a ← ReverseDiff.track(1.0), b ← 2, arr ← ReverseDiff.track(ones(2))], + MakeArray([a+b,a/b], arr)))) + @test trackedarr isa ReverseDiff.TrackedArray + @test trackedarr == [3, 1/2] + + trackedarr = eval(toexpr(Let([a ← ReverseDiff.track(1.0), b ← 2, arr ← ReverseDiff.track(ones(2))], + MakeArray([a b; a+b a/b], arr)))) + @test trackedarr isa ReverseDiff.TrackedArray + @test trackedarr == [1 2; 3 1/2] + + R1 = eval(toexpr(Let([a ← 1, b ← 2, arr ← @MVector([1,2])],MakeArray([a,b,a+b,a/b], arr)))) @test R1 == (@MVector [1, 2, 3, 1/2]) && R1 isa MVector From 58f3e46203882d611da667f14047bb10c8133be6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 15:04:25 +0530 Subject: [PATCH 3/4] test: include ReverseDiff as test dependency --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9d2b7755f..6c02fbf2e 100644 --- a/Project.toml +++ b/Project.toml @@ -67,8 +67,9 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"] +test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "Test", "Zygote"] From b409ba659eefea00c6ce3f2f71976d2597a06c87 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 3 Sep 2024 08:22:27 -0400 Subject: [PATCH 4/4] Update benchmark_pr.yml --- .github/workflows/benchmark_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml index 20a5467ba..ba2e5f7e9 100644 --- a/.github/workflows/benchmark_pr.yml +++ b/.github/workflows/benchmark_pr.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: - version: "1.8" + version: "1" - uses: julia-actions/cache@v1 - name: Extract Package Name from Project.toml id: extract-package-name