From f7ef86e210f2d27badf83c354fd28b44a5b30878 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Fri, 22 Dec 2023 17:01:28 -0500 Subject: [PATCH 1/2] add open_parenthese and merge_prod --- src/FeynmanDiagram.jl | 4 +- src/computational_graph/ComputationalGraph.jl | 2 +- src/computational_graph/transform.jl | 121 ++++++++++++++++++ 3 files changed, 124 insertions(+), 3 deletions(-) diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl index aed2d068..b78b1dd8 100644 --- a/src/FeynmanDiagram.jl +++ b/src/FeynmanDiagram.jl @@ -122,7 +122,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains - +export open_parenthesis, flatten_prod!, flatten_prod export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves! include("TaylorSeries/TaylorSeries.jl") @@ -165,7 +165,7 @@ export evalNaive, showTree include("utility.jl") using .Utility export Utility -export taylorexpansion! +export taylorexpansion!, flatten include("frontend/frontends.jl") using .FrontEnds diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index 97c2bffb..e565764e 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -50,7 +50,7 @@ export eval! include("transform.jl") export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains - +export open_parenthesis, flatten_prod!, flatten_prod include("optimize.jl") export optimize!, optimize diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 3c5c6380..576a47c8 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -155,6 +155,127 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph) return g_new end +function open_parenthesis(graph::AbstractGraph) + if isempty(graph.subgraphs) + return deepcopy(graph) + else + children = [] + for sub in graph.subgraphs + push!(children, open_parenthesis(sub)) + end + newnode = Graph([]; operator=Sum()) + if graph.operator == Sum + # flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum. + for (child_idx, child) in enumerate(children) + if isempty(child.subgraphs) + push!(newnode.subgraphs, child) + push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) + else + for (grandchild_idx, grandchild) in enumerate(child.subgraphs) + push!(newnode.subgraphs, grandchild) + push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + end + elseif graph.operator == Prod + # When opertaor is Prod, we expand parenthese and replace Prod with a Sum operator. + childsub_len = [length(child.subgraphs) for child in children] + ordtuple = ((childsub_len[num] > 0) ? (1:childsub_len[num]) : (0:0) for num in eachindex(childsub_len)) #The child with no grand child is labeled with a single idx=0 + for indices in collect(Iterators.product(ordtuple...)) #Indices for all combination of grandchilds, with one from each child. + newchildnode = Graph([]; operator=Prod()) + for (child_idx, grandchild_idx) in enumerate(indices) + child = children[child_idx] + if grandchild_idx == 0 #Meaning this node is a leaf node + push!(newchildnode.subgraphs, child) + push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx]) + else + push!(newchildnode.subgraphs, child.subgraphs[grandchild_idx]) + push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + push!(newnode.subgraphs, newchildnode) + push!(newnode.subgraph_factors, 1.0) + end + end + return newnode + end +end + +function flatten_prod!(graph::AbstractGraph) + if isempty(graph.subgraphs) + return graph + else + children = [] + for sub in graph.subgraphs + push!(children, flatten_prod!(sub)) + end + newchildren = [] + newfactors = [] + if graph.operator == Sum + for (child_idx, child) in enumerate(children) + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + end + elseif graph.operator == Prod + for (child_idx, child) in enumerate(children) + if isempty(child.subgraphs) || child.operator == Sum + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + else + for (grandchild_idx, grandchild) in enumerate(child.subgraphs) + push!(newchildren, grandchild) + if grandchild_idx == 1 + push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + else + push!(newfactors, child.subgraph_factors[grandchild_idx]) + end + end + end + end + end + graph.subgraphs = newchildren + graph.subgraph_factors = newfactors + return graph + end +end + +function flatten_prod(graph::AbstractGraph) + flatten_prod!(deepcopy(graph)) +end +# function flatten_prod(graph::AbstractGraph) +# if isempty(graph.subgraphs) +# return deepcopy(graph) +# else +# children = [] +# for sub in graph.subgraphs +# push!(children, merge_prod(sub)) +# end +# newnode = Graph([]; operator=Sum()) +# if graph.operator == Sum +# for (child_idx, child) in enumerate(children) +# push!(newnode.subgraphs, child) +# push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) +# end +# elseif graph.operator == Prod +# newnode.operator = Prod() +# for (child_idx, child) in enumerate(children) +# if isempty(child.subgraphs) || child.operator == Sum +# push!(newnode.subgraphs, child) +# push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) +# else +# for (grandchild_idx, grandchild) in enumerate(child.subgraphs) +# push!(newnode.subgraphs, grandchild) +# push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) +# end +# end +# end +# end +# return newnode +# end +# end + + + """ function flatten_chains!(g::AbstractGraph) From 2a4419aa23bf233bdffe269fa999ca5c4c4bbf61 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Fri, 22 Dec 2023 17:08:15 -0500 Subject: [PATCH 2/2] minor change --- src/FeynmanDiagram.jl | 2 +- src/computational_graph/transform.jl | 33 ---------------------------- 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl index 084acc6a..38a410a6 100644 --- a/src/FeynmanDiagram.jl +++ b/src/FeynmanDiagram.jl @@ -165,7 +165,7 @@ export evalNaive, showTree include("utility.jl") using .Utility export Utility -export taylorexpansion!, flatten +export taylorexpansion! include("frontend/frontends.jl") using .FrontEnds diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 615276eb..17372852 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -242,39 +242,6 @@ end function flatten_prod(graph::AbstractGraph) flatten_prod!(deepcopy(graph)) end -# function flatten_prod(graph::AbstractGraph) -# if isempty(graph.subgraphs) -# return deepcopy(graph) -# else -# children = [] -# for sub in graph.subgraphs -# push!(children, merge_prod(sub)) -# end -# newnode = Graph([]; operator=Sum()) -# if graph.operator == Sum -# for (child_idx, child) in enumerate(children) -# push!(newnode.subgraphs, child) -# push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) -# end -# elseif graph.operator == Prod -# newnode.operator = Prod() -# for (child_idx, child) in enumerate(children) -# if isempty(child.subgraphs) || child.operator == Sum -# push!(newnode.subgraphs, child) -# push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) -# else -# for (grandchild_idx, grandchild) in enumerate(child.subgraphs) -# push!(newnode.subgraphs, grandchild) -# push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) -# end -# end -# end -# end -# return newnode -# end -# end - - """ function flatten_chains!(g::AbstractGraph)