From 1654d46d78dff6ec7e8e230074200e7ab030bda0 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 27 May 2024 18:25:31 +0100 Subject: [PATCH] Work on checkpoints --- src/anemoi/inference/commands/edit.py | 55 +++++++++++++-------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/anemoi/inference/commands/edit.py b/src/anemoi/inference/commands/edit.py index 4e337c6..c9dcd3a 100644 --- a/src/anemoi/inference/commands/edit.py +++ b/src/anemoi/inference/commands/edit.py @@ -9,28 +9,13 @@ # -from ..checkpoint import Checkpoint -from . import Command - - -def visit(x, path, name, value): - if isinstance(x, dict): - for k, v in x.items(): - if k == name: - print(".".join(path), k, v) +import os +import subprocess +from tempfile import TemporaryDirectory - if v == value: - print(".".join(path), k, v) +import yaml - path.append(k) - visit(v, path, name, value) - path.pop() - - if isinstance(x, list): - for i, v in enumerate(x): - path.append(str(i)) - visit(v, path, name, value) - path.pop() +from . import Command class EditCmd(Command): @@ -39,19 +24,31 @@ class EditCmd(Command): def add_arguments(self, command_parser): command_parser.add_argument("path", help="Path to the checkpoint.") - command_parser.add_argument("--name", help="Search for a specific name.") - command_parser.add_argument("--value", help="Search for a specific value.") + command_parser.add_argument("--editor", help="Editor to use.", default=os.environ.get("EDITOR", "vi")) def run(self, args): - checkpoint = Checkpoint(args.path) + from anemoi.utils.checkpoints import DEFAULT_NAME + from anemoi.utils.checkpoints import load_metadata + from anemoi.utils.checkpoints import metadata_files + from anemoi.utils.checkpoints import replace_metadata + + OLD_NAME = "ai-models.json" + + names = metadata_files(args.path) + name = OLD_NAME if OLD_NAME in names else DEFAULT_NAME + + with TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, "checkpoint.yaml") + with open(path, "w") as f: + yaml.dump(load_metadata(args.path, name), f) + + subprocess.check_call([args.editor, path]) + + with open(path) as f: + replace_metadata(args.path, yaml.safe_load(f), OLD_NAME) - visit( - checkpoint, - [], - args.name if args.name is not None else object(), - args.value if args.value is not None else object(), - ) + # checkpoint.pack(temp_dir, args.path) command = EditCmd