Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/computgraph' into tao_AD
Browse files Browse the repository at this point in the history
  • Loading branch information
fsxbhyy committed Oct 22, 2023
2 parents 5ac9ac5 + ec763fb commit 071ce89
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/FeynmanDiagram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ export linear_combination, feynman_diagram, propagator, interaction, external_ve
# export 𝐺ᶠ, 𝐺ᵇ, 𝐺ᵠ, 𝑊, Green2, Interaction
# export Coupling_yukawa, Coupling_phi3, Coupling_phi4, Coupling_phi6
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_chains
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains

export optimize!, optimize, merge_all_chains!, merge_all_linear_combinations!, remove_duplicated_leaves!

Expand Down
20 changes: 14 additions & 6 deletions src/backend/static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ function to_julia_str(graphs::AbstractVector{G}; root::AbstractVector{Int}=[g.id
inds_visitednode = Int[]
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
target = "g$(g.id)"
isroot = false
if g.id in root
target = "root[$(findfirst(x -> x == g.id, root))]"
else
target = "g$(g.id)"
target_root = "root[$(findfirst(x -> x == g.id, root))]"
isroot = true
end
if isempty(g.subgraphs) #leaf
g.id in inds_visitedleaf && continue
Expand All @@ -80,6 +81,9 @@ function to_julia_str(graphs::AbstractVector{G}; root::AbstractVector{Int}=[g.id
body *= " $target = $(_to_static(g.operator, g.subgraphs, g.subgraph_factors))$factor_str\n "
push!(inds_visitednode, g.id)
end
if isroot
body *= " $target_root = $target\n "
end
end
end
tail = "end"
Expand Down Expand Up @@ -107,10 +111,11 @@ function to_julia_str(graphs::AbstractVector{G}, leafMap::Dict{Int,Int}; root::A
inds_visitednode = Int[]
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
target = "g$(g.id)"
isroot = false
if g.id in root
target = "root[$(findfirst(x -> x == g.id, root))]"
else
target = "g$(g.id)"
target_root = "root[$(findfirst(x -> x == g.id, root))]"
isroot = true
end
if isempty(g.subgraphs) #leaf
g.id in inds_visitedleaf && continue
Expand All @@ -123,6 +128,9 @@ function to_julia_str(graphs::AbstractVector{G}, leafMap::Dict{Int,Int}; root::A
body *= " $target = $(_to_static(g.operator, g.subgraphs, g.subgraph_factors))$factor_str\n "
push!(inds_visitednode, g.id)
end
if isroot
body *= " $target_root = $target\n "
end
end
end
tail = "end"
Expand Down
4 changes: 2 additions & 2 deletions src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ include("eval.jl")
export eval!

include("transform.jl")
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_chains
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains

include("optimize.jl")
export optimize!, optimize
Expand Down
47 changes: 47 additions & 0 deletions src/computational_graph/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,53 @@ function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{G}};
return graphs
end

"""
function merge_all_multi_products!(g::Graph; verbose=0)
In-place merge all nodes representing a multi product of a non-unique list of subgraphs within a single graph.
# Arguments:
- `g::Graph`: A Graph.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graph.
#
"""
function merge_all_multi_products!(g::Graph; verbose=0)
verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.")
# Post-order DFS
for sub_g in g.subgraphs
merge_all_multi_products!(sub_g)
merge_multi_product!(sub_g)
end
merge_multi_product!(g)
return g
end

"""
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0)
In-place merge all nodes representing a multi product of a non-unique list of subgraphs in given graphs.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{Graph}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graphs.
#
"""
function merge_all_multi_products!(graphs::Union{Tuple,AbstractVector{Graph}}; verbose=0)
verbose > 0 && println("merge nodes representing a multi product of a non-unique list of graphs.")
# Post-order DFS
for g in graphs
merge_all_multi_products!(g.subgraphs)
merge_multi_product!(g)
end
return graphs
end

"""
function unique_leaves(_graphs::Union{Tuple,AbstractVector{G}};) where {G<:AbstractGraph}
Expand Down
60 changes: 60 additions & 0 deletions src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,63 @@ function merge_linear_combination!(g::FeynmanGraph{F,W}) where {F,W}
end
return g
end

"""
function merge_multi_product(g::Graph{F,W}) where {F,W}
Merge multiple products within a computational graph `g` if they share the same operator (`Prod`).
If `g.operator == Prod`, this function will merge `N` identical subgraphs into a single subgraph with a power operator `Power(N)`.
The function ensures each unique subgraph is counted and merged appropriately, preserving any distinct subgraph_factors associated with them.
# Arguments:
- `g::Graph`: graph to be modified
# Returns
- A merged computational graph with potentially fewer subgraphs if there were repeating subgraphs
with the `Prod` operator. If the input graph's operator isn't `Prod`, the function returns the input graph unchanged.
"""
function merge_multi_product(g::Graph{F,W}) where {F,W}
if g.operator == Prod
unique_graphs = Vector{Graph{F,W}}()
unique_factors = F[]
repeated_counts = Int[]
for (idx, subg) in enumerate(g.subgraphs)
loc = findfirst(isequal(subg), unique_graphs)
if isnothing(loc)
push!(unique_graphs, subg)
push!(unique_factors, g.subgraph_factors[idx])
push!(repeated_counts, 1)
else
unique_factors[loc] *= g.subgraph_factors[idx]
repeated_counts[loc] += 1
end
end

if length(unique_factors) == 1
g_merged = Graph(unique_graphs; subgraph_factors=unique_factors, operator=Power(repeated_counts[1]), ftype=F, wtype=W)
else
subgraphs = Vector{Graph{F,W}}()
for (idx, g) in enumerate(unique_graphs)
if repeated_counts[idx] == 1
push!(subgraphs, g)
else
push!(subgraphs, Graph([g], operator=Power(repeated_counts[idx]), ftype=F, wtype=W))
end
end
g_merged = Graph(subgraphs; subgraph_factors=unique_factors, operator=Prod(), ftype=F, wtype=W)
end
return g_merged
else
return g
end
end

function merge_multi_product!(g::Graph{F,W}) where {F,W}
if g.operator == Prod
g_merged = merge_multi_product(g)
g.subgraphs = g_merged.subgraphs
g.subgraph_factors = g_merged.subgraph_factors
g.operator = g_merged.operator
end
return g
end
13 changes: 13 additions & 0 deletions test/computational_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true
@test h8.subgraph_factors == [36]
@test isequiv(h7_lc, h8, :id)
end
@testset "Merge multi prodict" begin
g1 = Graph([])
g2 = Graph([], factor=2)
g3 = Graph([], factor=3)
h1 = Graph([g1, g2, g1, g1, g3, g2]; subgraph_factors=[3, 2, 5, 1, 1, 3], operator=Graphs.Prod())
h1_mp = merge_multi_product(h1)
h1_s1 = Graph([g1], operator=Graphs.Power(3))
h1_s2 = Graph([g2], operator=Graphs.Power(2))
h1_r = Graph([h1_s1, h1_s2, g3], subgraph_factors=[15, 6, 1], operator=Graphs.Prod())
@test isequiv(h1_r, h1_mp, :id)
merge_multi_product!(h1)
@test isequiv(h1, h1_mp, :id)
end
end
@testset verbose = true "Optimizations" begin
@testset "Remove one-child parents" begin
Expand Down

0 comments on commit 071ce89

Please sign in to comment.