Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Apr 7, 2024
1 parent 19b39de commit 4d278d5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
12 changes: 6 additions & 6 deletions WDL/runtime/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def run(
doc.typecheck()
Walker.SetParents()(doc)
task = doc.tasks[0]
inputs = values_from_json(inputs, task.available_inputs) # pyre-ignore
inputs = values_from_json(inputs, task.available_inputs) # type: ignore[arg-type]
subdir, outputs_env = run_local_task(
cfg, task, inputs, run_id=("download-" + task.name), **kwargs
)

recv = cor.send(
{"outputs": values_to_json(outputs_env), "dir": subdir} # pyre-ignore
{"outputs": values_to_json(outputs_env), "dir": subdir} # type: ignore[arg-type]
)

ans = recv["outputs"]["directory" if directory else "file"]
Expand Down Expand Up @@ -310,7 +310,7 @@ def prepare_aws_credentials(
host_aws_credentials["AWS_EC2_METADATA_DISABLED"] = os.environ["AWS_EC2_METADATA_DISABLED"]
# get AWS credentials from boto3 (unless prevented by configuration)
if cfg["download_awscli"].get_bool("host_credentials"):
import boto3 # pyre-fixme
import boto3 # type: ignore

try:
b3creds = boto3.session.Session().get_credentials()
Expand All @@ -322,18 +322,18 @@ def prepare_aws_credentials(

if host_aws_credentials:
# write credentials to temp file that'll self-destruct afterwards
host_aws_credentials = (
host_aws_credentials_str = (
"\n".join(f"export {k}={shlex.quote(v)}" for (k, v) in host_aws_credentials.items())
+ "\n"
)
aws_credentials_file = cleanup.enter_context(
tempfile.NamedTemporaryFile(
prefix=hashlib.sha256(host_aws_credentials.encode()).hexdigest(),
prefix=hashlib.sha256(host_aws_credentials_str.encode()).hexdigest(),
delete=True,
mode="w",
)
)
print(host_aws_credentials, file=aws_credentials_file, flush=True)
print(host_aws_credentials_str, file=aws_credentials_file, flush=True)
# make file group-readable to ensure it'll be usable if the docker image runs as non-root
os.chmod(aws_credentials_file.name, os.stat(aws_credentials_file.name).st_mode | 0o40)
logger.getChild("awscli_downloader").info("loaded host AWS credentials")
Expand Down
15 changes: 8 additions & 7 deletions WDL/runtime/task_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import shutil
import threading
import typing
from typing import Callable, Iterable, Any, Dict, Optional, ContextManager
from typing import Callable, Iterable, Any, Dict, Optional, ContextManager, Set
from abc import ABC, abstractmethod
from contextlib import suppress
from .. import Error, Env, Value, Type
Expand Down Expand Up @@ -138,7 +138,7 @@ def add_paths(self, host_paths: Iterable[str]) -> None:
assert not self._running

# partition the files by host directory
host_paths_by_dir = {}
host_paths_by_dir: Dict[str, Set[str]] = {}
for host_path in host_paths:
host_path_strip = host_path.rstrip("/")
if host_path not in self.input_path_map and host_path_strip not in self.input_path_map:
Expand Down Expand Up @@ -203,8 +203,9 @@ def process_runtime(self, logger: logging.Logger, runtime_eval: Dict[str, Value.
dockerfile = runtime_eval["inlineDockerfile"]
if not isinstance(dockerfile, Value.Array):
dockerfile = Value.Array(dockerfile.type, [dockerfile])
dockerfile = "\n".join(elt.coerce(Type.String()).value for elt in dockerfile.value)
ans["inlineDockerfile"] = dockerfile
ans["inlineDockerfile"] = "\n".join(
elt.coerce(Type.String()).value for elt in dockerfile.value
)
elif "docker" in runtime_eval or "container" in runtime_eval:
docker_value = runtime_eval["container" if "container" in runtime_eval else "docker"]
if isinstance(docker_value, Value.Array) and len(docker_value.value):
Expand Down Expand Up @@ -249,8 +250,8 @@ def process_runtime(self, logger: logging.Logger, runtime_eval: Dict[str, Value.
except ValueError:
raise Error.RuntimeError("invalid setting of runtime.memory, " + memory_str)

memory_max = self.cfg["task_runtime"]["memory_max"].strip()
memory_max = -1 if memory_max == "-1" else parse_byte_size(memory_max)
memory_max_str = self.cfg["task_runtime"]["memory_max"].strip()
memory_max = -1 if memory_max_str == "-1" else parse_byte_size(memory_max_str)
if memory_max == 0:
memory_max = host_limits["mem_bytes"]
if memory_max > 0 and memory_bytes > memory_max:
Expand Down Expand Up @@ -525,7 +526,7 @@ def new(cfg: config.Loader, logger: logging.Logger, run_id: str, host_dir: str)
with _backends_lock:
if not _backends:
for plugin_name, plugin_cls in config.load_plugins(cfg, "container_backend"):
_backends[plugin_name] = plugin_cls # pyre-fixme
_backends[plugin_name] = plugin_cls # type: ignore
backend_cls = _backends[cfg["scheduler"]["container_backend"]]
if not getattr(backend_cls, "_global_init", False):
backend_cls.global_init(cfg, logger)
Expand Down

0 comments on commit 4d278d5

Please sign in to comment.