Skip to content

Commit

Permalink
⚡ updating metric to be more robust for different cases
Browse files Browse the repository at this point in the history
  • Loading branch information
rbturnbull committed Aug 25, 2024
1 parent d856813 commit 42c6add
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion hierarchicalsoftmax/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def depth_accurate(prediction_tensor, target_tensor, root:nodes.SoftmaxNode, max
if root.softmax_start_index is None:
raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
prediction_tensor = prediction_tensor[0]

if prediction_tensor.shape[-1] != root.layer_size:
raise ShapeError(
f"The predictions tensor given to {__name__} has final dimensions of {prediction_tensor.shape[-1]}. "
Expand All @@ -60,6 +63,9 @@ def depth_accurate(prediction_tensor, target_tensor, root:nodes.SoftmaxNode, max
node = root
depth = 0
target_node = root.node_list[target]
target_path = target_node.path
target_path_length = len(target_path)


while (node.children):
# This would be better if we could use torch.argmax but it doesn't work with MPS in the production version of pytorch
Expand All @@ -70,7 +76,7 @@ def depth_accurate(prediction_tensor, target_tensor, root:nodes.SoftmaxNode, max
node = node.children[prediction_child_index]
depth += 1

if node != target_node.path[depth]:
if depth < target_path_length and node != target_path[depth]:
depth -= 1
break

Expand Down

0 comments on commit 42c6add

Please sign in to comment.