Skip to content

Commit

Permalink
fix: DIA-1867: Fix Azure cost retrieval + task hanging on failure (#323)
Browse files Browse the repository at this point in the history
Co-authored-by: hakan458 <[email protected]>
  • Loading branch information
hakan458 and hakan458 authored Jan 29, 2025
1 parent 6157c09 commit 4132c6f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
19 changes: 15 additions & 4 deletions adala/runtimes/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
data["_prompt_cost_usd"] = prompt_cost
data["_completion_cost_usd"] = completion_cost
data["_total_cost_usd"] = prompt_cost + completion_cost
except NotFoundError:
except:
logger.error(f"Failed to get cost for model {model}")
data["_prompt_cost_usd"] = None
data["_completion_cost_usd"] = None
Expand All @@ -161,7 +161,7 @@ def normalize_litellm_model_and_provider(model_name: str, provider: str):
This helper function contains logic which normalizes this for supported providers
"""
if "/" in model_name:
model_name = model_name.split('/', 1)[1]
model_name = model_name.split("/", 1)[1]
provider = provider.lower()
if provider == "vertexai":
provider = "vertex_ai"
Expand Down Expand Up @@ -215,6 +215,8 @@ def handle_llm_exception(
# usage = e.total_usage
# not available here, so have to approximate by hand, assuming the same error occurred each time
n_attempts = retries.stop.max_attempt_number
# Note that the default model used in token_counter is gpt-3.5-turbo as of now - if model passed in
# does not match a mapped model, falls back to default
prompt_tokens = n_attempts * litellm.token_counter(
model=model, messages=messages[:-1]
) # response is appended as the last message
Expand Down Expand Up @@ -368,11 +370,16 @@ def record_to_record(
)
usage = completion.usage
dct = to_jsonable_python(response)
# With successful completions we can get canonical model name
usage_model = completion.model

except Exception as e:
dct, usage = handle_llm_exception(e, messages, self.model, retries)
            # With exceptions we don't have access to completion.model
usage_model = self.model

# Add usage data to the response (e.g. token counts, cost)
dct.update(_get_usage_dict(usage, model=self.model))
dct.update(_get_usage_dict(usage, model=usage_model))

return dct

Expand Down Expand Up @@ -499,13 +506,17 @@ async def batch_to_batch(
dct, usage = handle_llm_exception(
response, messages, self.model, retries
)
                        # With exceptions we don't have access to completion.model
usage_model = self.model
else:
resp, completion = response
usage = completion.usage
dct = to_jsonable_python(resp)
# With successful completions we can get canonical model name
usage_model = completion.model

# Add usage data to the response (e.g. token counts, cost)
dct.update(_get_usage_dict(usage, model=self.model))
dct.update(_get_usage_dict(usage, model=usage_model))

df_data.append(dct)

Expand Down
4 changes: 2 additions & 2 deletions adala/utils/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ class TemplateChunks(TypedDict):
start: int
end: int
type: str


match_fields_regex = re.compile(r"(?<!\{)\{([a-zA-Z0-9_]+)\}(?!})")


Expand Down
1 change: 1 addition & 0 deletions server/tasks/stream_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ async def async_process_streaming_input(input_task_done: asyncio.Event, agent: A
logger.error(
f"Error in async_process_streaming_input: {e}. Traceback: {traceback.format_exc()}"
)
input_task_done.set()
# cleans up after any exceptions raised here as well as asyncio.CancelledError resulting from failure in async_process_streaming_output
finally:
await agent.environment.finalize()
Expand Down

0 comments on commit 4132c6f

Please sign in to comment.