Extend the llama3 text-generation tests: add a third dialog and an
empty-dialog check, and tidy blank lines and comments.

The extra dialog is added to the in-method list instead of using
@pytest.mark.parametrize, because parametrization is not supported on
unittest.TestCase methods (the test would fail with a missing-argument
TypeError before running any generation).

diff --git a/models/llama3/tests/api/test_generation.py b/models/llama3/tests/api/test_generation.py
--- a/models/llama3/tests/api/test_generation.py
+++ b/models/llama3/tests/api/test_generation.py
@@ -7,13 +7,10 @@
 
 import os
 import unittest
-
 from pathlib import Path
-
 import numpy as np
 import pytest
 from llama_models.llama3.api.datatypes import ImageMedia, SystemMessage, UserMessage
-
 from llama_models.llama3.reference_impl.generation import Llama
 from PIL import Image as PIL_Image
 
@@ -42,33 +39,37 @@ class TestTextModelInference(unittest.TestCase):
     def setUpClass(cls):
         cls.generator = build_generator("TEXT_MODEL_CHECKPOINT_DIR")
 
     def test_run_generation(self):
         dialogs = [
             [
                 SystemMessage(content="Always answer with Haiku"),
                 UserMessage(content="I am going to Paris, what should I see?"),
             ],
             [
-                SystemMessage(
-                    content="Always answer with emojis",
-                ),
+                SystemMessage(content="Always answer with emojis"),
                 UserMessage(content="How to go from Beijing to NY?"),
             ],
+            [
+                SystemMessage(content="Always answer in riddles"),
+                UserMessage(content="What has keys but can't open locks?"),
+            ],
         ]
         for dialog in dialogs:
             result = self.__class__.generator.chat_completion(
                 dialog,
                 temperature=0,
                 logprobs=True,
             )
-
             out_message = result.generation
             self.assertTrue(len(out_message.content) > 0)
             shape = np.array(result.logprobs).shape
-            # assert at least 10 tokens
             self.assertTrue(shape[0] > 10)
             self.assertEqual(shape[1], 1)
 
+    def test_empty_dialog(self):
+        with self.assertRaises(ValueError):
+            self.__class__.generator.chat_completion([], temperature=0)
+
 
 class TestVisionModelInference(unittest.TestCase):
 
@@ -106,10 +107,8 @@ def test_run_generation(self):
                 temperature=0,
                 logprobs=True,
             )
-
             out_message = result.generation
             self.assertTrue(len(out_message.content) > 0)
             shape = np.array(result.logprobs).shape
-            # assert at least 10 tokens
             self.assertTrue(shape[0] > 10)
             self.assertEqual(shape[1], 1)