diff --git a/avalanche/models/ncm_classifier.py b/avalanche/models/ncm_classifier.py index 24e2556a4..cda20c9a7 100644 --- a/avalanche/models/ncm_classifier.py +++ b/avalanche/models/ncm_classifier.py @@ -3,8 +3,10 @@ import torch from torch import Tensor, nn +from avalanche.models import DynamicModule -class NCMClassifier(nn.Module): + +class NCMClassifier(DynamicModule): """ NCM Classifier. NCMClassifier performs nearest class mean classification @@ -122,5 +124,15 @@ def replace_class_means_dict(self, class_means_dict: Dict[int, Tensor]): self._vectorize_means_dict() + def eval_adaptation(self, experience): + if self.class_means is None: + return + for k in experience.classes_in_this_experience: + if k not in self.class_means_dict: + self.class_means_dict[k] = torch.zeros(self.class_means.shape[1]).to( + self.class_means.device + ) + self._vectorize_means_dict() + __all__ = ["NCMClassifier"]