From f0def31575ffeec60e1b967396ed32408f945519 Mon Sep 17 00:00:00 2001 From: julian-parker Date: Tue, 14 Jan 2025 10:36:13 +0000 Subject: [PATCH] Use 64bit ints to guard against issues with very large codebook size --- stable_audio_tools/models/fsq.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/stable_audio_tools/models/fsq.py b/stable_audio_tools/models/fsq.py index 1d11154..920fa42 100644 --- a/stable_audio_tools/models/fsq.py +++ b/stable_audio_tools/models/fsq.py @@ -35,10 +35,10 @@ def __init__( super().__init__() self.levels = levels - _levels = torch.tensor(levels, dtype=int32) + _levels = torch.tensor(levels, dtype=torch.int64) self.register_buffer("_levels", _levels, persistent = False) - _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64) self.register_buffer("_basis", _basis, persistent = False) codebook_dim = len(levels) @@ -90,7 +90,9 @@ def _indices_to_codes(self, indices): def _codes_to_indices(self, zhat): zhat = self._scale_and_shift(zhat) - return (zhat * self._basis).sum(dim=-1).to(int32) + zhat = zhat.round().to(torch.int64) + out = (zhat * self._basis).sum(dim=-1) + return out def _indices_to_level_indices(self, indices): indices = rearrange(indices, '... -> ... 1') @@ -100,7 +102,7 @@ def _indices_to_level_indices(self, indices): def indices_to_codes(self, indices): # Expects input of batch x sequence x num_codebooks assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}' - codes = self._indices_to_codes(indices) + codes = self._indices_to_codes(indices.to(torch.int64)) codes = rearrange(codes, '... c d -> ... (c d)') return codes @@ -116,7 +118,7 @@ def forward(self, z, skip_tanh: bool = False): # make sure allowed dtype before quantizing if z.dtype not in self.allowed_dtypes: - z = z.float() + z = z.to(torch.float64) codes = self.quantize(z, skip_tanh=skip_tanh) indices = self._codes_to_indices(codes)