From 5193eeb1b5487ffd25279811a84538e0184bd239 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Sat, 23 Dec 2023 12:09:07 -0500 Subject: [PATCH 1/3] add flatten_sum --- src/computational_graph/ComputationalGraph.jl | 2 +- src/computational_graph/transform.jl | 66 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index 3979f504..b2b2ea05 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -50,7 +50,7 @@ export eval! include("transform.jl") export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains -export open_parenthesis, flatten_prod!, flatten_prod +export open_parenthesis, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum include("optimize.jl") export optimize!, optimize diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 17372852..13d453d5 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -155,6 +155,15 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph) return g_new end +""" + open_parenthesis(graph::AbstractGraph) + + Recursively open parenthesis of subgraphs within the given graph `g`. The graph eventually becomes + a single Sum root node with multiple subgraphs that represents multi-product of nodes (not flattened). + +# Arguments: +- `g::AbstractGraph`: graph to be modified +""" function open_parenthesis(graph::AbstractGraph) if isempty(graph.subgraphs) return deepcopy(graph) @@ -201,6 +210,15 @@ function open_parenthesis(graph::AbstractGraph) end end +""" + flatten_prod!(graph::AbstractGraph) + + Recursively merge multi-product sub-branches within the given graph `g by merging product subgraphs + into their parent product graphs in the in-place form. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +""" function flatten_prod!(graph::AbstractGraph) if isempty(graph.subgraphs) return graph @@ -243,6 +261,54 @@ function flatten_prod(graph::AbstractGraph) flatten_prod!(deepcopy(graph)) end +""" + flatten_sum!(graph::AbstractGraph) + + Recursively merge multi-product sub-branches within the given graph `g by merging sum subgraphs + into their parent sum graphs in the in-place form. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +""" +function flatten_sum!(graph::AbstractGraph) + if isempty(graph.subgraphs) + return graph + else + children = [] + for sub in graph.subgraphs + push!(children, flatten_sum!(sub)) + end + newchildren = [] + newfactors = [] + if graph.operator == Sum + for (child_idx, child) in enumerate(children) + if isempty(child.subgraphs) || child.operator == Prod + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + else + for (grandchild_idx, grandchild) in enumerate(child.subgraphs) + push!(newchildren, grandchild) + push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + end + elseif graph.operator == Prod + for (child_idx, child) in enumerate(children) + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + end + end + graph.subgraphs = newchildren + graph.subgraph_factors = newfactors + return graph + end +end + +function flatten_sum!(graph::AbstractGraph) + flatten_sum!(deepcopy(graph)) +end + +to_feynman(g::AbstractGraph) = flatten_prod(open_parenthesis(g)) """ function flatten_chains!(g::AbstractGraph) From 6fee3c549ee36bf4a5e362fa3a3dfe8b04ceb34f Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Sat, 23 Dec 2023 22:26:16 -0500 Subject: [PATCH 2/3] make inplace open_parenthese; add map to all new transforms --- src/FeynmanDiagram.jl | 2 +- src/computational_graph/ComputationalGraph.jl | 4 +- src/computational_graph/eval.jl | 20 ++-- src/computational_graph/transform.jl | 102 +++++++++++++++--- 4 files changed, 105 insertions(+), 23 deletions(-) diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl index 38a410a6..9870d195 100644 --- a/src/FeynmanDiagram.jl +++ b/src/FeynmanDiagram.jl @@ -122,7 +122,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains -export open_parenthesis, flatten_prod!, flatten_prod +export open_parenthesis, open_parenthesis!, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves! include("TaylorSeries/TaylorSeries.jl") diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index b2b2ea05..2ea9d50e 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -3,7 +3,7 @@ module ComputationalGraphs using AbstractTrees using StaticArrays using Printf, PyCall, DataFrames -#using ..Taylor +using Random macro todo() return :(error("Not yet implemented!")) end @@ -50,7 +50,7 @@ export eval! include("transform.jl") export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains -export open_parenthesis, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum +export open_parenthesis!, open_parenthesis, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum include("optimize.jl") export optimize!, optimize diff --git a/src/computational_graph/eval.jl b/src/computational_graph/eval.jl index 05109cfc..04ca3c97 100644 --- a/src/computational_graph/eval.jl +++ b/src/computational_graph/eval.jl @@ -12,15 +12,23 @@ @inline apply(o::Prod, diag::FeynmanGraph{F,W}) where {F<:Number,W<:Number} = diag.weight @inline apply(o::Power{N}, diag::FeynmanGraph{F,W}) where {N,F<:Number,W<:Number} = diag.weight -function eval!(g::Graph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}()) where {F,W} +function eval!(g::Graph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}(); inherit=false, randseed::Int=-1) where {F,W} result = nothing - + if randseed > 0 + Random.seed!(randseed) + end for node in PostOrderDFS(g) if isleaf(node) - if isempty(leafmap) - node.weight = 1.0 - else - node.weight = leaf[leafmap[node.id]] + if !inherit + if isempty(leafmap) + if randseed < 0 + node.weight = 1.0 + else + node.weight = rand() + end + else + node.weight = leaf[leafmap[node.id]] + end end else node.weight = apply(node.operator, node.subgraphs, node.subgraph_factors) diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 13d453d5..31b8cf7a 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -156,14 +156,72 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph) end """ - open_parenthesis(graph::AbstractGraph) + open_parenthesis!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} - Recursively open parenthesis of subgraphs within the given graph `g`. The graph eventually becomes + Recursively open parenthesis of subgraphs within the given graph `g`with in place form. The graph eventually becomes a single Sum root node with multiple subgraphs that represents multi-product of nodes (not flattened). # Arguments: - `g::AbstractGraph`: graph to be modified +- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation. +In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them. +parents """ +function open_parenthesis!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + if haskey(map, graph.id) + return map[graph.id] + end + + if isempty(graph.subgraphs) + map[graph.id] = graph + return graph + else + children = [] + for sub in graph.subgraphs + push!(children, open_parenthesis(sub)) + end + newchildren = [] + newfactors = [] + if graph.operator == Sum + # flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum. + for (child_idx, child) in enumerate(children) + if isempty(child.subgraphs) + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + else + for (grandchild_idx, grandchild) in enumerate(child.subgraphs) + push!(newchildren, grandchild) + push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + end + elseif graph.operator == Prod + graph.operator = Sum + # When opertaor is Prod, we expand parenthese and replace Prod with a Sum operator. + childsub_len = [length(child.subgraphs) for child in children] + ordtuple = ((childsub_len[num] > 0) ? (1:childsub_len[num]) : (0:0) for num in eachindex(childsub_len)) #The child with no grand child is labeled with a single idx=0 + for indices in collect(Iterators.product(ordtuple...)) #Indices for all combination of grandchilds, with one from each child. + newchildnode = Graph([]; operator=Prod()) + for (child_idx, grandchild_idx) in enumerate(indices) + child = children[child_idx] + if grandchild_idx == 0 #Meaning this node is a leaf node + push!(newchildnode.subgraphs, child) + push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx]) + else + push!(newchildnode.subgraphs, child.subgraphs[grandchild_idx]) + push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + push!(newchildren, newchildnode) + push!(newfactors, 1.0) + end + end + graph.subgraphs = newchildren + graph.subgraph_factors = newfactors + return graph + end +end + function open_parenthesis(graph::AbstractGraph) if isempty(graph.subgraphs) return deepcopy(graph) @@ -172,6 +230,8 @@ function open_parenthesis(graph::AbstractGraph) for sub in graph.subgraphs push!(children, open_parenthesis(sub)) end + newchildren = [] + newfactors = [] newnode = Graph([]; operator=Sum()) if graph.operator == Sum # flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum. @@ -211,21 +271,28 @@ function open_parenthesis(graph::AbstractGraph) end """ - flatten_prod!(graph::AbstractGraph) + flatten_prod!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} Recursively merge multi-product sub-branches within the given graph `g by merging product subgraphs into their parent product graphs in the in-place form. # Arguments: - `g::AbstractGraph`: graph to be modified +- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation. +In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them. """ -function flatten_prod!(graph::AbstractGraph) +function flatten_prod!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + if haskey(map, graph.id) + return map[graph.id] + end + if isempty(graph.subgraphs) + map[graph.id] = graph return graph else children = [] for sub in graph.subgraphs - push!(children, flatten_prod!(sub)) + push!(children, flatten_prod!(sub, map=map)) end newchildren = [] newfactors = [] @@ -253,30 +320,37 @@ function flatten_prod!(graph::AbstractGraph) end graph.subgraphs = newchildren graph.subgraph_factors = newfactors + map[graph.id] = graph return graph end end -function flatten_prod(graph::AbstractGraph) - flatten_prod!(deepcopy(graph)) +function flatten_prod(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + flatten_prod!(deepcopy(graph), map=map) end -""" - flatten_sum!(graph::AbstractGraph) +""" + flatten_sum!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} Recursively merge multi-product sub-branches within the given graph `g by merging sum subgraphs into their parent sum graphs in the in-place form. # Arguments: - `g::AbstractGraph`: graph to be modified +- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation. +In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them. """ -function flatten_sum!(graph::AbstractGraph) +function flatten_sum!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + if haskey(map, graph.id) + return map[graph.id] + end if isempty(graph.subgraphs) + map[graph.id] = graph return graph else children = [] for sub in graph.subgraphs - push!(children, flatten_sum!(sub)) + push!(children, flatten_sum!(sub, map=map)) end newchildren = [] newfactors = [] @@ -300,15 +374,15 @@ function flatten_sum!(graph::AbstractGraph) end graph.subgraphs = newchildren graph.subgraph_factors = newfactors + map[graph.id] = graph return graph end end -function flatten_sum!(graph::AbstractGraph) - flatten_sum!(deepcopy(graph)) +function flatten_sum(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + flatten_sum!(deepcopy(graph), map=map) end -to_feynman(g::AbstractGraph) = flatten_prod(open_parenthesis(g)) """ function flatten_chains!(g::AbstractGraph) From f94f13cb0aa4ed0927ef1fddef3a68c6e0250402 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Sat, 23 Dec 2023 22:50:27 -0500 Subject: [PATCH 3/3] minor change --- src/computational_graph/transform.jl | 52 +++------------------------- 1 file changed, 4 insertions(+), 48 deletions(-) diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 31b8cf7a..a59df941 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -222,52 +222,8 @@ function open_parenthesis!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:A end end -function open_parenthesis(graph::AbstractGraph) - if isempty(graph.subgraphs) - return deepcopy(graph) - else - children = [] - for sub in graph.subgraphs - push!(children, open_parenthesis(sub)) - end - newchildren = [] - newfactors = [] - newnode = Graph([]; operator=Sum()) - if graph.operator == Sum - # flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum. - for (child_idx, child) in enumerate(children) - if isempty(child.subgraphs) - push!(newnode.subgraphs, child) - push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) - else - for (grandchild_idx, grandchild) in enumerate(child.subgraphs) - push!(newnode.subgraphs, grandchild) - push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) - end - end - end - elseif graph.operator == Prod - # When opertaor is Prod, we expand parenthese and replace Prod with a Sum operator. - childsub_len = [length(child.subgraphs) for child in children] - ordtuple = ((childsub_len[num] > 0) ? (1:childsub_len[num]) : (0:0) for num in eachindex(childsub_len)) #The child with no grand child is labeled with a single idx=0 - for indices in collect(Iterators.product(ordtuple...)) #Indices for all combination of grandchilds, with one from each child. - newchildnode = Graph([]; operator=Prod()) - for (child_idx, grandchild_idx) in enumerate(indices) - child = children[child_idx] - if grandchild_idx == 0 #Meaning this node is a leaf node - push!(newchildnode.subgraphs, child) - push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx]) - else - push!(newchildnode.subgraphs, child.subgraphs[grandchild_idx]) - push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) - end - end - push!(newnode.subgraphs, newchildnode) - push!(newnode.subgraph_factors, 1.0) - end - end - return newnode - end +function open_parenthesis(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + return open_parenthesis!(deepcopy(graph), map=map) end """ @@ -326,7 +282,7 @@ function flatten_prod!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:Abstr end function flatten_prod(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} - flatten_prod!(deepcopy(graph), map=map) + return flatten_prod!(deepcopy(graph), map=map) end """ @@ -380,7 +336,7 @@ function flatten_sum!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:Abstra end function flatten_sum(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} - flatten_sum!(deepcopy(graph), map=map) + return flatten_sum!(deepcopy(graph), map=map) end """