diff --git a/pyroved/models/trvae.py b/pyroved/models/trvae.py index c416ca1..851b1d7 100644 --- a/pyroved/models/trvae.py +++ b/pyroved/models/trvae.py @@ -179,6 +179,20 @@ def encode(self, x_new: torch.Tensor, **kwargs: int) -> torch.Tensor: z_scale = z[:, self.z_dim:] return z_loc, z_scale + def decode(self, z: torch.Tensor, y: torch.Tensor = None) -> torch.Tensor: + """ + Decodes a batch of latent coordnates + """ + if y is not None: + z = torch.cat([z.to(self.device), y.to(self.device)], -1) + z = [z] + if self.coord > 0: + grid = self.grid.expand(z.shape[0], *self.grid.shape) + z = z.append(grid.to(self.device)) + with torch.no_grad(): + loc = self.decoder_net(*z) + return loc + def manifold2d(self, d: int, plot: bool = True, **kwargs: Union[str, int]) -> torch.Tensor: """