diff --git a/src/jobflow_remote/config/base.py b/src/jobflow_remote/config/base.py
index 34dde1ac..27719420 100644
--- a/src/jobflow_remote/config/base.py
+++ b/src/jobflow_remote/config/base.py
@@ -533,10 +533,15 @@ class Project(BaseModel):
     )
     jobstore: dict = Field(
         default_factory=lambda: dict(DEFAULT_JOBSTORE),
-        description="The JobStore used for the input. Can contain the monty "
-        "serialized dictionary or the Store int the Jobflow format",
+        description="The JobStore used for the output. Can contain the monty "
+        "serialized dictionary or the Store in the Jobflow format",
         validate_default=True,
     )
+    remote_jobstore: Optional[dict] = Field(
+        None,
+        description="The JobStore used for the data transfer between the Runner "
+        "and the workers. The 'store' key can be a string with the standard values",
+    )
     metadata: Optional[dict] = Field(
         None, description="A dictionary with metadata associated to the project"
     )
diff --git a/src/jobflow_remote/jobs/jobcontroller.py b/src/jobflow_remote/jobs/jobcontroller.py
index 9c72938a..88525611 100644
--- a/src/jobflow_remote/jobs/jobcontroller.py
+++ b/src/jobflow_remote/jobs/jobcontroller.py
@@ -2796,7 +2796,9 @@ def complete_job(
                 self.update_flow_state(host_flow_id)
                 return True
 
-            remote_store = get_remote_store(store, local_path)
+            remote_store = get_remote_store(
+                store, local_path, self.project.remote_jobstore
+            )
 
             update_store(store, remote_store, job_doc["db_id"])
 
diff --git a/src/jobflow_remote/jobs/run.py b/src/jobflow_remote/jobs/run.py
index 5e121654..77909d9d 100644
--- a/src/jobflow_remote/jobs/run.py
+++ b/src/jobflow_remote/jobs/run.py
@@ -18,11 +18,7 @@
 from jobflow_remote.jobs.batch import LocalBatchManager
 from jobflow_remote.jobs.data import IN_FILENAME, OUT_FILENAME
-from jobflow_remote.remote.data import (
-    default_orjson_serializer,
-    get_job_path,
-    get_remote_store_filenames,
-)
+from jobflow_remote.remote.data import get_job_path, get_store_file_paths
 from jobflow_remote.utils.log import initialize_remote_run_log
 
 logger = logging.getLogger(__name__)
@@ -43,14 +39,6 @@ def run_remote_job(run_dir: str | Path = "."):
         job: Job = in_data["job"]
         store = in_data["store"]
 
-        # needs to be set here again since it does not get properly serialized.
-        # it is possible to serialize the default function before serializing, but
-        # avoided that to avoid that any refactoring of the
-        # default_orjson_serializer breaks the deserialization of old Fireworks
-        store.docs_store.serialization_default = default_orjson_serializer
-        for additional_store in store.additional_stores.values():
-            additional_store.serialization_default = default_orjson_serializer
-
         store.connect()
 
         initialize_logger()
@@ -58,9 +46,16 @@
             response = job.run(store=store)
         finally:
             # some jobs may have compressed the FW files while being executed,
-            # try to decompress them if that is the case.
+            # try to decompress them if that is the case, so that the output
+            # files can be found and read back.
             decompress_files(store)
 
+        # Close the store explicitly, as the minimal stores write to file on close.
+        try:
+            store.close()
+        except Exception:
+            logger.error("Error while closing the store", exc_info=True)
+
         # The output of the response has already been stored in the store.
         response.output = None
 
@@ -161,7 +156,7 @@ def run_batch_jobs(
 
 def decompress_files(store: JobStore):
     file_names = [OUT_FILENAME]
-    file_names.extend(get_remote_store_filenames(store))
+    file_names.extend(os.path.basename(p) for p in get_store_file_paths(store))
 
     for fn in file_names:
         # If the file is already present do not decompress it, even if
diff --git a/src/jobflow_remote/jobs/runner.py b/src/jobflow_remote/jobs/runner.py
index 045b54c2..d5cd9155 100644
--- a/src/jobflow_remote/jobs/runner.py
+++ b/src/jobflow_remote/jobs/runner.py
@@ -427,7 +427,7 @@ def upload(self, lock: MongoLock):
         # serializer could undergo refactoring and this could break deserialization
         # of older FWs. It is set in the FireTask at runtime.
         remote_store = get_remote_store(
-            store=store, launch_dir=remote_path, add_orjson_serializer=False
+            store=store, work_dir=remote_path, config_dict=self.project.remote_jobstore
         )
 
         created = host.mkdir(remote_path)
@@ -596,7 +596,11 @@ def download(self, lock):
             makedirs_p(local_path)
 
             fnames = [OUT_FILENAME]
-            fnames.extend(get_remote_store_filenames(store))
+            fnames.extend(
+                get_remote_store_filenames(
+                    store, config_dict=self.project.remote_jobstore
+                )
+            )
 
             for fname in fnames:
                 # in principle fabric should work by just passing the
diff --git a/src/jobflow_remote/remote/data.py b/src/jobflow_remote/remote/data.py
index aac851d0..c25f2cb5 100644
--- a/src/jobflow_remote/remote/data.py
+++ b/src/jobflow_remote/remote/data.py
@@ -7,17 +7,18 @@
 import os
 from collections.abc import Iterator
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import orjson
 from jobflow.core.job import Job
 from jobflow.core.store import JobStore
 from maggma.core import Sort, Store
+from maggma.stores import JSONStore
 from maggma.utils import to_dt
 from monty.io import zopen
 
 # from maggma.stores.mongolike import JSONStore
-from monty.json import jsanitize
+from monty.json import MontyDecoder, jsanitize
 
 from jobflow_remote.jobs.data import RemoteError
 from jobflow_remote.utils.data import uuid_to_path
@@ -58,22 +59,21 @@ def default_orjson_serializer(obj: Any) -> Any:
 
 
 def get_remote_store(
-    store: JobStore, launch_dir: str | Path, add_orjson_serializer: bool = True
+    store: JobStore, work_dir: str | Path, config_dict: dict | None
 ) -> JobStore:
-    serialization_default = None
-    if add_orjson_serializer:
-        serialization_default = default_orjson_serializer
-    docs_store = MinimalJSONStore(
-        os.path.join(launch_dir, "remote_job_data.json.gz"),
-        serialization_default=serialization_default,
+    docs_store = get_single_store(
+        config_dict=config_dict, file_name="remote_job_data", dir_path=work_dir
     )
+
     additional_stores = {}
     for k in store.additional_stores.keys():
-        additional_stores[k] = MinimalJSONStore(
-            os.path.join(launch_dir, f"additional_store_{k}.json.gz"),
-            serialization_default=serialization_default,
+        additional_stores[k] = get_single_store(
+            config_dict=config_dict,
+            file_name=f"additional_store_{k}",
+            dir_path=work_dir,
         )
+
     remote_store = JobStore(
         docs_store=docs_store,
         additional_stores=additional_stores,
@@ -84,14 +84,86 @@
     return remote_store
 
 
-def get_remote_store_filenames(store: JobStore) -> list[str]:
-    filenames = ["remote_job_data.json.gz"]
+default_remote_store = {"store": "maggma_json", "zip": False}
+
+
+def get_single_store(
+    config_dict: dict | None, file_name: str, dir_path: str | Path
+) -> Store:
+    config_dict = config_dict or default_remote_store
+
+    store_type = config_dict.get("store", default_remote_store["store"])
+    total_file_name = get_single_store_file_name(config_dict, file_name)
+    file_path = os.path.join(dir_path, total_file_name)
+    if store_type == "maggma_json":
+        return StdJSONStore(file_path)
+    elif store_type == "orjson":
+        return MinimalORJSONStore(file_path)
+    elif store_type == "msgspec_json":
+        return MinimalMsgspecJSONStore(file_path)
+    elif store_type == "msgpack":
+        return MinimalMsgpackStore(file_path)
+    elif isinstance(store_type, dict):
+        store_type = dict(store_type)
+        store_type["path"] = file_path
+        store = MontyDecoder().process_decoded(store_type)
+        if not isinstance(store, Store):
+            raise ValueError(
+                f"Could not instantiate a proper store from remote config dict {store_type}"
+            )
+        return store
+    else:
+        raise ValueError(f"remote store type not supported: {store_type}")
+
+
+def get_single_store_file_name(config_dict: dict | None, file_name: str) -> str:
+    config_dict = config_dict or default_remote_store
+    store_type = config_dict.get("store", default_remote_store["store"])
+
+    if isinstance(store_type, str) and "json" in store_type:
+        ext = "json"
+    elif isinstance(store_type, str) and "msgpack" in store_type:
+        ext = "msgpack"
+    else:
+        ext = config_dict.get("extension")  # type: ignore
+    if not ext:
+        raise ValueError(
+            f"Could not determine extension for remote store config dict: {config_dict}"
+        )
+    total_file_name = f"{file_name}.{ext}"
+    if config_dict.get("zip", False):
+        total_file_name += ".gz"
+    return total_file_name
+
+
+def get_remote_store_filenames(store: JobStore, config_dict: dict | None) -> list[str]:
+    filenames = [
+        get_single_store_file_name(config_dict=config_dict, file_name="remote_job_data")
+    ]
     for k in store.additional_stores.keys():
-        filenames.append(f"additional_store_{k}.json.gz")
+        filenames.append(
+            get_single_store_file_name(
+                config_dict=config_dict, file_name=f"additional_store_{k}"
+            )
+        )
     return filenames
 
 
+def get_store_file_paths(store: JobStore) -> list[str]:
+    def get_single_path(base_store: Store):
+        paths = getattr(base_store, "paths", None)
+        if paths:
+            return paths[0]
+        path = getattr(base_store, "path", None)
+        if not path:
+            raise RuntimeError(f"Could not determine the path for {base_store}")
+        return path
+
+    store_paths = [get_single_path(store.docs_store)]
+    store_paths.extend(get_single_path(s) for s in store.additional_stores.values())
+    return store_paths
+
+
 def update_store(store: JobStore, remote_store: JobStore, db_id: int):
     try:
         store.connect()
@@ -221,6 +293,21 @@ def check_additional_stores(job: dict | Job, store: JobStore) -> list[str]:
     return missing_stores
 
 
+class StdJSONStore(JSONStore):
+    """
+    Simple subclass of the JSONStore defining the serialization_default
+    for the objects that cannot be dumped to json.
+    """
+
+    def __init__(self, paths, **kwargs):
+        super().__init__(
+            paths=paths,
+            serialization_default=default_orjson_serializer,
+            read_only=False,
+            **kwargs,
+        )
+
+
 class MinimalFileStore(Store):
     """
     A Minimal Store for access to a single file.
@@ -231,12 +318,8 @@ class MinimalFileStore(Store):
     def _collection(self):
         raise NotImplementedError
 
-    @property
-    def name(self) -> str:
-        return f"json://{self.path}"
-
     def close(self):
-        pass
+        self.update_file()
 
     def count(self, criteria: dict | None = None) -> int:
         return len(self.data)
@@ -273,8 +356,6 @@ def groupby(
     def __init__(
         self,
         path: str,
-        serialization_option: int | None = None,
-        serialization_default: Callable[[Any], Any] | None = None,
         **kwargs,
     ):
         """
@@ -287,8 +368,6 @@ def __init__(
         self.kwargs = kwargs
         self.default_sort = None
 
-        self.serialization_option = serialization_option
-        self.serialization_default = serialization_default
         self.data: list[dict] = []
 
         super().__init__(**kwargs)
@@ -322,8 +401,6 @@ def update(self, docs: list[dict] | dict, key: list | str | None = None):
 
         self.data.extend(docs)
 
-        self.update_file()
-
     def update_file(self):
         raise NotImplementedError
 
@@ -358,41 +435,65 @@ def __eq__(self, other: object) -> bool:
         return all(getattr(self, f) == getattr(other, f) for f in fields)
 
 
-class MinimalJSONStore(MinimalFileStore):
+class MinimalORJSONStore(MinimalFileStore):
 
-    def __init__(
-        self,
-        path: str,
-        serialization_option: int | None = None,
-        serialization_default: Callable[[Any], Any] | None = None,
-        **kwargs,
-    ):
+    @property
+    def name(self) -> str:
+        return f"json://{self.path}"
+
+    def update_file(self):
         """
-        Args:
-            path: paths for json files to turn into a Store
-            serialization_option:
-                option that will be passed to the orjson.dump when saving to the
-                json file.
-            serialization_default:
-                default that will be passed to the orjson.dump when saving to the
-                json file.
+        Updates the json file when a write-like operation is performed.
         """
-        self.serialization_option = serialization_option
-        self.serialization_default = serialization_default
+        with zopen(self.path, "wb") as f:
+            for d in self.data:
+                d.pop("_id", None)
+            bytesdata = orjson.dumps(
+                self.data,
+                default=default_orjson_serializer,
+            )
+            f.write(bytesdata)
 
-        super().__init__(path=path, **kwargs)
+    def read_file(self) -> list:
+        """
+        Helper method to read the contents of a JSON file and generate
+        a list of docs.
+        """
+        with zopen(self.path, "rb") as f:
+            data = f.read()
+            if not data:
+                return []
+            objects = orjson.loads(data)
+            objects = [objects] if not isinstance(objects, list) else objects
+            # datetime objects deserialize to str. Try to convert the last_updated
+            # field back to datetime.
+            # TODO - there may still be problems caused if a JSONStore is init'ed from
+            # documents that don't contain a last_updated field
+            # See Store.last_updated in store.py.
+            for obj in objects:
+                if obj.get(self.last_updated_field):
+                    obj[self.last_updated_field] = to_dt(obj[self.last_updated_field])
+
+        return objects
+
+
+class MinimalMsgspecJSONStore(MinimalFileStore):
+
+    @property
+    def name(self) -> str:
+        return f"json://{self.path}"
 
     def update_file(self):
         """
         Updates the json file when a write-like operation is performed.
         """
+        import msgspec
+
         with zopen(self.path, "wb") as f:
             for d in self.data:
                 d.pop("_id", None)
-            bytesdata = orjson.dumps(
+            bytesdata = msgspec.json.encode(
                 self.data,
-                option=self.serialization_option,
-                default=self.serialization_default,
             )
             f.write(bytesdata)
 
@@ -401,9 +502,13 @@ def read_file(self) -> list:
         Helper method to read the contents of a JSON file and generate
         a list of docs.
         """
+        import msgspec
+
         with zopen(self.path, "rb") as f:
             data = f.read()
-            objects = orjson.loads(data)
+            if not data:
+                return []
+            objects = msgspec.json.decode(data)
             objects = [objects] if not isinstance(objects, list) else objects
             # datetime objects deserialize to str. Try to convert the last_updated
             # field back to datetime.
@@ -431,6 +536,10 @@ def encode_datetime(obj):
 
 class MinimalMsgpackStore(MinimalFileStore):
 
+    @property
+    def name(self) -> str:
+        return f"msgpack://{self.path}"
+
    def update_file(self):
        """
        Updates the msgpack file when a write-like operation is performed.