From 555de3d16d793ea5ec93470bd256e01ea4df0973 Mon Sep 17 00:00:00 2001 From: AndreaCossu Date: Thu, 28 Jul 2022 19:11:46 +0200 Subject: [PATCH] Fixed EWC bug in importances --- avalanche/training/plugins/ewc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py index 4a469ed8b..30c5cf841 100644 --- a/avalanche/training/plugins/ewc.py +++ b/avalanche/training/plugins/ewc.py @@ -89,7 +89,7 @@ def before_backward(self, strategy, **kwargs): # dynamic models may add new units # new units are ignored by the regularization n_units = saved_param.shape[0] - cur_param = saved_param[:n_units] + cur_param = cur_param[:n_units] penalty += (imp * (cur_param - saved_param).pow(2)).sum() elif self.mode == "online": prev_exp = exp_counter - 1 @@ -101,7 +101,7 @@ def before_backward(self, strategy, **kwargs): # dynamic models may add new units # new units are ignored by the regularization n_units = saved_param.shape[0] - cur_param = saved_param[:n_units] + cur_param = cur_param[:n_units] penalty += (imp * (cur_param - saved_param).pow(2)).sum() else: raise ValueError("Wrong EWC mode.")