Skip to content

Commit

Permalink
added annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed May 13, 2024
1 parent 617dee9 commit 95d7b44
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions luxonis_train/attached_modules/metrics/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

import torchmetrics
from torch import Tensor

from .base_metric import BaseMetric

Expand Down Expand Up @@ -47,12 +48,12 @@ def __init__(self, **kwargs):

self.metric = self.Metric(**kwargs)

def update(self, preds, target, *args, **kwargs):
def update(self, preds, target, *args, **kwargs) -> None:
if self.task in ["multiclass"]:
target = target.argmax(dim=1)
self.metric.update(preds, target, *args, **kwargs)

def compute(self):
def compute(self) -> Tensor:
return self.metric.compute()

def reset(self) -> None:
Expand Down

0 comments on commit 95d7b44

Please sign in to comment.