diff --git a/src/ai_server/src/utils.py b/src/ai_server/src/utils.py
index 667f5b0..2cb549f 100644
--- a/src/ai_server/src/utils.py
+++ b/src/ai_server/src/utils.py
@@ -4,6 +4,12 @@
 from dataclasses import dataclass, field
 
 
+DEFAULT_GENERAL_MODEL = "microsoft/Phi-3.5-mini-instruct"
+DEFAULT_VISION_MODEL = "microsoft/Phi-3.5-vision-instruct"
+DEFAULT_CONTENT_EXPANSION_MODEL = "microsoft/Phi-3.5-mini-instruct"
+DEFAULT_GOOGLE_QUERY_TRANSLATOR_MODEL = "microsoft/Phi-3.5-mini-instruct"
+
+
 @dataclass
 class EmbeddingsConfig:
     default_model_group: Optional[str] = "en"
@@ -22,14 +28,14 @@ def __post_init__(self):
 
 @dataclass
 class SamplingParams:
-    temperature: float = 0.8
+    temperature: float = 0.2
     top_p: float = 0.95
     max_tokens: int = 512
 
 
 @dataclass
 class ModelConfig:
-    id: str = "microsoft/Phi-3.5-mini-instruct"
+    id: str = DEFAULT_GENERAL_MODEL
     tensor_parallel_size: int = 1
     sampling_params: SamplingParams = field(default_factory=SamplingParams)
 
@@ -38,23 +44,23 @@ class ModelConfig:
 class LLMs:
     content_expansion: Optional[ModelConfig] = field(
         default_factory=lambda: ModelConfig(
-            id="microsoft/Phi-3.5-mini-instruct",
+            id=DEFAULT_CONTENT_EXPANSION_MODEL,
             tensor_parallel_size=1,
-            sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512),
+            sampling_params=SamplingParams(temperature=0.2, top_p=0.95, max_tokens=512),
         )
     )
     vision: Optional[ModelConfig] = field(
         default_factory=lambda: ModelConfig(
-            id="microsoft/Phi-3.5-vision-instruct",
+            id=DEFAULT_VISION_MODEL,
             tensor_parallel_size=1,
-            sampling_params=SamplingParams(temperature=0.7, top_p=0.95, max_tokens=512),
+            sampling_params=SamplingParams(temperature=0.2, top_p=0.95, max_tokens=512),
         )
     )
     google_query_translator: Optional[ModelConfig] = field(
         default_factory=lambda: ModelConfig(
-            id="microsoft/Phi-3.5-mini-instruct",
+            id=DEFAULT_GOOGLE_QUERY_TRANSLATOR_MODEL,
             tensor_parallel_size=1,
-            sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20),
+            sampling_params=SamplingParams(temperature=0.2, top_p=0.95, max_tokens=20),
         )
     )
 