diff --git a/src/graph_engine.jl b/src/graph_engine.jl index 9df58478..d07450e0 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -255,6 +255,7 @@ Base.broadcastable(label::NodeLabel) = Ref(label) getname(label::NodeLabel) = label.name getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(first(labels)) +getid(label::NodeLabel) = label.global_counter iterate(label::NodeLabel) = (label, nothing) iterate(label::NodeLabel, any) = nothing @@ -765,9 +766,10 @@ mutable struct NodeData const context :: Context const properties :: Union{VariableNodeProperties, FactorNodeProperties{NodeData}} const extra :: UnorderedDictionary{Symbol, Any} + const id :: Int end -NodeData(context, properties) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}()) +NodeData(context, properties, id) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}(), id) function Base.show(io::IO, nodedata::NodeData) context = getcontext(nodedata) @@ -1529,7 +1531,7 @@ end function __add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index) # In theory plugins are able to overwrite this potential_label = generate_nodelabel(model, name) - potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options)) + potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options), getid(potential_label)) label, nodedata = preprocess_plugins( UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options ) @@ -1643,7 +1645,7 @@ function add_atomic_factor_node!(model::Model, context::Context, options::NodeCr factornode_id = generate_factor_nodelabel(context, fform) potential_label = generate_nodelabel(model, fform) - potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options)) + potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options), getid(potential_label)) label, nodedata = preprocess_plugins( UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options diff --git a/test/graph_engine_tests.jl b/test/graph_engine_tests.jl index 678f7a0b..4e768f2f 100644 --- a/test/graph_engine_tests.jl +++ b/test/graph_engine_tests.jl @@ -118,7 +118,7 @@ end @testset "FactorNodeProperties" begin properties = FactorNodeProperties(fform = String) - nodedata = NodeData(context, properties) + nodedata = NodeData(context, properties, 1) @test getcontext(nodedata) === context @test getproperties(nodedata) === properties @@ -135,7 +135,7 @@ end @testset "VariableNodeProperties" begin properties = VariableNodeProperties(name = :x, index = 1) - nodedata = NodeData(context, properties) + nodedata = NodeData(context, properties, 1) @test getcontext(nodedata) === context @test getproperties(nodedata) === properties @@ -183,7 +183,7 @@ end context = getcontext(model) @testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1)) - nodedata = NodeData(context, properties) + nodedata = NodeData(context, properties, 1) @test !hasextra(nodedata, :a) @test getextra(nodedata, :a, 2) === 2 @@ -552,7 +552,10 @@ end function GraphPPL.preprocess_plugin(::AnArbitraryPluginForChangingOptions, model, context, label, nodedata, options) # Here we replace the original options entirely - return label, NodeData(context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0))) + return label, + NodeData( + context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0)), GraphPPL.getid(label) + ) end for model_fn in ModelsInTheZooWithoutArguments @@ -933,13 +936,13 @@ end model = create_test_model() ctx = getcontext(model) - model[NodeLabel(:μ, 1)] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing)) + model[NodeLabel(:μ, 1)] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing), 1) @test nv(model) == 1 && ne(model) == 0 - model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) + model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 2) @test nv(model) == 2 && ne(model) == 0 - model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum)) + model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum), 3) @test nv(model) == 3 && ne(model) == 0 @test_throws MethodError model[0] = 1 @@ -959,8 +962,8 @@ end μ = NodeLabel(:μ, 1) xref = NodeLabel(:x, 2) - model[μ] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing)) - model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) + model[μ] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing), 1) + model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 2) model[μ, xref] = EdgeLabel(:interface, 1) @test ne(model) == 1 @@ -990,7 +993,7 @@ end model = create_test_model() ctx = getcontext(model) label = NodeLabel(:x, 1) - model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) + model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 1) @test isa(model[label], NodeData) @test isa(getproperties(model[label]), VariableNodeProperties) @test_throws KeyError model[NodeLabel(:x, 10)] @@ -1024,8 +1027,8 @@ end @test nv(model) == 0 @test ne(model) == 0 - model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing)) - model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing)) + model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing), 1) + model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing), 2) @test !isempty(model) @test nv(model) == 2 @test ne(model) == 0 @@ -1059,8 +1062,8 @@ end ctx = getcontext(model) a = NodeLabel(:a, 1) b = NodeLabel(:b, 2) - model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing)) - model[b] = NodeData(ctx, FactorNodeProperties(fform = sum)) + model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing), 1) + model[b] = NodeData(ctx, FactorNodeProperties(fform = sum), 2) @test !has_edge(model, a, b) @test !has_edge(model, b, a) add_edge!(model, b, getproperties(model[b]), a, :edge, 1) @@ -1069,7 +1072,7 @@ end @test length(edges(model)) == 1 c = NodeLabel(:c, 2) - model[c] = NodeData(ctx, FactorNodeProperties(fform = sum)) + model[c] = NodeData(ctx, FactorNodeProperties(fform = sum), 2) @test !has_edge(model, a, c) @test !has_edge(model, c, a) add_edge!(model, c, getproperties(model[c]), a, :edge, 2) @@ -1109,8 +1112,8 @@ end a = NodeLabel(:a, 1) b = NodeLabel(:b, 2) - model[a] = NodeData(ctx, FactorNodeProperties(fform = sum)) - model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing)) + model[a] = NodeData(ctx, FactorNodeProperties(fform = sum), 1) + model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing), 2) add_edge!(model, a, getproperties(model[a]), b, :edge, 1) @test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)] @@ -1120,9 +1123,9 @@ end b = ResizableArray(NodeLabel, Val(1)) for i in 1:3 a[i] = NodeLabel(:a, i) - model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum)) + model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum), i) b[i] = NodeLabel(:b, i) - model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i)) + model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i), i) add_edge!(model, a[i], getproperties(model[a[i]]), b[i], :edge, i) end for n in b diff --git a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl b/test/plugins/variational_constraints/variational_constraints_engine_tests.jl index b69db741..3a4d907d 100644 --- a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl +++ b/test/plugins/variational_constraints/variational_constraints_engine_tests.jl @@ -862,35 +862,35 @@ end ]) variable = ResolvedIndexedVariable(:w, 2:3, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2) @test node_data ∈ variable variable = ResolvedIndexedVariable(:w, 2:3, context) - node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2)) + node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2), 2) @test !(node_data ∈ variable) variable = ResolvedIndexedVariable(:w, 2, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2) @test node_data ∈ variable variable = ResolvedIndexedVariable(:w, SplittedRange(2, 3), context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2) @test node_data ∈ variable variable = ResolvedIndexedVariable(:w, SplittedRange(10, 15), context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2) @test !(node_data ∈ variable) variable = ResolvedIndexedVariable(:x, nothing, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2)) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2), 2) @test node_data ∈ variable variable = ResolvedIndexedVariable(:x, nothing, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing)) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing), 1) @test node_data ∈ variable variable = ResolvedIndexedVariable(:prec, 3, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3))) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3)), 2) @test node_data ∈ variable end