Skip to content

Commit

Permalink
add flatten_chains and the relevant tests
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Oct 25, 2023
1 parent f0cee8b commit 39d7986
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ include("eval.jl")
export eval!

include("transform.jl")
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains

include("optimize.jl")
export optimize!, optimize
Expand Down
49 changes: 47 additions & 2 deletions src/computational_graph/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose
else
graphs = collect(graphs)
leaf_mapping = remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize)
merge_all_chains!(graphs, verbose=verbose)
# merge_all_chains!(graphs, verbose=verbose)
flatten_all_chains!(graphs, verbose=verbose)
merge_all_linear_combinations!(graphs, verbose=verbose)
return leaf_mapping
end
Expand Down Expand Up @@ -179,6 +180,50 @@ function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
return graphs
end

"""
function flatten_all_chains!(g::AbstractGraph; verbose=0)
In-place flattens all nodes representing trivial unary chains in the given graph `g`.
# Arguments:
- `graphs`: The graph to be processed.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- The mutated graph `g` with all chains flattened.
"""
function flatten_all_chains!(g::AbstractGraph; verbose=0)
verbose > 0 && println("flatten all nodes representing trivial unary chains.")
for sub_g in g.subgraphs
flatten_all_chains!(sub_g)
flatten_chains!(sub_g)
end
flatten_chains!(g)
return g
end

"""
function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
In-place flattens all nodes representing trivial unary chains in given graphs.
# Arguments:
- `graphs`: A collection of graphs to be processed.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- The mutated collection `graphs` with all chains in each graph flattened.
"""
function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
verbose > 0 && println("flatten all nodes representing trivial unary chains.")
# Post-order DFS
for g in graphs
flatten_all_chains!(g.subgraphs)
flatten_chains!(g)
end
return graphs
end

"""
function merge_all_linear_combinations!(g::AbstractGraph; verbose=0)
Expand Down Expand Up @@ -263,7 +308,7 @@ end
- Optimized graphs.
#
"""
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0)
function merge_all_multi_products!(graphs::AbstractVector{<:Graph}; verbose=0)
verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.")
# Post-order DFS
for g in graphs
Expand Down
42 changes: 42 additions & 0 deletions src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ end
- `g::AbstractGraph`: graph to be modified
"""
function merge_factorless_chain!(g::AbstractGraph)
if unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
child = eldest(g)
for field in fieldnames(typeof(g))
value = getproperty(child, field)
setproperty!(g, field, value)
end
end
while unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
child = eldest(g)
for field in fieldnames(typeof(g))
Expand Down Expand Up @@ -278,6 +285,41 @@ end
"""
merge_chains(g::AbstractGraph) = merge_chains!(deepcopy(g))

"""
function flatten_chains!(g::AbstractGraph)
Recursively flattens chains of subgraphs within the given graph `g` by merging certain trivial unary subgraphs
into their parent graphs in the in-place form.
Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!),
where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation.
# Arguments:
- `g::AbstractGraph`: graph to be modified
"""
function flatten_chains!(g::AbstractGraph)
for (i, sub_g) in enumerate(g.subgraphs)
if onechild(sub_g) && unary_istrivial(sub_g.operator)
flatten_chains!(sub_g)

g.subgraph_factors[i] *= sub_g.subgraph_factors[1]
g.subgraphs[i] = eldest(sub_g)
end
end
return g
end

"""
function flatten_chains(g::AbstractGraph)
Recursively flattens chains of subgraphs within a given graph `g` by merging certain trivial unary subgraphs into their parent graphs,
This function returns a new graph with flatten chains, dervied from the input graph `g` remaining unchanged.
# Arguments:
- `g::AbstractGraph`: graph to be modified
"""
flatten_chains(g::AbstractGraph) = flatten_chains!(deepcopy(g))

"""
function merge_linear_combination(g::Graph)
Expand Down
80 changes: 80 additions & 0 deletions test/computational_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,32 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
merge_multi_product!(h1)
@test isequiv(h1, h1_mp, :id)
end
@testset "Flatten chains" begin
l0 = Graph([])
l1 = Graph([l0]; subgraph_factors=[2])
g1 = Graph([l1]; subgraph_factors=[-1], operator=O())
g1c = deepcopy(g1)
g2 = 2 * g1
g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod())
g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod())
r1 = Graph([g4,]; subgraph_factors=[7,], operator=Graphs.Prod())
r2 = Graph([g4,]; subgraph_factors=[-1,], operator=Graphs.Prod())
r3 = Graph([g3, g4,]; subgraph_factors=[2, 7], operator=O())
rvec = deepcopy([r1, r2, r3])
Graphs.flatten_chains!(r1)
@test isequiv(g1, g1c, :id)
@test isequiv(r1, 210g1, :id)
@test isequiv(g2, 2g1, :id)
@test isequiv(g3, 6g1, :id)
@test isequiv(g4, 30g1, :id)
Graphs.flatten_chains!(r2)
@test isequiv(r2, -30g1, :id)
Graphs.flatten_chains!(r3)
@test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id)
@test r1 == Graphs.flatten_chains(rvec[1])
@test r2 == Graphs.flatten_chains(rvec[2])
@test r3 == Graphs.flatten_chains(rvec[3])
end
end
@testset verbose = true "Optimizations" begin
@testset "Remove one-child parents" begin
Expand Down Expand Up @@ -247,6 +273,40 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
@test h.subgraph_factors == [105]
@test eldest(h) == h0
end
@testset "Flatten all chains" begin
l0 = Graph([])
l1 = Graph([l0]; subgraph_factors=[2])
l2 = Graph([]; factor=3)
g1 = Graph([l1, l2]; subgraph_factors=[-1, 1])
g2 = 2 * g1
g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod())
g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod())
r1 = Graph([g4,]; subgraph_factors=[7,], operator=Graphs.Prod())
r2 = Graph([g4,]; subgraph_factors=[-1,], operator=Graphs.Prod())
r3 = Graph([g3, g4,]; subgraph_factors=[2, 7], operator=O())
rvec = deepcopy([r1, r2, r3])
rvec1 = deepcopy([r1, r2, r3])
Graphs.flatten_all_chains!(r1)
@test isequiv(g1, Graph([l0, l2]; subgraph_factors=[-2, 1]), :id)
@test isequiv(r1, 210g1, :id)
@test isequiv(g2, 2g1, :id)
@test isequiv(g3, 6g1, :id)
@test isequiv(g4, 30g1, :id)
Graphs.flatten_all_chains!(r2)
@test isequiv(r2, -30g1, :id)
Graphs.flatten_all_chains!(r3)
@test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id)
Graphs.flatten_all_chains!(rvec)
@test rvec == [r1, r2, r3]

Graphs.merge_all_chains!(rvec1)
@test rvec1[1].subgraph_factors == [210]
@test eldest(rvec1[1]) == g1
@test rvec1[2].subgraph_factors == [-1]
@test eldest(rvec1[2]) == g1 # BUG!
@test rvec1[3].subgraph_factors == [2, 7]
@test rvec1[3].subgraphs == [g1, g1] # BUG!
end
@testset "merge all linear combinations" begin
g1 = Graph([])
g2 = 2 * g1
Expand All @@ -265,6 +325,26 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
Graphs.merge_all_linear_combinations!(h0)
@test isequiv(h0.subgraphs[1], _h, :id)
end
@testset "Merge all multi prodicts" begin
g1 = Graph([])
g2 = Graph([], factor=2)
g3 = Graph([], factor=3)
h = Graph([g1, g2, g1, g1, g3, g2]; subgraph_factors=[3, 2, 5, 1, 1, 3], operator=Graphs.Prod())
hvec = repeat([deepcopy(h)], 3)
h0 = Graph([deepcopy(h), g2])
h_s1 = Graph([g1], operator=Graphs.Power(3))
h_s2 = Graph([g2], operator=Graphs.Power(2))
_h = Graph([h_s1, h_s2, g3], subgraph_factors=[15, 6, 1], operator=Graphs.Prod())
# Test on a single graph
Graphs.merge_all_multi_products!(h)
@test isequiv(h, _h, :id)
# Test on a vector of graphs
Graphs.merge_all_multi_products!(hvec)
@test all(isequiv(h, _h, :id) for h in hvec)

Graphs.merge_all_multi_products!(h0)
@test isequiv(h0.subgraphs[1], _h, :id)
end
@testset "optimize" begin
g1 = Graph([])
g2 = 2 * g1
Expand Down

0 comments on commit 39d7986

Please sign in to comment.