Skip to content

Commit

Permalink
Merge pull request #1482 from AndreaCossu/master
Browse files Browse the repository at this point in the history
Fix for NCM forward pass when no class mean is present yet
  • Loading branch information
AntonioCarta authored Jul 28, 2023
2 parents 7771033 + 6f91540 commit 5700405
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
31 changes: 20 additions & 11 deletions avalanche/models/ncm_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, normalize: bool = True):
self.class_means_dict = {}

self.normalize = normalize
self.max_class = -1

def load_state_dict(self, state_dict, strict: bool = True):
self.class_means = state_dict["class_means"]
Expand All @@ -43,6 +44,7 @@ def load_state_dict(self, state_dict, strict: bool = True):
for i in range(self.class_means.shape[0]):
if (self.class_means[i] != 0).any():
self.class_means_dict[i] = self.class_means[i].clone()
self.max_class = max(self.class_means_dict.keys())

def _vectorize_means_dict(self):
"""
Expand All @@ -55,11 +57,12 @@ def _vectorize_means_dict(self):
if self.class_means_dict == {}:
return

max_class = max(self.class_means_dict.keys()) + 1
max_class = max(self.class_means_dict.keys())
self.max_class = max(max_class, self.max_class)
first_mean = list(self.class_means_dict.values())[0]
feature_size = first_mean.size(0)
device = first_mean.device
self.class_means = torch.zeros(max_class, feature_size).to(device)
self.class_means = torch.zeros(self.max_class + 1, feature_size).to(device)

for k, v in self.class_means_dict.items():
self.class_means[k] = self.class_means_dict[k].clone()
Expand All @@ -73,6 +76,8 @@ def forward(self, x):
negative distance of each element in the mini-batch
with respect to each class.
"""
if self.class_means_dict == {}:
self.init_missing_classes(range(self.max_class + 1), x.shape[1], x.device)

assert self.class_means_dict != {}, "no class means available."
if self.normalize:
Expand Down Expand Up @@ -102,7 +107,7 @@ def update_class_means_dict(
"class_means_dict must be a dictionary mapping class_id " "to mean vector"
)
for k, v in class_means_dict.items():
if k not in self.class_means_dict:
if k not in self.class_means_dict or (self.class_means_dict[k] == 0).all():
self.class_means_dict[k] = class_means_dict[k].clone()
else:
device = self.class_means_dict[k].device
Expand All @@ -121,18 +126,22 @@ def replace_class_means_dict(self, class_means_dict: Dict[int, Tensor]):
"class_means_dict must be a dictionary mapping class_id " "to mean vector"
)
self.class_means_dict = {k: v.clone() for k, v in class_means_dict.items()}

self._vectorize_means_dict()

def eval_adaptation(self, experience):
if self.class_means is None:
return
for k in experience.classes_in_this_experience:
def init_missing_classes(self, classes, class_size, device):
    """Register a zero mean vector for every class that has no mean yet.

    Classes already present in ``self.class_means_dict`` are left untouched.
    Afterwards ``self._vectorize_means_dict()`` is called to refresh the
    tensor form of the means.

    :param classes: iterable of class ids to make sure are present.
    :param class_size: length of the zero vector created for each
        missing class.
    :param device: device the placeholder vectors are moved to.
    """
    for k in classes:
        if k not in self.class_means_dict:
            # Size and device come from the arguments rather than from
            # ``self.class_means``: callers test that attribute against
            # ``None``, so it may not exist yet when this method runs.
            self.class_means_dict[k] = torch.zeros(class_size).to(device)
    self._vectorize_means_dict()

def eval_adaptation(self, experience):
    """Account for the classes of ``experience`` before evaluation.

    ``self.max_class`` is raised to the largest class id found in
    ``experience.classes_in_this_experience``. If ``self.class_means`` has
    already been built, ``init_missing_classes`` is then called with the
    feature size (``class_means.shape[1]``) and device of that tensor.

    :param experience: object exposing ``classes_in_this_experience``.
    """
    seen_classes = experience.classes_in_this_experience
    for class_id in seen_classes:
        self.max_class = max(class_id, self.max_class)

    # Without a means tensor there is no feature size or device to copy
    # from, so only ``max_class`` is updated in that case.
    if self.class_means is None:
        return

    feature_size = self.class_means.shape[1]
    means_device = self.class_means.device
    self.init_missing_classes(seen_classes, feature_size, means_device)


__all__ = ["NCMClassifier"]
6 changes: 6 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,12 @@ def test_ncm_replace_means(self):
classifier.replace_class_means_dict(new_dict)
assert (classifier.class_means[:, 0] == 2).all()

def test_ncm_forward_without_class_means(self):
    """Forward pass yields one logit per class with only zero means set."""
    num_classes, feature_size, batch_size = 10, 7, 2

    classifier = NCMClassifier()
    # No real class mean is ever supplied: every class gets a zero vector.
    classifier.init_missing_classes(list(range(num_classes)), feature_size, "cpu")

    logits = classifier(torch.randn(batch_size, feature_size))
    assert logits.shape == (batch_size, num_classes)

def test_ncm_save_load(self):
classifier = NCMClassifier()
classifier.update_class_means_dict(
Expand Down
1 change: 1 addition & 0 deletions tests/training/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,7 @@ def test_icarl(self):
train_epochs=2,
eval_mb_size=50,
device=self.device,
eval_every=1,
)

run_strategy(benchmark, strategy)
Expand Down

0 comments on commit 5700405

Please sign in to comment.