Skip to content

Commit

Permalink
add parent_graphs field in AbstractGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Oct 24, 2023
1 parent 41dd80a commit 145b150
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 75 deletions.
2 changes: 1 addition & 1 deletion src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export linear_combination, feynman_diagram, propagator, interaction, external_ve
# export 𝐺ᶠ, 𝐺ᵇ, 𝐺ᵠ, 𝑊, Green2, Interaction

include("tree_properties.jl")
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest, count_operation
export haschildren, onechild, noparent, isleaf, isbranch, ischain, isfactorless, eldest, count_operation

include("operation.jl")
include("io.jl")
Expand Down
11 changes: 11 additions & 0 deletions src/computational_graph/abstractgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ function Base.isequal(a::AbstractGraph, b::AbstractGraph)
for field in fieldnames(typeof(a))
if field == :weight
(getproperty(a, :weight) getproperty(b, :weight)) == false && return false
elseif field == :parent_graphs
length(a.parent_graphs) != length(b.parent_graphs) && return false
ids_a = [pg.id for pg in a.parent_graphs]
ids_b = [pg.id for pg in b.parent_graphs]
Set(ids_a) != Set(ids_b) && return false
else
getproperty(a, field) != getproperty(b, field) && return false
end
Expand All @@ -65,6 +70,12 @@ function isequiv(a::AbstractGraph, b::AbstractGraph, args...)
elseif field == :subgraphs
length(a.subgraphs) != length(b.subgraphs) && return false
!all(isequiv.(getproperty(a, field), getproperty(b, field), args...)) && return false
elseif field == :parent_graphs
length(a.parent_graphs) != length(b.parent_graphs) && return false
!all(isequiv.(getproperty(a, field), getproperty(b, field), :subgraphs, args...)) && return false
# ids_a = [pg.id for pg in a.parent_graphs]
# ids_b = [pg.id for pg in b.parent_graphs]
# Set(ids_a) != Set(ids_b) && return false
else
getproperty(a, field) != getproperty(b, field) && return false
end
Expand Down
32 changes: 23 additions & 9 deletions src/computational_graph/feynmangraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ mutable struct FeynmanGraph{F,W} <: AbstractGraph # FeynmanGraph

subgraphs::Vector{FeynmanGraph{F,W}}
subgraph_factors::Vector{F}
parent_graphs::Vector{FeynmanGraph{F,W}}

operator::DataType
factor::F
Expand Down Expand Up @@ -108,8 +109,9 @@ mutable struct FeynmanGraph{F,W} <: AbstractGraph # FeynmanGraph
- `weight` weight of the diagram
"""
function FeynmanGraph(subgraphs::AbstractVector; topology=[], vertices::Union{Vector{OperatorProduct},Nothing}=nothing, external_indices=[], external_legs=[],
subgraph_factors=one.(eachindex(subgraphs)), name="", diagtype::DiagramType=GenericDiag(), operator::AbstractOperator=Sum(),
orders=zeros(Int, 16), ftype=_dtype.factor, wtype=_dtype.weight, factor=one(ftype), weight=zero(wtype)
subgraph_factors=one.(eachindex(subgraphs)), parent_graphs::AbstractVector=eltype(subgraphs)[],
name="", diagtype::DiagramType=GenericDiag(), operator::AbstractOperator=Sum(), orders=zeros(Int, 16),
ftype=_dtype.factor, wtype=_dtype.weight, factor=one(ftype), weight=zero(wtype)
)
@assert length(external_indices) == length(external_legs)
if typeof(operator) <: Power
Expand All @@ -120,7 +122,11 @@ mutable struct FeynmanGraph{F,W} <: AbstractGraph # FeynmanGraph
vertices = [external_operators(g) for g in subgraphs if diagram_type(g) != Propagator]
end
properties = FeynmanProperties(typeof(diagtype), vertices, topology, external_indices, external_legs)
return new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, typeof(operator), factor, weight)
g = new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, parent_graphs, typeof(operator), factor, weight)
for sub_g in subgraphs
g sub_g.parent_graphs && push!(sub_g.parent_graphs, g)
end
return g
end

"""
Expand All @@ -142,15 +148,20 @@ mutable struct FeynmanGraph{F,W} <: AbstractGraph # FeynmanGraph
- `weight` weight of the diagram
"""
function FeynmanGraph(subgraphs::AbstractVector, properties::FeynmanProperties;
subgraph_factors=one.(eachindex(subgraphs)), name="", operator::AbstractOperator=Sum(),
orders=zeros(Int, 16), ftype=_dtype.factor, wtype=_dtype.weight, factor=one(ftype), weight=zero(wtype)
subgraph_factors=one.(eachindex(subgraphs)), parent_graphs::AbstractVector=eltype(subgraphs)[],
name="", operator::AbstractOperator=Sum(), orders=zeros(Int, 16),
ftype=_dtype.factor, wtype=_dtype.weight, factor=one(ftype), weight=zero(wtype)
)
@assert length(properties.external_indices) == length(properties.external_legs)
if typeof(operator) <: Power
@assert length(subgraphs) == 1 "FeynmanGraph with Power operator must have one and only one subgraph."
end
# @assert allunique(subgraphs) "all subgraphs must be distinct."
return new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, typeof(operator), factor, weight)
g = new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, parent_graphs, typeof(operator), factor, weight)
for sub_g in subgraphs
g sub_g.parent_graphs && push!(sub_g.parent_graphs, g)
end
return g
end
end

Expand Down Expand Up @@ -248,9 +259,11 @@ end
function Base.:*(g1::FeynmanGraph{F,W}, c2::C) where {F,W,C}
g = FeynmanGraph([g1,], g1.properties; subgraph_factors=[F(c2),], operator=Prod(), orders=orders(g1), ftype=F, wtype=W)
# Merge multiplicative link
if g1.operator == Prod && onechild(g1)
if unary_istrivial(g1.operator) && onechild(g1)
g.subgraph_factors[1] *= g1.subgraph_factors[1]
g.subgraphs = g1.subgraphs
pop!(g1.parent_graphs)
push!(g1.subgraphs[1].parent_graphs, g)
end
return g
end
Expand All @@ -267,9 +280,11 @@ end
function Base.:*(c1::C, g2::FeynmanGraph{F,W}) where {F,W,C}
g = FeynmanGraph([g2,], g2.properties; subgraph_factors=[F(c1),], operator=Prod(), orders=orders(g2), ftype=F, wtype=W)
# Merge multiplicative link
if g2.operator == Prod && onechild(g2)
if unary_istrivial(g2.operator) && onechild(g2)
g.subgraph_factors[1] *= g2.subgraph_factors[1]
g.subgraphs = g2.subgraphs
pop!(g2.parent_graphs)
push!(g2.subgraphs[1].parent_graphs, g)
end
return g
end
Expand Down Expand Up @@ -305,7 +320,6 @@ function linear_combination(g1::FeynmanGraph{F,W}, g2::FeynmanGraph{F,W}, c1::C=
subgraph_factors[2] *= g2.subgraph_factors[1]
subgraphs[2] = g2.subgraphs[1]
end
# g = FeynmanGraph([g1, g2], properties; subgraph_factors=[F(c1), F(c2)], operator=Sum(), ftype=F, wtype=W)

if subgraphs[1] == subgraphs[2]
g = FeynmanGraph([subgraphs[1]], properties; subgraph_factors=[sum(subgraph_factors)], operator=Sum(), ftype=F, wtype=W)
Expand Down
17 changes: 14 additions & 3 deletions src/computational_graph/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mutable struct Graph{F,W} <: AbstractGraph # Graph

subgraphs::Vector{Graph{F,W}}
subgraph_factors::Vector{F}
parent_graphs::Vector{Graph{F,W}}

operator::DataType
factor::F
Expand All @@ -54,14 +55,20 @@ mutable struct Graph{F,W} <: AbstractGraph # Graph
- `factor` fixed scalar multiplicative factor for this diagram (e.g., a permutation sign)
- `weight` the weight of this node
"""
function Graph(subgraphs::AbstractVector; subgraph_factors=one.(eachindex(subgraphs)), name="", operator::AbstractOperator=Sum(),
orders=zeros(Int, 16), ftype=_dtype.factor, wtype=_dtype.weight, factor=one(ftype), weight=zero(wtype)
function Graph(subgraphs::AbstractVector; subgraph_factors=one.(eachindex(subgraphs)), parent_graphs::AbstractVector=eltype(subgraphs)[],
name="", operator::AbstractOperator=Sum(), orders=zeros(Int, 16),
ftype=_dtype.factor, wtype=_dtype.weight, factor=one(ftype), weight=zero(wtype)
)
if typeof(operator) <: Power
@assert length(subgraphs) == 1 "Graph with Power operator must have one and only one subgraph."
end
# @assert allunique(subgraphs) "all subgraphs must be distinct."
return new{ftype,wtype}(uid(), name, orders, subgraphs, subgraph_factors, typeof(operator), factor, weight)
# return new{ftype,wtype}(uid(), name, orders, subgraphs, subgraph_factors, parent_graphs, typeof(operator), factor, weight)
g = new{ftype,wtype}(uid(), name, orders, subgraphs, subgraph_factors, parent_graphs, typeof(operator), factor, weight)
for sub_g in subgraphs
g sub_g.parent_graphs && push!(sub_g.parent_graphs, g)
end
return g
end
end

Expand Down Expand Up @@ -99,6 +106,8 @@ function Base.:*(g1::Graph{F,W}, c2::C) where {F,W,C}
if unary_istrivial(g1.operator) && onechild(g1)
g.subgraph_factors[1] *= g1.subgraph_factors[1]
g.subgraphs = g1.subgraphs
pop!(g1.parent_graphs)
push!(g1.subgraphs[1].parent_graphs, g)
end
return g
end
Expand All @@ -118,6 +127,8 @@ function Base.:*(c1::C, g2::Graph{F,W}) where {F,W,C}
if unary_istrivial(g2.operator) && onechild(g2)
g.subgraph_factors[1] *= g2.subgraph_factors[1]
g.subgraphs = g2.subgraphs
pop!(g2.parent_graphs)
push!(g2.subgraphs[1].parent_graphs, g)
end
return g
end
Expand Down
28 changes: 24 additions & 4 deletions src/computational_graph/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,19 @@ end
"""
function merge_all_chains!(g::AbstractGraph; verbose=0)
verbose > 0 && println("merge all nodes representing trivial unary chains.")
merge_all_chain_prefactors!(g, verbose=verbose)
merge_all_factorless_chains!(g, verbose=verbose)
for sub_g in g.subgraphs
merge_all_chains!(sub_g)
merge_chains!(sub_g)
end
merge_chains!(g)
return g
end
# function merge_all_chains!(g::AbstractGraph; verbose=0)
# verbose > 0 && println("merge all nodes representing trivial unary chains.")
# merge_all_chain_prefactors!(g, verbose=verbose)
# merge_all_factorless_chains!(g, verbose=verbose)
# return g
# end

"""
function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) where {G<:AbstractGraph}
Expand All @@ -174,10 +183,20 @@ end
"""
function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
verbose > 0 && println("merge all nodes representing trivial unary chains.")
merge_all_chain_prefactors!(graphs, verbose=verbose)
merge_all_factorless_chains!(graphs, verbose=verbose)
# Post-order DFS
for g in graphs
merge_all_chains!(g.subgraphs)
merge_chains!(g)
end
return graphs
end
# function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
# verbose > 0 && println("merge all nodes representing trivial unary chains.")
# merge_all_chain_prefactors!(graphs, verbose=verbose)
# merge_all_factorless_chains!(graphs, verbose=verbose)
# return graphs
# end


"""
function merge_all_linear_combinations!(g::AbstractGraph; verbose=0)
Expand Down Expand Up @@ -345,6 +364,7 @@ function remove_duplicated_leaves!(graphs::AbstractVector{<:AbstractGraph}; verb
for (si, sub_g) in enumerate(n.subgraphs)
if isleaf(sub_g)
n.subgraphs[si] = uniqueLeaf[leafMap[sub_g.id]]
n n.subgraphs[si].parent_graphs && push!(n.subgraphs[si].parent_graphs, n)
end
end
end
Expand Down
85 changes: 79 additions & 6 deletions src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ function replace_subgraph!(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph)
for (i, child) in enumerate(children(node))
if isequiv(child, w, :id)
node.subgraphs[i] = m
node m.parent_graphs && push!(m.parent_graphs, node)
loc = findfirst(isequal(node), child.parent_graphs)
popat!(child.parent_graphs, loc)
return
end
end
Expand Down Expand Up @@ -150,6 +153,9 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph)
for (i, child) in enumerate(children(node))
if isequiv(child, w, :id)
node.subgraphs[i] = m
node m.parent_graphs && push!(m.parent_graphs, node)
loc = findfirst(isequal(node), child.parent_graphs)
popat!(child.parent_graphs, loc)
break
end
end
Expand All @@ -170,9 +176,18 @@ end
- `g::AbstractGraph`: graph to be modified
"""
function merge_factorless_chain!(g::AbstractGraph)
while unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
if unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
child = eldest(g)
for field in fieldnames(typeof(g))
field == :parent_graphs && continue
value = getproperty(child, field)
setproperty!(g, field, value)
end
end
while unary_istrivial(g.operator) && onechild(g) && oneparent(g) && isfactorless(g)
child = eldest(g)
for field in fieldnames(typeof(g))
field == :parent_graphs && continue
value = getproperty(child, field)
setproperty!(g, field, value)
end
Expand All @@ -194,10 +209,17 @@ end
- `g::AbstractGraph`: graph to be modified
"""
function merge_factorless_chain(g::AbstractGraph)
while unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
parent_g = g.parent_graphs
if unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
g = eldest(g)
end
while unary_istrivial(g.operator) && onechild(g) && oneparent(g) && isfactorless(g)
g = eldest(g)
end
return g
# g_new = deepcopy(g)
# g_new.parent_graphs = parent_g
# return g_new
end

"""
Expand All @@ -215,7 +237,7 @@ end
function merge_chain_prefactors!(g::AbstractGraph)
for (i, child) in enumerate(g.subgraphs)
total_chain_factor = 1
while onechild(child)
while onechild(child) && oneparent(child)
# Break case: end of trivial unary chain
unary_istrivial(child.operator) == false && break
# Move this subfactor to running total
Expand Down Expand Up @@ -244,6 +266,21 @@ end
"""
merge_chain_prefactors(g::AbstractGraph) = merge_chain_prefactors!(deepcopy(g))

# function merge_chain!(g::AbstractGraph)
# # while unary_istrivial(g.operator) && onechild(g)
# # !onechild(eldest(g)) && break
# while onechild(g) && onechild(eldest(g)) && unary_istrivial(eldest(g).operator)
# merge_chain!(eldest(g))

# loc = findfirst(isequal(g), eldest(g).parent_graphs)
# popat!(eldest(g).parent_graphs, loc)
# g.subgraph_factors[1] *= eldest(g).subgraph_factors[1]
# g.subgraphs = eldest(g).subgraphs
# push!(eldest(g).parent_graphs, g)
# end
# return g
# end

"""
function merge_chains!(g::AbstractGraph)
Expand All @@ -257,12 +294,26 @@ merge_chain_prefactors(g::AbstractGraph) = merge_chain_prefactors!(deepcopy(g))
- `g::AbstractGraph`: graph to be modified
"""
function merge_chains!(g::AbstractGraph)
merge_chain_prefactors!(g) # shift chain subgraph factors towards root level
for sub_g in g.subgraphs # prune factorless chain subgraphs
merge_factorless_chain!(sub_g)
for (i, sub_g) in enumerate(g.subgraphs)
if onechild(sub_g) && unary_istrivial(sub_g.operator)
merge_chains!(sub_g)

loc = findfirst(isequal(g), sub_g.parent_graphs)
popat!(sub_g.parent_graphs, loc)
g.subgraph_factors[i] *= sub_g.subgraph_factors[1]
g.subgraphs[i] = eldest(sub_g)
push!(eldest(sub_g).parent_graphs, g)
end
end
return g
end
# function merge_chains!(g::AbstractGraph)
# merge_chain_prefactors!(g) # shift chain subgraph factors towards root level
# for sub_g in g.subgraphs # prune factorless chain subgraphs
# merge_factorless_chain!(sub_g)
# end
# return g
# end

"""
function merge_chains(g::AbstractGraph)
Expand All @@ -278,6 +329,7 @@ end
"""
merge_chains(g::AbstractGraph) = merge_chains!(deepcopy(g))


"""
function merge_linear_combination(g::Graph)
Expand Down Expand Up @@ -356,6 +408,13 @@ function merge_linear_combination!(g::Graph{F,W}) where {F,W}
g_merged = merge_linear_combination(g)
g.subgraphs = g_merged.subgraphs
g.subgraph_factors = g_merged.subgraph_factors
for sub_g in g.subgraphs
if g in sub_g.parent_graphs
pop!(sub_g.parent_graphs) # delete the parent graph for g_merged.
else
sub_g.parent_graphs[end] = g
end
end
end
return g
end
Expand All @@ -365,6 +424,13 @@ function merge_linear_combination!(g::FeynmanGraph{F,W}) where {F,W}
g_merged = merge_linear_combination(g)
g.subgraphs = g_merged.subgraphs
g.subgraph_factors = g_merged.subgraph_factors
for sub_g in g.subgraphs
if g in sub_g.parent_graphs
pop!(sub_g.parent_graphs) # delete the parent graph for g_merged.
else
sub_g.parent_graphs[end] = g
end
end
end
return g
end
Expand Down Expand Up @@ -425,6 +491,13 @@ function merge_multi_product!(g::Graph{F,W}) where {F,W}
g.subgraphs = g_merged.subgraphs
g.subgraph_factors = g_merged.subgraph_factors
g.operator = g_merged.operator
for sub_g in g.subgraphs
if g in sub_g.parent_graphs
pop!(sub_g.parent_graphs) # delete the parent graph for g_merged.
else
sub_g.parent_graphs[end] = g
end
end
end
return g
end
Loading

0 comments on commit 145b150

Please sign in to comment.