diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl
index aed2d068..5ed4c1f2 100644
--- a/src/FeynmanDiagram.jl
+++ b/src/FeynmanDiagram.jl
@@ -119,7 +119,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti
# export reducibility, connectivity
# export 𝐺ᶠ, 𝐺ᵇ, 𝐺ᵠ, 𝑊, Green2, Interaction
# export Coupling_yukawa, Coupling_phi3, Coupling_phi4, Coupling_phi6
-export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest
+export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains
diff --git a/src/backend/to_dot.jl b/src/backend/to_dot.jl
index b001ada3..69e64f2f 100644
--- a/src/backend/to_dot.jl
+++ b/src/backend/to_dot.jl
@@ -1,26 +1,26 @@
function to_dotstatic(operator::Type, id::Int, factor, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
error(
- "Static representation for computational graph nodes with operator $(operator) not yet implemented! "
+ "Static representation for computational graph nodes with operator $(operator) not yet implemented! "
)
end
-function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
+function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
node_temp = ""
arrow_temp = ""
if factor != 1
opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n"
- opr_name = "g$(id)_t"
- node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
- arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
+ opr_name = "g$(id)_t"
+ node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
node_temp *= opr_fac * node_str
else
opr_name = "g$id"
end
opr_node = opr_name * "[shape=box, label = \"Add\", style=filled, fillcolor=cyan,]\n"
node_temp *= opr_node
- for (gix,(g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors))
- if gfactor!= 1
- factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n"
+ for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors))
+ if gfactor != 1
+ factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n"
subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
node_temp *= factor_str * subg_str
arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n"
@@ -29,7 +29,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgra
arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n"
end
end
- return node_temp,arrow_temp
+ return node_temp, arrow_temp
end
function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
@@ -37,27 +37,27 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg
arrow_temp = ""
if factor != 1
opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n"
- opr_name = "g$(id)_t"
- node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
- arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
+ opr_name = "g$(id)_t"
+ node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
node_temp *= opr_fac * node_str
else
opr_name = "g$id"
end
- opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
node_temp *= opr_node
if length(subgraphs) == 1
- if subgraph_factors[1] ==1
+ if subgraph_factors[1] == 1
arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n"
else
- factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
+ factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
node_temp *= factor_str
arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n"
end
else
- for (gix,(g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors))
- if gfactor!= 1
- factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n"
+ for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors))
+ if gfactor != 1
+ factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n"
subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
node_temp *= factor_str * subg_str
arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n"
@@ -67,7 +67,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg
end
end
end
- return node_temp,arrow_temp
+ return node_temp, arrow_temp
end
function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
@@ -75,19 +75,19 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F,
arrow_temp = ""
if factor != 1
opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n"
- opr_name = "g$(id)_t"
- node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
- arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
+ opr_name = "g$(id)_t"
+ node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
node_temp *= opr_fac * node_str
else
opr_name = "g$id"
end
- opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n"
+ opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n"
order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n"
node_temp *= opr_node * order_node
- arrow_temp*= "order$(id)->$opr_name[arrowhead=vee,]\n"
+ arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n"
if subgraph_factors[1] != 1
- factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
+ factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
node_temp *= factor_str * subg_str
arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n"
@@ -98,14 +98,14 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F,
return node_temp, arrow_temp
end
-function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
+function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
node_temp = ""
arrow_temp = ""
if factor != 1
opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n"
- opr_name = "g$(id)_t"
- node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
- arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
+ opr_name = "g$(id)_t"
+ node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
node_temp *= opr_fac * node_str
else
opr_name = "g$id"
@@ -123,7 +123,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgr
arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n"
end
end
- return node_temp,arrow_temp
+ return node_temp, arrow_temp
end
function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
@@ -131,20 +131,20 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg
arrow_temp = ""
if factor != 1
opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n"
- opr_name = "g$(id)_t"
- node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
- arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
+ opr_name = "g$(id)_t"
+ node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
node_temp *= opr_fac * node_str
else
opr_name = "g$id"
end
- opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
node_temp *= opr_node
if length(subgraphs) == 1
- if subgraph_factors[1] ==1
+ if subgraph_factors[1] == 1
arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n"
else
- factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
+ factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
node_temp *= factor_str
arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n"
end
@@ -161,7 +161,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg
end
end
end
- return node_temp,arrow_temp
+ return node_temp, arrow_temp
end
function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
@@ -169,19 +169,19 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F,
arrow_temp = ""
if factor != 1
opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n"
- opr_name = "g$(id)_t"
- node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
- arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
+ opr_name = "g$(id)_t"
+ node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
+ arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n"
node_temp *= opr_fac * node_str
else
opr_name = "g$id"
end
- opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n"
+ opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n"
order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n"
node_temp *= opr_node * order_node
- arrow_temp*= "order$(id)->$opr_name[arrowhead=vee,]\n"
+ arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n"
if subgraph_factors[1] != 1
- factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
+ factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n"
subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
node_temp *= factor_str * subg_str
arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n"
@@ -193,13 +193,15 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F,
end
"""
- function to_dot_str(graphs::AbstractVector{<:AbstractGraph})
+ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="", diagram_id_map=nothing)
+
Compile a list of graphs into a string for dot language.
+
# Arguments:
- `graphs` vector of computational graphs
- - `title` The name of the complied function (defaults to `"ComputationalGraph"`)
+ - `title` The name of the compiled function (defaults to nothing)
"""
-function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="ComputationalGraph")
+function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="", diagram_id_map=nothing)
head = "digraph ComputationalGraph { \nlabel=\"$name\"\n"
head *= "ReturnNode[shape=box, label = \"Return\", style=filled, fillcolor=darkorange,]\n"
body_node = ""
@@ -218,12 +220,12 @@ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="Compu
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
- leafname = getname(g.properties, leafidx)
+ leafname = get_leafname(g, leafidx, diagram_id_map)
if factor(g) == 1
gnode_str = "g$g_id[label=$leafname, style=filled, fillcolor=paleturquoise]\n"
body_node *= gnode_str
else
- factor_str = "factor$(leafidx)_inp[label=$(factor(g)), style=filled, fillcolor=lavender]\n"
+ factor_str = "factor$(leafidx)_inp[label=$(factor(g)), style=filled, fillcolor=lavender]\n"
leaf_node = "l$(leafidx)[label=$leafname, style=filled, fillcolor=paleturquoise]\n"
gnode_str = "g$g_id[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n"
body_node *= factor_str * leaf_node * gnode_str
@@ -233,14 +235,14 @@ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="Compu
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
- temp_node,temp_arrow = to_dotstatic(operator(g), g_id, factor(g), subgraphs(g), subgraph_factors(g))
- body_node *=temp_node
+ temp_node, temp_arrow = to_dotstatic(operator(g), g_id, factor(g), subgraphs(g), subgraph_factors(g))
+ body_node *= temp_node
body_arrow *= temp_arrow
push!(inds_visitednode, g_id)
end
if isroot
body_arrow *= "g$(g_id)->ReturnNode[arrowhead=vee,]\n"
- rootidx +=1
+ rootidx += 1
end
end
end
@@ -250,25 +252,42 @@ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="Compu
return expr
end
-function compile_dot(graphs::AbstractVector{<:AbstractGraph}, filename::String; graph_name="ComputationalGraph")
- dot_string = to_dot_str(graphs, graph_name)
+function compile_dot(graphs::AbstractVector{<:AbstractGraph}, filename::String; graph_name="", diagram_id_map=nothing)
+ dot_string = to_dot_str(graphs, graph_name, diagram_id_map)
open(filename, "w") do f
write(f, dot_string)
end
end
-function getname(properties,leafidx)
- if properties isa BareGreenId
- lfname = "<G$leafidx>"
- elseif properties isa BareInteractionId
- lfname = "<V$leafidx>"
- elseif typeof(properties) == FeynmanProperties && properties.diagtype == ComputationalGraphs.Propagator
- lfname = "<G$leafidx>"
- elseif typeof(properties) == FeynmanProperties && properties.diagtype == ComputationalGraphs.Interaction
- lfname = "<V$leafidx>"
+function get_leafname(g, leafidx, diagram_id_map=nothing)
+ println(typeof(g))
+ leaftype = Nothing
+ if g isa FeynmanGraph
+ leaftype = g.properties.diagtype
+ elseif g isa Graph
+ if isnothing(diagram_id_map) == false
+ leaftype = typeof(diagram_id_map[g.id])
+ end
+ else
+ error("Unknown graph type: $(typeof(g))")
+ end
+ if leaftype in [BareGreenId, ComputationalGraphs.Propagator]
+ leafname = "<G$leafidx>"
+ elseif leaftype in [BareInteractionId, ComputationalGraphs.Interaction]
+ leafname = "<V$leafidx>"
+ elseif leaftype == PolarId
+ leafname = "<Π$leafidx>"
+ elseif leaftype == Ver3Id
+ leafname = "<Γ(3)$leafidx>"
+ elseif leaftype == Ver4Id
+ leafname = "<Γ(4)$leafidx>"
else
- lfname = "$leafidx>"
+ leafname = "$leafidx>"
end
- return lfname
+ println()
+ println(g)
+ println(leaftype)
+ println(leafname)
+ println()
+ return leafname
end
-
diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl
index 97c2bffb..4a50a17f 100644
--- a/src/computational_graph/ComputationalGraph.jl
+++ b/src/computational_graph/ComputationalGraph.jl
@@ -38,7 +38,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti
include("tree_properties.jl")
-export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest, count_operation
+export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest, count_operation
include("operation.jl")
include("io.jl")
diff --git a/src/computational_graph/abstractgraph.jl b/src/computational_graph/abstractgraph.jl
index a5f53a1d..d47d6b8d 100644
--- a/src/computational_graph/abstractgraph.jl
+++ b/src/computational_graph/abstractgraph.jl
@@ -254,6 +254,16 @@ function set_subgraph_factors!(g::AbstractGraph, subgraph_factors::AbstractVecto
end
end
+"""
+function disconnect_subgraphs!(g::G) where {G<:AbstractGraph}
+
+ Empty the subgraphs and subgraph_factors of graph `g`. Any child nodes of g
+ not referenced elsewhere in the full computational graph are effectively deleted.
+"""
+function disconnect_subgraphs!(g::AbstractGraph)
+ empty!(subgraphs(g))
+ empty!(subgraph_factors(g))
+end
### Methods ###
diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl
index 95e7cf39..0017f3ad 100644
--- a/src/computational_graph/optimize.jl
+++ b/src/computational_graph/optimize.jl
@@ -16,7 +16,7 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose
remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize)
flatten_all_chains!(graphs, verbose=verbose)
merge_all_linear_combinations!(graphs, verbose=verbose)
-
+ remove_all_zero_valued_subgraphs!(graphs, verbose=verbose)
return graphs
end
end
@@ -65,7 +65,7 @@ end
"""
function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)
- Flattens all nodes representing trivial unary chains in-place in given graphs.
+ Flattens all nodes representing trivial unary chains in-place in the given graphs.
# Arguments:
- `graphs`: A collection of graphs to be processed.
@@ -84,10 +84,57 @@ function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}
return graphs
end
+"""
+ function remove_all_zero_valued_subgraphs!(g::AbstractGraph; verbose=0)
+
+ Recursively removes all zero-valued subgraph(s) in-place in the given graph `g`.
+
+# Arguments:
+- `g`: An AbstractGraph.
+- `verbose`: Level of verbosity (default: 0).
+
+# Returns:
+- Optimized graph.
+#
+"""
+function remove_all_zero_valued_subgraphs!(g::AbstractGraph; verbose=0)
+ verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.")
+ # Post-order DFS
+ for sub_g in subgraphs(g)
+ remove_all_zero_valued_subgraphs!(sub_g)
+ remove_zero_valued_subgraphs!(sub_g)
+ end
+ remove_zero_valued_subgraphs!(g)
+ return g
+end
+
+"""
+ function remove_all_zero_valued_subgraphs!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)
+
+ Recursively removes all zero-valued subgraph(s) in-place in the given graphs.
+
+# Arguments:
+- `graphs`: A collection of graphs to be processed.
+- `verbose`: Level of verbosity (default: 0).
+
+# Returns:
+- Optimized graphs.
+#
+"""
+function remove_all_zero_valued_subgraphs!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)
+ verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.")
+ # Post-order DFS
+ for g in graphs
+ remove_all_zero_valued_subgraphs!(subgraphs(g))
+ remove_zero_valued_subgraphs!(g)
+ end
+ return graphs
+end
+
"""
function merge_all_linear_combinations!(g::AbstractGraph; verbose=0)
- Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place within a single graph.
+ Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in the given graph `g`.
# Arguments:
- `g`: An AbstractGraph.
@@ -111,7 +158,7 @@ end
"""
function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)
- Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in given graphs.
+ Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in the given graphs.
# Arguments:
- `graphs`: A collection of graphs to be processed.
@@ -134,7 +181,7 @@ end
"""
function merge_all_multi_products!(g::Graph; verbose=0)
- Merges all nodes representing a multi product of a non-unique list of subgraphs in-place within a single graph.
+ Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in the given graph `g`.
# Arguments:
- `g::Graph`: A Graph.
@@ -158,7 +205,7 @@ end
"""
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{<:Graph}}; verbose=0)
- Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in given graphs.
+ Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in the given graphs.
# Arguments:
- `graphs`: A collection of graphs to be processed.
diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl
index 3c5c6380..fecf791a 100644
--- a/src/computational_graph/transform.jl
+++ b/src/computational_graph/transform.jl
@@ -183,7 +183,7 @@ end
function flatten_chains(g::AbstractGraph)
Recursively flattens chains of subgraphs within a given graph `g` by merging certain trivial unary subgraphs into their parent graphs,
- This function returns a new graph with flatten chains, dervied from the input graph `g` remaining unchanged.
+ This function returns a new graph with flatten chains, derived from the input graph `g` remaining unchanged.
# Arguments:
- `g::AbstractGraph`: graph to be modified
@@ -191,7 +191,49 @@ end
flatten_chains(g::AbstractGraph) = flatten_chains!(deepcopy(g))
"""
- function merge_linear_combination(g::AbstractGraph)
+ function remove_zero_valued_subgraphs!(g::AbstractGraph)
+
+ Removes zero-valued (zero subgraph_factor) subgraph(s) of a computational graph `g`. If all subgraphs are zero-valued, the first one (`eldest(g)`) will be retained.
+
+# Arguments:
+- `g::AbstractGraph`: graph to be modified
+"""
+function remove_zero_valued_subgraphs!(g::AbstractGraph)
+ if isleaf(g) || isbranch(g) # we must retain at least one subgraph
+ return g
+ end
+ subg = collect(subgraphs(g))
+ subg_fac = collect(subgraph_factors(g))
+ zero_sgf = zero(subg_fac[1]) # F(0)
+ # Find subgraphs with all-zero subgraph_factors and propagate subfactor one level up
+ for (i, sub_g) in enumerate(subg)
+ if has_zero_subfactors(sub_g)
+ subg_fac[i] = zero_sgf
+ end
+ end
+ # Remove marked zero subgraph factor subgraph(s) of g
+ mask_zeros = findall(x -> x != zero(x), subg_fac)
+ if isempty(mask_zeros)
+ mask_zeros = [1] # retain eldest(g) if all subfactors are zero
+ end
+ set_subgraphs!(g, subg[mask_zeros])
+ set_subgraph_factors!(g, subg_fac[mask_zeros])
+ return g
+end
+
+"""
+ function remove_zero_valued_subgraphs(g::AbstractGraph)
+
+ Returns a copy of graph `g` with zero-valued (zero subgraph_factor) subgraph(s) removed.
+ If all subgraphs are zero-valued, the first one (`eldest(g)`) will be retained.
+
+# Arguments:
+- `g::AbstractGraph`: graph to be modified
+"""
+remove_zero_valued_subgraphs(g::AbstractGraph) = remove_zero_valued_subgraphs!(deepcopy(g))
+
+"""
+ function merge_linear_combination!(g::AbstractGraph)
Modifies a computational graph `g` by factorizing multiplicative prefactors, e.g.,
3*g1 + 5*g2 + 7*g1 + 9*g2 ↦ 10*g1 + 14*g2 = linear_combination(g1, g2, 10, 14).
diff --git a/src/computational_graph/tree_properties.jl b/src/computational_graph/tree_properties.jl
index 84e2819b..b6d820b5 100644
--- a/src/computational_graph/tree_properties.jl
+++ b/src/computational_graph/tree_properties.jl
@@ -98,6 +98,25 @@ function isfactorless(g::AbstractGraph)
end
end
+"""
+ function has_zero_subfactors(g)
+
+ Returns whether the graph g has only zero-valued subgraph factor(s).
+ Note that this function does not recurse through subgraphs of g, so that one may have, e.g.,
+ `isfactorless(g) == true` but `isfactorless(eldest(g)) == false`.
+ By convention, returns `false` if g is a leaf.
+
+# Arguments:
+- `g::AbstractGraph`: graph to be analyzed
+"""
+function has_zero_subfactors(g::AbstractGraph)
+ if isleaf(g)
+ return false # convention: subgraph_factors = [] ⟹ subfactorless = false
+ else
+ return iszero(subgraph_factors(g))
+ end
+end
+
"""
function eldest(g::AbstractGraph)
diff --git a/test/computational_graph.jl b/test/computational_graph.jl
index bb36652a..f61a5c05 100644
--- a/test/computational_graph.jl
+++ b/test/computational_graph.jl
@@ -105,6 +105,12 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
Graphs.set_subgraph_factors!(g, [5.0, 2.0, 3.0], [3, 1, 2]) # default method
@test Graphs.subgraph_factors(g) == [2.0, 3.0, 5.0]
end
+ @testset "Disconnect subgraphs" begin
+ g_dc = deepcopy(g)
+ Graphs.disconnect_subgraphs!(g_dc)
+ @test isempty(Graphs.subgraphs(g_dc))
+ @test isempty(Graphs.subgraph_factors(g_dc))
+ end
@testset "Equivalence" begin
Graphs.set_name!(g, Graphs.name(gp))
@test g == g
@@ -302,6 +308,32 @@ end
@test r2 == Graphs.flatten_chains(rvec[2])
@test r3 == Graphs.flatten_chains(rvec[3])
end
+ @testset "Remove zero-valued subgraphs" begin
+ # leaves
+ l1 = Graph([]; factor=1)
+ l2 = Graph([]; factor=2)
+ l3 = Graph([]; factor=3)
+ l4 = Graph([]; factor=4)
+ l5 = Graph([]; factor=5)
+ l6 = Graph([]; factor=6)
+ l7 = Graph([]; factor=7)
+ l8 = Graph([]; factor=8)
+ # subgraphs
+ sg1 = l1
+ sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=O1())
+ sg3 = Graph([l4]; subgraph_factors=[0], operator=O2())
+ sg4 = Graph([l5, l6, l7]; subgraph_factors=[0, 0, 0], operator=O3())
+ sg5 = l8
+ # graphs
+ g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=O())
+ g_test = Graph([sg1, sg2]; subgraph_factors=[1, 1], operator=O())
+ gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 1, 0], operator=O())
+ gp_test = Graph([sg3]; subgraph_factors=[0], operator=O())
+ Graphs.remove_zero_valued_subgraphs!(g)
+ Graphs.remove_zero_valued_subgraphs!(gp)
+ @test isequiv(g, g_test, :id)
+ @test isequiv(gp, gp_test, :id)
+ end
end
@testset verbose = true "Optimizations" begin
@testset "Flatten all chains" begin
@@ -330,6 +362,35 @@ end
Graphs.flatten_all_chains!(rvec)
@test rvec == [r1, r2, r3]
end
+ @testset "Remove all zero-valued subgraphs" begin
+ # leaves
+ l1 = Graph([]; factor=1)
+ l2 = Graph([]; factor=2)
+ l3 = Graph([]; factor=3)
+ l4 = Graph([]; factor=4)
+ l5 = Graph([]; factor=5)
+ l6 = Graph([]; factor=6)
+ l7 = Graph([]; factor=7)
+ l8 = Graph([]; factor=8)
+ # sub-subgraph
+ ssg1 = Graph([l7]; subgraph_factors=[0], operator=O())
+ # subgraphs
+ sg1 = l1
+ sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=O1())
+ sg2_test = Graph([l2]; subgraph_factors=[1.0], operator=O1())
+ sg3 = Graph([l4]; subgraph_factors=[0], operator=O2())
+ sg4 = Graph([l5, l6, ssg1]; subgraph_factors=[0, 0, 1], operator=O3())
+ sg5 = l8
+ # graphs
+ g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=O())
+ g_test = Graph([sg1, sg2_test]; subgraph_factors=[1, 1], operator=O())
+ gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 1, 0], operator=O())
+ gp_test = Graph([sg3]; subgraph_factors=[0], operator=O())
+ Graphs.remove_all_zero_valued_subgraphs!(g)
+ Graphs.remove_all_zero_valued_subgraphs!(gp)
+ @test isequiv(g, g_test, :id)
+ @test isequiv(gp, gp_test, :id)
+ end
@testset "Merge all linear combinations" begin
g1 = Graph([])
g2 = 2 * g1
@@ -1037,6 +1098,7 @@ end
g3 = 1 * g1
g4 = 1 * g2
g5 = 2 * g1
+ h1 = 0 * g1
# Chains: Ⓧ --- Ⓧ --- gᵢ (simplified by default)
g6 = Graph([g5,]; subgraph_factors=[1,], operator=Graphs.Prod())
g7 = Graph([g3,]; subgraph_factors=[2,], operator=Graphs.Prod())
@@ -1044,6 +1106,8 @@ end
g8 = 2 * (3 * g1 + 5 * g2)
g9 = g1 + 2 * (3 * g1 + 5 * g2)
g10 = g1 * g2 + g8 * g9
+ h2 = Graph([g1, g2]; subgraph_factors=[0, 0], operator=Graphs.Sum())
+ h3 = Graph([g1, g2]; subgraph_factors=[1, 0], operator=Graphs.Sum())
glist = [g1, g2, g8, g9, g10]
@testset "Leaves" begin
@@ -1068,6 +1132,7 @@ end
@test isfactorless(g4)
@test isfactorless(g5) == false
@test isleaf(eldest(g3))
+ @test has_zero_subfactors(h1)
end
@testset "Chains" begin
@test haschildren(g6)
@@ -1090,6 +1155,8 @@ end
@test count_operation(g8) == [1, 0]
@test count_operation(g9) == [2, 0]
@test count_operation(g10) == [4, 2]
+ @test has_zero_subfactors(h2)
+ @test has_zero_subfactors(h3) == false
end
@testset "Iteration" begin
count_pre = sum(1 for node in PreOrderDFS(g9))