Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flatten_sum and refactor all new transform functions #171

Merged
merged 3 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/FeynmanDiagram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains
export open_parenthesis, flatten_prod!, flatten_prod
export open_parenthesis, open_parenthesis!, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum
export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves!

include("TaylorSeries/TaylorSeries.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ComputationalGraphs
using AbstractTrees
using StaticArrays
using Printf, PyCall, DataFrames
#using ..Taylor
using Random
macro todo()
return :(error("Not yet implemented!"))
end
Expand Down Expand Up @@ -50,7 +50,7 @@ export eval!
include("transform.jl")
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains
export open_parenthesis, flatten_prod!, flatten_prod
export open_parenthesis!, open_parenthesis, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum
include("optimize.jl")
export optimize!, optimize

Expand Down
20 changes: 14 additions & 6 deletions src/computational_graph/eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,23 @@
@inline apply(o::Prod, diag::FeynmanGraph{F,W}) where {F<:Number,W<:Number} = diag.weight
@inline apply(o::Power{N}, diag::FeynmanGraph{F,W}) where {N,F<:Number,W<:Number} = diag.weight

function eval!(g::Graph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}()) where {F,W}
function eval!(g::Graph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}(); inherit=false, randseed::Int=-1) where {F,W}
result = nothing

if randseed > 0
Random.seed!(randseed)
end
for node in PostOrderDFS(g)
if isleaf(node)
if isempty(leafmap)
node.weight = 1.0
else
node.weight = leaf[leafmap[node.id]]
if !inherit
if isempty(leafmap)
if randseed < 0
node.weight = 1.0
else
node.weight = rand()
end
else
node.weight = leaf[leafmap[node.id]]
end
end
else
node.weight = apply(node.operator, node.subgraphs, node.subgraph_factors)
Expand Down
124 changes: 110 additions & 14 deletions src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,29 +155,48 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph)
return g_new
end

function open_parenthesis(graph::AbstractGraph)
"""
open_parenthesis!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}

Recursively open parenthesis of subgraphs within the given graph `g`with in place form. The graph eventually becomes
a single Sum root node with multiple subgraphs that represents multi-product of nodes (not flattened).

# Arguments:
- `g::AbstractGraph`: graph to be modified
- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation.
In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them.
parents
"""
function open_parenthesis!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}
if haskey(map, graph.id)
return map[graph.id]
end

if isempty(graph.subgraphs)
return deepcopy(graph)
map[graph.id] = graph
return graph
else
children = []
for sub in graph.subgraphs
push!(children, open_parenthesis(sub))
end
newnode = Graph([]; operator=Sum())
newchildren = []
newfactors = []
if graph.operator == Sum
# flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum.
for (child_idx, child) in enumerate(children)
if isempty(child.subgraphs)
push!(newnode.subgraphs, child)
push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx])
push!(newchildren, child)
push!(newfactors, graph.subgraph_factors[child_idx])
else
for (grandchild_idx, grandchild) in enumerate(child.subgraphs)
push!(newnode.subgraphs, grandchild)
push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx])
push!(newchildren, grandchild)
push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx])
end
end
end
elseif graph.operator == Prod
graph.operator = Sum
# When opertaor is Prod, we expand parenthese and replace Prod with a Sum operator.
childsub_len = [length(child.subgraphs) for child in children]
ordtuple = ((childsub_len[num] > 0) ? (1:childsub_len[num]) : (0:0) for num in eachindex(childsub_len)) #The child with no grand child is labeled with a single idx=0
Expand All @@ -193,21 +212,43 @@ function open_parenthesis(graph::AbstractGraph)
push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx])
end
end
push!(newnode.subgraphs, newchildnode)
push!(newnode.subgraph_factors, 1.0)
push!(newchildren, newchildnode)
push!(newfactors, 1.0)
end
end
return newnode
graph.subgraphs = newchildren
graph.subgraph_factors = newfactors
return graph
end
end

function flatten_prod!(graph::AbstractGraph)
function open_parenthesis(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}
return open_parenthesis!(deepcopy(graph), map=map)
end

"""
flatten_prod!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}

Recursively merge multi-product sub-branches within the given graph `g by merging product subgraphs
into their parent product graphs in the in-place form.

# Arguments:
- `g::AbstractGraph`: graph to be modified
- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation.
In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them.
"""
function flatten_prod!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}
if haskey(map, graph.id)
return map[graph.id]
end

if isempty(graph.subgraphs)
map[graph.id] = graph
return graph
else
children = []
for sub in graph.subgraphs
push!(children, flatten_prod!(sub))
push!(children, flatten_prod!(sub, map=map))
end
newchildren = []
newfactors = []
Expand Down Expand Up @@ -235,12 +276,67 @@ function flatten_prod!(graph::AbstractGraph)
end
graph.subgraphs = newchildren
graph.subgraph_factors = newfactors
map[graph.id] = graph
return graph
end
end

function flatten_prod(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}
return flatten_prod!(deepcopy(graph), map=map)
end

"""
flatten_sum!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}

Recursively merge multi-product sub-branches within the given graph `g by merging sum subgraphs
into their parent sum graphs in the in-place form.

# Arguments:
- `g::AbstractGraph`: graph to be modified
- `map::Dict{Int,G}=Dict{Int,G}()`: A dictionary that maps the id of an original node with its corresponding new node after transformation.
In recursive transform, nodes can be visited several times by different parents. This map keeps track of those visited, and reuse those transformed sub-branches instead of recreating them.
"""
function flatten_sum!(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}
if haskey(map, graph.id)
return map[graph.id]
end
if isempty(graph.subgraphs)
map[graph.id] = graph
return graph
else
children = []
for sub in graph.subgraphs
push!(children, flatten_sum!(sub, map=map))
end
newchildren = []
newfactors = []
if graph.operator == Sum
for (child_idx, child) in enumerate(children)
if isempty(child.subgraphs) || child.operator == Prod
push!(newchildren, child)
push!(newfactors, graph.subgraph_factors[child_idx])
else
for (grandchild_idx, grandchild) in enumerate(child.subgraphs)
push!(newchildren, grandchild)
push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx])
end
end
end
elseif graph.operator == Prod
for (child_idx, child) in enumerate(children)
push!(newchildren, child)
push!(newfactors, graph.subgraph_factors[child_idx])
end
end
graph.subgraphs = newchildren
graph.subgraph_factors = newfactors
map[graph.id] = graph
return graph
end
end

function flatten_prod(graph::AbstractGraph)
flatten_prod!(deepcopy(graph))
function flatten_sum(graph::G; map::Dict{Int,G}=Dict{Int,G}()) where {G<:AbstractGraph}
return flatten_sum!(deepcopy(graph), map=map)
end

"""
Expand Down
Loading