diff --git a/interpreter/core/llm/llm.py b/interpreter/core/llm/llm.py index e6de76365e..f95f5fc18d 100644 --- a/interpreter/core/llm/llm.py +++ b/interpreter/core/llm/llm.py @@ -1,4 +1,5 @@ import litellm +import tokentrim as tt from .run_function_calling_llm import run_function_calling_llm from .run_text_llm import run_text_llm @@ -93,25 +94,32 @@ def run(self, messages): shrink_images=self.interpreter.shrink_images, ) + system_message = messages[0]["content"] + messages = messages[1:] + # Trim messages try: if self.context_window and self.max_tokens: trim_to_be_this_many_tokens = ( self.context_window - self.max_tokens - 25 ) # arbitrary buffer - messages = litellm.utils.trim_messages( + messages = tt.trim( messages, + system_message=system_message, max_tokens=trim_to_be_this_many_tokens, ) elif self.context_window and not self.max_tokens: # Just trim to the context window if max_tokens not set - messages = litellm.utils.trim_messages( + messages = tt.trim( messages, + system_message=system_message, max_tokens=self.context_window, ) else: try: - messages = litellm.utils.trim_messages(messages, model=self.model) + messages = tt.trim( + messages, system_message=system_message, model=self.model + ) except: if len(messages) == 1: print( @@ -121,11 +129,17 @@ def run(self, messages): Also, please set max_tokens: `interpreter --max_tokens {max tokens per response}` or `interpreter.llm.max_tokens = {max tokens per response}` """ ) - messages = litellm.utils.trim_messages(messages, max_tokens=3000) + messages = tt.trim( + messages, system_message=system_message, max_tokens=3000 + ) except: # If we're trimming messages, this won't work. # If we're trimming from a model we don't know, this won't work. # Better not to fail until `messages` is too big, just for frustrations sake, I suppose. + + # Reunite system message with messages + messages = [system_message] + messages + pass ## Start forming the request diff --git a/tests/test_interpreter.py b/tests/test_interpreter.py index 856f6997fb..5724c765a9 100644 --- a/tests/test_interpreter.py +++ b/tests/test_interpreter.py @@ -12,6 +12,20 @@ ) +def test_long_message(): + messages = [ + { + "role": "user", + "type": "message", + "content": "ABCD" * 20000 + "\ndescribe to me what i just said", + } + ] + interpreter.llm.context_window = 300 + interpreter.chat(messages) + assert len(interpreter.messages) > 1 + assert "ABCD" in interpreter.messages[-1]["content"] + + @pytest.mark.skip(reason="Computer with display only + no way to fail test") def test_display_api(): interpreter.computer.mouse.move(icon="gear")