From 082211cbe95806620c3b00b912ff071a66164e65 Mon Sep 17 00:00:00 2001 From: Haiqi Xu <14502009+haiqi96@users.noreply.github.com> Date: Fri, 17 Jan 2025 20:47:33 -0500 Subject: [PATCH] fix: Forbid extra fields in CLP config models --- .../clp-py-utils/clp_py_utils/clp_config.py | 47 ++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/components/clp-py-utils/clp_py_utils/clp_config.py b/components/clp-py-utils/clp_py_utils/clp_config.py index 11c2b8aea..55b68d365 100644 --- a/components/clp-py-utils/clp_py_utils/clp_config.py +++ b/components/clp-py-utils/clp_py_utils/clp_config.py @@ -3,7 +3,7 @@ from typing import Literal, Optional, Tuple, Union from dotenv import dotenv_values -from pydantic import BaseModel, PrivateAttr, validator +from pydantic import BaseModel, Extra, PrivateAttr, validator from strenum import KebabCaseStrEnum, LowercaseStrEnum from .clp_logging import get_valid_logging_level, is_valid_logging_level @@ -57,7 +57,12 @@ class StorageType(LowercaseStrEnum): VALID_STORAGE_ENGINES = [storage_engine.value for storage_engine in StorageEngine] -class Package(BaseModel): +class BaseModelForbidExtra(BaseModel): + class Config: + extra = Extra.forbid + + +class Package(BaseModelForbidExtra): storage_engine: str = "clp" @validator("storage_engine") @@ -70,7 +75,7 @@ def validate_storage_engine(cls, field): return field -class Database(BaseModel): +class Database(BaseModelForbidExtra): type: str = "mariadb" host: str = "localhost" port: int = 3306 @@ -174,7 +179,7 @@ def _validate_port(cls, field): ) -class CompressionScheduler(BaseModel): +class CompressionScheduler(BaseModelForbidExtra): jobs_poll_delay: float = 0.1 # seconds logging_level: str = "INFO" @@ -184,7 +189,7 @@ def validate_logging_level(cls, field): return field -class QueryScheduler(BaseModel): +class QueryScheduler(BaseModelForbidExtra): host = "localhost" port = 7000 jobs_poll_delay: float = 0.1 # seconds @@ -209,7 +214,7 @@ def validate_port(cls, field): return field -class 
CompressionWorker(BaseModel): +class CompressionWorker(BaseModelForbidExtra): logging_level: str = "INFO" @validator("logging_level") @@ -218,7 +223,7 @@ def validate_logging_level(cls, field): return field -class QueryWorker(BaseModel): +class QueryWorker(BaseModelForbidExtra): logging_level: str = "INFO" @validator("logging_level") @@ -227,7 +232,7 @@ def validate_logging_level(cls, field): return field -class Redis(BaseModel): +class Redis(BaseModelForbidExtra): host: str = "localhost" port: int = 6379 query_backend_database: int = 0 @@ -242,7 +247,7 @@ def validate_host(cls, field): return field -class Reducer(BaseModel): +class Reducer(BaseModelForbidExtra): host: str = "localhost" base_port: int = 14009 logging_level: str = "INFO" @@ -272,7 +277,7 @@ def validate_upsert_interval(cls, field): return field -class ResultsCache(BaseModel): +class ResultsCache(BaseModelForbidExtra): host: str = "localhost" port: int = 27017 db_name: str = "clp-query-results" @@ -302,7 +307,7 @@ def get_uri(self): return f"mongodb://{self.host}:{self.port}/{self.db_name}" -class Queue(BaseModel): +class Queue(BaseModelForbidExtra): host: str = "localhost" port: int = 5672 @@ -310,7 +315,7 @@ class Queue(BaseModel): password: Optional[str] -class S3Credentials(BaseModel): +class S3Credentials(BaseModelForbidExtra): access_key_id: str secret_access_key: str @@ -327,7 +332,7 @@ def validate_secret_access_key(cls, field): return field -class S3Config(BaseModel): +class S3Config(BaseModelForbidExtra): region_code: str bucket: str key_prefix: str @@ -360,7 +365,7 @@ def get_credentials(self) -> Tuple[Optional[str], Optional[str]]: return self.credentials.access_key_id, self.credentials.secret_access_key -class FsStorage(BaseModel): +class FsStorage(BaseModelForbidExtra): type: Literal[StorageType.FS.value] = StorageType.FS.value directory: pathlib.Path @@ -379,7 +384,7 @@ def dump_to_primitive_dict(self): return d -class S3Storage(BaseModel): +class S3Storage(BaseModelForbidExtra): type: 
Literal[StorageType.S3.value] = StorageType.S3.value staging_directory: pathlib.Path s3_config: S3Config @@ -437,7 +442,7 @@ def _set_directory_for_storage_config( raise NotImplementedError(f"storage.type {storage_type} is not supported") -class ArchiveOutput(BaseModel): +class ArchiveOutput(BaseModelForbidExtra): storage: Union[ArchiveFsStorage, ArchiveS3Storage] = ArchiveFsStorage() target_archive_size: int = 256 * 1024 * 1024 # 256 MB target_dictionaries_size: int = 32 * 1024 * 1024 # 32 MB @@ -480,7 +485,7 @@ def dump_to_primitive_dict(self): return d -class StreamOutput(BaseModel): +class StreamOutput(BaseModelForbidExtra): storage: Union[StreamFsStorage, StreamS3Storage] = StreamFsStorage() target_uncompressed_size: int = 128 * 1024 * 1024 @@ -502,7 +507,7 @@ def dump_to_primitive_dict(self): return d -class WebUi(BaseModel): +class WebUi(BaseModelForbidExtra): host: str = "localhost" port: int = 4000 logging_level: str = "INFO" @@ -523,7 +528,7 @@ def validate_logging_level(cls, field): return field -class LogViewerWebUi(BaseModel): +class LogViewerWebUi(BaseModelForbidExtra): host: str = "localhost" port: int = 3000 @@ -538,7 +543,7 @@ def validate_port(cls, field): return field -class CLPConfig(BaseModel): +class CLPConfig(BaseModelForbidExtra): execution_container: Optional[str] = None input_logs_directory: pathlib.Path = pathlib.Path("/") @@ -678,7 +683,7 @@ def dump_to_primitive_dict(self): return d -class WorkerConfig(BaseModel): +class WorkerConfig(BaseModelForbidExtra): package: Package = Package() archive_output: ArchiveOutput = ArchiveOutput() data_directory: pathlib.Path = CLPConfig().data_directory