diff --git a/README.md b/README.md index a713fd2..c4bee2c 100644 --- a/README.md +++ b/README.md @@ -69,11 +69,15 @@ name: ns-example managed: true # list of dns servers, if empty dns servers from default netns will be used dns-server: [10.10.10.1, 10.10.10.2] -# shell hooks, e.g. to set firewall rules -pre-up: echo pre-up -post-up: echo post-up -pre-own: echo pre-down -post-down: echo post-down +# shell hooks, e.g. to set firewall rules, two formats are supported +pre-up: echo pre-up from managed netns +post-up: +- host-namespace: true + command: echo post-up from host netns +- host-namespace: false + command: echo post-up from managed netns +pre-down: echo pre-down from managed netns +post-down: echo post-down from managed netns # list of wireguard interfaces inside the netns interfaces: # interface name, required diff --git a/wgnetns/main.py b/wgnetns/main.py index d4ef602..c643d0c 100755 --- a/wgnetns/main.py +++ b/wgnetns/main.py @@ -175,13 +175,63 @@ def exists(self, namespace: Namespace) -> bool: return False +@dataclasses.dataclass +class ScriptletItem: + command: str + host_namespace: bool = False + + @classmethod + def from_str(cls, data: str) -> ScriptletItem: + return cls(command=data) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ScriptletItem: + data = {key.replace('-', '_'): value for key, value in data.items()} + host_namespace = bool(data.pop('host_namespace', None)) + return cls(**data, host_namespace=host_namespace) + + def run(self, netns: str): + if self.host_namespace: + host_eval(self.command) + else: + ip_netns_eval(self.command, netns=netns) + + +@dataclasses.dataclass +class Scriptlet: + items: list[ScriptletItem] = dataclasses.field(default_factory=list) + + @classmethod + def from_value(cls, data) -> Scriptlet: + if isinstance(data, list): + return cls.from_list(data) + elif isinstance(data, str): + return cls.from_singleton(data) + else: + raise RuntimeError(f'unsupported scriptlet type: {data.__class__.__name__}') + + @classmethod + def from_list(cls, data: list[Any]) -> Scriptlet: + items = [ScriptletItem.from_dict(item) for item in data] + return cls(items=items) + + @classmethod + def from_singleton(cls, data) -> Scriptlet: + item = ScriptletItem.from_str(data) + return cls(items=[item]) + + def run(self, netns: str): + for item in self.items: + item.run(netns=netns) + + @dataclasses.dataclass class Namespace: name: str - pre_up: Optional[str] = None - post_up: Optional[str] = None - pre_down: Optional[str] = None - post_down: Optional[str] = None + pre_up: Optional[Scriptlet] = None + post_up: Optional[Scriptlet] = None + pre_down: Optional[Scriptlet] = None + post_down: Optional[Scriptlet] = None managed: bool = True dns_server: list[str] = dataclasses.field(default_factory=list) interfaces: list[Interface] = dataclasses.field(default_factory=list) @@ -191,7 +241,7 @@ def from_profile(cls, path: Path) -> Namespace: try: return cls.from_dict(cls._read_profile(cls._find_profile(path))) except Exception as e: - raise RuntimeError('failed to load profile') from e + raise RuntimeError(f'failed to load profile: {e}') from e @staticmethod def _find_profile(profile: Path) -> Path: @@ -217,32 +267,34 @@ def _read_profile(profile: Path) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> Namespace: data = {key.replace('-', '_'): value for key, value in data.items()} + scriptlets = {key: data.pop(key, None) for key in ['pre_up', 'post_up', 'pre_down', 'post_down']} + scriptlets = {key: Scriptlet.from_value(value) for key, value in scriptlets.items() if value is not None} interfaces = data.pop('interfaces', list()) interfaces = [Interface.from_dict({key.replace('-', '_'): value for key, value in interface.items()}) for interface in interfaces] - return cls(**data, interfaces=interfaces) + return cls(**data, **scriptlets, interfaces=interfaces) def setup(self) -> Namespace: - if self.pre_up: - ip_netns_eval(self.pre_up, netns=self.name) if self.managed: self._create() self._write_resolvconf() + if self.pre_up: + self.pre_up.run(netns=self.name) for interface in self.interfaces: interface.setup(self) if self.post_up: - ip_netns_eval(self.post_up, netns=self.name) + self.post_up.run(netns=self.name) return self def teardown(self, check=True) -> Namespace: if self.pre_down: - ip_netns_eval(self.pre_down, netns=self.name) + self.pre_down.run(netns=self.name) for interface in self.interfaces: interface.teardown(self, check=check) + if self.post_down: + self.post_down.run(netns=self.name) if self.managed and self.exists(): self._delete(check) self._delete_resolvconf() - if self.post_down: - ip_netns_eval(self.post_down, netns=self.name) return self def exists(self) -> bool: @@ -291,6 +343,10 @@ def ip(*args, stdin: str = None, check=True, capture=False) -> str: return run('ip', *args, stdin=stdin, check=check, capture=capture) +def host_eval(*args, stdin: str = None, check=True, capture=False) -> str: + return run(SHELL, '-c', *args, stdin=stdin, check=check, capture=capture) + + def run(*args, stdin: str = None, check=True, capture=False) -> str: args = [str(item) if item is not None else '' for item in args] if VERBOSE: