Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to specify times/durations in config as user-friendly strings #876

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cluster: testworker
loglevel: DEBUG
update_frequency: 5.0
update_frequency: 5

server:
fractal_uri: http://localhost:7900
Expand Down
13 changes: 12 additions & 1 deletion qcfractal/qcfractal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sqlalchemy.engine.url import URL, make_url

from qcfractal.port_util import find_open_port
from qcportal.utils import duration_to_seconds


def update_nested_dict(d, u):
Expand Down Expand Up @@ -315,6 +316,10 @@ class WebAPIConfig(ConfigBase):
None, description="Any additional options to pass directly to the waitress serve function"
)

@validator("jwt_access_token_expires", "jwt_refresh_token_expires", pre=True)
def _convert_durations(cls, v):
return duration_to_seconds(v)

class Config(ConfigCommon):
env_prefix = "QCF_API_"

Expand Down Expand Up @@ -374,7 +379,9 @@ class FractalConfig(ConfigBase):

# Access logging
log_access: bool = Field(False, description="Store API access in the database")
access_log_keep: int = Field(0, description="Number of days of access logs to keep. 0 means keep all")
access_log_keep: int = Field(
0, description="How far back to keep access logs (in seconds or a string). 0 means keep all"
)

# maxmind_account_id: Optional[int] = Field(None, description="Account ID for MaxMind GeoIP2 service")
maxmind_license_key: Optional[str] = Field(
Expand Down Expand Up @@ -454,6 +461,10 @@ def _check_loglevel(cls, v):
raise ValidationError(f"{v} is not a valid loglevel. Must be DEBUG, INFO, WARNING, ERROR, or CRITICAL")
return v

@validator("service_frequency", "heartbeat_frequency", "access_log_keep", pre=True)
def _convert_durations(cls, v):
return duration_to_seconds(v)

class Config(ConfigCommon):
env_prefix = "QCF_"

Expand Down
65 changes: 65 additions & 0 deletions qcfractal/qcfractal/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import copy

from qcfractal.config import FractalConfig

_base_config = {
"api": {
"secret_key": "abc1234def456",
"jwt_secret_key": "abc123def456",
},
"database": {"username": "qcfractal", "password": "abc123def456"},
}


def test_config_durations_plain(tmp_path):
base_folder = str(tmp_path)

base_config = copy.deepcopy(_base_config)
base_config["service_frequency"] = 3600
base_config["heartbeat_frequency"] = 30
base_config["access_log_keep"] = 100802
base_config["api"]["jwt_access_token_expires"] = 7450
base_config["api"]["jwt_refresh_token_expires"] = 637277
cfg = FractalConfig(base_folder=base_folder, **base_config)

assert cfg.service_frequency == 3600
assert cfg.heartbeat_frequency == 30
assert cfg.access_log_keep == 100802
assert cfg.api.jwt_access_token_expires == 7450
assert cfg.api.jwt_refresh_token_expires == 637277


def test_config_durations_str(tmp_path):
base_folder = str(tmp_path)

base_config = copy.deepcopy(_base_config)
base_config["service_frequency"] = "1h"
base_config["heartbeat_frequency"] = "30s"
base_config["access_log_keep"] = "1d4h2s"
base_config["api"]["jwt_access_token_expires"] = "2h4m10s"
base_config["api"]["jwt_refresh_token_expires"] = "7d9h77s"
cfg = FractalConfig(base_folder=base_folder, **base_config)

assert cfg.service_frequency == 3600
assert cfg.heartbeat_frequency == 30
assert cfg.access_log_keep == 100802
assert cfg.api.jwt_access_token_expires == 7450
assert cfg.api.jwt_refresh_token_expires == 637277


def test_config_durations_dhms(tmp_path):
base_folder = str(tmp_path)

base_config = copy.deepcopy(_base_config)
base_config["service_frequency"] = "1:00:00"
base_config["heartbeat_frequency"] = "30"
base_config["access_log_keep"] = "1:04:00:02"
base_config["api"]["jwt_access_token_expires"] = "2:04:10"
base_config["api"]["jwt_refresh_token_expires"] = "7:09:00:77"
cfg = FractalConfig(base_folder=base_folder, **base_config)

assert cfg.service_frequency == 3600
assert cfg.heartbeat_frequency == 30
assert cfg.access_log_keep == 100802
assert cfg.api.jwt_access_token_expires == 7450
assert cfg.api.jwt_refresh_token_expires == 637277
16 changes: 7 additions & 9 deletions qcfractalcompute/qcfractalcompute/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import BaseModel, Field, validator
from typing_extensions import Literal

from qcportal.utils import seconds_to_hms
from qcportal.utils import seconds_to_hms, duration_to_seconds


def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -120,10 +120,7 @@ class TorqueExecutorConfig(ExecutorConfig):

@validator("walltime", pre=True)
def walltime_must_be_str(cls, v):
if isinstance(v, int):
return seconds_to_hms(v)
else:
return v
return seconds_to_hms(duration_to_seconds(v))


class LSFExecutorConfig(ExecutorConfig):
Expand All @@ -143,10 +140,7 @@ class LSFExecutorConfig(ExecutorConfig):

@validator("walltime", pre=True)
def walltime_must_be_str(cls, v):
if isinstance(v, int):
return seconds_to_hms(v)
else:
return v
return seconds_to_hms(duration_to_seconds(v))


AllExecutorTypes = Union[
Expand Down Expand Up @@ -226,6 +220,10 @@ def _check_logfile(cls, v, values):
def _check_run_dir(cls, v, values):
return _make_abs_path(v, values["base_folder"], "parsl_run_dir")

@validator("update_frequency", pre=True)
def _convert_durations(cls, v):
return duration_to_seconds(v)


def read_configuration(file_paths: List[str], extra_config: Optional[Dict[str, Any]] = None) -> FractalComputeConfig:
logger = logging.getLogger(__name__)
Expand Down
33 changes: 32 additions & 1 deletion qcfractalcompute/qcfractalcompute/test_manager_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from __future__ import annotations

import copy

import pytest
import yaml

from qcfractalcompute.config import SlurmExecutorConfig
from qcfractalcompute.config import SlurmExecutorConfig, FractalComputeConfig

_base_config = {
"cluster": "testcluster",
"server": {
"fractal_uri": "http://localhost:7777/",
},
"executors": {},
}


@pytest.mark.parametrize("time_str", ["02:01:59", "72:00:00", "10:00:00"])
Expand Down Expand Up @@ -39,3 +49,24 @@ def test_manager_config_walltime(time_str):
config = yaml.load(config_yaml, yaml.SafeLoader)
executor_config = SlurmExecutorConfig(**config)
assert executor_config.walltime == time_str


def test_manager_config_durations(tmp_path):
base_folder = str(tmp_path)
base_config = copy.deepcopy(_base_config)

base_config["update_frequency"] = "900"
manager_config = FractalComputeConfig(base_folder=base_folder, **base_config)
assert manager_config.update_frequency == 900

base_config["update_frequency"] = 900
manager_config = FractalComputeConfig(base_folder=base_folder, **base_config)
assert manager_config.update_frequency == 900

base_config["update_frequency"] = "3d4h80m09s"
manager_config = FractalComputeConfig(base_folder=base_folder, **base_config)
assert manager_config.update_frequency == 278409

base_config["update_frequency"] = "3:04:80:9"
manager_config = FractalComputeConfig(base_folder=base_folder, **base_config)
assert manager_config.update_frequency == 278409
34 changes: 33 additions & 1 deletion qcportal/qcportal/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from qcportal.utils import chunk_iterable, seconds_to_hms, is_included
from qcportal.utils import chunk_iterable, seconds_to_hms, duration_to_seconds, is_included


def test_chunk_iterable():
Expand Down Expand Up @@ -29,6 +29,38 @@ def test_seconds_to_hms():
assert seconds_to_hms(3670.12) == "01:01:10.12"


def test_duration_to_seconds():
assert duration_to_seconds(0) == 0
assert duration_to_seconds("0") == 0
assert duration_to_seconds(17) == 17
assert duration_to_seconds("17") == 17
assert duration_to_seconds(17.0) == 17
assert duration_to_seconds("17.0") == 17

assert duration_to_seconds("17s") == 17
assert duration_to_seconds("70s") == 70
assert duration_to_seconds("8m17s") == 497
assert duration_to_seconds("80m72s") == 4872
assert duration_to_seconds("3h8m17s") == 11297
assert duration_to_seconds("03h08m07s") == 11287
assert duration_to_seconds("03h08m070s") == 11350
assert duration_to_seconds("9d03h08m070s") == 788950

assert duration_to_seconds("9d") == 777600
assert duration_to_seconds("10m") == 600
assert duration_to_seconds("90m") == 5400
assert duration_to_seconds("04h") == 14400
assert duration_to_seconds("4h5s") == 14405
assert duration_to_seconds("1d9s") == 86409

assert duration_to_seconds("8:17") == 497
assert duration_to_seconds("80:72") == 4872
assert duration_to_seconds("3:8:17") == 11297
assert duration_to_seconds("03:08:07") == 11287
assert duration_to_seconds("03:08:070") == 11350
assert duration_to_seconds("9:03:08:07") == 788887


def test_is_included():
assert is_included("test", None, None, True) is True
assert is_included("test", None, None, False) is False
Expand Down
48 changes: 48 additions & 0 deletions qcportal/qcportal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import logging
import math
import re
import time
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from hashlib import sha256
Expand Down Expand Up @@ -261,6 +262,53 @@ def seconds_to_hms(seconds: Union[float, int]) -> str:
return f"{hours:02d}:{minutes:02d}:{seconds+fraction:02.2f}"


def duration_to_seconds(s: Union[int, str, float]) -> int:
"""
Parses a string in dd:hh:mm:ss or 1d2h3m4s to an integer number of seconds
"""

# Is already an int
if isinstance(s, int):
return s

# Is a float but represents an integer
if isinstance(s, float):
if s.is_integer():
return int(s)
else:
raise ValueError(f"Invalid duration format: {s} - cannot represent fractional seconds")

# Plain number of seconds (as a string)
if s.isdigit():
return int(s)

try:
f = float(s)
if f.is_integer():
return int(f)
else:
raise ValueError(f"Invalid duration format: {s} - cannot represent fractional seconds")
except ValueError:
pass

# Handle dd:hh:mm:ss format
if ":" in s:
parts = list(map(int, s.split(":")))
while len(parts) < 4: # Pad missing parts with zeros
parts.insert(0, 0)
days, hours, minutes, seconds = parts
return days * 86400 + hours * 3600 + minutes * 60 + seconds

# Handle format like 3d4h7m10s
pattern = re.compile(r"(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?")
match = pattern.fullmatch(s)
if not match:
raise ValueError(f"Invalid duration format: {s}")

days, hours, minutes, seconds = map(lambda x: int(x) if x else 0, match.groups())
return days * 86400 + hours * 3600 + minutes * 60 + seconds


def recursive_normalizer(value: Any, digits: int = 10, lowercase: bool = True) -> Any:
"""
Prepare a structure for hashing by lowercasing all values and round all floats
Expand Down
Loading