Skip to content

Commit

Permalink
diversion: Introduce protocols to unite Action and Condition classes
Browse files Browse the repository at this point in the history
Enforce a common interface for all Action and Condition related classes
and connect them to a common protocol class to support isinstance
checks.

Related pwr-Solaar#2659
  • Loading branch information
MattHag committed Jan 2, 2025
1 parent 64f22ce commit 6a99fc6
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 40 deletions.
93 changes: 56 additions & 37 deletions lib/logitech_receiver/diversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,8 @@ def charging(f, r, d, _a):
}


def compile_component(c):
if isinstance(c, Rule) or isinstance(c, Condition) or isinstance(c, Action):
def compile_component(c) -> Rule | type[ConditionProtocol] | type[ActionProtocol]:
if isinstance(c, Rule) or isinstance(c, ConditionProtocol) or isinstance(c, ActionProtocol):
return c
elif isinstance(c, dict) and len(c) == 1:
k, v = next(iter(c.items()))
Expand All @@ -523,16 +523,16 @@ def compile_component(c):
except KeyError:
pass
logger.warning("illegal component in rule: %s", c)
return FalllbackCondition()
return FallbackCondition()


def _evaluate(components, feature, notification: HIDPPNotification, device, result) -> Any:
res = True
for component in components:
res = component.evaluate(feature, notification, device, result)
if not isinstance(component, Action) and res is None:
if not isinstance(component, ActionProtocol) and res is None:
return None
if isinstance(component, Condition) and not res:
if isinstance(component, ConditionProtocol) and not res:
return res
return res

Expand All @@ -559,7 +559,22 @@ def data(self):
return {"Rule": [c.data() for c in self.components]}


class Condition:
@typing.runtime_checkable
class ConditionProtocol(typing.Protocol):
def __init__(self, args: Any, warn: bool) -> None:
...

def __str__(self) -> str:
...

def evaluate(self, feature, notification: HIDPPNotification, device, last_result) -> bool:
...

def data(self) -> dict[str, Any]:
...


class FallbackCondition(ConditionProtocol):
def __init__(self, *args):
pass

Expand All @@ -572,7 +587,7 @@ def evaluate(self, feature, notification: HIDPPNotification, device, last_result
return False


class Not(Condition):
class Not(ConditionProtocol):
def __init__(self, op, warn=True):
if isinstance(op, list) and len(op) == 1:
op = op[0]
Expand All @@ -592,7 +607,7 @@ def data(self):
return {"Not": self.component.data()}


class Or(Condition):
class Or(ConditionProtocol):
def __init__(self, args, warn=True):
self.components = [compile_component(a) for a in args]

Expand All @@ -605,17 +620,17 @@ def evaluate(self, feature, notification: HIDPPNotification, device, last_result
result = False
for component in self.components:
result = component.evaluate(feature, notification, device, last_result)
if not isinstance(component, Action) and result is None:
if not isinstance(component, ActionProtocol) and result is None:
return None
if isinstance(component, Condition) and result:
if isinstance(component, ConditionProtocol) and result:
return result
return result

def data(self):
return {"Or": [c.data() for c in self.components]}


class And(Condition):
class And(ConditionProtocol):
def __init__(self, args, warn=True):
self.components = [compile_component(a) for a in args]

Expand Down Expand Up @@ -677,7 +692,7 @@ def gnome_dbus_pointer_prog():
return (wm_class,) if wm_class else None


class Process(Condition):
class Process(ConditionProtocol):
def __init__(self, process, warn=True):
self.process = process
if (not wayland and not x11_setup()) or (wayland and not gnome_dbus_interface_setup()):
Expand Down Expand Up @@ -708,7 +723,7 @@ def data(self):
return {"Process": str(self.process)}


class MouseProcess(Condition):
class MouseProcess(ConditionProtocol):
def __init__(self, process, warn=True):
self.process = process
if (not wayland and not x11_setup()) or (wayland and not gnome_dbus_interface_setup()):
Expand Down Expand Up @@ -739,7 +754,7 @@ def data(self):
return {"MouseProcess": str(self.process)}


class Feature(Condition):
class Feature(ConditionProtocol):
def __init__(self, feature: str, warn: bool = True):
try:
self.feature = SupportedFeature[feature]
Expand All @@ -760,7 +775,7 @@ def data(self):
return {"Feature": str(self.feature)}


class Report(Condition):
class Report(ConditionProtocol):
def __init__(self, report, warn=True):
if not (isinstance(report, int)):
if warn:
Expand All @@ -782,7 +797,7 @@ def data(self):


# Setting(device, setting, [key], value...)
class Setting(Condition):
class Setting(ConditionProtocol):
def __init__(self, args, warn=True):
if not (isinstance(args, list) and len(args) > 2):
if warn:
Expand Down Expand Up @@ -829,7 +844,7 @@ def data(self):
MODIFIER_MASK = MODIFIERS["Shift"] + MODIFIERS["Control"] + MODIFIERS["Alt"] + MODIFIERS["Super"]


class Modifiers(Condition):
class Modifiers(ConditionProtocol):
def __init__(self, modifiers, warn=True):
modifiers = [modifiers] if isinstance(modifiers, str) else modifiers
self.desired = 0
Expand Down Expand Up @@ -859,7 +874,7 @@ def data(self):
return {"Modifiers": [str(m) for m in self.modifiers]}


class Key(Condition):
class Key(ConditionProtocol):
DOWN = "pressed"
UP = "released"

Expand Down Expand Up @@ -914,7 +929,7 @@ def data(self):
return {"Key": [str(self.key), self.action]}


class KeyIsDown(Condition):
class KeyIsDown(ConditionProtocol):
def __init__(self, args, warn=True):
default_key = 0

Expand Down Expand Up @@ -958,7 +973,7 @@ def range_test_helper(_f, _r, d):
return range_test_helper


class Test(Condition):
class Test(ConditionProtocol):
def __init__(self, test, warn=True):
self.test = ""
self.parameter = None
Expand Down Expand Up @@ -1000,7 +1015,7 @@ def data(self):
return {"Test": ([self.test, self.parameter] if self.parameter is not None else [self.test])}


class TestBytes(Condition):
class TestBytes(ConditionProtocol):
def __init__(self, test, warn=True):
self.test = test
if (
Expand Down Expand Up @@ -1028,7 +1043,7 @@ def data(self):
return {"TestBytes": self.test[:]}


class MouseGesture(Condition):
class MouseGesture(ConditionProtocol):
MOVEMENTS = [
"Mouse Up",
"Mouse Down",
Expand Down Expand Up @@ -1083,7 +1098,7 @@ def data(self):
return {"MouseGesture": [str(m) for m in self.movements]}


class Active(Condition):
class Active(ConditionProtocol):
def __init__(self, devID, warn=True):
if not (isinstance(devID, str)):
if warn:
Expand All @@ -1104,7 +1119,7 @@ def data(self):
return {"Active": self.devID}


class Device(Condition):
class Device(ConditionProtocol):
def __init__(self, devID, warn=True):
if not (isinstance(devID, str)):
if warn:
Expand All @@ -1124,7 +1139,7 @@ def data(self):
return {"Device": self.devID}


class Host(Condition):
class Host(ConditionProtocol):
def __init__(self, host, warn=True):
if not (isinstance(host, str)):
if warn:
Expand All @@ -1145,12 +1160,16 @@ def data(self):
return {"Host": self.host}


class Action:
def __init__(self, *args):
pass
@typing.runtime_checkable
class ActionProtocol(typing.Protocol):
def __init__(self, args: Any, warn: bool) -> None:
...

def evaluate(self, feature, notification: HIDPPNotification, device, last_result):
return None
def evaluate(self, feature, notification: HIDPPNotification, device, last_result) -> None:
...

def data(self) -> dict[str, Any]:
...


def keysym_to_keycode(keysym, _modifiers) -> Tuple[int, int]: # maybe should take shift into account
Expand Down Expand Up @@ -1179,7 +1198,7 @@ def keysym_to_keycode(keysym, _modifiers) -> Tuple[int, int]: # maybe should ta
return keycode, level


class KeyPress(Action):
class KeyPress(ActionProtocol):
def __init__(self, args, warn=True):
self.key_names, self.action = self.regularize_args(args)
if not isinstance(self.key_names, list):
Expand Down Expand Up @@ -1269,7 +1288,7 @@ def data(self):
# super().keyUp(self.keys, current_key_modifiers)


class MouseScroll(Action):
class MouseScroll(ActionProtocol):
def __init__(self, amounts, warn=True):
if len(amounts) == 1 and isinstance(amounts[0], list):
amounts = amounts[0]
Expand Down Expand Up @@ -1297,7 +1316,7 @@ def data(self):
return {"MouseScroll": self.amounts[:]}


class MouseClick(Action):
class MouseClick(ActionProtocol):
def __init__(self, args, warn=True):
if len(args) == 1 and isinstance(args[0], list):
args = args[0]
Expand Down Expand Up @@ -1336,7 +1355,7 @@ def data(self):
return {"MouseClick": [self.button, self.count]}


class Set(Action):
class Set(ActionProtocol):
def __init__(self, args, warn=True):
if not (isinstance(args, list) and len(args) > 2):
if warn:
Expand Down Expand Up @@ -1382,7 +1401,7 @@ def data(self):
return {"Set": self.args[:]}


class Execute(Action):
class Execute(ActionProtocol):
def __init__(self, args, warn=True):
if isinstance(args, str):
args = [args]
Expand All @@ -1406,7 +1425,7 @@ def data(self):
return {"Execute": self.args[:]}


class Later(Action):
class Later(ActionProtocol):
def __init__(self, args, warn=True):
self.delay = 0
self.rule = Rule([])
Expand Down Expand Up @@ -1441,7 +1460,7 @@ def data(self):
return {"Later": data}


COMPONENTS = {
COMPONENTS: dict[str, Rule | ConditionProtocol | ActionProtocol] = {
"Rule": Rule,
"Not": Not,
"Or": Or,
Expand Down
2 changes: 1 addition & 1 deletion lib/solaar/ui/diversion_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ def left_label(cls, component):


class ActionUI(RuleComponentUI):
CLASS = diversion.Action
CLASS = diversion.ActionProtocol

@classmethod
def icon_name(cls):
Expand Down
2 changes: 1 addition & 1 deletion lib/solaar/ui/rule_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class ActionUI(RuleComponentUI):
CLASS = diversion.Action
CLASS = diversion.ActionProtocol

@classmethod
def icon_name(cls):
Expand Down
2 changes: 1 addition & 1 deletion lib/solaar/ui/rule_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


class ConditionUI(RuleComponentUI):
CLASS = diversion.Condition
CLASS = diversion.ConditionProtocol

@classmethod
def icon_name(cls):
Expand Down

0 comments on commit 6a99fc6

Please sign in to comment.