Skip to content

Commit

Permalink
[Feature] Add retrieving auxiliary and Z losses from ExpertLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Nov 30, 2023
1 parent b0fe0a6 commit cf0f927
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
8 changes: 4 additions & 4 deletions pipegoose/nn/expert_parallel/expert_context.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from __future__ import annotations
from typing import List


from torchtyping import TensorType


class ExpertContext:
_instance = None

def __init__(self):
self.aux_loss = []
self.z_loss = []

def push_aux_loss(self, aux_loss: TensorType):
self.aux_loss.append(aux_loss)

def pop_all_aux_loss(self) -> List[TensorType]:
def pop_all_aux_loss(self) -> list[TensorType]:
aux_loss, self.aux_loss = self.aux_loss, []
return aux_loss

def push_z_loss(self, z_loss: TensorType):
self.z_loss.append(z_loss)

def pop_all_z_loss(self) -> List[TensorType]:
def pop_all_z_loss(self) -> list[TensorType]:
z_loss, self.z_loss = self.z_loss, []
return z_loss

Expand Down
12 changes: 11 additions & 1 deletion pipegoose/nn/expert_parallel/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, List

from torchtyping import TensorType

Expand All @@ -11,6 +11,16 @@ def __init__(self, loss_func: Callable, aux_weight: float = 0.01, z_weight: floa
self.aux_weight = aux_weight
self.z_weight = z_weight

@property
def aux_loss(self) -> List[float]:
expert_context = ExpertContext.get_instance()
return expert_context.aux_loss

@property
def z_loss(self) -> List[float]:
expert_context = ExpertContext.get_instance()
return expert_context.z_loss

def __call__(self, *args, **kwargs) -> TensorType:
loss = self.loss_func(*args, **kwargs)
expert_context = ExpertContext.get_instance()
Expand Down

0 comments on commit cf0f927

Please sign in to comment.