Skip to content

Commit

Permalink
add option to run shell hooks in host netns
Browse files Browse the repository at this point in the history
Closes #6.
  • Loading branch information
chrisbouchard authored and dadevel committed Oct 2, 2022
1 parent baa1d01 commit 98e7915
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 17 deletions.
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,15 @@ name: ns-example
managed: true
# list of dns servers, if empty dns servers from default netns will be used
dns-server: [10.10.10.1, 10.10.10.2]
# shell hooks, e.g. to set firewall rules
pre-up: echo pre-up
post-up: echo post-up
pre-own: echo pre-down
post-down: echo post-down
# shell hooks, e.g. to set firewall rules, two formats are supported
pre-up: echo pre-up from managed netns
post-up:
- host-namespace: true
command: echo post-up from host netns
- host-namespace: false
command: echo post-up from managed netns
pre-down: echo pre-down from managed netns
post-down: echo post-down from managed netns
# list of wireguard interfaces inside the netns
interfaces:
# interface name, required
Expand Down
80 changes: 68 additions & 12 deletions wgnetns/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,63 @@ def exists(self, namespace: Namespace) -> bool:
return False


@dataclasses.dataclass
class ScriptletItem:
command: str
host_namespace: bool = False

@classmethod
def from_str(cls, data: str) -> ScriptletItem:
return cls(command=data)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> ScriptletItem:
data = {key.replace('-', '_'): value for key, value in data.items()}
host_namespace = bool(data.pop('host_namespace', None))
return cls(**data, host_namespace=host_namespace)

def run(self, netns: str):
if self.host_namespace:
host_eval(self.command)
else:
ip_netns_eval(self.command, netns=netns)


@dataclasses.dataclass
class Scriptlet:
items: list[ScriptletItem] = dataclasses.field(default_factory=list)

@classmethod
def from_value(cls, data) -> Scriptlet:
if isinstance(data, list):
return cls.from_list(data)
elif isinstance(data, str):
return cls.from_singleton(data)
else:
raise RuntimeError(f'unsupported scriptlet type: {data.__class__.__name__}')

@classmethod
def from_list(cls, data: list[Any]) -> Scriptlet:
items = [ScriptletItem.from_dict(item) for item in data]
return cls(items=items)

@classmethod
def from_singleton(cls, data) -> Scriptlet:
item = ScriptletItem.from_str(data)
return cls(items=[item])

def run(self, netns: str):
for item in self.items:
item.run(netns=netns)


@dataclasses.dataclass
class Namespace:
name: str
pre_up: Optional[str] = None
post_up: Optional[str] = None
pre_down: Optional[str] = None
post_down: Optional[str] = None
pre_up: Optional[Scriptlet] = None
post_up: Optional[Scriptlet] = None
pre_down: Optional[Scriptlet] = None
post_down: Optional[Scriptlet] = None
managed: bool = True
dns_server: list[str] = dataclasses.field(default_factory=list)
interfaces: list[Interface] = dataclasses.field(default_factory=list)
Expand All @@ -191,7 +241,7 @@ def from_profile(cls, path: Path) -> Namespace:
try:
return cls.from_dict(cls._read_profile(cls._find_profile(path)))
except Exception as e:
raise RuntimeError('failed to load profile') from e
raise RuntimeError(f'failed to load profile: {e}') from e

@staticmethod
def _find_profile(profile: Path) -> Path:
Expand All @@ -217,32 +267,34 @@ def _read_profile(profile: Path) -> dict[str, Any]:
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Namespace:
data = {key.replace('-', '_'): value for key, value in data.items()}
scriptlets = {key: data.pop(key, None) for key in ['pre_up', 'post_up', 'pre_down', 'post_down']}
scriptlets = {key: Scriptlet.from_value(value) for key, value in scriptlets.items() if value is not None}
interfaces = data.pop('interfaces', list())
interfaces = [Interface.from_dict({key.replace('-', '_'): value for key, value in interface.items()}) for interface in interfaces]
return cls(**data, interfaces=interfaces)
return cls(**data, **scriptlets, interfaces=interfaces)

def setup(self) -> Namespace:
if self.pre_up:
ip_netns_eval(self.pre_up, netns=self.name)
if self.managed:
self._create()
self._write_resolvconf()
if self.pre_up:
self.pre_up.run(netns=self.name)
for interface in self.interfaces:
interface.setup(self)
if self.post_up:
ip_netns_eval(self.post_up, netns=self.name)
self.post_up.run(netns=self.name)
return self

def teardown(self, check=True) -> Namespace:
if self.pre_down:
ip_netns_eval(self.pre_down, netns=self.name)
self.pre_down.run(netns=self.name)
for interface in self.interfaces:
interface.teardown(self, check=check)
if self.post_down:
self.post_down.run(netns=self.name)
if self.managed and self.exists():
self._delete(check)
self._delete_resolvconf()
if self.post_down:
ip_netns_eval(self.post_down, netns=self.name)
return self

def exists(self) -> bool:
Expand Down Expand Up @@ -291,6 +343,10 @@ def ip(*args, stdin: str = None, check=True, capture=False) -> str:
return run('ip', *args, stdin=stdin, check=check, capture=capture)


def host_eval(*args, stdin: str = None, check=True, capture=False) -> str:
return run(SHELL, '-c', *args, stdin=stdin, check=check, capture=capture)


def run(*args, stdin: str = None, check=True, capture=False) -> str:
args = [str(item) if item is not None else '' for item in args]
if VERBOSE:
Expand Down

0 comments on commit 98e7915

Please sign in to comment.