Skip to content

Commit

Permalink
udpate mask_zero_subgraph_factors and has_zero_subfactors with multip…
Browse files Browse the repository at this point in the history
…le dispatch
  • Loading branch information
houpc committed Oct 31, 2024
1 parent db67fa3 commit 391fd5c
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 45 deletions.
2 changes: 2 additions & 0 deletions src/computational_graph/feynmangraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ mutable struct FeynmanGraph{F<:Number,W} <: AbstractGraph # FeynmanGraph
@assert length(external_indices) == length(external_legs)
if typeof(operator) <: Power
@assert length(subgraphs) == 1 "FeynmanGraph with Power operator must have one and only one subgraph."
elseif typeof(operator) <: Unitary
@assert length(subgraphs) == 0 "FeynmanGraph with Unitary operator must have no subgraphs."
end
# @assert allunique(subgraphs) "all subgraphs must be distinct."
if isnothing(vertices)
Expand Down
2 changes: 2 additions & 0 deletions src/computational_graph/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ mutable struct Graph{F<:Number,W} <: AbstractGraph # Graph
)
if typeof(operator) <: Power
@assert length(subgraphs) == 1 "Graph with Power operator must have one and only one subgraph."
elseif typeof(operator) <: Unitary
@assert length(subgraphs) == 0 "Graph with Unitary operator must have no subgraphs."
end
# @assert allunique(subgraphs) "all subgraphs must be distinct."
g = new{ftype,wtype}(uid(), String(name), orders, subgraphs, subgraph_factors, typeof(operator), weight, properties)
Expand Down
59 changes: 47 additions & 12 deletions src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,47 @@ end
"""
flatten_chains(g::AbstractGraph) = flatten_chains!(deepcopy(g))

"""
function mask_zero_subgraph_factors(operator::Type{<:AbstractOperator}, subg_fac::Vector{F}) where {F}
Returns a list of indices that should be considered when performing the operation (e.g., Sum, Prod, Power), effectively masking out zero values as appropriate.
The behavior of the function depends on the operator type:
- `Sum`: Returns all indices that are not equal to zero.
- `Prod`: Returns the index of the first zero value, or all indices if none are found.
- `Power`: Returns `[1]`, or error if the power is negative.
- Other `AbstractOperator`: Defaults to return all indices.
"""
function mask_zero_subgraph_factors(::Type{Sum}, subg_fac::Vector{F}) where {F}
mask_zeros = findall(x -> x != zero(x), subg_fac)
if isempty(mask_zeros)
mask_zeros = [1]
end
return mask_zeros
end
function mask_zero_subgraph_factors(::Type{Prod}, subg_fac::Vector{F}) where {F}
idx = findfirst(x -> x == zero(x), subg_fac)
if isnothing(idx)
mask_zeros = eachindex(subg_fac)
else
mask_zeros = [idx]
end
return mask_zeros
end
function mask_zero_subgraph_factors(::Type{Power{N}}, subg_fac::Vector{F}) where {N,F}
if N >= 0
return [1]
else
error("0^$N is illegal!")
end
end
function mask_zero_subgraph_factors(::Type{<:AbstractOperator}, subg_fac::Vector{F}) where {F}
@info("Masking zero-valued subgraphs when the node operator is $operator is not implemented. Defaulted to no mask! \n" *
"It's better to define a method `mask_zero_subgraph_factors(operator::Type, subg_fac::Vector{F})`."
)
return eachindex(subg_fac)
end

"""
function remove_zero_valued_subgraphs!(g::AbstractGraph)
Expand All @@ -391,22 +432,16 @@ function remove_zero_valued_subgraphs!(g::AbstractGraph)
zero_sgf = zero(subg_fac[1]) # F(0)
# Find subgraphs with all-zero subgraph_factors and propagate subfactor one level up
for (i, sub_g) in enumerate(subg)
if has_zero_subfactors(sub_g)
if isleaf(sub_g)
continue
end
if has_zero_subfactors(sub_g, sub_g.operator)
subg_fac[i] = zero_sgf
end
end

# Remove marked zero subgraph factor subgraph(s) of g
mask_zeros = findall(x -> x != zero(x), subg_fac)
if isempty(mask_zeros)
mask_zeros = [1] # retain eldest(g) if all subfactors are zero
else
if g.operator == Prod
idx = findfirst(x -> x==zero(x), subg_fac)
if !isnothing(idx)
append!(mask_zeros, idx)
end
end
end
mask_zeros = mask_zero_subgraph_factors(g.operator, subg_fac)
set_subgraphs!(g, subg[mask_zeros])
set_subgraph_factors!(g, subg_fac[mask_zeros])
return g
Expand Down
41 changes: 28 additions & 13 deletions src/computational_graph/tree_properties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,39 @@ function ischain(g::AbstractGraph)
end

"""
function has_zero_subfactors(g)
function has_zero_subfactors(g::AbstractGraph, operator_type::Type{<:AbstractOperator})
Returns whether the graph g has only zero-valued subgraph factor(s).
Note that this function does not recurse through subgraphs of g, so that one may have, e.g.,
`has_zero_subfactors(g) == true` but `has_zero_subfactors(eldest(g)) == false`.
By convention, returns `false` if g is a leaf.
Determines whether the graph `g` has only zero-valued subgraph factors based on the specified operator type.
This function does not recurse through the subgraphs of `g`, so it only checks the immediate subgraph factors.
If `g` is a leaf (i.e., has no subgraphs), the function returns `false` by convention.
The behavior of the function depends on the operator type:
- `Sum`: Checks if all subgraph factors are zero.
- `Prod`: Checks if any subgraph factor is zero.
- `Power{N}`: Checks if the first subgraph factor is zero.
- Other `AbstractOperator`: Defaults to return `false`.
# Arguments:
- `g::AbstractGraph`: graph to be analyzed
- `operator`: the operator used in graph `g`
"""
function has_zero_subfactors(g::AbstractGraph)
if isleaf(g)
return false # convention: subgraph_factors = [] ⟹ subfactorless = false
elseif g.operator == Prod && 0 in subgraph_factors(g)
return true
else
return iszero(subgraph_factors(g))
end
function has_zero_subfactors(g::AbstractGraph, ::Type{Sum})
@assert g.operator == Sum "Operator must be Sum"
return iszero(subgraph_factors(g))
end

function has_zero_subfactors(g::AbstractGraph, ::Type{Prod})
@assert g.operator == Prod "Operator must be Prod"
return 0 in subgraph_factors(g)
end

function has_zero_subfactors(g::AbstractGraph, ::Type{Power{N}}) where {N}
@assert g.operator <: Power "Operator must be a Power"
return iszero(subgraph_factors(g)[1])
end

function has_zero_subfactors(g::AbstractGraph, ::Type{<:AbstractOperator})
@info "has_zero_subfactors: Operator type $operator is not specifically defined. Defaults to return false."
return false
end

"""
Expand Down
72 changes: 52 additions & 20 deletions test/computational_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,25 +324,40 @@ end
l6 = Graph([]; factor=6)
l7 = Graph([]; factor=7)
l8 = Graph([]; factor=8)
l2_test = Graph([]; factor=2)
Graphs.remove_zero_valued_subgraphs(l2)
@test isequiv(l2, l2_test, :id)
# subgraphs
sg1 = l1
sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=O1())
sg3 = Graph([l4]; subgraph_factors=[0], operator=O2())
sg4 = Graph([l5, l6, l7]; subgraph_factors=[0, 0, 0], operator=O3())
sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=Graphs.Sum())
sg2_test = Graph([l2]; subgraph_factors=[1.0], operator=Graphs.Sum())
sg3 = Graph([l4]; subgraph_factors=[0], operator=Graphs.Power(2))
sg3_test = Graph([l4]; subgraph_factors=[0], operator=Graphs.Power(2))
sg4 = Graph([l5, l6, l7]; subgraph_factors=[0, 0, 0], operator=Graphs.Sum())
sg5 = l8
sg6 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=Graphs.Prod())
sg6c = deepcopy(sg6)
sg6c_test = Graph([l3]; subgraph_factors=[0.0], operator=Graphs.Prod())
Graphs.remove_zero_valued_subgraphs!(sg2)
Graphs.remove_zero_valued_subgraphs!(sg3)
@test isequiv(sg2, sg2_test, :id)
@test isequiv(sg3, sg3_test, :id)
@test isequiv(sg6, sg6c, :id)
# graphs
g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=O())
g_test = Graph([sg1, sg2]; subgraph_factors=[1, 1], operator=O())
g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=Graphs.Sum())
g_test = Graph([sg1, sg2]; subgraph_factors=[1, 1], operator=Graphs.Sum())
g1 = Graph([sg1, sg2, sg3, sg4, sg5, sg6]; subgraph_factors=[1, 1, 1, 1, 0, 2], operator=Graphs.Sum())
g1_test = Graph([sg1, sg2]; subgraph_factors=[1, 1], operator=Graphs.Sum())
gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 1, 0], operator=O())
gp_test = Graph([sg3]; subgraph_factors=[0], operator=O())
g2 = Graph([sg1, sg2, sg3, sg4, sg5, sg6]; subgraph_factors=[1, 1, 1, 1, 0, 2], operator=O1())
g2_test = Graph([sg1, sg2, sg3, sg4, sg5, sg6]; subgraph_factors=[1, 1, 1, 1, 0, 2], operator=O1())
gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 1, 0], operator=Graphs.Sum())
gp_test = Graph([sg3]; subgraph_factors=[0], operator=Graphs.Sum())
Graphs.remove_zero_valued_subgraphs!(g)
Graphs.remove_zero_valued_subgraphs!(g1)
Graphs.remove_zero_valued_subgraphs!(gp)
@test isequiv(g, g_test, :id)
@test isequiv(g1, g1_test, :id)
@test isequiv(g2, g2_test, :id)
@test isequiv(gp, gp_test, :id)
end
end
Expand Down Expand Up @@ -387,24 +402,32 @@ end
ssg1 = Graph([l7]; subgraph_factors=[0], operator=O())
# subgraphs
sg1 = l1
sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=O1())
sg2_test = Graph([l2]; subgraph_factors=[1.0], operator=O1())
sg3 = Graph([l4]; subgraph_factors=[0], operator=O2())
sg4 = Graph([l5, l6, ssg1]; subgraph_factors=[0, 0, 1], operator=O3())
sg2 = Graph([l2, l3]; subgraph_factors=[1.0, 0.0], operator=Graphs.Sum())
sg2c = deepcopy(sg2)
sg2_test = Graph([l2]; subgraph_factors=[1.0], operator=Graphs.Sum())
sg3 = Graph([l4]; subgraph_factors=[0], operator=Graphs.Sum())
sg4 = Graph([l5, l6, ssg1]; subgraph_factors=[0, 0, 3], operator=Graphs.Sum())
sg4c = deepcopy(sg4)
sg4_test = Graph([ssg1], subgraph_factors=[3], operator=Graphs.Sum())
sg5 = l8
sg6 = Graph([l2, sg3]; subgraph_factors=[1.0, 2.0], operator=Graphs.Prod())
# graphs
g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=O())
g_test = Graph([sg1, sg2_test]; subgraph_factors=[1, 1], operator=O())
g = Graph([sg1, sg2, sg3, sg4, sg5]; subgraph_factors=[1, 1, 1, 1, 0], operator=Graphs.Sum())
g_test = Graph([sg1, sg2_test, sg4_test]; subgraph_factors=[1, 1, 1], operator=Graphs.Sum())
g1 = Graph([sg1, sg2, sg3, sg4, sg5, sg6]; subgraph_factors=[1, 1, 1, 1, 0, -1], operator=Graphs.Sum())
g1_test = Graph([sg1, sg2_test]; subgraph_factors=[1, 1], operator=Graphs.Sum())
gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 1, 0], operator=O())
gp_test = Graph([sg3]; subgraph_factors=[0], operator=O())
g1_test = Graph([sg1, sg2_test, sg4_test]; subgraph_factors=[1, 1, 1], operator=Graphs.Sum())

g2 = Graph([sg1, sg2c, sg3, sg4c, sg5, sg6]; subgraph_factors=[1, 0, 1, 1, 0, -1], operator=O1())
g2_test = Graph([sg1, sg2_test, sg3, sg4_test, sg5, sg6]; subgraph_factors=[1, 0, 0, 1, 0, 0], operator=O1())
gp = Graph([sg3, sg4, sg5]; subgraph_factors=[1, 0, 0], operator=Graphs.Sum())
gp_test = Graph([sg3]; subgraph_factors=[0], operator=Graphs.Sum())
Graphs.remove_all_zero_valued_subgraphs!(g)
Graphs.remove_all_zero_valued_subgraphs!(g1)
Graphs.remove_all_zero_valued_subgraphs!(g2)
Graphs.remove_all_zero_valued_subgraphs!(gp)
@test isequiv(g, g_test, :id)
@test isequiv(g1, g1_test, :id)
@test isequiv(g2, g2_test, :id)
@test isequiv(gp, gp_test, :id)
end
@testset "Merge all linear combinations" begin
Expand Down Expand Up @@ -1050,7 +1073,8 @@ end

@testset verbose = true "Tree properties" begin
using FeynmanDiagram.ComputationalGraphs:
haschildren, onechild, isleaf, isbranch, ischain, eldest, count_operation
haschildren, onechild, isleaf, isbranch, ischain, eldest, count_operation, has_zero_subfactors

# Leaves: gᵢ
g1 = Graph([])
g2 = Graph([], factor=2)
Expand All @@ -1068,6 +1092,8 @@ end
g10 = g1 * g2 + g8 * g9
h2 = Graph([g1, g2]; subgraph_factors=[0, 0], operator=Graphs.Sum())
h3 = Graph([g1, g2]; subgraph_factors=[1, 0], operator=Graphs.Sum())
h4 = Graph([g1]; subgraph_factors=[0], operator=Graphs.Power(2))
h5 = Graph([g1, g2]; subgraph_factors=[0, 0], operator=O())
glist = [g1, g2, g8, g9, g10]

@testset "Leaves" begin
Expand All @@ -1087,7 +1113,7 @@ end
@test isbranch(g3)
@test ischain(g3)
@test isleaf(eldest(g3))
@test has_zero_subfactors(h1)
@test has_zero_subfactors(h1, h1.operator)
end
@testset "Chains" begin
@test haschildren(g6)
Expand All @@ -1107,8 +1133,14 @@ end
@test count_operation(g8) == [1, 0]
@test count_operation(g9) == [2, 0]
@test count_operation(g10) == [4, 2]
@test has_zero_subfactors(h2)
@test has_zero_subfactors(h3) == false
@test has_zero_subfactors(h2, h2.operator)
@test has_zero_subfactors(h3, h3.operator) == false
@test has_zero_subfactors(h4, h4.operator)
@test has_zero_subfactors(h5, h5.operator) == false
function FeynmanDiagram.has_zero_subfactors(g::AbstractGraph, ::Type{O})
return iszero(g.subgraph_factors)
end
@test has_zero_subfactors(h5, h5.operator)
end
@testset "Iteration" begin
count_pre = sum(1 for node in PreOrderDFS(g9))
Expand Down

0 comments on commit 391fd5c

Please sign in to comment.