diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl index 981ff311..38a983c1 100644 --- a/src/FeynmanDiagram.jl +++ b/src/FeynmanDiagram.jl @@ -117,8 +117,8 @@ export linear_combination, feynman_diagram, propagator, interaction, external_ve # export ๐บแถ , ๐บแต‡, ๐บแต , ๐‘Š, Green2, Interaction # export Coupling_yukawa, Coupling_phi3, Coupling_phi4, Coupling_phi6 export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest -export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_chains! -export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_chains +export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! +export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves! diff --git a/src/backend/static.jl b/src/backend/static.jl index 53cf49c2..4412fffc 100644 --- a/src/backend/static.jl +++ b/src/backend/static.jl @@ -63,10 +63,11 @@ function to_julia_str(graphs::AbstractVector{G}; root::AbstractVector{Int}=[g.id inds_visitednode = Int[] for graph in graphs for g in PostOrderDFS(graph) #leaf first search + target = "g$(g.id)" + isroot = false if g.id in root - target = "root[$(findfirst(x -> x == g.id, root))]" - else - target = "g$(g.id)" + target_root = "root[$(findfirst(x -> x == g.id, root))]" + isroot = true end if isempty(g.subgraphs) #leaf g.id in inds_visitedleaf && continue @@ -80,6 +81,9 @@ function to_julia_str(graphs::AbstractVector{G}; root::AbstractVector{Int}=[g.id body *= " $target = $(_to_static(g.operator, g.subgraphs, g.subgraph_factors))$factor_str\n " push!(inds_visitednode, g.id) end + if isroot + body *= " $target_root = $target\n " + end end end tail = "end" @@ -107,10 +111,11 @@ function to_julia_str(graphs::AbstractVector{G}, leafMap::Dict{Int,Int}; root::A inds_visitednode = Int[] for graph in graphs for g in PostOrderDFS(graph) #leaf first search + target = "g$(g.id)" + isroot = false if g.id in root - target = "root[$(findfirst(x -> x == g.id, root))]" - else - target = "g$(g.id)" + target_root = "root[$(findfirst(x -> x == g.id, root))]" + isroot = true end if isempty(g.subgraphs) #leaf g.id in inds_visitedleaf && continue @@ -123,6 +128,9 @@ function to_julia_str(graphs::AbstractVector{G}, leafMap::Dict{Int,Int}; root::A body *= " $target = $(_to_static(g.operator, g.subgraphs, g.subgraph_factors))$factor_str\n " push!(inds_visitednode, g.id) end + if isroot + body *= " $target_root = $target\n " + end end end tail = "end" diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index fc0fa049..c252261b 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -45,8 +45,8 @@ include("eval.jl") export eval! include("transform.jl") -export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_chains! -export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_chains +export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! +export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains include("optimize.jl") export optimize!, optimize diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl index 1efd49bc..65e94582 100644 --- a/src/computational_graph/optimize.jl +++ b/src/computational_graph/optimize.jl @@ -226,6 +226,53 @@ function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{G}}; return graphs end +""" + function merge_all_multi_products!(g::Graph; verbose=0) + + In-place merge all nodes representing a multi product of a non-unique list of subgraphs within a single graph. + +# Arguments: +- `g::Graph`: A Graph. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graph. +# +""" +function merge_all_multi_products!(g::Graph; verbose=0) + verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.") + # Post-order DFS + for sub_g in g.subgraphs + merge_all_multi_products!(sub_g) + merge_multi_product!(sub_g) + end + merge_multi_product!(g) + return g +end + +""" + function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0) + + In-place merge all nodes representing a multi product of a non-unique list of subgraphs in given graphs. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{Graph}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graphs. +# +""" +function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0) + verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.") + # Post-order DFS + for g in graphs + merge_all_multi_products!(g.subgraphs) + merge_multi_product!(g) + end + return graphs +end + """ function unique_leaves(_graphs::Union{Tuple,AbstractVector{G}};) where {G<:AbstractGraph} diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index a1b3629e..c5040715 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -368,3 +368,63 @@ function merge_linear_combination!(g::FeynmanGraph{F,W}) where {F,W} end return g end + +""" + function merge_multi_product(g::Graph{F,W}) where {F,W} + + Merge multiple products within a computational graph `g` if they share the same operator (`Prod`). + If `g.operator == Prod`, this function will merge `N` identical subgraphs into a single subgraph with a power operator `Power(N)`. + The function ensures each unique subgraph is counted and merged appropriately, preserving any distinct subgraph_factors associated with them. + +# Arguments: +- `g::Graph`: graph to be modified + +# Returns +- A merged computational graph with potentially fewer subgraphs if there were repeating subgraphs + with the `Prod` operator. If the input graph's operator isn't `Prod`, the function returns the input graph unchanged. +""" +function merge_multi_product(g::Graph{F,W}) where {F,W} + if g.operator == Prod + unique_graphs = Vector{Graph{F,W}}() + unique_factors = F[] + repeated_counts = Int[] + for (idx, subg) in enumerate(g.subgraphs) + loc = findfirst(isequal(subg), unique_graphs) + if isnothing(loc) + push!(unique_graphs, subg) + push!(unique_factors, g.subgraph_factors[idx]) + push!(repeated_counts, 1) + else + unique_factors[loc] *= g.subgraph_factors[idx] + repeated_counts[loc] += 1 + end + end + + if length(unique_factors) == 1 + g_merged = Graph(unique_graphs; subgraph_factors=unique_factors, operator=Power(repeated_counts[1]), ftype=F, wtype=W) + else + subgraphs = Vector{Graph{F,W}}() + for (idx, g) in enumerate(unique_graphs) + if repeated_counts[idx] == 1 + push!(subgraphs, g) + else + push!(subgraphs, Graph([g], operator=Power(repeated_counts[idx]), ftype=F, wtype=W)) + end + end + g_merged = Graph(subgraphs; subgraph_factors=unique_factors, operator=Prod(), ftype=F, wtype=W) + end + return g_merged + else + return g + end +end + +function merge_multi_product!(g::Graph{F,W}) where {F,W} + if g.operator == Prod + g_merged = merge_multi_product(g) + g.subgraphs = g_merged.subgraphs + g.subgraph_factors = g_merged.subgraph_factors + g.operator = g_merged.operator + end + return g +end \ No newline at end of file diff --git a/test/computational_graph.jl b/test/computational_graph.jl index 166b02e8..84cd0f4a 100644 --- a/test/computational_graph.jl +++ b/test/computational_graph.jl @@ -198,6 +198,19 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true @test h8.subgraph_factors == [36] @test isequiv(h7_lc, h8, :id) end + @testset "Merge multi prodict" begin + g1 = Graph([]) + g2 = Graph([], factor=2) + g3 = Graph([], factor=3) + h1 = Graph([g1, g2, g1, g1, g3, g2]; subgraph_factors=[3, 2, 5, 1, 1, 3], operator=Graphs.Prod()) + h1_mp = merge_multi_product(h1) + h1_s1 = Graph([g1], operator=Graphs.Power(3)) + h1_s2 = Graph([g2], operator=Graphs.Power(2)) + h1_r = Graph([h1_s1, h1_s2, g3], subgraph_factors=[15, 6, 1], operator=Graphs.Prod()) + @test isequiv(h1_r, h1_mp, :id) + merge_multi_product!(h1) + @test isequiv(h1, h1_mp, :id) + end end @testset verbose = true "Optimizations" begin @testset "Remove one-child parents" begin