Skip to content

Commit

Permalink
⚡ allowing greedy accuracy to be a tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
rbturnbull committed Dec 2, 2024
1 parent 72f603e commit 5a4d8bd
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions hierarchicalsoftmax/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def greedy_predictions(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, m
"""
prediction_nodes = []

if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
prediction_tensor = prediction_tensor[0]

if root.softmax_start_index is None:
raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

Expand Down

0 comments on commit 5a4d8bd

Please sign in to comment.