Skip to content

Commit

Permalink
feat: extend docker compose integration to support native execution
Browse files Browse the repository at this point in the history
  • Loading branch information
rpoisel committed Jan 12, 2025
1 parent eebf3f2 commit 77cdd95
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 26 deletions.
2 changes: 1 addition & 1 deletion tests/openvpn/compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ services:
- NET_ADMIN
privileged: true
ports:
- "1194:1194/udp"
- "{openvpn}:1194/udp"
networks:
- shared_network

Expand Down
13 changes: 5 additions & 8 deletions tests/test_openvpn.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import logging
from collections.abc import Iterator
from pathlib import Path

import openwrt
import pytest
from docker import DockerComposeWrapper
from labgrid.driver import SSHDriver
from network import NetworkError, primary_host_ip, resolve
from process import run
from ssh import put_file
from x509 import PKI, create_pki
Expand Down Expand Up @@ -35,6 +33,8 @@ def openvpn_server_env(pki: PKI) -> Iterator[DockerComposeWrapper]:
compose.up(build=True)
yield compose
compose.rm(force=True, stop=True)
compose.kill()
compose.cleanup()


@pytest.mark.openvpn
Expand All @@ -51,12 +51,9 @@ def step_openwrt_setup_openvpn() -> None:
put_file(ssh_command, Path("/etc") / "openvpn" / "client.crt", pki.client_cert)
put_file(ssh_command, Path("/etc") / "openvpn" / "client.key", pki.client_key)
run(ssh_command, "uci set openvpn.sample_client.enabled='1'")
try:
openvpn_server_ip = resolve("openvpn-server")
except NetworkError:
logging.info("Could not resolve openvpn-server address. Trying published port on host.")
openvpn_server_ip = primary_host_ip()
run(ssh_command, f"uci set openvpn.sample_client.remote='{openvpn_server_ip} 1194'")
openvpn_server_name = openvpn_server_env.map_hostname("openvpn-server")
openvpn_server_port = openvpn_server_env.port_mappings["udp"]["openvpn"]
run(ssh_command, f"uci set openvpn.sample_client.remote='{openvpn_server_name} {openvpn_server_port}'")
run(ssh_command, "uci commit openvpn")
run(ssh_command, "/etc/init.d/openvpn restart sample_client")

Expand Down
34 changes: 34 additions & 0 deletions tests/util/test_docker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
from ipaddress import IPv4Address
from pathlib import Path
from unittest.mock import MagicMock, patch

from docker import DockerComposeWrapper, DockerInDockerComposeRenderer, LocalComposeRenderer

OPENVPN_COMPOSE_TEMPLATE: str = (Path(__file__).parent.parent / "openvpn" / "compose.yaml").read_text()


@patch("docker.in_docker_container")
def test_docker_compose_env_ok(in_docker_container: MagicMock) -> None:
in_docker_container.return_value = True
docker_compose_wrapper = DockerComposeWrapper(OPENVPN_COMPOSE_TEMPLATE, {})
logging.info(f"Current Docker Compose Env Dir: {docker_compose_wrapper.cwd}")


@patch("docker.primary_host_ip")
def test_docker_compose_renderer_local(primary_host_ip: MagicMock) -> None:
primary_host_ip.return_value = IPv4Address("1.2.3.4")

renderer = LocalComposeRenderer(OPENVPN_COMPOSE_TEMPLATE)
assert not renderer.port_mappings["tcp"]
assert len(renderer.port_mappings["udp"]) == 1 # openvpn port
assert renderer.map_service("openvpn-server") == "1.2.3.4"


@patch("docker.resolve")
def test_docker_compose_renderer_dind(resolve: MagicMock) -> None:
resolve.return_value = "mapped-service-name"

renderer = DockerInDockerComposeRenderer(OPENVPN_COMPOSE_TEMPLATE)
assert renderer.port_mappings["udp"]["openvpn"] == 1194
assert renderer.map_service("openvpn-server") == "mapped-service-name"
File renamed without changes.
136 changes: 119 additions & 17 deletions util/docker.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,137 @@
import json
import logging
import random
import re
import shutil
import string
import subprocess
import tempfile
from pathlib import Path
from typing import Any

import yaml
from fs import create_temp_dir
from network import NetworkError, get_free_tcp_port, get_free_udp_port, primary_host_ip, resolve

def create_temp_dir() -> Path:
while True:
# we don't need a super-secure cryptographic PRNG
random_name = "".join(random.choices(string.ascii_letters + string.digits, k=12)) # noqa: S311
temp_path = Path(tempfile.gettempdir()) / random_name
logging.info(f"Creating temporary directory {temp_path}.")

def in_docker_container() -> bool:
return Path("/.dockerenv").exists()


PortMappings = dict[str, dict[str, int]]


class ComposeRenderer:
def __init__(self, compose_template: str) -> None:
self._compose_data: dict = yaml.safe_load(compose_template)
self._port_mappings: PortMappings = {"tcp": {}, "udp": {}}

@property
def rendered(self) -> str:
return yaml.dump(self._compose_data, default_flow_style=False)

@property
def port_mappings(self) -> PortMappings:
return self._port_mappings

def map_service(self, hostname: str) -> str:
del hostname # unused
raise NotImplementedError()


class LocalComposeRenderer(ComposeRenderer):
def __init__(self, compose_template: str) -> None:
super().__init__(compose_template)
services = self._compose_data.get("services", {})
for _, service_data in services.items():
networks = service_data.get("networks", [])
if networks:
service_data["networks"] = [network for network in networks if network != "shared_network"]

ports: list[str] = service_data.get("ports", [])
for idx, port in enumerate(ports):
match = DOCKER_PORT_REGEX.match(port)
if match:
name = match.group(1)
port = int(match.group(2))
proto = match.group(4)
proto_str = proto if proto else "tcp"
free_port = get_free_udp_port() if proto else get_free_tcp_port()
self._port_mappings[proto_str][name] = free_port
ports[idx] = f"{free_port}:{port}/{proto_str}"
else:
logging.warning(f"Could not parse port expression {port}")

def map_service(self, hostname: str) -> str:
del hostname # unused
return str(primary_host_ip())


DOCKER_PORT_REGEX: re.Pattern = re.compile(
r"\{([\w-]+)\}:(\d+)(/((tcp)|(udp)))?",
flags=re.DOTALL | re.IGNORECASE,
)


class DockerInDockerComposeRenderer(ComposeRenderer):
def __init__(self, compose_template: str) -> None:
super().__init__(compose_template)
services = self._compose_data.get("services", {})
for service_name, service_data in services.items():
if "volumes" in service_data:
logging.info(
"Volume mounts are not supported in docker-in-docker scenarios. "
f"Some have been found in the {service_name} service."
)

ports: list[str] = service_data.get("ports", [])
for idx, port in enumerate(ports):
match = DOCKER_PORT_REGEX.match(port)
if match:
name = match.group(1)
port = int(match.group(2))
proto = match.group(4)
proto_str = proto if proto else "tcp"
self._port_mappings[proto_str][name] = port
ports[idx] = f"{port}:{port}/{proto_str}"
else:
logging.warning(f"Could not parse port expression {port}")

def map_service(self, hostname: str) -> str:
try:
temp_path.mkdir(exist_ok=False)
return temp_path
except FileExistsError:
continue
return str(resolve(hostname))
except NetworkError:
logging.info("Could not resolve openvpn-server address. Trying published port on host.")
return str(primary_host_ip())


def create_compose_renderer(compose_template: str) -> ComposeRenderer:
if in_docker_container():
return DockerInDockerComposeRenderer(compose_template)
return LocalComposeRenderer(compose_template)


class DockerComposeWrapper:
def __init__(self, compose_yaml: str, files: dict[str, bytes]) -> None:
def __init__(self, compose_template: str, files: dict[str, bytes]) -> None:
self._tmpdir = create_temp_dir()
self.compose_file = Path(self._tmpdir) / "compose.yaml"
self.compose_file.write_text(compose_yaml)
self._compose = create_compose_renderer(compose_template)
logging.info(f"Rendered compose YAML:\n{self._compose.rendered}")
(Path(self._tmpdir) / "compose.yaml").write_text(
self._compose.rendered,
)
for filename, contents in files.items():
(Path(self._tmpdir) / filename).write_bytes(contents)

@property
def cwd(self) -> Path:
return self._tmpdir

@property
def port_mappings(self) -> PortMappings:
return self._compose.port_mappings

def map_hostname(self, hostname: str) -> str:
return self._compose.map_service(hostname)

def _run_command(self, *args: str) -> str:
command = ["docker", "compose", "-f", self.compose_file] + list(args)
command = ["docker", "compose"] + list(args)
result = subprocess.run(
command,
cwd=self._tmpdir,
Expand Down Expand Up @@ -81,6 +181,8 @@ def kill(self, service: str | None = None, signal: str | None = None) -> None:
if service:
args.append(service)
self._run_command(*args)

def cleanup(self) -> None:
shutil.rmtree(self._tmpdir)

def ps(self, services: list[str] | None = None) -> list[dict[str, Any]]:
Expand Down
18 changes: 18 additions & 0 deletions util/fs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import logging
import random
import string
import tempfile
from pathlib import Path


def create_temp_dir() -> Path:
while True:
# we don't need a super-secure cryptographic PRNG
random_name = "".join(random.choices(string.ascii_letters + string.digits, k=12)) # noqa: S311
temp_path = Path(tempfile.gettempdir()) / random_name
logging.info(f"Creating temporary directory {temp_path}.")
try:
temp_path.mkdir(exist_ok=False)
return temp_path
except FileExistsError:
continue
14 changes: 14 additions & 0 deletions util/network.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import socket
from contextlib import closing
from ipaddress import IPv4Address
from socket import AF_INET, SOCK_DGRAM, SOCK_STREAM

from process import shell_run

Expand Down Expand Up @@ -28,3 +30,15 @@ def resolve(name: str) -> IPv4Address:
def is_port_in_use(port: int, kind: socket.SocketKind = socket.SOCK_STREAM) -> bool:
with socket.socket(socket.AF_INET, kind) as s:
return s.connect_ex(("localhost", port)) == 0


def get_free_tcp_port() -> int:
with closing(socket.socket(AF_INET, SOCK_STREAM)) as s:
s.bind(("", 0))
return s.getsockname()[1]


def get_free_udp_port() -> int:
with closing(socket.socket(AF_INET, SOCK_DGRAM)) as s:
s.bind(("", 0))
return s.getsockname()[1]

0 comments on commit 77cdd95

Please sign in to comment.