diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl index 5ed4c1f2..38a410a6 100644 --- a/src/FeynmanDiagram.jl +++ b/src/FeynmanDiagram.jl @@ -122,7 +122,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains - +export open_parenthesis, flatten_prod!, flatten_prod export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves! include("TaylorSeries/TaylorSeries.jl") diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index 4a50a17f..3979f504 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -50,7 +50,7 @@ export eval! include("transform.jl") export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains - +export open_parenthesis, flatten_prod!, flatten_prod include("optimize.jl") export optimize!, optimize diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index fecf791a..17372852 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -155,6 +155,94 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph) return g_new end +function open_parenthesis(graph::AbstractGraph) + if isempty(graph.subgraphs) + return deepcopy(graph) + else + children = [] + for sub in graph.subgraphs + push!(children, open_parenthesis(sub)) + end + newnode = Graph([]; operator=Sum()) + if graph.operator == Sum + # flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum. + for (child_idx, child) in enumerate(children) + if isempty(child.subgraphs) + push!(newnode.subgraphs, child) + push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) + else + for (grandchild_idx, grandchild) in enumerate(child.subgraphs) + push!(newnode.subgraphs, grandchild) + push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + end + elseif graph.operator == Prod + # When opertaor is Prod, we expand parenthese and replace Prod with a Sum operator. + childsub_len = [length(child.subgraphs) for child in children] + ordtuple = ((childsub_len[num] > 0) ? (1:childsub_len[num]) : (0:0) for num in eachindex(childsub_len)) #The child with no grand child is labeled with a single idx=0 + for indices in collect(Iterators.product(ordtuple...)) #Indices for all combination of grandchilds, with one from each child. + newchildnode = Graph([]; operator=Prod()) + for (child_idx, grandchild_idx) in enumerate(indices) + child = children[child_idx] + if grandchild_idx == 0 #Meaning this node is a leaf node + push!(newchildnode.subgraphs, child) + push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx]) + else + push!(newchildnode.subgraphs, child.subgraphs[grandchild_idx]) + push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + push!(newnode.subgraphs, newchildnode) + push!(newnode.subgraph_factors, 1.0) + end + end + return newnode + end +end + +function flatten_prod!(graph::AbstractGraph) + if isempty(graph.subgraphs) + return graph + else + children = [] + for sub in graph.subgraphs + push!(children, flatten_prod!(sub)) + end + newchildren = [] + newfactors = [] + if graph.operator == Sum + for (child_idx, child) in enumerate(children) + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + end + elseif graph.operator == Prod + for (child_idx, child) in enumerate(children) + if isempty(child.subgraphs) || child.operator == Sum + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + else + for (grandchild_idx, grandchild) in enumerate(child.subgraphs) + push!(newchildren, grandchild) + if grandchild_idx == 1 + push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + else + push!(newfactors, child.subgraph_factors[grandchild_idx]) + end + end + end + end + end + graph.subgraphs = newchildren + graph.subgraph_factors = newfactors + return graph + end +end + +function flatten_prod(graph::AbstractGraph) + flatten_prod!(deepcopy(graph)) +end + """ function flatten_chains!(g::AbstractGraph)