diff --git a/src/FeynmanDiagram.jl b/src/FeynmanDiagram.jl index aed2d068..5ed4c1f2 100644 --- a/src/FeynmanDiagram.jl +++ b/src/FeynmanDiagram.jl @@ -119,7 +119,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti # export reducibility, connectivity # export 𝐺ᶠ, 𝐺ᵇ, 𝐺ᵠ, 𝑊, Green2, Interaction # export Coupling_yukawa, Coupling_phi3, Coupling_phi4, Coupling_phi6 -export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest +export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains! export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains diff --git a/src/backend/to_dot.jl b/src/backend/to_dot.jl index b001ada3..69e64f2f 100644 --- a/src/backend/to_dot.jl +++ b/src/backend/to_dot.jl @@ -1,26 +1,26 @@ function to_dotstatic(operator::Type, id::Int, factor, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector) error( - "Static representation for computational graph nodes with operator $(operator) not yet implemented! " + "Static representation for computational graph nodes with operator $(operator) not yet implemented! " ) end -function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} +function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} node_temp = "" arrow_temp = "" if factor != 1 opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + opr_name = "g$(id)_t" + node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" node_temp *= opr_fac * node_str else opr_name = "g$id" end opr_node = opr_name * "[shape=box, label = \"Add\", style=filled, fillcolor=cyan,]\n" node_temp *= opr_node - for (gix,(g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) - if gfactor!= 1 - factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" + for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) + if gfactor != 1 + factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" node_temp *= factor_str * subg_str arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" @@ -29,7 +29,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgra arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n" end end - return node_temp,arrow_temp + return node_temp, arrow_temp end function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} @@ -37,27 +37,27 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg arrow_temp = "" if factor != 1 opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + opr_name = "g$(id)_t" + node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" node_temp *= opr_fac * node_str else opr_name = "g$id" end - opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" node_temp *= opr_node if length(subgraphs) == 1 - if subgraph_factors[1] ==1 + if subgraph_factors[1] == 1 arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" else - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" node_temp *= factor_str arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" end else - for (gix,(g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) - if gfactor!= 1 - factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" + for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) + if gfactor != 1 + factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" node_temp *= factor_str * subg_str arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" @@ -67,7 +67,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg end end end - return node_temp,arrow_temp + return node_temp, arrow_temp end function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} @@ -75,19 +75,19 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, arrow_temp = "" if factor != 1 opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + opr_name = "g$(id)_t" + node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" node_temp *= opr_fac * node_str else opr_name = "g$id" end - opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n" + opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n" order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n" node_temp *= opr_node * order_node - arrow_temp*= "order$(id)->$opr_name[arrowhead=vee,]\n" + arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n" if subgraph_factors[1] != 1 - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" node_temp *= factor_str * subg_str arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" @@ -98,14 +98,14 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, return node_temp, arrow_temp end -function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} +function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} node_temp = "" arrow_temp = "" if factor != 1 opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + opr_name = "g$(id)_t" + node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" node_temp *= opr_fac * node_str else opr_name = "g$id" @@ -123,7 +123,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F,subgr arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n" end end - return node_temp,arrow_temp + return node_temp, arrow_temp end function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} @@ -131,20 +131,20 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg arrow_temp = "" if factor != 1 opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + opr_name = "g$(id)_t" + node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" node_temp *= opr_fac * node_str else opr_name = "g$id" end - opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" node_temp *= opr_node if length(subgraphs) == 1 - if subgraph_factors[1] ==1 + if subgraph_factors[1] == 1 arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" else - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" node_temp *= factor_str arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" end @@ -161,7 +161,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg end end end - return node_temp,arrow_temp + return node_temp, arrow_temp end function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} @@ -169,19 +169,19 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, arrow_temp = "" if factor != 1 opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp*= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + opr_name = "g$(id)_t" + node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" node_temp *= opr_fac * node_str else opr_name = "g$id" end - opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n" + opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n" order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n" node_temp *= opr_node * order_node - arrow_temp*= "order$(id)->$opr_name[arrowhead=vee,]\n" + arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n" if subgraph_factors[1] != 1 - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" node_temp *= factor_str * subg_str arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" @@ -193,13 +193,15 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, end """ - function to_dot_str(graphs::AbstractVector{<:AbstractGraph}) + function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="", diagram_id_map=nothing) + Compile a list of graphs into a string for dot language. + # Arguments: - `graphs` vector of computational graphs - - `title` The name of the complied function (defaults to `"ComputationalGraph"`) + - `title` The name of the compiled function (defaults to nothing) """ -function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="ComputationalGraph") +function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="", diagram_id_map=nothing) head = "digraph ComputationalGraph { \nlabel=\"$name\"\n" head *= "ReturnNode[shape=box, label = \"Return\", style=filled, fillcolor=darkorange,]\n" body_node = "" @@ -218,12 +220,12 @@ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="Compu end if isempty(subgraphs(g)) #leaf g_id in inds_visitedleaf && continue - leafname = getname(g.properties, leafidx) + leafname = get_leafname(g, leafidx, diagram_id_map) if factor(g) == 1 gnode_str = "g$g_id[label=$leafname, style=filled, fillcolor=paleturquoise]\n" body_node *= gnode_str else - factor_str = "factor$(leafidx)_inp[label=$(factor(g)), style=filled, fillcolor=lavender]\n" + factor_str = "factor$(leafidx)_inp[label=$(factor(g)), style=filled, fillcolor=lavender]\n" leaf_node = "l$(leafidx)[label=$leafname, style=filled, fillcolor=paleturquoise]\n" gnode_str = "g$g_id[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" body_node *= factor_str * leaf_node * gnode_str @@ -233,14 +235,14 @@ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="Compu push!(inds_visitedleaf, g_id) else g_id in inds_visitednode && continue - temp_node,temp_arrow = to_dotstatic(operator(g), g_id, factor(g), subgraphs(g), subgraph_factors(g)) - body_node *=temp_node + temp_node, temp_arrow = to_dotstatic(operator(g), g_id, factor(g), subgraphs(g), subgraph_factors(g)) + body_node *= temp_node body_arrow *= temp_arrow push!(inds_visitednode, g_id) end if isroot body_arrow *= "g$(g_id)->ReturnNode[arrowhead=vee,]\n" - rootidx +=1 + rootidx += 1 end end end @@ -250,25 +252,42 @@ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="Compu return expr end -function compile_dot(graphs::AbstractVector{<:AbstractGraph}, filename::String; graph_name="ComputationalGraph") - dot_string = to_dot_str(graphs, graph_name) +function compile_dot(graphs::AbstractVector{<:AbstractGraph}, filename::String; graph_name="", diagram_id_map=nothing) + dot_string = to_dot_str(graphs, graph_name, diagram_id_map) open(filename, "w") do f write(f, dot_string) end end -function getname(properties,leafidx) - if properties isa BareGreenId - lfname = "<G$leafidx>" - elseif properties isa BareInteractionId - lfname = "<V$leafidx>" - elseif typeof(properties) == FeynmanProperties && properties.diagtype == ComputationalGraphs.Propagator - lfname = "<G$leafidx>" - elseif typeof(properties) == FeynmanProperties && properties.diagtype == ComputationalGraphs.Interaction - lfname = "<V$leafidx>" +function get_leafname(g, leafidx, diagram_id_map=nothing) + println(typeof(g)) + leaftype = Nothing + if g isa FeynmanGraph + leaftype = g.properties.diagtype + elseif g isa Graph + if isnothing(diagram_id_map) == false + leaftype = typeof(diagram_id_map[g.id]) + end + else + error("Unknown graph type: $(typeof(g))") + end + if leaftype in [BareGreenId, ComputationalGraphs.Propagator] + leafname = "<G$leafidx>" + elseif leaftype in [BareInteractionId, ComputationalGraphs.Interaction] + leafname = "<V$leafidx>" + elseif leaftype == PolarId + leafname = "<Π$leafidx>" + elseif leaftype == Ver3Id + leafname = "<Γ(3)$leafidx>" + elseif leaftype == Ver4Id + leafname = "<Γ(4)$leafidx>" else - lfname = "$leafidx>" + leafname = "$leafidx>" end - return lfname + println() + println(g) + println(leaftype) + println(leafname) + println() + return leafname end - diff --git a/src/computational_graph/ComputationalGraph.jl b/src/computational_graph/ComputationalGraph.jl index 97c2bffb..4a50a17f 100644 --- a/src/computational_graph/ComputationalGraph.jl +++ b/src/computational_graph/ComputationalGraph.jl @@ -38,7 +38,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti include("tree_properties.jl") -export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest, count_operation +export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest, count_operation include("operation.jl") include("io.jl") diff --git a/src/computational_graph/abstractgraph.jl b/src/computational_graph/abstractgraph.jl index a5f53a1d..d47d6b8d 100644 --- a/src/computational_graph/abstractgraph.jl +++ b/src/computational_graph/abstractgraph.jl @@ -254,6 +254,16 @@ function set_subgraph_factors!(g::AbstractGraph, subgraph_factors::AbstractVecto end end +""" +function disconnect_subgraphs!(g::G) where {G<:AbstractGraph} + + Empty the subgraphs and subgraph_factors of graph `g`. Any child nodes of g + not referenced elsewhere in the full computational graph are effectively deleted. +""" +function disconnect_subgraphs!(g::AbstractGraph) + empty!(subgraphs(g)) + empty!(subgraph_factors(g)) +end ### Methods ### diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl index 95e7cf39..0017f3ad 100644 --- a/src/computational_graph/optimize.jl +++ b/src/computational_graph/optimize.jl @@ -16,7 +16,7 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize) flatten_all_chains!(graphs, verbose=verbose) merge_all_linear_combinations!(graphs, verbose=verbose) - + remove_all_zero_valued_subgraphs!(graphs, verbose=verbose) return graphs end end @@ -65,7 +65,7 @@ end """ function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) - Flattens all nodes representing trivial unary chains in-place in given graphs. + Flattens all nodes representing trivial unary chains in-place in the given graphs. # Arguments: - `graphs`: A collection of graphs to be processed. @@ -84,10 +84,57 @@ function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph} return graphs end +""" + function remove_all_zero_valued_subgraphs!(g::AbstractGraph; verbose=0) + + Recursively removes all zero-valued subgraph(s) in-place in the given graph `g`. + +# Arguments: +- `g`: An AbstractGraph. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graph. +# +""" +function remove_all_zero_valued_subgraphs!(g::AbstractGraph; verbose=0) + verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.") + # Post-order DFS + for sub_g in subgraphs(g) + remove_all_zero_valued_subgraphs!(sub_g) + remove_zero_valued_subgraphs!(sub_g) + end + remove_zero_valued_subgraphs!(g) + return g +end + +""" + function remove_all_zero_valued_subgraphs!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) + + Recursively removes all zero-valued subgraph(s) in-place in the given graphs. + +# Arguments: +- `graphs`: A collection of graphs to be processed. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graphs. +# +""" +function remove_all_zero_valued_subgraphs!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) + verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.") + # Post-order DFS + for g in graphs + remove_all_zero_valued_subgraphs!(subgraphs(g)) + remove_zero_valued_subgraphs!(g) + end + return graphs +end + """ function merge_all_linear_combinations!(g::AbstractGraph; verbose=0) - Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place within a single graph. + Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in the given graph `g`. # Arguments: - `g`: An AbstractGraph. @@ -111,7 +158,7 @@ end """ function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0) - Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in given graphs. + Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in the given graphs. # Arguments: - `graphs`: A collection of graphs to be processed. @@ -134,7 +181,7 @@ end """ function merge_all_multi_products!(g::Graph; verbose=0) - Merges all nodes representing a multi product of a non-unique list of subgraphs in-place within a single graph. + Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in the given graph `g`. # Arguments: - `g::Graph`: A Graph. @@ -158,7 +205,7 @@ end """ function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{<:Graph}}; verbose=0) - Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in given graphs. + Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in the given graphs. # Arguments: - `graphs`: A collection of graphs to be processed. diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 3c5c6380..fecf791a 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -183,7 +183,7 @@ end function flatten_chains(g::AbstractGraph) Recursively flattens chains of subgraphs within a given graph `g` by merging certain trivial unary subgraphs into their parent graphs, - This function returns a new graph with flatten chains, dervied from the input graph `g` remaining unchanged. + This function returns a new graph with flatten chains, derived from the input graph `g` remaining unchanged. # Arguments: - `g::AbstractGraph`: graph to be modified @@ -191,7 +191,49 @@ end flatten_chains(g::AbstractGraph) = flatten_chains!(deepcopy(g)) """ - function merge_linear_combination(g::AbstractGraph) + function remove_zero_valued_subgraphs!(g::AbstractGraph) + + Removes zero-valued (zero subgraph_factor) subgraph(s) of a computational graph `g`. If all subgraphs are zero-valued, the first one (`eldest(g)`) will be retained. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +""" +function remove_zero_valued_subgraphs!(g::AbstractGraph) + if isleaf(g) || isbranch(g) # we must retain at least one subgraph + return g + end + subg = collect(subgraphs(g)) + subg_fac = collect(subgraph_factors(g)) + zero_sgf = zero(subg_fac[1]) # F(0) + # Find subgraphs with all-zero subgraph_factors and propagate subfactor one level up + for (i, sub_g) in enumerate(subg) + if has_zero_subfactors(sub_g) + subg_fac[i] = zero_sgf + end + end + # Remove marked zero subgraph factor subgraph(s) of g + mask_zeros = findall(x -> x != zero(x), subg_fac) + if isempty(mask_zeros) + mask_zeros = [1] # retain eldest(g) if all subfactors are zero + end + set_subgraphs!(g, subg[mask_zeros]) + set_subgraph_factors!(g, subg_fac[mask_zeros]) + return g +end + +""" + function remove_zero_valued_subgraphs(g::AbstractGraph) + + Returns a copy of graph `g` with zero-valued (zero subgraph_factor) subgraph(s) removed. + If all subgraphs are zero-valued, the first one (`eldest(g)`) will be retained. + +# Arguments: +- `g::AbstractGraph`: graph to be modified +""" +remove_zero_valued_subgraphs(g::AbstractGraph) = remove_zero_valued_subgraphs!(deepcopy(g)) + +""" + function merge_linear_combination!(g::AbstractGraph) Modifies a computational graph `g` by factorizing multiplicative prefactors, e.g., 3*g1 + 5*g2 + 7*g1 + 9*g2 ↦ 10*g1 + 14*g2 = linear_combination(g1, g2, 10, 14). diff --git a/src/computational_graph/tree_properties.jl b/src/computational_graph/tree_properties.jl index 84e2819b..b6d820b5 100644 --- a/src/computational_graph/tree_properties.jl +++ b/src/computational_graph/tree_properties.jl @@ -98,6 +98,25 @@ function isfactorless(g::AbstractGraph) end end +""" + function has_zero_subfactors(g) + + Returns whether the graph g has only zero-valued subgraph factor(s). + Note that this function does not recurse through subgraphs of g, so that one may have, e.g., + `isfactorless(g) == true` but `isfactorless(eldest(g)) == false`. + By convention, returns `false` if g is a leaf. + +# Arguments: +- `g::AbstractGraph`: graph to be analyzed +""" +function has_zero_subfactors(g::AbstractGraph) + if isleaf(g) + return false # convention: subgraph_factors = [] ⟹ subfactorless = false + else + return iszero(subgraph_factors(g)) + end +end + """ function eldest(g::AbstractGraph) diff --git a/test/computational_graph.jl b/test/computational_graph.jl index bb36652a..f61a5c05 100644 --- a/test/computational_graph.jl +++ b/test/computational_graph.jl @@ -105,6 +105,12 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true Graphs.set_subgraph_factors!(g, [5.0, 2.0, 3.0], [3, 1, 2]) # default method @test Graphs.subgraph_factors(g) == [2.0, 3.0, 5.0] end + @testset "Disconnect subgraphs" begin + g_dc = deepcopy(g) + Graphs.disconnect_subgraphs!(g_dc) + @test isempty(Graphs.subgraphs(g_dc)) + @test isempty(Graphs.subgraph_factors(g_dc)) + end @testset "Equivalence" begin Graphs.set_name!(g, Graphs.name(gp)) @test g == g @@ -302,6 +308,32 @@ end @test r2 == Graphs.flatten_chains(rvec[2]) @test r3 == Graphs.flatten_chains(rvec[3]) end + @testset "Remove zero-valued subgraphs" begin + # leaves + l1 = Graph([]; factor=1) + l2 = Graph([]; factor=2) + l3 = Graph([]; factor=3) + l4 = Graph([]; factor=4) + l5 = Graph([]; factor=5) + l6 = Graph([]; factor=6) + l7 = Graph([]; factor=7) + l8 = Graph([]; factor=8) + # subgraphs + sg1 = l1 + sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=O1()) + sg3 = Graph([l4]; subgraph_factors=[0], operator=O2()) + sg4 = Graph([l5, l6, l7]; subgraph_factors=[0, 0, 0], operator=O3()) + sg5 = l8 + # graphs + g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=O()) + g_test = Graph([sg1, sg2]; subgraph_factors=[1, 1], operator=O()) + gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 1, 0], operator=O()) + gp_test = Graph([sg3]; subgraph_factors=[0], operator=O()) + Graphs.remove_zero_valued_subgraphs!(g) + Graphs.remove_zero_valued_subgraphs!(gp) + @test isequiv(g, g_test, :id) + @test isequiv(gp, gp_test, :id) + end end @testset verbose = true "Optimizations" begin @testset "Flatten all chains" begin @@ -330,6 +362,35 @@ end Graphs.flatten_all_chains!(rvec) @test rvec == [r1, r2, r3] end + @testset "Remove all zero-valued subgraphs" begin + # leaves + l1 = Graph([]; factor=1) + l2 = Graph([]; factor=2) + l3 = Graph([]; factor=3) + l4 = Graph([]; factor=4) + l5 = Graph([]; factor=5) + l6 = Graph([]; factor=6) + l7 = Graph([]; factor=7) + l8 = Graph([]; factor=8) + # sub-subgraph + ssg1 = Graph([l7]; subgraph_factors=[0], operator=O()) + # subgraphs + sg1 = l1 + sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=O1()) + sg2_test = Graph([l2]; subgraph_factors=[1.0], operator=O1()) + sg3 = Graph([l4]; subgraph_factors=[0], operator=O2()) + sg4 = Graph([l5, l6, ssg1]; subgraph_factors=[0, 0, 1], operator=O3()) + sg5 = l8 + # graphs + g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=O()) + g_test = Graph([sg1, sg2_test]; subgraph_factors=[1, 1], operator=O()) + gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 1, 0], operator=O()) + gp_test = Graph([sg3]; subgraph_factors=[0], operator=O()) + Graphs.remove_all_zero_valued_subgraphs!(g) + Graphs.remove_all_zero_valued_subgraphs!(gp) + @test isequiv(g, g_test, :id) + @test isequiv(gp, gp_test, :id) + end @testset "Merge all linear combinations" begin g1 = Graph([]) g2 = 2 * g1 @@ -1037,6 +1098,7 @@ end g3 = 1 * g1 g4 = 1 * g2 g5 = 2 * g1 + h1 = 0 * g1 # Chains: Ⓧ --- Ⓧ --- gᵢ (simplified by default) g6 = Graph([g5,]; subgraph_factors=[1,], operator=Graphs.Prod()) g7 = Graph([g3,]; subgraph_factors=[2,], operator=Graphs.Prod()) @@ -1044,6 +1106,8 @@ end g8 = 2 * (3 * g1 + 5 * g2) g9 = g1 + 2 * (3 * g1 + 5 * g2) g10 = g1 * g2 + g8 * g9 + h2 = Graph([g1, g2]; subgraph_factors=[0, 0], operator=Graphs.Sum()) + h3 = Graph([g1, g2]; subgraph_factors=[1, 0], operator=Graphs.Sum()) glist = [g1, g2, g8, g9, g10] @testset "Leaves" begin @@ -1068,6 +1132,7 @@ end @test isfactorless(g4) @test isfactorless(g5) == false @test isleaf(eldest(g3)) + @test has_zero_subfactors(h1) end @testset "Chains" begin @test haschildren(g6) @@ -1090,6 +1155,8 @@ end @test count_operation(g8) == [1, 0] @test count_operation(g9) == [2, 0] @test count_operation(g10) == [4, 2] + @test has_zero_subfactors(h2) + @test has_zero_subfactors(h3) == false end @testset "Iteration" begin count_pre = sum(1 for node in PreOrderDFS(g9))