Skip to content

Commit

Permalink
Merge pull request ContinualAI#1446 from AlbinSou/fix_mer
Browse files Browse the repository at this point in the history
Fix MER
  • Loading branch information
AntonioCarta authored Jul 6, 2023
2 parents 200b8c4 + 861b074 commit 2b7fa26
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 42 deletions.
1 change: 1 addition & 0 deletions avalanche/training/supervised/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .l2p import LearningToPrompt
from .supervised_contrastive_replay import SCR
from .expert_gate import ExpertGateStrategy
from .mer import MER
70 changes: 36 additions & 34 deletions avalanche/training/supervised/mer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
from typing import Callable, Sequence, Optional, Union
from copy import deepcopy
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from torch.nn import Module, CrossEntropyLoss
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer
from copy import deepcopy

from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates import OnlineSupervisedMetaLearningTemplate
from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins import EvaluationPlugin, SupervisedPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.storage_policy import ReservoirSamplingBuffer
from avalanche.training.templates import SupervisedMetaLearningTemplate


class MERBuffer:
    """Reservoir-sampling replay buffer used by the MER strategy.

    :param mem_size: maximum number of samples kept in the buffer.
    :param batch_size_mem: number of buffer samples mixed into each
        incoming mini-batch.
    :param device: device onto which retrieved tensors are moved.
    """

    # NOTE(review): the diff scrape left two interleaved __init__ definitions
    # here (pre- and post-merge); only the merged one is kept.
    def __init__(self, mem_size=100, batch_size_mem=10, device=torch.device("cpu")):
        # Reservoir sampling keeps a uniform random sample of the stream.
        self.storage_policy = ReservoirSamplingBuffer(max_size=mem_size)
        self.batch_size_mem = batch_size_mem
        self.device = device

def update(self, strategy):
Expand All @@ -31,8 +29,8 @@ def get_batch(self, x, y, t):
if len(self) == 0:
return x, y, t

bsize = min(len(self), self.buffer_mb_size)
rnd_ind = torch.randperm(len(self))[:bsize]
bsize = min(len(self), self.batch_size_mem)
rnd_ind = torch.randperm(len(self))[:bsize].tolist()
buff_x = torch.cat(
[self.storage_policy.buffer[i][0].unsqueeze(0) for i in rnd_ind]
).to(self.device)
Expand All @@ -50,19 +48,19 @@ def get_batch(self, x, y, t):
return mixed_x, mixed_y, mixed_t


class MER(OnlineSupervisedMetaLearningTemplate):
class MER(SupervisedMetaLearningTemplate):
def __init__(
self,
model: Module,
optimizer: Optimizer,
criterion=CrossEntropyLoss(),
max_buffer_size=200,
buffer_mb_size=10,
mem_size=200,
batch_size_mem=10,
n_inner_steps=5,
beta=0.1,
gamma=0.1,
train_mb_size: int = 1,
train_passes: int = 1,
train_epochs: int = 1,
eval_mb_size: int = 1,
device: Union[str, torch.device] = "cpu",
plugins: Optional[Sequence["SupervisedPlugin"]] = None,
Expand All @@ -78,8 +76,8 @@ def __init__(
:param model: PyTorch model.
:param optimizer: PyTorch optimizer.
:param criterion: loss function.
:param max_buffer_size: maximum size of the buffer.
:param buffer_mb_size: number of samples to retrieve from buffer
:param mem_size: maximum size of the buffer.
:param batch_size_mem: number of samples to retrieve from buffer
for each sample.
:param n_inner_steps: number of inner updates per sample.
:param beta: coefficient for within-batch Reptile update.
Expand All @@ -91,7 +89,7 @@ def __init__(
optimizer,
criterion,
train_mb_size,
train_passes,
train_epochs,
eval_mb_size,
device,
plugins,
Expand All @@ -101,8 +99,8 @@ def __init__(
)

self.buffer = MERBuffer(
max_buffer_size=max_buffer_size,
buffer_mb_size=buffer_mb_size,
mem_size=mem_size,
batch_size_mem=batch_size_mem,
device=self.device,
)
self.n_inner_steps = n_inner_steps
Expand Down Expand Up @@ -132,21 +130,25 @@ def _inner_updates(self, **kwargs):

# Within-batch Reptile update
w_aft_t = self.model.state_dict()
self.model.load_state_dict(
{
name: w_bef_t[name] + ((w_aft_t[name] - w_bef_t[name]) * self.beta)
for name in w_bef_t
}
)
load_dict = {}
for name, param in self.model.named_parameters():
load_dict[name] = w_bef_t[name] + (
(w_aft_t[name] - w_bef_t[name]) * self.beta
)

self.model.load_state_dict(load_dict, strict=False)

def _outer_update(self, **kwargs):
w_aft = self.model.state_dict()
self.model.load_state_dict(
{
name: self.w_bef[name] + ((w_aft[name] - self.w_bef[name]) * self.gamma)
for name in self.w_bef
}
)

load_dict = {}
for name, param in self.model.named_parameters():
load_dict[name] = self.w_bef[name] + (
(w_aft[name] - self.w_bef[name]) * self.gamma
)

self.model.load_state_dict(load_dict, strict=False)

with torch.no_grad():
pred = self.forward()
self.loss = self._criterion(pred, self.mb_y)
Expand Down
21 changes: 13 additions & 8 deletions avalanche/training/supervised/supervised_contrastive_replay.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import Sequence, Optional
from typing import Optional, Sequence

import torch
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates import SupervisedTemplate
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Lambda

from avalanche.core import BaseSGDPlugin
from torchvision.transforms import Compose
from torchvision.transforms import Lambda
from avalanche.training.plugins import ReplayPlugin
from avalanche.models import SCRModel
from avalanche.training.losses import SCRLoss
from avalanche.training.plugins import ReplayPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.models import SCRModel
from avalanche.training.templates import SupervisedTemplate


class SCR(SupervisedTemplate):
Expand Down Expand Up @@ -157,7 +157,9 @@ def _after_forward(self, **kwargs):
    def _after_training_exp(self, **kwargs):
        """Update NCM means.

        After the standard end-of-experience hooks, recompute the per-class
        feature means used by the nearest-class-mean classifier. The model
        is put in eval mode for the computation — presumably so layers like
        BatchNorm/Dropout behave deterministically (confirm) — and restored
        to train mode afterwards.
        """
        super()._after_training_exp(**kwargs)
        self.model.eval()
        self.compute_class_means()
        self.model.train()

@torch.no_grad()
def compute_class_means(self):
Expand All @@ -166,7 +168,10 @@ def compute_class_means(self):
# for each class
for dataset in self.replay_plugin.storage_policy.buffer_datasets:
dl = DataLoader(
dataset, shuffle=False, batch_size=self.eval_mb_size, drop_last=False
dataset.eval(),
shuffle=False,
batch_size=self.eval_mb_size,
drop_last=False,
)
num_els = 0
# for each mini-batch in each class
Expand Down
34 changes: 34 additions & 0 deletions tests/training/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
DER,
LearningToPrompt,
ExpertGateStrategy,
MER,
)
from avalanche.training.supervised.cumulative import Cumulative
from avalanche.training.supervised.icarl import ICaRL
Expand Down Expand Up @@ -410,6 +411,7 @@ def test_replay(self):

# MT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True)

strategy = Replay(
model,
optimizer,
Expand All @@ -420,6 +422,38 @@ def test_replay(self):
eval_mb_size=50,
train_epochs=2,
)

run_strategy(benchmark, strategy)

def test_mer(self):
# SIT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
strategy = MER(
model,
optimizer,
criterion,
mem_size=10,
train_mb_size=64,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)
run_strategy(benchmark, strategy)

# MT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True)

strategy = MER(
model,
optimizer,
criterion,
mem_size=10,
train_mb_size=64,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)

run_strategy(benchmark, strategy)

def test_gdumb(self):
Expand Down

0 comments on commit 2b7fa26

Please sign in to comment.