diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl index 38a410a6..9870d195 100644 --- a/src/FeynmanDiagram.jl +++ b/src/FeynmanDiagram.jl @@ -122,7 +122,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains -export open_parenthesis, flatten_prod!, flatten_prod +export open_parenthesis, open_parenthesis!, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves! include("TaylorSeries/TaylorSeries.jl") diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index 3979f504..2ea9d50e 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -3,7 +3,7 @@ module ComputationalGraphs using AbstractTrees using StaticArrays using Printf, PyCall, DataFrames -#using ..Taylor +using Random macro todo() return :(error("Not yet implemented!")) end @@ -50,7 +50,7 @@ export eval! include("transform.jl") export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains -export open_parenthesis, flatten_prod!, flatten_prod +export open_parenthesis!, open_parenthesis, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum include("optimize.jl") export optimize!, optimize diff --git a/src/computational_graph/eval.jl b/src/computational_graph/eval.jl index 05109cfc..04ca3c97 100644 --- a/src/computational_graph/eval.jl +++ b/src/computational_graph/eval.jl @@ -12,15 +12,23 @@ @inline apply(o::Prod, diag::FeynmanGraph{F,W}) where {F<:Number,W<:Number} = diag.weight @inline apply(o::Power{N}, diag::FeynmanGraph{F,W}) where {N,F<:Number,W<:Number} = diag.weight -function eval!(g::Graph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}()) where {F,W} +function eval!(g::Graph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}(); inherit=false, randseed::Int=-1) where {F,W} result = nothing - + if randseed > 0 + Random.seed!(randseed) + end for node in PostOrderDFS(g) if isleaf(node) - if isempty(leafmap) - node.weight = 1.0 - else - node.weight = leaf[leafmap[node.id]] + if !inherit + if isempty(leafmap) + if randseed < 0 + node.weight = 1.0 + else + node.weight = rand() + end + else + node.weight = leaf[leafmap[node.id]] + end end else node.weight = apply(node.operator, node.subgraphs, node.subgraph_factors) diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 17372852..a59df941 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -155,29 +155,48 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph) return g_new end -function open_parenthesis(graph::AbstractGraph) +""" + open_parenthesis!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + + Recursively open parenthesis of subgraphs within the given graph `g`with in place form. The graph eventually becomes + a single Sum root node with multiple subgraphs that represents multi-product of nodes (not flattened). + +# Arguments: +- `g::AbstractGraph`: graph to be modified +- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation. +In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them. +parents +""" +function open_parenthesis!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + if haskey(map, graph.id) + return map[graph.id] + end + if isempty(graph.subgraphs) - return deepcopy(graph) + map[graph.id] = graph + return graph else children = [] for sub in graph.subgraphs push!(children, open_parenthesis(sub)) end - newnode = Graph([]; operator=Sum()) + newchildren = [] + newfactors = [] if graph.operator == Sum # flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum. for (child_idx, child) in enumerate(children) if isempty(child.subgraphs) - push!(newnode.subgraphs, child) - push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx]) + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) else for (grandchild_idx, grandchild) in enumerate(child.subgraphs) - push!(newnode.subgraphs, grandchild) - push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + push!(newchildren, grandchild) + push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) end end end elseif graph.operator == Prod + graph.operator = Sum # When opertaor is Prod, we expand parenthese and replace Prod with a Sum operator. childsub_len = [length(child.subgraphs) for child in children] ordtuple = ((childsub_len[num] > 0) ? (1:childsub_len[num]) : (0:0) for num in eachindex(childsub_len)) #The child with no grand child is labeled with a single idx=0 @@ -193,21 +212,43 @@ function open_parenthesis(graph::AbstractGraph) push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) end end - push!(newnode.subgraphs, newchildnode) - push!(newnode.subgraph_factors, 1.0) + push!(newchildren, newchildnode) + push!(newfactors, 1.0) end end - return newnode + graph.subgraphs = newchildren + graph.subgraph_factors = newfactors + return graph end end -function flatten_prod!(graph::AbstractGraph) +function open_parenthesis(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + return open_parenthesis!(deepcopy(graph), map=map) +end + +""" + flatten_prod!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + + Recursively merge multi-product sub-branches within the given graph `g by merging product subgraphs + into their parent product graphs in the in-place form. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation. +In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them. +""" +function flatten_prod!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + if haskey(map, graph.id) + return map[graph.id] + end + if isempty(graph.subgraphs) + map[graph.id] = graph return graph else children = [] for sub in graph.subgraphs - push!(children, flatten_prod!(sub)) + push!(children, flatten_prod!(sub, map=map)) end newchildren = [] newfactors = [] @@ -235,12 +276,67 @@ function flatten_prod!(graph::AbstractGraph) end graph.subgraphs = newchildren graph.subgraph_factors = newfactors + map[graph.id] = graph + return graph + end +end + +function flatten_prod(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + return flatten_prod!(deepcopy(graph), map=map) +end + +""" + flatten_sum!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + + Recursively merge multi-product sub-branches within the given graph `g by merging sum subgraphs + into their parent sum graphs in the in-place form. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation. +In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them. +""" +function flatten_sum!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + if haskey(map, graph.id) + return map[graph.id] + end + if isempty(graph.subgraphs) + map[graph.id] = graph + return graph + else + children = [] + for sub in graph.subgraphs + push!(children, flatten_sum!(sub, map=map)) + end + newchildren = [] + newfactors = [] + if graph.operator == Sum + for (child_idx, child) in enumerate(children) + if isempty(child.subgraphs) || child.operator == Prod + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + else + for (grandchild_idx, grandchild) in enumerate(child.subgraphs) + push!(newchildren, grandchild) + push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx]) + end + end + end + elseif graph.operator == Prod + for (child_idx, child) in enumerate(children) + push!(newchildren, child) + push!(newfactors, graph.subgraph_factors[child_idx]) + end + end + graph.subgraphs = newchildren + graph.subgraph_factors = newfactors + map[graph.id] = graph return graph end end -function flatten_prod(graph::AbstractGraph) - flatten_prod!(deepcopy(graph)) +function flatten_sum(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph} + return flatten_sum!(deepcopy(graph), map=map) end """