Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modernize type annotations and fix some discrepancies #451

Merged
merged 8 commits into from
Jun 11, 2024
10 changes: 6 additions & 4 deletions canopen/network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from collections.abc import MutableMapping
import logging
import threading
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Callable, Dict, Iterator, List, Optional, Union

try:
import can
Expand Down Expand Up @@ -82,7 +84,7 @@ def unsubscribe(self, can_id, callback=None) -> None:
else:
self.subscribers[can_id].remove(callback)

def connect(self, *args, **kwargs) -> "Network":
def connect(self, *args, **kwargs) -> Network:
"""Connect to CAN bus using python-can.

Arguments are passed directly to :class:`can.BusABC`. Typically these
Expand Down Expand Up @@ -214,7 +216,7 @@ def send_message(self, can_id: int, data: bytes, remote: bool = False) -> None:

def send_periodic(
self, can_id: int, data: bytes, period: float, remote: bool = False
) -> "PeriodicMessageTask":
) -> PeriodicMessageTask:
"""Start sending a message periodically.

:param can_id:
Expand Down Expand Up @@ -277,7 +279,7 @@ def __delitem__(self, node_id: int):
self.nodes[node_id].remove_network()
del self.nodes[node_id]

def __iter__(self) -> Iterable[int]:
def __iter__(self) -> Iterator[int]:
return iter(self.nodes)

def __len__(self) -> int:
Expand Down
38 changes: 20 additions & 18 deletions canopen/objectdictionary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""
Object Dictionary module
"""
from __future__ import annotations

import struct
from typing import Dict, Iterable, List, Optional, TextIO, Union
from typing import Dict, Iterator, List, Optional, TextIO, Union
from collections.abc import MutableMapping, Mapping
import logging

Expand All @@ -13,7 +15,7 @@
logger = logging.getLogger(__name__)


def export_od(od, dest:Union[str,TextIO,None]=None, doc_type:Optional[str]=None):
def export_od(od, dest: Union[str, TextIO, None] = None, doc_type: Optional[str] = None):
""" Export :class: ObjectDictionary to a file.

:param od:
Expand Down Expand Up @@ -55,7 +57,7 @@ def export_od(od, dest:Union[str,TextIO,None]=None, doc_type:Optional[str]=None)
def import_od(
source: Union[str, TextIO, None],
node_id: Optional[int] = None,
) -> "ObjectDictionary":
) -> ObjectDictionary:
"""Parse an EDS, DCF, or EPF file.

:param source:
Expand Down Expand Up @@ -102,7 +104,7 @@ def __init__(self):

def __getitem__(
self, index: Union[int, str]
) -> Union["ODArray", "ODRecord", "ODVariable"]:
) -> Union[ODArray, ODRecord, ODVariable]:
"""Get object from object dictionary by name or index."""
item = self.names.get(index) or self.indices.get(index)
if item is None:
Expand All @@ -113,7 +115,7 @@ def __getitem__(
return item

def __setitem__(
self, index: Union[int, str], obj: Union["ODArray", "ODRecord", "ODVariable"]
self, index: Union[int, str], obj: Union[ODArray, ODRecord, ODVariable]
):
assert index == obj.index or index == obj.name
self.add_object(obj)
Expand All @@ -123,7 +125,7 @@ def __delitem__(self, index: Union[int, str]):
del self.indices[obj.index]
del self.names[obj.name]

def __iter__(self) -> Iterable[int]:
def __iter__(self) -> Iterator[int]:
return iter(sorted(self.indices))

def __len__(self) -> int:
Expand All @@ -132,7 +134,7 @@ def __len__(self) -> int:
def __contains__(self, index: Union[int, str]):
return index in self.names or index in self.indices

def add_object(self, obj: Union["ODArray", "ODRecord", "ODVariable"]) -> None:
def add_object(self, obj: Union[ODArray, ODRecord, ODVariable]) -> None:
"""Add object to the object dictionary.

:param obj:
Expand All @@ -147,7 +149,7 @@ def add_object(self, obj: Union["ODArray", "ODRecord", "ODVariable"]) -> None:

def get_variable(
self, index: Union[int, str], subindex: int = 0
) -> Optional["ODVariable"]:
) -> Optional[ODVariable]:
"""Get the variable object at specified index (and subindex if applicable).

:return: ODVariable if found, else `None`
Expand Down Expand Up @@ -182,13 +184,13 @@ def __init__(self, name: str, index: int):
def __repr__(self) -> str:
return f"<{type(self).__qualname__} {self.name!r} at {pretty_index(self.index)}>"

def __getitem__(self, subindex: Union[int, str]) -> "ODVariable":
def __getitem__(self, subindex: Union[int, str]) -> ODVariable:
item = self.names.get(subindex) or self.subindices.get(subindex)
if item is None:
raise KeyError(f"Subindex {pretty_index(None, subindex)} was not found")
return item

def __setitem__(self, subindex: Union[int, str], var: "ODVariable"):
def __setitem__(self, subindex: Union[int, str], var: ODVariable):
assert subindex == var.subindex
self.add_member(var)

Expand All @@ -200,16 +202,16 @@ def __delitem__(self, subindex: Union[int, str]):
def __len__(self) -> int:
return len(self.subindices)

def __iter__(self) -> Iterable[int]:
def __iter__(self) -> Iterator[int]:
return iter(sorted(self.subindices))

def __contains__(self, subindex: Union[int, str]) -> bool:
return subindex in self.names or subindex in self.subindices

def __eq__(self, other: "ODRecord") -> bool:
def __eq__(self, other: ODRecord) -> bool:
return self.index == other.index

def add_member(self, variable: "ODVariable") -> None:
def add_member(self, variable: ODVariable) -> None:
"""Adds a :class:`~canopen.objectdictionary.ODVariable` to the record."""
variable.parent = self
self.subindices[variable.subindex] = variable
Expand Down Expand Up @@ -241,7 +243,7 @@ def __init__(self, name: str, index: int):
def __repr__(self) -> str:
return f"<{type(self).__qualname__} {self.name!r} at {pretty_index(self.index)}>"

def __getitem__(self, subindex: Union[int, str]) -> "ODVariable":
def __getitem__(self, subindex: Union[int, str]) -> ODVariable:
var = self.names.get(subindex) or self.subindices.get(subindex)
if var is not None:
# This subindex is defined
Expand All @@ -264,13 +266,13 @@ def __getitem__(self, subindex: Union[int, str]) -> "ODVariable":
def __len__(self) -> int:
return len(self.subindices)

def __iter__(self) -> Iterable[int]:
def __iter__(self) -> Iterator[int]:
return iter(sorted(self.subindices))

def __eq__(self, other: "ODArray") -> bool:
def __eq__(self, other: ODArray) -> bool:
return self.index == other.index

def add_member(self, variable: "ODVariable") -> None:
def add_member(self, variable: ODVariable) -> None:
"""Adds a :class:`~canopen.objectdictionary.ODVariable` to the record."""
variable.parent = self
self.subindices[variable.subindex] = variable
Expand Down Expand Up @@ -348,7 +350,7 @@ def qualname(self) -> str:
return f"{self.parent.name}.{self.name}"
return self.name

def __eq__(self, other: "ODVariable") -> bool:
def __eq__(self, other: ODVariable) -> bool:
return (self.index == other.index and
self.subindex == other.subindex)

Expand Down
41 changes: 24 additions & 17 deletions canopen/pdo/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import threading
import math
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Callable, Dict, Iterator, List, Optional, Union, TYPE_CHECKING
from collections.abc import Mapping
import logging
import binascii
Expand All @@ -9,6 +10,12 @@
from canopen import objectdictionary
from canopen import variable

if TYPE_CHECKING:
from canopen.network import Network
from canopen import LocalNode, RemoteNode
from canopen.pdo import RPDO, TPDO
from canopen.sdo import SdoRecord

PDO_NOT_VALID = 1 << 31
RTR_NOT_ALLOWED = 1 << 30

Expand All @@ -22,10 +29,10 @@ class PdoBase(Mapping):
Parent object associated with this PDO instance
"""

def __init__(self, node):
self.network = None
self.map = None # instance of PdoMaps
self.node = node
def __init__(self, node: Union[LocalNode, RemoteNode]):
self.network: Optional[Network] = None
self.map: Optional[PdoMaps] = None
self.node: Union[LocalNode, RemoteNode] = node
Comment on lines +32 to +35
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the intention of having Union[LocalNode, RemoteNode] as just using BaseNode? Do we expect other BaseNode derivatives that are not compatible with PdoBase

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simply because we need the .sdo attribute here and that is not part of BaseNode. I previously asked the same question in #446 (comment).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, otherwise we only have network, object_dictionary and id attributes. Also explicit is better than implicit, thats why I don't like defining these attributes in the BaseNode which would be another solution.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks.

Perhaps we should at some point make a NodeProtocol for the required attributes instead of using a exhaustive list? It might not be worth the efforts thou, since the Union is only used a few places.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I think the Union approach is simple enough and warranted. We can still rework it when there's a stronger argument for it.


def __iter__(self):
return iter(self.map)
Expand Down Expand Up @@ -131,7 +138,7 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None):
:param pdo_node:
:param cob_base:
"""
self.maps: Dict[int, "PdoMap"] = {}
self.maps: Dict[int, PdoMap] = {}
for map_no in range(512):
if com_offset + map_no in pdo_node.node.object_dictionary:
new_map = PdoMap(
Expand All @@ -143,10 +150,10 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None):
new_map.predefined_cob_id = cob_base + map_no * 0x100 + pdo_node.node.id
self.maps[map_no + 1] = new_map

def __getitem__(self, key: int) -> "PdoMap":
def __getitem__(self, key: int) -> PdoMap:
return self.maps[key]

def __iter__(self) -> Iterable[int]:
def __iter__(self) -> Iterator[int]:
return iter(self.maps)

def __len__(self) -> int:
Expand All @@ -157,9 +164,9 @@ class PdoMap:
"""One message which can have up to 8 bytes of variables mapped."""

def __init__(self, pdo_node, com_record, map_array):
self.pdo_node = pdo_node
self.com_record = com_record
self.map_array = map_array
self.pdo_node: Union[TPDO, RPDO] = pdo_node
self.com_record: SdoRecord = com_record
self.map_array: SdoRecord = map_array
#: If this map is valid
self.enabled: bool = False
#: COB-ID for this PDO
Expand All @@ -177,7 +184,7 @@ def __init__(self, pdo_node, com_record, map_array):
#: Ignores SYNC objects up to this SYNC counter value (optional)
self.sync_start_value: Optional[int] = None
#: List of variables mapped to this PDO
self.map: List["PdoVariable"] = []
self.map: List[PdoVariable] = []
self.length: int = 0
#: Current message data
self.data = bytearray()
Expand Down Expand Up @@ -214,7 +221,7 @@ def __getitem_by_name(self, value):
raise KeyError(f"{value} not found in map. Valid entries are "
f"{', '.join(valid_values)}")

def __getitem__(self, key: Union[int, str]) -> "PdoVariable":
def __getitem__(self, key: Union[int, str]) -> PdoVariable:
if isinstance(key, int):
# there is a maximum available of 8 slots per PDO map
if key in range(0, 8):
Expand All @@ -228,7 +235,7 @@ def __getitem__(self, key: Union[int, str]) -> "PdoVariable":
var = self.__getitem_by_name(key)
return var

def __iter__(self) -> Iterable["PdoVariable"]:
def __iter__(self) -> Iterator[PdoVariable]:
return iter(self.map)

def __len__(self) -> int:
Expand Down Expand Up @@ -303,7 +310,7 @@ def on_message(self, can_id, data, timestamp):
for callback in self.callbacks:
callback(self)

def add_callback(self, callback: Callable[["PdoMap"], None]) -> None:
def add_callback(self, callback: Callable[[PdoMap], None]) -> None:
"""Add a callback which will be called on receive.

:param callback:
Expand Down Expand Up @@ -451,7 +458,7 @@ def add_variable(
index: Union[str, int],
subindex: Union[str, int] = 0,
length: Optional[int] = None,
) -> "PdoVariable":
) -> PdoVariable:
"""Add a variable from object dictionary as the next entry.

:param index: Index of variable as name or number
Expand Down Expand Up @@ -544,7 +551,7 @@ class PdoVariable(variable.Variable):

def __init__(self, od: objectdictionary.ODVariable):
#: PDO object that is associated with this ODVariable Object
self.pdo_parent = None
self.pdo_parent: Optional[PdoMap] = None
#: Location of variable in the message in bits
self.offset = None
self.length = len(od)
Expand Down
Loading
Loading