diff --git a/DEVELOPING.md b/DEVELOPING.md
index e94f658..b230366 100644
--- a/DEVELOPING.md
+++ b/DEVELOPING.md
@@ -14,20 +14,31 @@ Before you begin, ensure you have the following installed:
 - `pip` (Python package installer)
 - `git` (version control system)

-### Installation
+### Clone the repository

-1. Clone the repository:
+```sh
+git clone https://github.com/neuralmagic/guidellm.git
+cd guidellm
+```

-   ```bash
-   git clone https://github.com/neuralmagic/guidellm.git
-   cd guidellm
-   ```
+### Install dependencies

-2. Install the required dependencies:
+All dependencies are specified in the `pyproject.toml` file. You can install only the required dependencies or add optional dependency groups on top of them.

-   ```bash
-   pip install -e .[dev]
-   ```
+To install the required dependencies along with the optional `dev` dependencies:
+
+```bash
+git clone https://github.com/neuralmagic/guidellm.git
+cd guidellm
+pip install -e .[dev]
+```
+
+Some backends, such as `deepsparse`, have additional software requirements and version constraints. To install the dependencies for a specific backend, run:
+
+```sh
+pip install -e .[deepsparse]
+# or pip install -e '.[deepsparse]'
+```

 ## Project Structure

@@ -46,8 +57,9 @@ guidellm/
 └── README.md
 ```

-- **src/guidellm/**: Main source code for the project.
-- **tests/**: Test cases categorized into unit, integration, and end-to-end tests.
+- `pyproject.toml`: Project metadata, dependencies, and tool configuration.
+- `src/guidellm/`: Main source code for the project.
+- `tests/`: Test cases categorized into unit, integration, and end-to-end tests.

 ## Development Environment Setup

@@ -234,12 +246,14 @@ The project configuration entrypoint is represented by lazy-loaded `settigns` si
 The project is fully configurable with environment variables. All the default values and

 ```py
-class NestedIntoLogging(BaseModel):
+class Nested(BaseModel):
     nested: str = "default value"

 class LoggingSettings(BaseModel):
     # ...
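+    # Nested models map onto double-underscore environment variables; the documented
+    # example is `export GUIDELLM__LOGGING__DISABLED=true`, and (assuming the same
+    # delimiter rules apply one level deeper) the nested value below could be set with
+    # `export GUIDELLM__LOGGING__NESTED__NESTED="custom value"` (illustrative only).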
+ disabled: bool = False + nested: Nested = Nested() class Settings(BaseSettings): diff --git a/pyproject.toml b/pyproject.toml index db2e65a..942c1cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ dev = [ "types-requests~=2.32.0", "types-toml", ] +deepsparse = [ + "deepsparse; python_version < '3.12'", +] [project.entry-points.console_scripts] @@ -104,6 +107,10 @@ exclude = ["venv", ".tox"] # Check: https://mypy.readthedocs.io/en/latest/config_file.html#import-discovery follow_imports = 'silent' +[[tool.mypy.overrides]] +module = ["deepsparse.*", "transformers.*"] +ignore_missing_imports=true + [tool.ruff] line-length = 88 @@ -117,11 +124,14 @@ indent-style = "space" [tool.ruff.lint] ignore = [ "PLR0913", + "PLR2004", # allow numbers without constants definitions + "RET505", # allow `else` block after `if (condition): return value` line "TCH001", "COM812", "ISC001", "TCH002", "PLW1514", # allow Path.open without encoding + "S311", # allow standard pseudo-random generators ] select = [ @@ -177,19 +187,19 @@ select = [ "FIX", # flake8-fixme: detects FIXMEs and other temporary comments that should be resolved ] -[tool.ruff.lint.extend-per-file-ignores] -"tests/**/*.py" = [ + +[tool.ruff.lint.per-file-ignores] +"tests/*" = [ "S101", # asserts allowed in tests + "S105", # allow hardcoded passwords in tests + "S106", # allow hardcoded passwords in tests "ARG", # Unused function args allowed in tests "PLR2004", # Magic value used in comparison "TCH002", # No import only type checking in tests "SLF001", # enable private member access in tests - "S105", # allow hardcoded passwords in tests - "S311", # allow standard pseudo-random generators in tests "PT011", # allow generic exceptions in tests "N806", # allow uppercase variable names in tests "PGH003", # allow general ignores in tests - "S106", # allow hardcoded passwords in tests "PLR0915", # allow complext statements in tests ] diff --git a/src/guidellm/backend/__init__.py b/src/guidellm/backend/__init__.py index 875e319..b6d1b9d 100644 --- a/src/guidellm/backend/__init__.py +++ b/src/guidellm/backend/__init__.py @@ -1,10 +1,3 @@ from .base import Backend, BackendEngine, BackendEnginePublic, GenerativeResponse -from .openai import OpenAIBackend -__all__ = [ - "Backend", - "BackendEngine", - "BackendEnginePublic", - "GenerativeResponse", - "OpenAIBackend", -] +__all__ = ["Backend", "BackendEngine", "BackendEnginePublic", "GenerativeResponse"] diff --git a/src/guidellm/backend/base.py b/src/guidellm/backend/base.py index d71c5f6..010cdd2 100644 --- a/src/guidellm/backend/base.py +++ b/src/guidellm/backend/base.py @@ -15,7 +15,7 @@ __all__ = ["Backend", "BackendEngine", "BackendEnginePublic", "GenerativeResponse"] -BackendEnginePublic = Literal["openai_server"] +BackendEnginePublic = Literal["openai_server", "deepsparse"] BackendEngine = Union[BackendEnginePublic, Literal["test"]] @@ -117,9 +117,10 @@ def __init__(self, type_: BackendEngine, target: str, model: str): :param target: The target URL for the backend. :param model: The model used by the backend. 
""" - self._type = type_ - self._target = target - self._model = model + + self._type: BackendEngine = type_ + self._target: str = target + self._model: str = model self.test_connection() diff --git a/src/guidellm/backend/deepsparse/__init__.py b/src/guidellm/backend/deepsparse/__init__.py new file mode 100644 index 0000000..c45a112 --- /dev/null +++ b/src/guidellm/backend/deepsparse/__init__.py @@ -0,0 +1,28 @@ +""" +This package encapsulates the "Deepsparse Backend" implementation. + +ref: https://github.com/neuralmagic/deepsparse + +The `deepsparse` package supports Python3.6..Python3.11, +when the `guidellm` start from Python3.8. + +Safe range of versions is Python3.8..Python3.11 +for the Deepsparse Backend implementation. + +In the end ensure that the `deepsparse` package is installed. +""" + +from guidellm.utils import check_python_version, module_is_available + +check_python_version(min_version="3.8", max_version="3.11") +module_is_available( + module="deepsparse", + helper=( + "`deepsparse` package is not available. " + "Please try `pip install -e '.[deepsparse]'`" + ), +) + +from .backend import DeepsparseBackend # noqa: E402 + +__all__ = ["DeepsparseBackend"] diff --git a/src/guidellm/backend/deepsparse/backend.py b/src/guidellm/backend/deepsparse/backend.py new file mode 100644 index 0000000..924d04e --- /dev/null +++ b/src/guidellm/backend/deepsparse/backend.py @@ -0,0 +1,121 @@ +from typing import Any, AsyncGenerator, Dict, List, Optional + +from deepsparse import Pipeline, TextGeneration +from loguru import logger + +from guidellm.backend import Backend, GenerativeResponse +from guidellm.config import settings +from guidellm.core import TextGenerationRequest + + +@Backend.register(backend_type="deepsparse") +class DeepsparseBackend(Backend): + """ + An Deepsparse backend implementation for the generative AI result. + """ + + def __init__(self, model: Optional[str] = None, **request_args): + self._request_args: Dict[str, Any] = request_args + self._model = self._get_model(model) + self.pipeline: Pipeline = TextGeneration(model=self._model) + + super().__init__(type_="deepsparse", model=self._model, target="not used") + + logger.info(f"Deepsparse Backend uses model {self._model}") + + def _get_model(self, model_from_cli: Optional[str] = None) -> str: + """Provides the model by the next priority list: + 1. from function argument (comes from CLI) + 1. from environment variable + 2. `self.default_model` from `self.available_models` + """ + + if model_from_cli is not None: + return model_from_cli + elif settings.llm_model is not None: + logger.info( + "Using Deepsparse model from environment variable: " + f"{settings.llm_model}" + ) + return settings.llm_model + else: + logger.info(f"Using default Deepsparse model: {self.default_model}") + logger.info( + "To customize the model either set the 'GUIDELLM__LLM_MODEL' " + "environment variable or set the CLI argument '--model'" + ) + return self.default_model + + async def make_request( + self, request: TextGenerationRequest + ) -> AsyncGenerator[GenerativeResponse, None]: + """ + Make a request to the Deepsparse Python API client. + + :param request: The result request to submit. + :type request: TextGenerationRequest + :return: An iterator over the generative responses. 
+ :rtype: Iterator[GenerativeResponse] + """ + + logger.debug( + f"Making request to Deepsparse backend with prompt: {request.prompt}" + ) + + token_count = 0 + request_args = { + **self._request_args, + "streaming": True, + "max_new_tokens": request.output_token_count, + } + + if not (output := self.pipeline(prompt=request.prompt, **request_args)): + yield GenerativeResponse( + type_="final", + prompt=request.prompt, + prompt_token_count=request.prompt_token_count, + output_token_count=token_count, + ) + return + + for generation in output.generations: + if not (token := generation.text): + yield GenerativeResponse( + type_="final", + prompt=request.prompt, + prompt_token_count=request.prompt_token_count, + output_token_count=token_count, + ) + return + else: + token_count += 1 + yield GenerativeResponse( + type_="token_iter", + add_token=token, + prompt=request.prompt, + prompt_token_count=request.prompt_token_count, + output_token_count=token_count, + ) + + yield GenerativeResponse( + type_="final", + prompt=request.prompt, + prompt_token_count=request.prompt_token_count, + output_token_count=token_count, + ) + + def available_models(self) -> List[str]: + """ + Get the available models for the backend. + + :return: A list of available models. + :rtype: List[str] + """ + + # WARNING: The default model from the documentation is defined here + return ["hf:mgoin/TinyStories-33M-quant-deepsparse"] + + def _token_count(self, text: str) -> int: + token_count = len(text.split()) + logger.debug(f"Token count for text '{text}': {token_count}") + return token_count diff --git a/src/guidellm/config.py b/src/guidellm/config.py index c3d950e..e6f5432 100644 --- a/src/guidellm/config.py +++ b/src/guidellm/config.py @@ -128,6 +128,7 @@ class Settings(BaseSettings): ```sh export GUIDELLM__LOGGING__DISABLED=true export GUIDELLM__OPENAI__API_KEY=****** + export GUIDELLM__LLM_MODEL=****** ``` """ @@ -141,6 +142,7 @@ class Settings(BaseSettings): # general settings env: Environment = Environment.PROD + llm_model: str = "mistralai/Mistral-7B-Instruct-v0.3" request_timeout: int = 30 max_concurrency: int = 512 num_sweep_profiles: int = 9 @@ -152,8 +154,6 @@ class Settings(BaseSettings): # Request settings openai: OpenAISettings = OpenAISettings() - - # Report settings report_generation: ReportGenerationSettings = ReportGenerationSettings() @model_validator(mode="after") diff --git a/src/guidellm/executor/profile_generator.py b/src/guidellm/executor/profile_generator.py index 703ea05..c37b1da 100644 --- a/src/guidellm/executor/profile_generator.py +++ b/src/guidellm/executor/profile_generator.py @@ -190,12 +190,14 @@ def next(self, current_report: TextGenerationBenchmarkReport) -> Optional[Profil elif self.mode == "sweep": profile = self.create_sweep_profile( self.generated_count, - sync_benchmark=current_report.benchmarks[0] - if current_report.benchmarks - else None, - throughput_benchmark=current_report.benchmarks[1] - if len(current_report.benchmarks) > 1 - else None, + sync_benchmark=( + current_report.benchmarks[0] if current_report.benchmarks else None + ), + throughput_benchmark=( + current_report.benchmarks[1] + if len(current_report.benchmarks) > 1 + else None + ), ) else: err = ValueError(f"Invalid mode: {self.mode}") diff --git a/src/guidellm/scheduler/base.py b/src/guidellm/scheduler/base.py index 602166b..3f63372 100644 --- a/src/guidellm/scheduler/base.py +++ b/src/guidellm/scheduler/base.py @@ -227,9 +227,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]: count_total = ( 
self.max_number if self.max_number - else round(self.max_duration) - if self.max_duration - else 0 + else round(self.max_duration) if self.max_duration else 0 ) # yield initial result for progress tracking @@ -246,9 +244,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]: count_completed = ( min(run_count, self.max_number) if self.max_number - else round(time.time() - start_time) - if self.max_duration - else 0 + else round(time.time() - start_time) if self.max_duration else 0 ) yield SchedulerResult( @@ -267,9 +263,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]: count_completed=( benchmark.request_count + benchmark.error_count if self.max_number - else round(time.time() - start_time) - if self.max_duration - else 0 + else round(time.time() - start_time) if self.max_duration else 0 ), benchmark=benchmark, ) diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 2fdd8ca..c084202 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,3 +1,4 @@ +from .dependencies import check_python_version, module_is_available from .injector import create_report, inject_data from .progress import BenchmarkReportProgress from .text import ( @@ -9,6 +10,7 @@ load_text, load_text_lines, parse_text_objects, + random_strings, split_lines_by_punctuation, split_text, ) @@ -31,10 +33,13 @@ "load_text", "load_text_lines", "load_transformers_dataset", + "random_strings", "parse_text_objects", "resolve_transformers_dataset", "resolve_transformers_dataset_column", "resolve_transformers_dataset_split", "split_lines_by_punctuation", "split_text", + "check_python_version", + "module_is_available", ] diff --git a/src/guidellm/utils/dependencies.py b/src/guidellm/utils/dependencies.py new file mode 100644 index 0000000..5022068 --- /dev/null +++ b/src/guidellm/utils/dependencies.py @@ -0,0 +1,59 @@ +import importlib +import sys +from typing import NoReturn, Tuple, Union + + +def _extract_python_version(data: str) -> Tuple[int, ...]: + """Extract '3.12' -> (3, 12).""" + + if len(items := data.split(".")) > 2: + raise ValueError("Python version format: MAJOR.MINOR") + + if not all(item.isnumeric() for item in items): + raise ValueError("Python version must include only numbers") + + return tuple(int(item) for item in items) + + +def check_python_version( + min_version: str, max_version: str, raise_error=True +) -> Union[NoReturn, bool]: + """Validate Python version. + + :param min_version: the min (included) Python version in format: MAJOR.MINOR + :param max_version: the max (included) Python version in format: MAJOR.MINOR + :param raise_error: set to False if you don't want to raise the RuntimeError in + case the validation is failed + """ + + min_version_info: Tuple[int, ...] = _extract_python_version(min_version) + max_version_info: Tuple[int, ...] = _extract_python_version(max_version) + current_version_info: Tuple[int, int] = ( + sys.version_info.major, + sys.version_info.minor, + ) + + if not (min_version_info <= current_version_info <= max_version_info): + if raise_error is False: + return False + else: + raise RuntimeError( + "This feature requires Python version " + f"to be in range: {min_version}..{max_version}." 
+ "You are using Python {}.{}.{}".format( + sys.version_info.major, + sys.version_info.minor, + sys.version_info.micro, + ) + ) + else: + return True + + +def module_is_available(module: str, helper: str): + """Ensure that the module is available for other project components.""" + + try: + importlib.import_module(module) + except ImportError: + raise RuntimeError(f"Module '{module}' is not available. {helper}") from None diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index 13a0dff..1bdba67 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -1,8 +1,10 @@ import csv import json +import random import re +import string from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from urllib.parse import urlparse import ftfy @@ -23,6 +25,7 @@ "parse_text_objects", "split_lines_by_punctuation", "split_text", + "random_strings", ] @@ -453,3 +456,32 @@ def load_text_lines( # extract the lines from the data return [row[filter_] for row in data] if filter_ else [str(row) for row in data] + + +def random_strings( + min_chars: int, max_chars: int, n: int = 0, dataset: Optional[str] = None +) -> Generator[str, None, None]: + """Yield random strings. + + :param min: the min number of output characters + :param max: the max number of output characters + :param n: the number of outputs. If `0` -> works for infinite + :param dataset: represents allowed characters for the operation + """ + + characters: str = dataset or string.printable + + if n < 0: + raise ValueError("'n' must be >= '0'") + elif n == 0: + while True: + yield "".join( + random.choice(characters) + for _ in range(random.randint(min_chars, max_chars)) + ) + else: + for _ in range(n): + yield "".join( + random.choice(characters) + for _ in range(random.randint(min_chars, max_chars)) + ) diff --git a/tests/unit/backend/test_deepsparse_backend.py b/tests/unit/backend/test_deepsparse_backend.py new file mode 100644 index 0000000..58e5761 --- /dev/null +++ b/tests/unit/backend/test_deepsparse_backend.py @@ -0,0 +1,175 @@ +import sys +from typing import Any, Dict, Generator, List, Optional + +import pytest +from pydantic import BaseModel + +from guidellm.backend import Backend +from guidellm.config import reload_settings +from guidellm.core import TextGenerationRequest +from guidellm.utils import random_strings + +pytestmark = pytest.mark.skipif( + sys.version_info >= (3, 12), reason="Unsupported Python version" +) + + +@pytest.fixture(scope="module") +def backend_class(): + from guidellm.backend.deepsparse import DeepsparseBackend + + return DeepsparseBackend + + +class TestDeepsparseTextGeneration(BaseModel): + """The representation of a deepsparse data structure.""" + + text: str + + +class TestTextGenerationPipeline: + """Deepsparse TextGeneration test interface. + + By default this class generates '10' text responses. + + This class includes an additional development information + for better testing experience. + + Method `__call__` allows to mock the result object that comes from + `deepsparse.pipeline.Pipeline()` so everything is encapsulated right here. + + :param self._generation: dynamic representation of generated responses + from deepsparse interface. 
+ """ + + def __init__(self): + self._generations: List[TestDeepsparseTextGeneration] = [] + self._prompt: Optional[str] = None + self._max_new_tokens: Optional[int] = None + + def __call__( + self, *_, prompt: str, max_new_tokens: Optional[int] = None, **kwargs + ) -> Any: + """Mocks the result from `deepsparse.pipeline.Pipeline()()`. + Set reserved request arguments on call. + + Note: `**kwargs` is required since it allows to mimic + the `deepsparse.Pipeline` behavior. + """ + + self._prompt = prompt + self._max_new_tokens = max_new_tokens + + return self + + @property + def generations(self) -> Generator[TestDeepsparseTextGeneration, None, None]: + for text in random_strings( + min_chars=10, + max_chars=50, + n=self._max_new_tokens if self._max_new_tokens else 10, + ): + generation = TestDeepsparseTextGeneration(text=text) + self._generations.append(generation) + yield generation + + +@pytest.fixture(autouse=True) +def mock_deepsparse_pipeline(mocker): + return mocker.patch( + "deepsparse.Pipeline.create", return_value=TestTextGenerationPipeline() + ) + + +@pytest.mark.smoke() +@pytest.mark.parametrize( + "create_payload", + [ + {}, + {"model": "test/custom_llm"}, + ], +) +def test_backend_creation(create_payload: Dict, backend_class): + """Test the "Deepspaarse Backend" class + with defaults and custom input parameters. + """ + + backends = [ + Backend.create("deepsparse", **create_payload), + backend_class(**create_payload), + ] + + for backend in backends: + assert backend.pipeline + ( + backend.model == custom_model + if (custom_model := create_payload.get("model")) + else backend.default_model + ) + + +@pytest.mark.smoke() +def test_backend_model_from_env(mocker, backend_class): + mocker.patch.dict( + "os.environ", + {"GUIDELLM__LLM_MODEL": "test_backend_model_from_env"}, + ) + + reload_settings() + + backends = [Backend.create("deepsparse"), backend_class()] + + for backend in backends: + assert backend.model == "test_backend_model_from_env" + + +@pytest.mark.smoke() +@pytest.mark.parametrize( + "text_generation_request_create_payload", + [ + {"prompt": "Test prompt"}, + {"prompt": "Test prompt", "output_token_count": 20}, + ], +) +@pytest.mark.asyncio() +async def test_make_request( + text_generation_request_create_payload: Dict, backend_class +): + backend = backend_class() + + output_tokens: List[str] = [] + async for response in backend.make_request( + request=TextGenerationRequest(**text_generation_request_create_payload) + ): + if response.add_token: + output_tokens.append(response.add_token) + assert "".join(output_tokens) == "".join( + generation.text for generation in backend.pipeline._generations + ) + + if max_tokens := text_generation_request_create_payload.get("output_token_count"): + assert len(backend.pipeline._generations) == max_tokens + + +@pytest.mark.smoke() +@pytest.mark.parametrize( + ("text_generation_request_create_payload", "error"), + [ + ( + {"prompt": "Test prompt", "output_token_count": -1}, + ValueError, + ), + ], +) +@pytest.mark.asyncio() +async def test_make_request_invalid_request_payload( + text_generation_request_create_payload: Dict, error, backend_class +): + backend = backend_class() + with pytest.raises(error): + [ + respnose + async for respnose in backend.make_request( + request=TextGenerationRequest(**text_generation_request_create_payload) + ) + ] diff --git a/tests/unit/backend/test_openai_backend.py b/tests/unit/backend/test_openai_backend.py index 396eb4c..6c11081 100644 --- a/tests/unit/backend/test_openai_backend.py +++ 
b/tests/unit/backend/test_openai_backend.py @@ -2,7 +2,8 @@ import pytest -from guidellm.backend import Backend, OpenAIBackend +from guidellm.backend import Backend +from guidellm.backend.openai import OpenAIBackend from guidellm.config import reload_settings, settings from guidellm.core import TextGenerationRequest @@ -245,8 +246,8 @@ def test_openai_backend_target(mock_openai_client): assert backend._client.kwargs["base_url"] == "http://test-target" # type: ignore backend = OpenAIBackend() - assert backend._async_client.kwargs["base_url"] == "http://localhost:8000/v1" # type: ignore - assert backend._client.kwargs["base_url"] == "http://localhost:8000/v1" # type: ignore + assert backend._async_client.kwargs["base_url"] == settings.openai.base_url # type: ignore + assert backend._client.kwargs["base_url"] == settings.openai.base_url # type: ignore backend = OpenAIBackend() assert backend._async_client.kwargs["base_url"] == settings.openai.base_url # type: ignore diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3257a8d..406460b 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,6 +5,8 @@ import pytest import requests_mock +from guidellm.config import settings + @pytest.fixture() def mock_auto_tokenizer(): @@ -22,14 +24,9 @@ def _fake_tokenize(text: str) -> List[int]: @pytest.fixture() def mock_requests_pride_and_prejudice(): - text_path = ( - Path(__file__).parent.parent / "dummy" / "data" / "pride_and_prejudice.txt" - ) + text_path = Path(__file__).parent.parent / "dummy/data/pride_and_prejudice.txt" text_content = text_path.read_text() with requests_mock.Mocker() as mock: - mock.get( - "https://www.gutenberg.org/files/1342/1342-0.txt", - text=text_content, - ) + mock.get(settings.emulated_data.source, text=text_content) yield mock diff --git a/tests/unit/core/test_distribution.py b/tests/unit/core/test_distribution.py index 95b7e92..2e2dd50 100644 --- a/tests/unit/core/test_distribution.py +++ b/tests/unit/core/test_distribution.py @@ -73,9 +73,9 @@ def test_distribution_str(): "'percentile_values': [1.4, 1.8, 2.2, 2.6, 3.0, 3.4, 3.8, 4.2, 4.6, 4.8, 4.96]" in str(dist) ) - assert "'min': 1" in str(dist) - assert "'max': 5" in str(dist) - assert "'range': 4" in str(dist) + assert "'min': 1.0" in str(dist) + assert "'max': 5.0" in str(dist) + assert "'range': 4.0" in str(dist) @pytest.mark.regression() diff --git a/tests/unit/executor/test_base.py b/tests/unit/executor/test_base.py index 844cf7f..1eec6ef 100644 --- a/tests/unit/executor/test_base.py +++ b/tests/unit/executor/test_base.py @@ -5,9 +5,7 @@ from guidellm.backend import Backend from guidellm.config import settings -from guidellm.core import ( - TextGenerationBenchmarkReport, -) +from guidellm.core import TextGenerationBenchmarkReport from guidellm.executor import ( Executor, ExecutorResult, @@ -269,9 +267,11 @@ async def test_executor_run_sweep(mock_scheduler): result=result, expected_completed=False, expected_count_total=num_profiles, - expected_count_completed=scheduler_index - if request_index < num_requests + 1 - else scheduler_index + 1, + expected_count_completed=( + scheduler_index + if request_index < num_requests + 1 + else scheduler_index + 1 + ), expected_generation_modes=generation_modes, # type: ignore ) _check_executor_result_report( @@ -280,9 +280,11 @@ async def test_executor_run_sweep(mock_scheduler): rate=None, max_number=num_requests, max_duration=None, - benchmarks_count=scheduler_index - if request_index < num_requests + 1 - else scheduler_index + 1, + 
benchmarks_count=( + scheduler_index + if request_index < num_requests + 1 + else scheduler_index + 1 + ), ) _check_executor_result_scheduler( result=result, @@ -429,7 +431,6 @@ async def test_executor_run_non_rate_modes(mock_scheduler, mode): @pytest.mark.smoke() -@pytest.mark.asyncio() @pytest.mark.parametrize( ("mode", "rate"), [ @@ -491,9 +492,11 @@ async def test_executor_run_rate_modes(mock_scheduler, mode, rate): result=result, expected_completed=False, expected_count_total=num_profiles, - expected_count_completed=scheduler_index - if request_index < num_requests + 1 - else scheduler_index + 1, + expected_count_completed=( + scheduler_index + if request_index < num_requests + 1 + else scheduler_index + 1 + ), expected_generation_modes=generation_modes, ) _check_executor_result_report( @@ -502,9 +505,11 @@ async def test_executor_run_rate_modes(mock_scheduler, mode, rate): rate=rate, max_number=num_requests, max_duration=None, - benchmarks_count=scheduler_index - if request_index < num_requests + 1 - else scheduler_index + 1, + benchmarks_count=( + scheduler_index + if request_index < num_requests + 1 + else scheduler_index + 1 + ), ) _check_executor_result_scheduler( result=result, diff --git a/tests/unit/executor/test_profile_generator.py b/tests/unit/executor/test_profile_generator.py index 9c91d57..9a01961 100644 --- a/tests/unit/executor/test_profile_generator.py +++ b/tests/unit/executor/test_profile_generator.py @@ -4,10 +4,7 @@ import pytest from guidellm import settings -from guidellm.core import ( - TextGenerationBenchmark, - TextGenerationBenchmarkReport, -) +from guidellm.core import TextGenerationBenchmark, TextGenerationBenchmarkReport from guidellm.executor import Profile, ProfileGenerationMode, ProfileGenerator diff --git a/tests/unit/request/test_emulated.py b/tests/unit/request/test_emulated.py index f6af130..8c010ab 100644 --- a/tests/unit/request/test_emulated.py +++ b/tests/unit/request/test_emulated.py @@ -355,6 +355,8 @@ def test_emulated_request_generator_lifecycle( str(file_path) if config_type == "file_str" else file_path, tokenizer="mock-tokenizer", ) + else: + raise Exception for _ in range(5): request = generator.create_item() diff --git a/tests/unit/scheduler/test_base.py b/tests/unit/scheduler/test_base.py index b485e59..cb51cb0 100644 --- a/tests/unit/scheduler/test_base.py +++ b/tests/unit/scheduler/test_base.py @@ -11,11 +11,7 @@ TextGenerationResult, ) from guidellm.request import RequestGenerator -from guidellm.scheduler import ( - LoadGenerator, - Scheduler, - SchedulerResult, -) +from guidellm.scheduler import LoadGenerator, Scheduler, SchedulerResult @pytest.mark.smoke() @@ -109,7 +105,6 @@ def test_scheduler_invalid_instantiation( @pytest.mark.sanity() -@pytest.mark.asyncio() @pytest.mark.parametrize( "mode", [ @@ -119,6 +114,7 @@ def test_scheduler_invalid_instantiation( "constant", ], ) +@pytest.mark.asyncio() async def test_scheduler_run_number(mode): rate = 10.0 max_number = 20 @@ -194,7 +190,6 @@ def _submit(req): @pytest.mark.sanity() -@pytest.mark.asyncio() @pytest.mark.parametrize( "mode", [ @@ -203,6 +198,7 @@ def _submit(req): ], ) @pytest.mark.flaky(reruns=5) +@pytest.mark.asyncio() async def test_scheduler_run_duration(mode): rate = 10 max_duration = 2 diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 13e1699..c79ba88 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,4 +1,5 @@ import pytest +from pydantic_settings import BaseSettings, SettingsConfigDict from guidellm.config import ( 
Environment, @@ -9,9 +10,18 @@ ) +class DefaultSettings(Settings, BaseSettings): + """ + This class overrides the original `Settings` class with another `model_config` + to ignore local environment variables for each runtime. + """ + + model_config = SettingsConfigDict() + + @pytest.mark.smoke() def test_default_settings(): - settings = Settings() + settings = DefaultSettings() assert settings.env == Environment.PROD assert settings.logging == LoggingSettings() assert settings.openai == OpenAISettings() diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py index 1d89ee3..9173cf7 100644 --- a/tests/unit/utils/test_text.py +++ b/tests/unit/utils/test_text.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import List from unittest.mock import patch import pytest @@ -13,6 +14,7 @@ load_text, load_text_lines, parse_text_objects, + random_strings, split_lines_by_punctuation, split_text, ) @@ -392,3 +394,25 @@ def test_split_text_with_mixed_separators(): assert words == ["This", "is", "a", "test", "with", "mixed", "separators."] assert separators == ["\t", " ", " ", "\n", " ", " ", " "] assert new_lines == [0, 4] + + +@pytest.mark.regression() +@pytest.mark.parametrize( + ("min_chars", "max_chars", "n", "dataset", "total_chars_len"), + [ + (5, 5, 10, None, 50), # always 5 chars per response + (1, 10, 10, None, None), # 1..10 chars per each + ], +) +def test_random_strings_generation(min_chars, max_chars, n, dataset, total_chars_len): + results: List[str] = list( + random_strings(min_chars=min_chars, max_chars=max_chars, n=n, dataset=dataset) + ) + + # Ensure total results + assert len(results) == n + + if total_chars_len is not None: + assert sum(len(r) for r in results) == total_chars_len + else: + assert min_chars * n <= sum(len(r) for r in results) < max_chars * n diff --git a/tox.ini b/tox.ini index 36e2809..40611c5 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ env_list = py38,py39,py310,py311,py312 [testenv] description = Run all tests deps = - .[dev] + .[dev,deepsparse] commands = pytest tests/ {posargs} @@ -14,7 +14,7 @@ commands = [testenv:test-unit] description = Run unit tests deps = - .[dev] + .[dev,deepsparse] commands = python -m pytest tests/unit {posargs} @@ -22,7 +22,7 @@ commands = [testenv:test-integration] description = Run integration tests deps = - .[dev] + .[dev,deepsparse] commands = python -m pytest tests/integration {posargs} @@ -30,7 +30,7 @@ commands = [testenv:test-e2e] description = Run end-to-end tests deps = - .[dev] + .[dev,deepsparse] commands = python -m pytest tests/e2e {posargs}
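The new `deepsparse` backend can be exercised end to end in the same way the unit tests above drive it. Below is a minimal sketch, assuming the `deepsparse` extra is installed and using the default model from `available_models()`; the prompt and token count are illustrative values borrowed from the tests:

```py
import asyncio

from guidellm.backend import Backend
from guidellm.core import TextGenerationRequest


async def main() -> None:
    # Registered under the "deepsparse" type via @Backend.register in backend.py
    backend = Backend.create(
        "deepsparse", model="hf:mgoin/TinyStories-33M-quant-deepsparse"
    )

    # make_request() is an async generator: it yields "token_iter" responses with
    # add_token set, followed by a "final" response carrying the aggregate counts.
    request = TextGenerationRequest(prompt="Test prompt", output_token_count=20)
    async for response in backend.make_request(request=request):
        if response.add_token:
            print(response.add_token, end="")


if __name__ == "__main__":
    asyncio.run(main())
```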