Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flatten_chains and the relevant tests #150

Merged
merged 3 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ include("eval.jl")
export eval!

include("transform.jl")
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains

include("optimize.jl")
export optimize!, optimize
Expand Down
49 changes: 47 additions & 2 deletions src/computational_graph/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose
else
graphs = collect(graphs)
leaf_mapping = remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize)
merge_all_chains!(graphs, verbose=verbose)
# merge_all_chains!(graphs, verbose=verbose)
houpc marked this conversation as resolved.
Show resolved Hide resolved
flatten_all_chains!(graphs, verbose=verbose)
merge_all_linear_combinations!(graphs, verbose=verbose)
return leaf_mapping
end
Expand Down Expand Up @@ -179,6 +180,50 @@ function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
return graphs
end

"""
function flatten_all_chains!(g::AbstractGraph; verbose=0)

In-place flattens all nodes representing trivial unary chains in the given graph `g`.
houpc marked this conversation as resolved.
Show resolved Hide resolved

# Arguments:
- `graphs`: The graph to be processed.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- The mutated graph `g` with all chains flattened.
"""
function flatten_all_chains!(g::AbstractGraph; verbose=0)
verbose > 0 && println("flatten all nodes representing trivial unary chains.")
for sub_g in g.subgraphs
flatten_all_chains!(sub_g)
flatten_chains!(sub_g)
end
flatten_chains!(g)
return g
end

"""
function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)

In-place flattens all nodes representing trivial unary chains in given graphs.
houpc marked this conversation as resolved.
Show resolved Hide resolved

# Arguments:
- `graphs`: A collection of graphs to be processed.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- The mutated collection `graphs` with all chains in each graph flattened.
"""
function flatten_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
verbose > 0 && println("flatten all nodes representing trivial unary chains.")
# Post-order DFS
for g in graphs
flatten_all_chains!(g.subgraphs)
flatten_chains!(g)
end
return graphs
end

"""
function merge_all_linear_combinations!(g::AbstractGraph; verbose=0)

Expand Down Expand Up @@ -263,7 +308,7 @@ end
- Optimized graphs.
#
"""
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0)
function merge_all_multi_products!(graphs::AbstractVector{<:Graph}; verbose=0)
houpc marked this conversation as resolved.
Show resolved Hide resolved
verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.")
# Post-order DFS
for g in graphs
Expand Down
42 changes: 42 additions & 0 deletions src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ end
- `g::AbstractGraph`: graph to be modified
"""
function merge_factorless_chain!(g::AbstractGraph)
if unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
houpc marked this conversation as resolved.
Show resolved Hide resolved
child = eldest(g)
for field in fieldnames(typeof(g))
value = getproperty(child, field)
setproperty!(g, field, value)
end
end
while unary_istrivial(g.operator) && onechild(g) && isfactorless(g)
child = eldest(g)
for field in fieldnames(typeof(g))
Expand Down Expand Up @@ -278,6 +285,41 @@ end
"""
merge_chains(g::AbstractGraph) = merge_chains!(deepcopy(g))

"""
function flatten_chains!(g::AbstractGraph)

Recursively flattens chains of subgraphs within the given graph `g` by merging certain trivial unary subgraphs
into their parent graphs in the in-place form.

Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!),
where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation.

# Arguments:
- `g::AbstractGraph`: graph to be modified
"""
function flatten_chains!(g::AbstractGraph)
for (i, sub_g) in enumerate(g.subgraphs)
if onechild(sub_g) && unary_istrivial(sub_g.operator)
flatten_chains!(sub_g)

g.subgraph_factors[i] *= sub_g.subgraph_factors[1]
g.subgraphs[i] = eldest(sub_g)
end
end
return g
end

"""
function flatten_chains(g::AbstractGraph)

Recursively flattens chains of subgraphs within a given graph `g` by merging certain trivial unary subgraphs into their parent graphs,
This function returns a new graph with flatten chains, dervied from the input graph `g` remaining unchanged.

# Arguments:
- `g::AbstractGraph`: graph to be modified
"""
flatten_chains(g::AbstractGraph) = flatten_chains!(deepcopy(g))

"""
function merge_linear_combination(g::Graph)

Expand Down
80 changes: 80 additions & 0 deletions test/computational_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,32 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
merge_multi_product!(h1)
@test isequiv(h1, h1_mp, :id)
end
@testset "Flatten chains" begin
l0 = Graph([])
l1 = Graph([l0]; subgraph_factors=[2])
g1 = Graph([l1]; subgraph_factors=[-1], operator=O())
g1c = deepcopy(g1)
g2 = 2 * g1
g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod())
g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod())
r1 = Graph([g4,]; subgraph_factors=[7,], operator=Graphs.Prod())
r2 = Graph([g4,]; subgraph_factors=[-1,], operator=Graphs.Prod())
r3 = Graph([g3, g4,]; subgraph_factors=[2, 7], operator=O())
rvec = deepcopy([r1, r2, r3])
Graphs.flatten_chains!(r1)
@test isequiv(g1, g1c, :id)
@test isequiv(r1, 210g1, :id)
@test isequiv(g2, 2g1, :id)
@test isequiv(g3, 6g1, :id)
@test isequiv(g4, 30g1, :id)
Graphs.flatten_chains!(r2)
@test isequiv(r2, -30g1, :id)
Graphs.flatten_chains!(r3)
@test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id)
@test r1 == Graphs.flatten_chains(rvec[1])
@test r2 == Graphs.flatten_chains(rvec[2])
@test r3 == Graphs.flatten_chains(rvec[3])
end
end
@testset verbose = true "Optimizations" begin
@testset "Remove one-child parents" begin
Expand Down Expand Up @@ -247,6 +273,40 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
@test h.subgraph_factors == [105]
@test eldest(h) == h0
end
@testset "Flatten all chains" begin
l0 = Graph([])
l1 = Graph([l0]; subgraph_factors=[2])
l2 = Graph([]; factor=3)
g1 = Graph([l1, l2]; subgraph_factors=[-1, 1])
g2 = 2 * g1
g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod())
g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod())
r1 = Graph([g4,]; subgraph_factors=[7,], operator=Graphs.Prod())
r2 = Graph([g4,]; subgraph_factors=[-1,], operator=Graphs.Prod())
r3 = Graph([g3, g4,]; subgraph_factors=[2, 7], operator=O())
rvec = deepcopy([r1, r2, r3])
rvec1 = deepcopy([r1, r2, r3])
Graphs.flatten_all_chains!(r1)
@test isequiv(g1, Graph([l0, l2]; subgraph_factors=[-2, 1]), :id)
@test isequiv(r1, 210g1, :id)
@test isequiv(g2, 2g1, :id)
@test isequiv(g3, 6g1, :id)
@test isequiv(g4, 30g1, :id)
Graphs.flatten_all_chains!(r2)
@test isequiv(r2, -30g1, :id)
Graphs.flatten_all_chains!(r3)
@test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id)
Graphs.flatten_all_chains!(rvec)
@test rvec == [r1, r2, r3]

Graphs.merge_all_chains!(rvec1)
@test rvec1[1].subgraph_factors == [210]
@test eldest(rvec1[1]) == g1
@test rvec1[2].subgraph_factors == [-1]
@test eldest(rvec1[2]) == g1 # BUG!
@test rvec1[3].subgraph_factors == [2, 7]
@test rvec1[3].subgraphs == [g1, g1] # BUG!
end
@testset "merge all linear combinations" begin
g1 = Graph([])
g2 = 2 * g1
Expand All @@ -265,6 +325,26 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
Graphs.merge_all_linear_combinations!(h0)
@test isequiv(h0.subgraphs[1], _h, :id)
end
@testset "Merge all multi prodicts" begin
g1 = Graph([])
g2 = Graph([], factor=2)
g3 = Graph([], factor=3)
h = Graph([g1, g2, g1, g1, g3, g2]; subgraph_factors=[3, 2, 5, 1, 1, 3], operator=Graphs.Prod())
hvec = repeat([deepcopy(h)], 3)
h0 = Graph([deepcopy(h), g2])
h_s1 = Graph([g1], operator=Graphs.Power(3))
h_s2 = Graph([g2], operator=Graphs.Power(2))
_h = Graph([h_s1, h_s2, g3], subgraph_factors=[15, 6, 1], operator=Graphs.Prod())
# Test on a single graph
Graphs.merge_all_multi_products!(h)
@test isequiv(h, _h, :id)
# Test on a vector of graphs
Graphs.merge_all_multi_products!(hvec)
@test all(isequiv(h, _h, :id) for h in hvec)

Graphs.merge_all_multi_products!(h0)
@test isequiv(h0.subgraphs[1], _h, :id)
end
@testset "optimize" begin
g1 = Graph([])
g2 = 2 * g1
Expand Down
Loading