Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flatten_chains and the relevant tests #150

Merged
merged 3 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ include("eval.jl")
export eval!

include("transform.jl")
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!, flatten_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains, flatten_chains

include("optimize.jl")
export optimize!, optimize
Expand Down
166 changes: 37 additions & 129 deletions src/computational_graph/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose
else
graphs = collect(graphs)
leaf_mapping = remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize)
merge_all_chains!(graphs, verbose=verbose)
flatten_all_chains!(graphs, verbose=verbose)
merge_all_linear_combinations!(graphs, verbose=verbose)
return leaf_mapping
end
Expand All @@ -26,7 +26,7 @@ end
"""
function optimize(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0, normalize=nothing)

Optimize a copy of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations.
Optimizes a copy of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations.

# Arguments:
- `graphs`: A tuple or vector of graphs.
Expand All @@ -44,145 +44,53 @@ function optimize(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=
end

"""
function merge_all_chain_prefactors!(g::AbstractGraph; verbose=0)

In-place merge prefactors of all nodes representing trivial unary chains towards the root level for a single graph.

# Arguments:
- `g`: An AbstractGraph.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graph.
#
"""
function merge_all_chain_prefactors!(g::AbstractGraph; verbose=0)
verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.")
# Post-order DFS
for sub_g in g.subgraphs
merge_all_chain_prefactors!(sub_g)
merge_chain_prefactors!(sub_g)
end
merge_chain_prefactors!(g)
return g
end

"""
function merge_all_chain_prefactors!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)

In-place merge prefactors of all nodes representing trivial unary chains towards the root level for given graphs.
function flatten_all_chains!(g::AbstractGraph; verbose=0)
F
Flattens all nodes representing trivial unary chains in-place in the given graph `g`.

# Arguments:
- `graphs`: An AbstractVector of graphs.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graphs.
#
"""
function merge_all_chain_prefactors!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.")
# Post-order DFS
for g in graphs
merge_all_chain_prefactors!(g.subgraphs)
merge_chain_prefactors!(g)
end
return graphs
end

"""
function merge_all_factorless_chains!(g::AbstractGraph; verbose=0)

In-place merge all nodes representing factorless trivial unary chains within a single graph.

# Arguments:
- `g`: An AbstractGraph.
- `graphs`: The graph to be processed.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graph.
#
- The mutated graph `g` with all chains flattened.
"""
function merge_all_factorless_chains!(g::AbstractGraph; verbose=0)
verbose > 0 && println("merge all nodes representing factorless trivial unary chains.")
# Post-order DFS
function flatten_all_chains!(g::AbstractGraph; verbose=0)
verbose > 0 && println("flatten all nodes representing trivial unary chains.")
for sub_g in g.subgraphs
merge_all_factorless_chains!(sub_g)
merge_factorless_chain!(sub_g)
flatten_all_chains!(sub_g)
flatten_chains!(sub_g)
end
merge_factorless_chain!(g)
flatten_chains!(g)
return g
end

"""
function merge_all_factorless_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)

In-place merge all nodes representing factorless trivial unary chains within given graphs.
Flattens all nodes representing trivial unary chains in-place in given graphs.

# Arguments:
- `graphs`: An AbstractVector of graphs.
- `graphs`: A collection of graphs to be processed.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graphs.
#
- The mutated collection `graphs` with all chains in each graph flattened.
"""
function merge_all_factorless_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
verbose > 0 && println("merge all nodes representing factorless trivial unary chains.")
function flatten_all_chains!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)
verbose > 0 && println("flatten all nodes representing trivial unary chains.")
# Post-order DFS
for g in graphs
merge_all_factorless_chains!(g.subgraphs)
merge_factorless_chain!(g)
flatten_all_chains!(g.subgraphs)
flatten_chains!(g)
end
return graphs
end

"""
function merge_all_chains!(g::AbstractGraph; verbose=0)

In-place merge all nodes representing trivial unary chains within a single graph.
This function consolidates both chain prefactors and factorless chains.

# Arguments:
- `g`: An AbstractGraph.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graph.
#
"""
function merge_all_chains!(g::AbstractGraph; verbose=0)
verbose > 0 && println("merge all nodes representing trivial unary chains.")
merge_all_chain_prefactors!(g, verbose=verbose)
merge_all_factorless_chains!(g, verbose=verbose)
return g
end

"""
function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) where {G<:AbstractGraph}

In-place merge all nodes representing trivial unary chains in given graphs.
This function consolidates both chain prefactors and factorless chains.

# Arguments:
- `graphs`: An AbstractVector of graphs.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graphs.
#
"""
function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
verbose > 0 && println("merge all nodes representing trivial unary chains.")
merge_all_chain_prefactors!(graphs, verbose=verbose)
merge_all_factorless_chains!(graphs, verbose=verbose)
return graphs
end

"""
function merge_all_linear_combinations!(g::AbstractGraph; verbose=0)

In-place merge all nodes representing a linear combination of a non-unique list of subgraphs within a single graph.
Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place within a single graph.

# Arguments:
- `g`: An AbstractGraph.
Expand All @@ -204,19 +112,19 @@ function merge_all_linear_combinations!(g::AbstractGraph; verbose=0)
end

"""
function merge_all_linear_combinations!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)

In-place merge all nodes representing a linear combination of a non-unique list of subgraphs in given graphs.
Merges all nodes representing a linear combination of a non-unique list of subgraphs in-place in given graphs.

# Arguments:
- `graphs`: An AbstractVector of graphs.
- `graphs`: A collection of graphs to be processed.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graphs.
#
"""
function merge_all_linear_combinations!(graphs::AbstractVector{<:AbstractGraph}; verbose=0)
function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0)
verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.")
# Post-order DFS
for g in graphs
Expand All @@ -229,7 +137,7 @@ end
"""
function merge_all_multi_products!(g::Graph; verbose=0)

In-place merge all nodes representing a multi product of a non-unique list of subgraphs within a single graph.
Merges all nodes representing a multi product of a non-unique list of subgraphs in-place within a single graph.

# Arguments:
- `g::Graph`: A Graph.
Expand All @@ -251,19 +159,19 @@ function merge_all_multi_products!(g::Graph; verbose=0)
end

"""
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0)
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{<:Graph}}; verbose=0)

In-place merge all nodes representing a multi product of a non-unique list of subgraphs in given graphs.
Merges all nodes representing a multi product of a non-unique list of subgraphs in-place in given graphs.

# Arguments:
- `graphs::Union{Tuple,AbstractVector{Graph}}`: A tuple or vector of graphs.
- `graphs`: A collection of graphs to be processed.
- `verbose`: Level of verbosity (default: 0).

# Returns:
- Optimized graphs.
#
"""
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0)
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{<:Graph}}; verbose=0)
verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.")
# Post-order DFS
for g in graphs
Expand All @@ -276,10 +184,10 @@ end
"""
function unique_leaves(_graphs::AbstractVector{<:AbstractGraph})

Identify and retrieve unique leaf nodes from a set of graphs.
Identifies and retrieves unique leaf nodes from a set of graphs.

# Arguments:
- `_graphs`: A tuple or vector of graphs.
- `_graphs`: A collection of graphs to be processed.

# Returns:
- The vector of unique leaf nodes.
Expand Down Expand Up @@ -310,19 +218,19 @@ function unique_leaves(_graphs::AbstractVector{<:AbstractGraph})
end

"""
function remove_duplicated_leaves!(graphs::AbstractVector{<:AbstractGraph}; verbose=0, normalize=nothing, kwargs...)
function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0, normalize=nothing, kwargs...)

In-place remove duplicated leaf nodes from a collection of graphs. It also provides optional normalization for these leaves.
Removes duplicated leaf nodes in-place from a collection of graphs. It also provides optional normalization for these leaves.

# Arguments:
- `graphs`: An AbstractVector of graphs.
- `graphs`: A collection of graphs to be processed.
- `verbose`: Level of verbosity (default: 0).
- `normalize`: Optional function to normalize the graphs (default: nothing).

# Returns:
- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)).
"""
function remove_duplicated_leaves!(graphs::AbstractVector{<:AbstractGraph}; verbose=0, normalize=nothing, kwargs...)
function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose=0, normalize=nothing, kwargs...)
verbose > 0 && println("remove duplicated leaves.")
leaves = Vector{eltype(graphs)}()
for g in graphs
Expand Down Expand Up @@ -356,7 +264,7 @@ end
"""
function burn_from_targetleaves!(graphs::AbstractVector{G}, targetleaves_id::AbstractVector{Int}; verbose=0) where {G<:AbstractGraph}

In-place remove all nodes connected to the target leaves via "Prod" operators.
Removes all nodes connected to the target leaves in-place via "Prod" operators.

# Arguments:
- `graphs`: An AbstractVector of graphs.
Expand Down
Loading
Loading