Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Apr 6, 2024
1 parent 29c0ef9 commit 19b39de
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 39 deletions.
55 changes: 29 additions & 26 deletions WDL/runtime/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import threading
import shutil
import regex
from typing import Tuple, List, Dict, Optional, Callable, Set, Any, Union
from typing import Tuple, List, Dict, Optional, Callable, Set, Any, Union, TYPE_CHECKING
from contextlib import ExitStack, suppress
from collections import Counter

Expand All @@ -36,8 +36,11 @@
from .cache import CallCache, new as new_call_cache
from .error import OutputError, Interrupted, Terminated, RunFailed, error_json

if TYPE_CHECKING: # otherwise-delayed heavy imports
from .task_container import TaskContainer

def run_local_task(

def run_local_task( # type: ignore[return]
cfg: config.Loader,
task: Tree.Task,
inputs: Env.Bindings[Value.Base],
Expand Down Expand Up @@ -108,7 +111,7 @@ def run_local_task(
assert cache

cleanup.enter_context(_statusbar.task_slotted())
container = None
maybe_container = None
try:
cache_key = f"{task.name}/{task.digest}/{Value.digest_env(inputs)}"
cached = cache.get(cache_key, inputs, task.effective_outputs)
Expand Down Expand Up @@ -142,7 +145,7 @@ def run_local_task(
with compose_coroutines(
[
(
lambda kwargs, cor=cor: cor(
lambda kwargs, cor=cor: cor( # type: ignore
cfg, logger, _run_id_stack + [run_id], run_dir, task, **kwargs
)
)
Expand All @@ -168,6 +171,7 @@ def run_local_task(

# create TaskContainer according to configuration
container = new_task_container(cfg, logger, run_id, run_dir)
maybe_container = container

# evaluate input/postinput declarations, including mapping from host to
# in-container file paths
Expand Down Expand Up @@ -263,7 +267,8 @@ def run_local_task(
logger.debug(traceback.format_exc())
logger.critical(_("failed to write error.json", dir=run_dir, message=str(exn2)))
try:
_delete_work(cfg, logger, container, False)
if maybe_container:
_delete_work(cfg, logger, maybe_container, False)
except Exception as exn2:
logger.debug(traceback.format_exc())
logger.error(_("delete_work also failed", exception=str(exn2)))
Expand Down Expand Up @@ -360,7 +365,7 @@ def _eval_task_inputs(
logger: logging.Logger,
task: Tree.Task,
posix_inputs: Env.Bindings[Value.Base],
container: "runtime.task_container.TaskContainer",
container: "TaskContainer",
) -> Env.Bindings[Value.Base]:
# Preprocess inputs: if None value is supplied for an input declared with a default but without
# the ? type quantifier, remove the binding entirely so that the default will be used. In
Expand Down Expand Up @@ -390,7 +395,7 @@ def map_paths(fn: Union[Value.File, Value.Directory]) -> str:
container_inputs = Value.rewrite_env_paths(posix_inputs, map_paths)

# initialize value environment with the inputs
container_env = Env.Bindings()
container_env: Env.Bindings[Value.Base] = Env.Bindings()
for b in container_inputs:
assert isinstance(b, Env.Binding)
v = b.value
Expand Down Expand Up @@ -457,9 +462,7 @@ def collector(v: Value.Base) -> None:
return ans


def _warn_input_basename_collisions(
logger: logging.Logger, container: "runtime.task_container.TaskContainer"
) -> None:
def _warn_input_basename_collisions(logger: logging.Logger, container: "TaskContainer") -> None:
basenames = Counter(
[os.path.basename((p[:-1] if p.endswith("/") else p)) for p in container.input_path_map_rev]
)
Expand All @@ -479,7 +482,7 @@ def _eval_task_runtime(
run_id: str,
task: Tree.Task,
inputs: Env.Bindings[Value.Base],
container: "runtime.task_container.TaskContainer",
container: "TaskContainer",
env: Env.Bindings[Value.Base],
stdlib: StdLib.Base,
) -> None:
Expand Down Expand Up @@ -538,7 +541,7 @@ def _eval_task_runtime(
container.runtime_values["env"].update(env_vars_override)

# process decls with "env" decorator (EXPERIMENTAL)
env_decls = {}
env_decls: Dict[str, Value.Base] = {}
for decl in (task.inputs or []) + task.postinputs:
if decl.decor.get("env", False) is True:
if not env_decls:
Expand Down Expand Up @@ -566,7 +569,7 @@ def _try_task(
cfg: config.Loader,
task: Tree.Task,
logger: logging.Logger,
container: "runtime.task_container.TaskContainer",
container: "TaskContainer",
command: str,
terminating: Callable[[], bool],
) -> None:
Expand Down Expand Up @@ -651,7 +654,7 @@ def _eval_task_outputs(
run_id: str,
task: Tree.Task,
env: Env.Bindings[Value.Base],
container: "runtime.task_container.TaskContainer",
container: "TaskContainer",
) -> Env.Bindings[Value.Base]:
stdout_file = os.path.join(container.host_dir, "stdout.txt")
with suppress(FileNotFoundError):
Expand All @@ -666,7 +669,7 @@ def _eval_task_outputs(
if isinstance(expr, Expr.Apply) and expr.function_name == "stdout":
stdout_used = True
else:
expr_stack.extend(expr.children)
expr_stack.extend(expr.children) # type: ignore[arg-type]
if not stdout_used:
logger.info(
_(
Expand Down Expand Up @@ -710,7 +713,7 @@ def rewriter2(v: Union[Value.File, Value.Directory], output_name: str) -> Option
return host_path

stdlib = OutputStdLib(task.effective_wdl_version, logger, container)
outputs = Env.Bindings()
outputs: Env.Bindings[Value.Base] = Env.Bindings()
for decl in task.outputs:
assert decl.expr
try:
Expand Down Expand Up @@ -741,9 +744,9 @@ def rewriter2(v: Union[Value.File, Value.Directory], output_name: str) -> Option
try:
v = v.coerce(decl.type)
except FileNotFoundError:
exn = OutputError("File/Directory path not found in task output " + decl.name)
setattr(exn, "job_id", decl.workflow_node_id)
raise exn
err = OutputError("File/Directory path not found in task output " + decl.name)
setattr(err, "job_id", decl.workflow_node_id)
raise err
# Rewrite in-container paths to host paths
v = Value.rewrite_paths(v, lambda w: rewriter2(w, decl.name))
outputs = outputs.bind(decl.name, v)
Expand Down Expand Up @@ -907,7 +910,7 @@ def link_outputs_relative(
link_outputs with [file_io] use_relative_output_paths = true. We organize the links to reflect
the generated files' paths relative to their task working directory.
"""
link_destinations = dict()
link_destinations: Dict[str, str] = dict()

def map_path_relative(v: Union[Value.File, Value.Directory]) -> str:
target = (
Expand Down Expand Up @@ -961,7 +964,7 @@ def map_path_relative(v: Union[Value.File, Value.Directory]) -> str:
def _warn_output_basename_collisions(
logger: logging.Logger, outputs: Env.Bindings[Value.Base]
) -> None:
targets_by_basename = {}
targets_by_basename: Dict[str, Set[str]] = {}

def walker(v: Union[Value.File, Value.Directory]) -> str:
target = v.value
Expand All @@ -987,7 +990,7 @@ def walker(v: Union[Value.File, Value.Directory]) -> str:
def _delete_work(
cfg: config.Loader,
logger: logging.Logger,
container: "Optional[runtime.task_container.TaskContainer]",
container: "Optional[TaskContainer]",
success: bool,
) -> None:
opt = cfg["file_io"]["delete_work"].strip().lower()
Expand All @@ -1004,14 +1007,14 @@ def _delete_work(

class _StdLib(StdLib.Base):
logger: logging.Logger
container: "runtime.task_container.TaskContainer"
container: "TaskContainer"
inputs_only: bool # if True then only permit access to input files

def __init__(
self,
wdl_version: str,
logger: logging.Logger,
container: "runtime.task_container.TaskContainer",
container: "TaskContainer",
inputs_only: bool,
) -> None:
super().__init__(wdl_version, write_dir=os.path.join(container.host_dir, "write_"))
Expand Down Expand Up @@ -1043,7 +1046,7 @@ def __init__(
self,
wdl_version: str,
logger: logging.Logger,
container: "runtime.task_container.TaskContainer",
container: "TaskContainer",
) -> None:
super().__init__(wdl_version, logger, container, True)

Expand All @@ -1054,7 +1057,7 @@ def __init__(
self,
wdl_version: str,
logger: logging.Logger,
container: "runtime.task_container.TaskContainer",
container: "TaskContainer",
) -> None:
super().__init__(wdl_version, logger, container, False)

Expand Down
14 changes: 1 addition & 13 deletions stubs/docker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Any, List, Iterable, Tuple
from typing import Dict, Any, List, Iterable, Tuple, Optional

class Container:
@property
Expand Down Expand Up @@ -96,18 +96,6 @@ class Mount:
def __init__(self, *args, **kwargs):
...

class errors:
class BuildError(Exception):
msg : str
build_log : Iterable[Dict[str,str]]

class ImageNotFound(Exception):
pass

class APIError(Exception):
def is_server_error(self) -> bool:
...

class DockerClient:
@property
def containers(self) -> Containers:
Expand Down
12 changes: 12 additions & 0 deletions stubs/docker/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Iterable, Dict

class BuildError(Exception):
msg : str
build_log : Iterable[Dict[str,str]]

class ImageNotFound(Exception):
pass

class APIError(Exception):
def is_server_error(self) -> bool:
...

0 comments on commit 19b39de

Please sign in to comment.