diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 6366fc390..e384425f9 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -24,7 +24,7 @@ class GraphDefinition(Model): def __init__( self, detector: Detector, - node_definition: NodeDefinition = NodesAsPulses(), + node_definition: NodeDefinition = None, edge_definition: Optional[EdgeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -69,6 +69,9 @@ def __init__( # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + if node_definition is None: + node_definition = NodesAsPulses() + # Member Variables self._detector = detector self._edge_definition = edge_definition