Skip to content

Commit

Permalink
fix modelpool, add tests in test/test_model_helpers.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Nov 18, 2024
1 parent fa8825f commit 01e6b93
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 18 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ jobs:
METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
echo "Running tokenizer tests..."
python3 ./test/test_tokenizers.py
python3 ./test/test_model_helpers.py
discovery_integration_test:
macos:
Expand Down
17 changes: 2 additions & 15 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict
from exo.inference.inference_engine import inference_engine_classes
from exo.inference.shard import Shard
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
from typing import Callable, Optional


Expand Down Expand Up @@ -208,18 +206,7 @@ async def handle_model_support(self, request):
return web.json_response({
"model pool": {
model_name: pretty_name.get(model_name, model_name)
for model_name in [
model_id for model_id, model_info in model_cards.items()
if all(map(
lambda engine: engine in model_info["repo"],
list(dict.fromkeys([
inference_engine_classes.get(engine_name, None)
for engine_list in self.node.topology_inference_engines_pool
for engine_name in engine_list
if engine_name is not None
] + [self.inference_engine_classname]))
))
]
for model_name in get_supported_models(self.node.topology_inference_engines_pool)
}
})

Expand Down
27 changes: 24 additions & 3 deletions exo/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from exo.inference.shard import Shard
from typing import Optional
from typing import Optional, List

model_cards = {
### llama
"llama-3.2-1b": {
"layers": 16,
"repo": {
"repo": {
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
},
Expand Down Expand Up @@ -124,4 +124,25 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
if repo is None or n_layers < 1:
return None
return Shard(model_id, 0, 0, n_layers)


def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
if not supported_inference_engine_lists:
return list(model_cards.keys())

from exo.inference.inference_engine import inference_engine_classes
supported_inference_engine_lists = [
[inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
for engine_list in supported_inference_engine_lists
]

def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
return any(engine in model_info.get("repo", {}) for engine in engine_list)

def supports_all_engine_lists(model_info: dict) -> bool:
return all(has_any_engine(model_info, engine_list)
for engine_list in supported_inference_engine_lists)

return [
model_id for model_id, model_info in model_cards.items()
if supports_all_engine_lists(model_info)
]
121 changes: 121 additions & 0 deletions test/test_model_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import unittest
from exo.models import get_supported_models, model_cards
from exo.inference.inference_engine import inference_engine_classes
from typing import NamedTuple

class TestCase(NamedTuple):
name: str
engine_lists: list # Will contain short names, will be mapped to class names
expected_models_contains: list
min_count: int | None
exact_count: int | None
max_count: int | None

# Helper function to map short names to class names
def expand_engine_lists(engine_lists):
def map_engine(engine):
return inference_engine_classes.get(engine, engine) # Return original name if not found

return [[map_engine(engine) for engine in sublist]
for sublist in engine_lists]

test_cases = [
TestCase(
name="single_mlx_engine",
engine_lists=[["mlx"]],
expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
min_count=10,
exact_count=None,
max_count=None
),
TestCase(
name="single_tinygrad_engine",
engine_lists=[["tinygrad"]],
expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
min_count=5,
exact_count=None,
max_count=10
),
TestCase(
name="multiple_engines_or",
engine_lists=[["mlx", "tinygrad"], ["mlx"]],
expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
min_count=10,
exact_count=None,
max_count=None
),
TestCase(
name="multiple_engines_all",
engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
min_count=10,
exact_count=None,
max_count=None
),
TestCase(
name="distinct_engine_lists",
engine_lists=[["mlx"], ["tinygrad"]],
expected_models_contains=["llama-3.2-1b"],
min_count=5,
exact_count=None,
max_count=10
),
TestCase(
name="no_engines",
engine_lists=[],
expected_models_contains=None,
min_count=None,
exact_count=len(model_cards),
max_count=None
),
TestCase(
name="nonexistent_engine",
engine_lists=[["NonexistentEngine"]],
expected_models_contains=[],
min_count=None,
exact_count=0,
max_count=None
),
TestCase(
name="dummy_engine",
engine_lists=[["dummy"]],
expected_models_contains=["dummy"],
min_count=None,
exact_count=1,
max_count=None
),
]

class TestModelHelpers(unittest.TestCase):
def test_get_supported_models(self):
for case in test_cases:
with self.subTest(f"{case.name}_short_names"):
result = get_supported_models(case.engine_lists)
self._verify_results(case, result)

with self.subTest(f"{case.name}_class_names"):
class_name_lists = expand_engine_lists(case.engine_lists)
result = get_supported_models(class_name_lists)
self._verify_results(case, result)

def _verify_results(self, case, result):
if case.expected_models_contains:
for model in case.expected_models_contains:
self.assertIn(model, result)

if case.min_count:
self.assertGreater(len(result), case.min_count)

if case.exact_count is not None:
self.assertEqual(len(result), case.exact_count)

# Special case for distinct lists test
if case.name == "distinct_engine_lists":
self.assertLess(len(result), 10)
self.assertNotIn("mistral-nemo", result)

if case.max_count:
self.assertLess(len(result), case.max_count)

if __name__ == '__main__':
unittest.main()

0 comments on commit 01e6b93

Please sign in to comment.