From 342de780cd15bade5b785d0b8f132e8dd89f069e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20K=C3=BCnkele?= Date: Fri, 1 Sep 2023 18:51:49 +0200 Subject: [PATCH] fixes signal game with GumbelSoftmax --- egg/zoo/signal_game/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egg/zoo/signal_game/train.py b/egg/zoo/signal_game/train.py index e86888fbb..ffc4bd58f 100644 --- a/egg/zoo/signal_game/train.py +++ b/egg/zoo/signal_game/train.py @@ -67,7 +67,7 @@ def loss_nll( NLL loss - differentiable and can be used with both GS and Reinforce """ nll = F.nll_loss(receiver_output, labels, reduction="none") - acc = (labels == receiver_output.argmax(dim=1)).float().mean() + acc = (labels == receiver_output.argmax(dim=1)).float() return nll, {"acc": acc}