Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add open_parenthesis and flatten_prod #170

Merged
merged 3 commits into from
Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/FeynmanDiagram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains

export open_parenthesis, flatten_prod!, flatten_prod
export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves!

include("TaylorSeries/TaylorSeries.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export eval!
include("transform.jl")
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains

export open_parenthesis, flatten_prod!, flatten_prod
include("optimize.jl")
export optimize!, optimize

Expand Down
88 changes: 88 additions & 0 deletions src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,94 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph)
return g_new
end

function open_parenthesis(graph::AbstractGraph)
if isempty(graph.subgraphs)
return deepcopy(graph)
else
children = []
for sub in graph.subgraphs
push!(children, open_parenthesis(sub))
end
newnode = Graph([]; operator=Sum())
if graph.operator == Sum
# flatten function make sure that all children are already converted to Sum->Prod two layer graphs, so here when merging the subgraphs we just consider the case when operator are Sum.
for (child_idx, child) in enumerate(children)
if isempty(child.subgraphs)
push!(newnode.subgraphs, child)
push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx])
else
for (grandchild_idx, grandchild) in enumerate(child.subgraphs)
push!(newnode.subgraphs, grandchild)
push!(newnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx])
end
end
end
elseif graph.operator == Prod
# When opertaor is Prod, we expand parenthese and replace Prod with a Sum operator.
childsub_len = [length(child.subgraphs) for child in children]
ordtuple = ((childsub_len[num] > 0) ? (1:childsub_len[num]) : (0:0) for num in eachindex(childsub_len)) #The child with no grand child is labeled with a single idx=0
for indices in collect(Iterators.product(ordtuple...)) #Indices for all combination of grandchilds, with one from each child.
newchildnode = Graph([]; operator=Prod())
for (child_idx, grandchild_idx) in enumerate(indices)
child = children[child_idx]
if grandchild_idx == 0 #Meaning this node is a leaf node
push!(newchildnode.subgraphs, child)
push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx])
else
push!(newchildnode.subgraphs, child.subgraphs[grandchild_idx])
push!(newchildnode.subgraph_factors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx])
end
end
push!(newnode.subgraphs, newchildnode)
push!(newnode.subgraph_factors, 1.0)
end
end
return newnode
end
end

function flatten_prod!(graph::AbstractGraph)
if isempty(graph.subgraphs)
return graph
else
children = []
for sub in graph.subgraphs
push!(children, flatten_prod!(sub))
end
newchildren = []
newfactors = []
if graph.operator == Sum
for (child_idx, child) in enumerate(children)
push!(newchildren, child)
push!(newfactors, graph.subgraph_factors[child_idx])
end
elseif graph.operator == Prod
for (child_idx, child) in enumerate(children)
if isempty(child.subgraphs) || child.operator == Sum
push!(newchildren, child)
push!(newfactors, graph.subgraph_factors[child_idx])
else
for (grandchild_idx, grandchild) in enumerate(child.subgraphs)
push!(newchildren, grandchild)
if grandchild_idx == 1
push!(newfactors, graph.subgraph_factors[child_idx] * child.subgraph_factors[grandchild_idx])
else
push!(newfactors, child.subgraph_factors[grandchild_idx])
end
end
end
end
end
graph.subgraphs = newchildren
graph.subgraph_factors = newfactors
return graph
end
end

function flatten_prod(graph::AbstractGraph)
flatten_prod!(deepcopy(graph))
end

"""
function flatten_chains!(g::AbstractGraph)

Expand Down
Loading