From 3822ebb5050c7e2d9c3f504feb2d4b3e00e40577 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 25 Jun 2023 12:38:56 -0700 Subject: [PATCH] soundstorm can now accept semantic token ids for conditioning, which will be time aligned based on the sampling hz and downsample factors of both semantic and codec modules. eventually, will accept text and sample the semantic token ids from the finished https://github.com/lucidrains/spear-tts-pytorch repository --- setup.py | 3 +- soundstorm_pytorch/soundstorm.py | 158 ++++++++++++++++++++++++++++--- 2 files changed, 147 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index 389015b..deeb369 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.0.15', + version = '0.0.16', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', @@ -23,6 +23,7 @@ 'beartype', 'classifier-free-guidance-pytorch>=0.1.5', 'einops>=0.6.1', + 'spear-tts-pytorch', 'torch>=1.6', ], classifiers=[ diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index 82984e7..5014db4 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -16,6 +16,8 @@ from soundstorm_pytorch.attend import Attend +from spear_tts_pytorch import TextToSemantic + from audiolm_pytorch import SoundStream from tqdm import tqdm @@ -509,6 +511,7 @@ def __init__( net: ConformerWrapper, *, soundstream: Optional[SoundStream] = None, + spear_tts_text_to_semantic: Optional[TextToSemantic] = None, steps = 18, self_cond = False, self_cond_train_prob = 0.75, @@ -517,23 +520,76 @@ def __init__( schedule = 'linear', can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked self_token_critic = False, # https://aclanthology.org/2021.naacl-main.409/ - critic_loss_weight = 1. + critic_loss_weight = 1., + num_semantic_token_ids = None, + semantic_pad_id = -1, + wav2vec_target_sample_hz = None, + wav2vec_downsample_factor = None, + codec_target_sample_hz = None, + codec_downsample_factor = None, ): super().__init__() + + # conformer settings + + self.net = net + dim = net.dim + self.dim = dim + self.num_tokens = net.codebook_size + + # set soundstream + self.soundstream = soundstream + if exists(soundstream): + self.codec_target_sample_hz = soundstream.target_sample_hz + self.codec_downsample_factor = soundstream.downsample_factor + else: + self.codec_target_sample_hz = codec_target_sample_hz + self.codec_downsample_factor = codec_downsample_factor + if exists(self.soundstream): assert net.grouped_quantizers == soundstream.rq_groups assert net.codebook_size == soundstream.codebook_size assert net.num_quantizers == soundstream.num_quantizers - assert not (self_token_critic and exists(token_critic)) + # set text-to-semantic - self.net = net + self.text_to_semantic = spear_tts_text_to_semantic + + if exists(spear_tts_text_to_semantic) and exists(spear_tts_text_to_semantic.wav2vec): + assert not (exists(wav2vec_downsample_factor) or exists(wav2vec_target_sample_hz)), 'wav2vec downsample factor and sampling freq being auto-set from the text-to-semantic module passed in, as it contains the wav2vec instance' + self.wav2vec = spear_tts_text_to_semantic.wav2vec + self.wav2vec_target_sample_hz = maybe_wav2vec.target_sample_hz + self.wav2vec_downsample_factor = maybe_wav2vec.downsample_factor + else: + self.wav2vec = None + self.wav2vec_target_sample_hz = wav2vec_target_sample_hz + self.wav2vec_downsample_factor = wav2vec_downsample_factor + + # whether to text condition on audio generation is dependent on whether hyperparameters are supplied + + self.should_condition = exists(self.wav2vec_downsample_factor) and exists(self.wav2vec_target_sample_hz) + + # in the case that text-to-semantic module passed in + + if self.should_condition: + assert exists(self.codec_target_sample_hz) and exists(self.codec_downsample_factor) + + if exists(spear_tts_text_to_semantic): + self.semantic_token_emb = spear_tts_text_to_semantic.semantic_token_emb + self.semantic_cond_to_model_dim = nn.Linear(spear_tts_text_to_semantic, net.dim) + self.semantic_pad_id = spear_tts_text_to_semantic.semantic_pad_id + else: + assert exists(num_semantic_token_ids), 'if you are conditioning, you must pass in the number of semantic token ids' + self.semantic_token_emb = nn.Embedding(num_semantic_token_ids, dim) + self.semantic_cond_to_model_dim = nn.Identity() + self.semantic_pad_id = semantic_pad_id + + # detect token critic settings + + assert not (self_token_critic and exists(token_critic)) - dim = net.dim - self.dim = dim - self.num_tokens = net.codebook_size self.num_quantizers = net.num_quantizers self.grouped_quantizers = net.grouped_quantizers @@ -575,12 +631,17 @@ def __init__( self.critic_loss_weight = critic_loss_weight + @property + def device(self): + return next(self.net.parameters()).device + @torch.no_grad() @eval_decorator def generate( self, num_latents = None, *, + cond_semantic_token_ids = None, seconds = None, batch_size = None, start_temperature = 1., @@ -588,23 +649,35 @@ def generate( noise_level_scale = 1., **kwargs ): + assert not (exists(cond_semantic_token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa' + assert exists(num_latents) ^ exists(seconds) if not exists(num_latents): assert exists(self.soundstream), 'soundstream must be passed in to generate in seconds' num_latents = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of - sample_one = not exists(batch_size) - batch_size = default(batch_size, 1) + # maybe condition - device = next(self.net.parameters()).device + cond_tokens = self.maybe_get_condition(cond_semantic_token_ids) - times = torch.linspace(0., 1., self.steps + 1) + # determine batch size and sequence length, which depends whether it is conditioning - # sequence length of the conformer is the number of latents + if exists(cond_tokens): + batch_size, num_latents = cond_tokens.shape[:2] + sample_one = batch_size == 1 + else: + sample_one = not exists(batch_size) + batch_size = default(batch_size, 1) seq_len = num_latents * self.grouped_quantizers * self.num_quantizers + # device and time + + device = self.device + + times = torch.linspace(0., 1., self.steps + 1) + # sequence starts off as all masked shape = (batch_size, seq_len) @@ -627,6 +700,7 @@ def generate( logits, embeds = self.net( seq, + cond = cond_tokens, sum_embeds = self_cond, return_logits_and_embeddings = True, **kwargs @@ -681,14 +755,64 @@ def generate( return out + def maybe_get_condition(self, token_ids = None, length = None): + assert not (exists(token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa' + + if not exists(token_ids): + return None + + context = torch.no_grad if exists(self.text_to_semantic) else nullcontext + + with context(): + mask = token_ids != self.semantic_pad_id + token_ids = token_ids.masked_fill(~mask, 0) + + semantic_tokens = self.semantic_token_emb(token_ids) + cond_tokens = self.semantic_cond_to_model_dim(semantic_tokens) + + # just mask out the padding to 0s and let the network learn that for now + # eventually should add self attention masking to conformer, and calculate the correct number of masked tokens per variable lengthed batch row + + cond_tokens = cond_tokens.masked_fill(~rearrange(mask, '... -> ... 1'), 0.) + + + # now need to interpolate the conditioning tokens + # to align semantic and vector quantized tokens, time-wise + + cond_length = cond_tokens.shape[-2] + + target_cond_length = math.ceil(cond_length * (self.wav2vec_downsample_factor / self.wav2vec_target_sample_hz) / (self.codec_downsample_factor / self.codec_target_sample_hz)) + + # pytorch does not interpolate 1d, so hack by convert to 2d + + cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1') + cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear') + cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d') + + # whether to curtail or pad to length + + cond_length = cond_tokens.shape[-2] + + if exists(length): + if cond_length < length: + cond_tokens = F.pad(cond_tokens, (0, 0, 0, length - cond_length), value = 0.) + elif cond_length > length: + cond_tokens = cond_tokens[:, :length] + + return cond_tokens + def forward( self, x, + *, + cond_semantic_token_ids = None, only_train_generator = False, only_train_critic = False, generator_sample_temperature = None, **kwargs ): + # if raw audio passed in, convert to residual quantized vectors + is_raw_audio = x.dtype == torch.float if is_raw_audio: @@ -697,8 +821,16 @@ def forward( self.soundstream.eval() _, x, _ = self.soundstream(x, return_encoded = True) + # shape + b, n, gq, device = *x.shape, x.device + # maybe condition + + cond_tokens = self.maybe_get_condition(cond_semantic_token_ids, length = x.shape[-2]) + + # prepare masking + orig_seq = rearrange(x.clone(), 'b n q -> b (n q)') t = torch.randint(0, n, (1,)).item() @@ -749,7 +881,7 @@ def forward( if sample_prob(self.self_cond_train_prob): with torch.no_grad(): - self_cond = self.net(masked, return_embeddings = True, **kwargs).detach() + self_cond = self.net(masked, cond = cond_tokens, return_embeddings = True, **kwargs).detach() kwargs.update(sum_embeds = self.to_self_cond(self_cond)) @@ -758,7 +890,7 @@ def forward( context = torch.no_grad if only_train_critic else nullcontext with context(): - logits = self.net(masked, **kwargs) + logits = self.net(masked, cond = cond_tokens, **kwargs) # cross entropy loss