Skip to content

Commit

Permalink
feat: improves defaults
Browse files (browse the repository at this point in the history)
  • Loading branch information
micheleriva committed Jan 11, 2025
1 parent 5edb115 commit a6699aa
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/ai_server/src/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,12 @@
from dataclasses import dataclass, field


+ DEFAULT_GENERAL_MODEL = "microsoft/Phi-3.5-mini-instruct"
+ DEFAULT_VISION_MODEL = "microsoft/Phi-3.5-vision-instruct"
+ DEFAULT_CONTENT_EXPANSION_MODEL = "microsoft/Phi-3.5-mini-instruct"
+ DEFAULT_GOOGLE_QUERY_TRANSLATOR_MODEL = "microsoft/Phi-3.5-mini-instruct"


@dataclass
class EmbeddingsConfig:
default_model_group: Optional[str] = "en"
Expand All @@ -22,14 +28,14 @@ def __post_init__(self):

@dataclass
class SamplingParams:
- temperature: float = 0.8
+ temperature: float = 0.2
top_p: float = 0.95
max_tokens: int = 512


@dataclass
class ModelConfig:
- id: str = "microsoft/Phi-3.5-mini-instruct"
+ id: str = DEFAULT_GENERAL_MODEL
tensor_parallel_size: int = 1
sampling_params: SamplingParams = field(default_factory=SamplingParams)

Expand All @@ -38,23 +44,23 @@ class ModelConfig:
class LLMs:
content_expansion: Optional[ModelConfig] = field(
default_factory=lambda: ModelConfig(
- id="microsoft/Phi-3.5-mini-instruct",
+ id=DEFAULT_CONTENT_EXPANSION_MODEL,
tensor_parallel_size=1,
- sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512),
+ sampling_params=SamplingParams(temperature=0.2, top_p=0.95, max_tokens=512),
)
)
vision: Optional[ModelConfig] = field(
default_factory=lambda: ModelConfig(
- id="microsoft/Phi-3.5-vision-instruct",
+ id=DEFAULT_VISION_MODEL,
tensor_parallel_size=1,
- sampling_params=SamplingParams(temperature=0.7, top_p=0.95, max_tokens=512),
+ sampling_params=SamplingParams(temperature=0.2, top_p=0.95, max_tokens=512),
)
)
google_query_translator: Optional[ModelConfig] = field(
default_factory=lambda: ModelConfig(
- id="microsoft/Phi-3.5-mini-instruct",
+ id=DEFAULT_GOOGLE_QUERY_TRANSLATOR_MODEL,
tensor_parallel_size=1,
- sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20),
+ sampling_params=SamplingParams(temperature=0.2, top_p=0.95, max_tokens=20),
)
)

Expand Down

0 comments on commit a6699aa

Please sign in to comment.