From 5e8bda54ef1812f03dbb1ff01621700c9c6cbe92 Mon Sep 17 00:00:00 2001 From: Antonios Sarikas Date: Mon, 2 Dec 2024 17:17:37 +0200 Subject: [PATCH] Fix `requires_grad` for `identity` in `TNet` --- src/aidsorb/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aidsorb/modules.py b/src/aidsorb/modules.py index aa85531..46f9d34 100644 --- a/src/aidsorb/modules.py +++ b/src/aidsorb/modules.py @@ -207,7 +207,7 @@ def forward(self, x): x = self.dense_blocks(x) # Initialize the identity matrix. - identity = torch.eye(self.embed_dim, device=x.device, requires_grad=True).repeat(bs, 1, 1) + identity = torch.eye(self.embed_dim, device=x.device, requires_grad=x.requires_grad).repeat(bs, 1, 1) # Output has shape (B, self.embed_dim, self.embed_dim). x = x.view(-1, self.embed_dim, self.embed_dim) + identity