diff --git a/pipegoose/nn/expert_parallel/expert_context.py b/pipegoose/nn/expert_parallel/expert_context.py index ad760fb..13e9982 100644 --- a/pipegoose/nn/expert_parallel/expert_context.py +++ b/pipegoose/nn/expert_parallel/expert_context.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import List + from torchtyping import TensorType class ExpertContext: _instance = None - + def __init__(self): self.aux_loss = [] self.z_loss = [] @@ -14,14 +14,14 @@ def __init__(self): def push_aux_loss(self, aux_loss: TensorType): self.aux_loss.append(aux_loss) - def pop_all_aux_loss(self) -> List[TensorType]: + def pop_all_aux_loss(self) -> list[TensorType]: aux_loss, self.aux_loss = self.aux_loss, [] return aux_loss def push_z_loss(self, z_loss: TensorType): self.z_loss.append(z_loss) - def pop_all_z_loss(self) -> List[TensorType]: + def pop_all_z_loss(self) -> list[TensorType]: z_loss, self.z_loss = self.z_loss, [] return z_loss diff --git a/pipegoose/nn/expert_parallel/loss.py b/pipegoose/nn/expert_parallel/loss.py index d290a88..d4fdead 100644 --- a/pipegoose/nn/expert_parallel/loss.py +++ b/pipegoose/nn/expert_parallel/loss.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, List from torchtyping import TensorType @@ -11,6 +11,16 @@ def __init__(self, loss_func: Callable, aux_weight: float = 0.01, z_weight: floa self.aux_weight = aux_weight self.z_weight = z_weight + @property + def aux_loss(self) -> List[float]: + expert_context = ExpertContext.get_instance() + return expert_context.aux_loss + + @property + def z_loss(self) -> List[float]: + expert_context = ExpertContext.get_instance() + return expert_context.z_loss + def __call__(self, *args, **kwargs) -> TensorType: loss = self.loss_func(*args, **kwargs) expert_context = ExpertContext.get_instance()