From 39d79865c00f2c39053049785fe955a3db86564b Mon Sep 17 00:00:00 2001 From: houpc Date: Wed, 25 Oct 2023 22:39:09 +0800 Subject: [PATCH 1/3] add flatten_chains and the relevant tests --- src/computational_graph/ComputationalGraph.jl | 4 +- src/computational_graph/optimize.jl | 49 +++++++++++- src/computational_graph/transform.jl | 42 ++++++++++ test/computational_graph.jl | 80 +++++++++++++++++++ 4 files changed, 171 insertions(+), 4 deletions(-) diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index 2ea2379c..fcfa7c1e 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -47,8 +47,8 @@ include("eval.jl") export eval! include("transform.jl") -export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! -export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains +export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains! +export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains include("optimize.jl") export optimize!, optimize diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl index b289d58f..eb856e78 100644 --- a/src/computational_graph/optimize.jl +++ b/src/computational_graph/optimize.jl @@ -17,7 +17,8 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose else graphs = collect(graphs) leaf_mapping = remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize) - merge_all_chains!(graphs, verbose=verbose) + # merge_all_chains!(graphs, verbose=verbose) + flatten_all_chains!(graphs, verbose=verbose) merge_all_linear_combinations!(graphs, verbose=verbose) return leaf_mapping end @@ -179,6 +180,50 @@ function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) return graphs end +""" + function flatten_all_chains!(g::AbstractGraph; verbose=0) + + In-place flattens all nodes representing trivial unary chains in the given graph `g`. + +# Arguments: +- `graphs`: The graph to be processed. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- The mutated graph `g` with all chains flattened. +""" +function flatten_all_chains!(g::AbstractGraph; verbose=0) + verbose > 0 && println("flatten all nodes representing trivial unary chains.") + for sub_g in g.subgraphs + flatten_all_chains!(sub_g) + flatten_chains!(sub_g) + end + flatten_chains!(g) + return g +end + +""" + function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) + + In-place flattens all nodes representing trivial unary chains in given graphs. + +# Arguments: +- `graphs`: A collection of graphs to be processed. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- The mutated collection `graphs` with all chains in each graph flattened. +""" +function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) + verbose > 0 && println("flatten all nodes representing trivial unary chains.") + # Post-order DFS + for g in graphs + flatten_all_chains!(g.subgraphs) + flatten_chains!(g) + end + return graphs +end + """ function merge_all_linear_combinations!(g::AbstractGraph; verbose=0) @@ -263,7 +308,7 @@ end - Optimized graphs. # """ -function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0) +function merge_all_multi_products!(graphs::AbstractVector{<:Graph}; verbose=0) verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.") # Post-order DFS for g in graphs diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index c5040715..7c291ca4 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -170,6 +170,13 @@ end - `g::AbstractGraph`: graph to be modified """ function merge_factorless_chain!(g::AbstractGraph) + if unary_istrivial(g.operator) && onechild(g) && isfactorless(g) + child = eldest(g) + for field in fieldnames(typeof(g)) + value = getproperty(child, field) + setproperty!(g, field, value) + end + end while unary_istrivial(g.operator) && onechild(g) && isfactorless(g) child = eldest(g) for field in fieldnames(typeof(g)) @@ -278,6 +285,41 @@ end """ merge_chains(g::AbstractGraph) = merge_chains!(deepcopy(g)) +""" + function flatten_chains!(g::AbstractGraph) + + Recursively flattens chains of subgraphs within the given graph `g` by merging certain trivial unary subgraphs + into their parent graphs in the in-place form. + + Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), + where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +""" +function flatten_chains!(g::AbstractGraph) + for (i, sub_g) in enumerate(g.subgraphs) + if onechild(sub_g) && unary_istrivial(sub_g.operator) + flatten_chains!(sub_g) + + g.subgraph_factors[i] *= sub_g.subgraph_factors[1] + g.subgraphs[i] = eldest(sub_g) + end + end + return g +end + +""" + function flatten_chains(g::AbstractGraph) + + Recursively flattens chains of subgraphs within a given graph `g` by merging certain trivial unary subgraphs into their parent graphs, + This function returns a new graph with flatten chains, dervied from the input graph `g` remaining unchanged. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +""" +flatten_chains(g::AbstractGraph) = flatten_chains!(deepcopy(g)) + """ function merge_linear_combination(g::Graph) diff --git a/test/computational_graph.jl b/test/computational_graph.jl index a1b63e61..e56643a3 100644 --- a/test/computational_graph.jl +++ b/test/computational_graph.jl @@ -211,6 +211,32 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true merge_multi_product!(h1) @test isequiv(h1, h1_mp, :id) end + @testset "Flatten chains" begin + l0 = Graph([]) + l1 = Graph([l0]; subgraph_factors=[2]) + g1 = Graph([l1]; subgraph_factors=[-1], operator=O()) + g1c = deepcopy(g1) + g2 = 2 * g1 + g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod()) + g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod()) + r1 = Graph([g4,]; subgraph_factors=[7,], operator=Graphs.Prod()) + r2 = Graph([g4,]; subgraph_factors=[-1,], operator=Graphs.Prod()) + r3 = Graph([g3, g4,]; subgraph_factors=[2, 7], operator=O()) + rvec = deepcopy([r1, r2, r3]) + Graphs.flatten_chains!(r1) + @test isequiv(g1, g1c, :id) + @test isequiv(r1, 210g1, :id) + @test isequiv(g2, 2g1, :id) + @test isequiv(g3, 6g1, :id) + @test isequiv(g4, 30g1, :id) + Graphs.flatten_chains!(r2) + @test isequiv(r2, -30g1, :id) + Graphs.flatten_chains!(r3) + @test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id) + @test r1 == Graphs.flatten_chains(rvec[1]) + @test r2 == Graphs.flatten_chains(rvec[2]) + @test r3 == Graphs.flatten_chains(rvec[3]) + end end @testset verbose = true "Optimizations" begin @testset "Remove one-child parents" begin @@ -247,6 +273,40 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true @test h.subgraph_factors == [105] @test eldest(h) == h0 end + @testset "Flatten all chains" begin + l0 = Graph([]) + l1 = Graph([l0]; subgraph_factors=[2]) + l2 = Graph([]; factor=3) + g1 = Graph([l1, l2]; subgraph_factors=[-1, 1]) + g2 = 2 * g1 + g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod()) + g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod()) + r1 = Graph([g4,]; subgraph_factors=[7,], operator=Graphs.Prod()) + r2 = Graph([g4,]; subgraph_factors=[-1,], operator=Graphs.Prod()) + r3 = Graph([g3, g4,]; subgraph_factors=[2, 7], operator=O()) + rvec = deepcopy([r1, r2, r3]) + rvec1 = deepcopy([r1, r2, r3]) + Graphs.flatten_all_chains!(r1) + @test isequiv(g1, Graph([l0, l2]; subgraph_factors=[-2, 1]), :id) + @test isequiv(r1, 210g1, :id) + @test isequiv(g2, 2g1, :id) + @test isequiv(g3, 6g1, :id) + @test isequiv(g4, 30g1, :id) + Graphs.flatten_all_chains!(r2) + @test isequiv(r2, -30g1, :id) + Graphs.flatten_all_chains!(r3) + @test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id) + Graphs.flatten_all_chains!(rvec) + @test rvec == [r1, r2, r3] + + Graphs.merge_all_chains!(rvec1) + @test rvec1[1].subgraph_factors == [210] + @test eldest(rvec1[1]) == g1 + @test rvec1[2].subgraph_factors == [-1] + @test eldest(rvec1[2]) == g1 # BUG! + @test rvec1[3].subgraph_factors == [2, 7] + @test rvec1[3].subgraphs == [g1, g1] # BUG! + end @testset "merge all linear combinations" begin g1 = Graph([]) g2 = 2 * g1 @@ -265,6 +325,26 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true Graphs.merge_all_linear_combinations!(h0) @test isequiv(h0.subgraphs[1], _h, :id) end + @testset "Merge all multi prodicts" begin + g1 = Graph([]) + g2 = Graph([], factor=2) + g3 = Graph([], factor=3) + h = Graph([g1, g2, g1, g1, g3, g2]; subgraph_factors=[3, 2, 5, 1, 1, 3], operator=Graphs.Prod()) + hvec = repeat([deepcopy(h)], 3) + h0 = Graph([deepcopy(h), g2]) + h_s1 = Graph([g1], operator=Graphs.Power(3)) + h_s2 = Graph([g2], operator=Graphs.Power(2)) + _h = Graph([h_s1, h_s2, g3], subgraph_factors=[15, 6, 1], operator=Graphs.Prod()) + # Test on a single graph + Graphs.merge_all_multi_products!(h) + @test isequiv(h, _h, :id) + # Test on a vector of graphs + Graphs.merge_all_multi_products!(hvec) + @test all(isequiv(h, _h, :id) for h in hvec) + + Graphs.merge_all_multi_products!(h0) + @test isequiv(h0.subgraphs[1], _h, :id) + end @testset "optimize" begin g1 = Graph([]) g2 = 2 * g1 From dafa0bd54c84f5d7133eef21d20a82a5316696f3 Mon Sep 17 00:00:00 2001 From: houpc Date: Thu, 26 Oct 2023 22:59:47 +0800 Subject: [PATCH 2/3] delete all the functions relevant to merge_all_chains. --- src/computational_graph/optimize.jl | 141 +----------------------- src/computational_graph/transform.jl | 128 ---------------------- test/computational_graph.jl | 154 --------------------------- 3 files changed, 2 insertions(+), 421 deletions(-) diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl index eb856e78..206fe576 100644 --- a/src/computational_graph/optimize.jl +++ b/src/computational_graph/optimize.jl @@ -17,7 +17,6 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose else graphs = collect(graphs) leaf_mapping = remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize) - # merge_all_chains!(graphs, verbose=verbose) flatten_all_chains!(graphs, verbose=verbose) merge_all_linear_combinations!(graphs, verbose=verbose) return leaf_mapping @@ -44,146 +43,10 @@ function optimize(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose= return graphs_new, leaf_mapping end -""" - function merge_all_chain_prefactors!(g::AbstractGraph; verbose=0) - - In-place merge prefactors of all nodes representing trivial unary chains towards the root level for a single graph. - -# Arguments: -- `g`: An AbstractGraph. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graph. -# -""" -function merge_all_chain_prefactors!(g::AbstractGraph; verbose=0) - verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.") - # Post-order DFS - for sub_g in g.subgraphs - merge_all_chain_prefactors!(sub_g) - merge_chain_prefactors!(sub_g) - end - merge_chain_prefactors!(g) - return g -end - -""" - function merge_all_chain_prefactors!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - - In-place merge prefactors of all nodes representing trivial unary chains towards the root level for given graphs. - -# Arguments: -- `graphs`: An AbstractVector of graphs. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graphs. -# -""" -function merge_all_chain_prefactors!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.") - # Post-order DFS - for g in graphs - merge_all_chain_prefactors!(g.subgraphs) - merge_chain_prefactors!(g) - end - return graphs -end - -""" - function merge_all_factorless_chains!(g::AbstractGraph; verbose=0) - - In-place merge all nodes representing factorless trivial unary chains within a single graph. - -# Arguments: -- `g`: An AbstractGraph. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graph. -# -""" -function merge_all_factorless_chains!(g::AbstractGraph; verbose=0) - verbose > 0 && println("merge all nodes representing factorless trivial unary chains.") - # Post-order DFS - for sub_g in g.subgraphs - merge_all_factorless_chains!(sub_g) - merge_factorless_chain!(sub_g) - end - merge_factorless_chain!(g) - return g -end - -""" - function merge_all_factorless_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - - In-place merge all nodes representing factorless trivial unary chains within given graphs. - -# Arguments: -- `graphs`: An AbstractVector of graphs. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graphs. -# -""" -function merge_all_factorless_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - verbose > 0 && println("merge all nodes representing factorless trivial unary chains.") - # Post-order DFS - for g in graphs - merge_all_factorless_chains!(g.subgraphs) - merge_factorless_chain!(g) - end - return graphs -end - -""" - function merge_all_chains!(g::AbstractGraph; verbose=0) - - In-place merge all nodes representing trivial unary chains within a single graph. - This function consolidates both chain prefactors and factorless chains. - -# Arguments: -- `g`: An AbstractGraph. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graph. -# -""" -function merge_all_chains!(g::AbstractGraph; verbose=0) - verbose > 0 && println("merge all nodes representing trivial unary chains.") - merge_all_chain_prefactors!(g, verbose=verbose) - merge_all_factorless_chains!(g, verbose=verbose) - return g -end - -""" - function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) where {G<:AbstractGraph} - - In-place merge all nodes representing trivial unary chains in given graphs. - This function consolidates both chain prefactors and factorless chains. - -# Arguments: -- `graphs`: An AbstractVector of graphs. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graphs. -# -""" -function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - verbose > 0 && println("merge all nodes representing trivial unary chains.") - merge_all_chain_prefactors!(graphs, verbose=verbose) - merge_all_factorless_chains!(graphs, verbose=verbose) - return graphs -end - """ function flatten_all_chains!(g::AbstractGraph; verbose=0) - - In-place flattens all nodes representing trivial unary chains in the given graph `g`. +F + Flattens all nodes representing trivial unary chains in-place in the given graph `g`. # Arguments: - `graphs`: The graph to be processed. diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 7c291ca4..a0683380 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -157,134 +157,6 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph) return g_new end -""" - function merge_factorless_chain!(g::AbstractGraph) - - Simplifies `g` in-place if it represents a factorless trivial unary chain. For example, +(+(+g)) ↦ g. - - Does nothing unless g has the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, - a node with non-unity multiplicative prefactor, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_factorless_chain!(g::AbstractGraph) - if unary_istrivial(g.operator) && onechild(g) && isfactorless(g) - child = eldest(g) - for field in fieldnames(typeof(g)) - value = getproperty(child, field) - setproperty!(g, field, value) - end - end - while unary_istrivial(g.operator) && onechild(g) && isfactorless(g) - child = eldest(g) - for field in fieldnames(typeof(g)) - value = getproperty(child, field) - setproperty!(g, field, value) - end - end - return g -end - -""" - function merge_factorless_chain(g::AbstractGraph) - - Returns a simplified copy of `g` if it represents a factorless trivial unary chain. - Otherwise, returns the original graph. For example, +(+(+g)) ↦ g. - - Does nothing unless g has the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, - a node with non-unity multiplicative prefactor, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_factorless_chain(g::AbstractGraph) - while unary_istrivial(g.operator) && onechild(g) && isfactorless(g) - g = eldest(g) - end - return g -end - -""" - function merge_chain_prefactors!(g::AbstractGraph) - - Simplifies subgraphs of g representing trivial unary chains by merging their - subgraph factors toward root level, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*(*(*g)) + 63*(*h). - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_chain_prefactors!(g::AbstractGraph) - for (i, child) in enumerate(g.subgraphs) - total_chain_factor = 1 - while onechild(child) - # Break case: end of trivial unary chain - unary_istrivial(child.operator) == false && break - # Move this subfactor to running total - total_chain_factor *= child.subgraph_factors[1] - child.subgraph_factors[1] = 1 - # Descend one level - child = eldest(child) - end - # Update g subfactor with total factors from children - g.subgraph_factors[i] *= total_chain_factor - end - return g -end - -""" - function merge_chain_prefactors(g::AbstractGraph) - - Returns a copy of g with subgraphs representing trivial unary chains simplified by merging - their subgraph factors toward root level, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*(*(*g)) + 63*(*h). - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -merge_chain_prefactors(g::AbstractGraph) = merge_chain_prefactors!(deepcopy(g)) - -""" - function merge_chains!(g::AbstractGraph) - - Converts subgraphs of g representing trivial unary chains - to in-place form, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*g + 63*h. - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_chains!(g::AbstractGraph) - merge_chain_prefactors!(g) # shift chain subgraph factors towards root level - for sub_g in g.subgraphs # prune factorless chain subgraphs - merge_factorless_chain!(sub_g) - end - return g -end - -""" - function merge_chains(g::AbstractGraph) - - Returns a copy of a graph g with subgraphs representing trivial unary chain - simplified to in-place form, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*g + 63*h. - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -merge_chains(g::AbstractGraph) = merge_chains!(deepcopy(g)) - """ function flatten_chains!(g::AbstractGraph) diff --git a/test/computational_graph.jl b/test/computational_graph.jl index e56643a3..37ab49bf 100644 --- a/test/computational_graph.jl +++ b/test/computational_graph.jl @@ -111,14 +111,8 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true g4p = Graph([g3p,]; operator=Graphs.Sum()) @test Graphs.unary_istrivial(Graphs.Prod) @test Graphs.unary_istrivial(Graphs.Sum) - @test Graphs.merge_factorless_chain(g2) == g1 - @test Graphs.merge_factorless_chain(g3) == g1 - @test Graphs.merge_factorless_chain(g4) == g1 - @test Graphs.merge_factorless_chain(g3p) == g3p - @test Graphs.merge_factorless_chain(g4p) == g3p g5 = Graph([g1,]; operator=O()) @test Graphs.unary_istrivial(O) == false - @test Graphs.merge_factorless_chain(g5) == g5 end g1 = Graph([]) g2 = Graph([g1,]; subgraph_factors=[5,], operator=Graphs.Prod()) @@ -130,43 +124,6 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true g2p = Graph([g1, g2]; operator=Graphs.Sum()) g3p = Graph([g2p,]; subgraph_factors=[3,], operator=Graphs.Prod()) gp = Graph([g3p,]; subgraph_factors=[2,], operator=Graphs.Prod()) - @testset "Merge chains" begin - # g ↦ 30*(*(*g1)) - g_merged = Graphs.merge_chain_prefactors(g) - @test g_merged.subgraph_factors == [30,] - @test all(isfactorless(node) for node in PreOrderDFS(eldest(g_merged))) - # in-place form - gc = deepcopy(g) - Graphs.merge_chain_prefactors!(gc) - @test isequiv(gc, g_merged, :id) - # gp ↦ 6*(*(g1 + 5*g1)) - gp_merged = Graphs.merge_chain_prefactors(gp) - @test gp_merged.subgraph_factors == [6,] - @test isfactorless(eldest(gp)) == false - @test isfactorless(eldest(gp_merged)) - @test eldest(eldest(gp_merged)) == g2p - # g ↦ 30*g1 - g_merged = merge_chains(g) - @test isequiv(g_merged, 30 * g1, :id) - # in-place form - merge_chains!(g) - @test isequiv(g, 30 * g1, :id) - # gp ↦ 6*(g1 + 5*g1) - gp_merged = merge_chains(gp) - @test isequiv(gp_merged, 6 * g2p, :id) - # Test a generic trivial unary chain - # *(O3(5 * O2(3 * O1(2 * h)))) ↦ 30 * h - h = Graph([]) - h1 = Graph([h,]; subgraph_factors=[2,], operator=O1()) - h2 = Graph([h1,]; subgraph_factors=[3,], operator=O2()) - h3 = Graph([h2,]; subgraph_factors=[5,], operator=O3()) - h4 = Graph([h3,]; operator=Graphs.Prod()) - h4_merged = merge_chains(h4) - @test isequiv(h4_merged, 30 * h, :id) - # in-place form - merge_chains!(h4) - @test isequiv(h4, 30 * h, :id) - end @testset "Merge prefactors" begin g1 = propagator(𝑓⁺(1)𝑓⁻(2)) h1 = FeynmanGraph([g1, g1], drop_topology(g1.properties); subgraph_factors=[1, 2], operator=Graphs.Sum()) @@ -239,40 +196,6 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true end end @testset verbose = true "Optimizations" begin - @testset "Remove one-child parents" begin - # h = O(7 * (5 * (3 * (2 * g)))) ↦ O(210 * g) - g1 = Graph([]) - g2 = 2 * g1 - g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod()) - h = Graph([g4,]; subgraph_factors=[7,], operator=O()) - hvec = repeat([deepcopy(h)], 3) - # Test on a single graph - Graphs.merge_all_chains!(h) - @test h.operator == O - @test h.subgraph_factors == [210,] - @test eldest(h) == g1 - # Test on a vector of graphs - Graphs.merge_all_chains!(hvec) - @test all(h.operator == O for h in hvec) - @test all(h.subgraph_factors == [210,] for h in hvec) - @test all(eldest(h) == g1 for h in hvec) - - g2 = 2 * g1 - g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod()) - h0 = Graph([g1, g4]; subgraph_factors=[2, 7], operator=O()) - Graphs.merge_all_chains!(h0) - @test h0.subgraph_factors == [2, 210] - @test h0.subgraphs[2] == g1 - - h1 = Graph([h0]; subgraph_factors=[3,], operator=Graphs.Prod()) - h2 = Graph([h1]; subgraph_factors=[5,], operator=Graphs.Prod()) - h = Graph([h2]; subgraph_factors=[7,], operator=O()) - Graphs.merge_all_chains!(h) - @test h.subgraph_factors == [105] - @test eldest(h) == h0 - end @testset "Flatten all chains" begin l0 = Graph([]) l1 = Graph([l0]; subgraph_factors=[2]) @@ -298,14 +221,6 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true @test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id) Graphs.flatten_all_chains!(rvec) @test rvec == [r1, r2, r3] - - Graphs.merge_all_chains!(rvec1) - @test rvec1[1].subgraph_factors == [210] - @test eldest(rvec1[1]) == g1 - @test rvec1[2].subgraph_factors == [-1] - @test eldest(rvec1[2]) == g1 # BUG! - @test rvec1[3].subgraph_factors == [2, 7] - @test rvec1[3].subgraphs == [g1, g1] # BUG! end @testset "merge all linear combinations" begin g1 = Graph([]) @@ -540,14 +455,8 @@ end g4p = FeynmanGraph([g3p,], drop_topology(g3p.properties); operator=Graphs.Sum()) @test Graphs.unary_istrivial(Graphs.Prod) @test Graphs.unary_istrivial(Graphs.Sum) - @test Graphs.merge_factorless_chain(g2) == g1 - @test Graphs.merge_factorless_chain(g3) == g1 - @test Graphs.merge_factorless_chain(g4) == g1 - @test Graphs.merge_factorless_chain(g3p) == g3p - @test Graphs.merge_factorless_chain(g4p) == g3p g5 = FeynmanGraph([g1,], drop_topology(g1.properties); operator=O()) @test Graphs.unary_istrivial(O) == false - @test Graphs.merge_factorless_chain(g5) == g5 end g1 = propagator(𝑓⁻(1)𝑓⁺(2)) g2 = FeynmanGraph([g1,], g1.properties; subgraph_factors=[5,], operator=Graphs.Prod()) @@ -559,43 +468,6 @@ end g2p = FeynmanGraph([g1, g2], drop_topology(g1.properties)) g3p = FeynmanGraph([g2p,], g2p.properties; subgraph_factors=[3,], operator=Graphs.Prod()) gp = FeynmanGraph([g3p,], g3p.properties; subgraph_factors=[2,], operator=Graphs.Prod()) - @testset "Merge chains" begin - # g ↦ 30*(*(*g1)) - g_merged = Graphs.merge_chain_prefactors(g) - @test g_merged.subgraph_factors == [30,] - @test all(isfactorless(node) for node in PreOrderDFS(eldest(g_merged))) - # in-place form - gc = deepcopy(g) - Graphs.merge_chain_prefactors!(gc) - @test isequiv(gc, g_merged, :id) - # gp ↦ 6*(*(g1 + 5*g1)) - gp_merged = Graphs.merge_chain_prefactors(gp) - @test gp_merged.subgraph_factors == [6,] - @test isfactorless(eldest(gp)) == false - @test isfactorless(eldest(gp_merged)) - @test isequiv(eldest(eldest(gp_merged)), g2p, :id) - # g ↦ 30*g1 - g_merged = merge_chains(g) - @test isequiv(g_merged, 30 * g1, :id) - # in-place form - merge_chains!(g) - @test isequiv(g, 30 * g1, :id) - # gp ↦ 6*(g1 + 5*g1) - gp_merged = merge_chains(gp) - @test isequiv(gp_merged, 6 * g2p, :id) - # Test a generic trivial unary chain - # *(O3(5 * O2(3 * O1(2 * h)))) ↦ 30 * h - h = propagator(𝑓⁻(1)𝑓⁺(2)) - h1 = FeynmanGraph([h,], h.properties; subgraph_factors=[2,], operator=O1()) - h2 = FeynmanGraph([h1,], h1.properties; subgraph_factors=[3,], operator=O2()) - h3 = FeynmanGraph([h2,], h2.properties; subgraph_factors=[5,], operator=O3()) - h4 = FeynmanGraph([h3,], h3.properties; operator=Graphs.Prod()) - h4_merged = merge_chains(h4) - @test isequiv(h4_merged, 30 * h, :id) - # in-place form - merge_chains!(h4) - @test isequiv(h4, 30 * h, :id) - end @testset "Merge prefactors" begin g1 = propagator(𝑓⁺(1)𝑓⁻(2)) h1 = FeynmanGraph([g1, g1], drop_topology(g1.properties), subgraph_factors=[1, 2]) @@ -630,32 +502,6 @@ end end @testset verbose = true "Optimizations" begin - @testset "Remove one-child parents" begin - g1 = propagator(𝑓⁻(1)𝑓⁺(2)) - g2 = 2 * g1 - # h = O(7 * (5 * (3 * (2 * g)))) ↦ O(210 * g) - g3 = FeynmanGraph([g2,], g2.properties; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = FeynmanGraph([g3,], g3.properties; subgraph_factors=[5,], operator=Graphs.Prod()) - h = FeynmanGraph([g4,], drop_topology(g4.properties); subgraph_factors=[7,], operator=O()) - hvec = repeat([h], 3) - # Test on a single graph - Graphs.merge_all_chains!(h) - @test h.operator == O - @test h.subgraph_factors == [210,] - @test isequiv(eldest(h), g1, :id) - # Test on a vector of graphs - Graphs.merge_all_chains!(hvec) - @test all(h.operator == O for h in hvec) - @test all(h.subgraph_factors == [210,] for h in hvec) - @test all(isequiv(eldest(h), g1, :id) for h in hvec) - - g2 = 2 * g1 - g3 = FeynmanGraph([g2,], g2.properties; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = FeynmanGraph([g3,], g3.properties; subgraph_factors=[5,], operator=Graphs.Prod()) - h = FeynmanGraph([g1, g4], drop_topology(g4.properties); subgraph_factors=[2, 7], operator=O()) - Graphs.merge_all_chains!(h) - @test h.subgraph_factors == [2, 210] - end @testset "optimize" begin g1 = propagator(𝑓⁻(1)𝑓⁺(2)) g2 = 2 * g1 From b78ca2a55ae28469c4ddecac9a4bea5fc3435192 Mon Sep 17 00:00:00 2001 From: houpc Date: Thu, 26 Oct 2023 23:06:51 +0800 Subject: [PATCH 3/3] graphs::Union{Tuple,AbstractVector{<:AbstractGraph}} for optimized functions --- src/computational_graph/optimize.jl | 42 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl index 206fe576..2ceb47b3 100644 --- a/src/computational_graph/optimize.jl +++ b/src/computational_graph/optimize.jl @@ -26,7 +26,7 @@ end """ function optimize(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0, normalize=nothing) - Optimize a copy of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations. + Optimizes a copy of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations. # Arguments: - `graphs`: A tuple or vector of graphs. @@ -66,9 +66,9 @@ function flatten_all_chains!(g::AbstractGraph; verbose=0) end """ - function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) + function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) - In-place flattens all nodes representing trivial unary chains in given graphs. + Flattens all nodes representing trivial unary chains in-place in given graphs. # Arguments: - `graphs`: A collection of graphs to be processed. @@ -77,7 +77,7 @@ end # Returns: - The mutated collection `graphs` with all chains in each graph flattened. """ -function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) +function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) verbose > 0 && println("flatten all nodes representing trivial unary chains.") # Post-order DFS for g in graphs @@ -90,7 +90,7 @@ end """ function merge_all_linear_combinations!(g::AbstractGraph; verbose=0) - In-place merge all nodes representing a linear combination of a non-unique list of subgraphs within a single graph. + Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place within a single graph. # Arguments: - `g`: An AbstractGraph. @@ -112,19 +112,19 @@ function merge_all_linear_combinations!(g::AbstractGraph; verbose=0) end """ - function merge_all_linear_combinations!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) + function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) - In-place merge all nodes representing a linear combination of a non-unique list of subgraphs in given graphs. + Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in given graphs. # Arguments: -- `graphs`: An AbstractVector of graphs. +- `graphs`: A collection of graphs to be processed. - `verbose`: Level of verbosity (default: 0). # Returns: - Optimized graphs. # """ -function merge_all_linear_combinations!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) +function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.") # Post-order DFS for g in graphs @@ -137,7 +137,7 @@ end """ function merge_all_multi_products!(g::Graph; verbose=0) - In-place merge all nodes representing a multi product of a non-unique list of subgraphs within a single graph. + Merges all nodes representing a multi product of a non-unique list of subgraphs in-place within a single graph. # Arguments: - `g::Graph`: A Graph. @@ -159,19 +159,19 @@ function merge_all_multi_products!(g::Graph; verbose=0) end """ - function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0) + function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{<:Graph}}; verbose=0) - In-place merge all nodes representing a multi product of a non-unique list of subgraphs in given graphs. + Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in given graphs. # Arguments: -- `graphs::Union{Tuple,AbstractVector{Graph}}`: A tuple or vector of graphs. +- `graphs`: A collection of graphs to be processed. - `verbose`: Level of verbosity (default: 0). # Returns: - Optimized graphs. # """ -function merge_all_multi_products!(graphs::AbstractVector{<:Graph}; verbose=0) +function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{<:Graph}}; verbose=0) verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.") # Post-order DFS for g in graphs @@ -184,10 +184,10 @@ end """ function unique_leaves(_graphs::AbstractVector{<:AbstractGraph}) - Identify and retrieve unique leaf nodes from a set of graphs. + Identifies and retrieves unique leaf nodes from a set of graphs. # Arguments: -- `_graphs`: A tuple or vector of graphs. +- `_graphs`: A collection of graphs to be processed. # Returns: - The vector of unique leaf nodes. @@ -218,19 +218,19 @@ function unique_leaves(_graphs::AbstractVector{<:AbstractGraph}) end """ - function remove_duplicated_leaves!(graphs::AbstractVector{<:AbstractGraph}; verbose=0, normalize=nothing, kwargs...) + function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0, normalize=nothing, kwargs...) - In-place remove duplicated leaf nodes from a collection of graphs. It also provides optional normalization for these leaves. + Removes duplicated leaf nodes in-place from a collection of graphs. It also provides optional normalization for these leaves. # Arguments: -- `graphs`: An AbstractVector of graphs. +- `graphs`: A collection of graphs to be processed. - `verbose`: Level of verbosity (default: 0). - `normalize`: Optional function to normalize the graphs (default: nothing). # Returns: - A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)). """ -function remove_duplicated_leaves!(graphs::AbstractVector{<:AbstractGraph}; verbose=0, normalize=nothing, kwargs...) +function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0, normalize=nothing, kwargs...) verbose > 0 && println("remove duplicated leaves.") leaves = Vector{eltype(graphs)}() for g in graphs @@ -264,7 +264,7 @@ end """ function burn_from_targetleaves!(graphs::AbstractVector{G}, targetleaves_id::AbstractVector{Int}; verbose=0) where {G<:AbstractGraph} - In-place remove all nodes connected to the target leaves via "Prod" operators. + Removes all nodes connected to the target leaves in-place via "Prod" operators. # Arguments: - `graphs`: An AbstractVector of graphs.