Commit

soundstorm can now accept semantic token ids for conditioning, which will be time aligned based on the sampling hz and downsample factors of both semantic and codec modules. eventually, will accept text and sample the semantic token ids from the finished https://github.com/lucidrains/spear-tts-pytorch repository

lucidrains committed Jun 25, 2023
1 parent 2539868 commit 3822ebb
Showing 2 changed files with 147 additions and 14 deletions.
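For context on the time alignment the commit message describes: each semantic token covers wav2vec_downsample_factor / wav2vec_target_sample_hz seconds of audio, and each codec latent covers codec_downsample_factor / codec_target_sample_hz seconds, so a semantic sequence is rescaled by the ratio of those two durations before being used as conditioning. A minimal sketch of the arithmetic, with illustrative rates (the 16 kHz / 24 kHz / 320 figures below are hypothetical, not taken from this diff):

import math

# hypothetical hyperparameters, for illustration only
wav2vec_target_sample_hz = 16000  # audio rate seen by the semantic model
wav2vec_downsample_factor = 320   # samples per semantic token -> 50 tokens/sec
codec_target_sample_hz = 24000    # audio rate seen by the codec
codec_downsample_factor = 320     # samples per codec latent -> 75 latents/sec

seconds_per_semantic_token = wav2vec_downsample_factor / wav2vec_target_sample_hz
seconds_per_codec_latent = codec_downsample_factor / codec_target_sample_hz

num_semantic_tokens = 100  # 2 seconds of audio at 50 tokens/sec
target_num_latents = math.ceil(num_semantic_tokens * seconds_per_semantic_token / seconds_per_codec_latent)
print(target_num_latents)  # 150, i.e. the conditioning is stretched 1.5x to match the codec
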
3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.15',
+ version = '0.0.16',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
@@ -23,6 +23,7 @@
'beartype',
'classifier-free-guidance-pytorch>=0.1.5',
'einops>=0.6.1',
'spear-tts-pytorch',
'torch>=1.6',
],
classifiers=[
158 changes: 145 additions & 13 deletions soundstorm_pytorch/soundstorm.py
@@ -16,6 +16,8 @@

from soundstorm_pytorch.attend import Attend

from spear_tts_pytorch import TextToSemantic

from audiolm_pytorch import SoundStream

from tqdm import tqdm
@@ -509,6 +511,7 @@ def __init__(
net: ConformerWrapper,
*,
soundstream: Optional[SoundStream] = None,
spear_tts_text_to_semantic: Optional[TextToSemantic] = None,
steps = 18,
self_cond = False,
self_cond_train_prob = 0.75,
@@ -517,23 +520,76 @@
schedule = 'linear',
can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked
self_token_critic = False, # https://aclanthology.org/2021.naacl-main.409/
- critic_loss_weight = 1.
+ critic_loss_weight = 1.,
num_semantic_token_ids = None,
semantic_pad_id = -1,
wav2vec_target_sample_hz = None,
wav2vec_downsample_factor = None,
codec_target_sample_hz = None,
codec_downsample_factor = None,
):
super().__init__()

# conformer settings

self.net = net
dim = net.dim
self.dim = dim
self.num_tokens = net.codebook_size

# set soundstream

self.soundstream = soundstream

if exists(soundstream):
self.codec_target_sample_hz = soundstream.target_sample_hz
self.codec_downsample_factor = soundstream.downsample_factor
else:
self.codec_target_sample_hz = codec_target_sample_hz
self.codec_downsample_factor = codec_downsample_factor

if exists(self.soundstream):
assert net.grouped_quantizers == soundstream.rq_groups
assert net.codebook_size == soundstream.codebook_size
assert net.num_quantizers == soundstream.num_quantizers

- assert not (self_token_critic and exists(token_critic))
# set text-to-semantic

- self.net = net
self.text_to_semantic = spear_tts_text_to_semantic

if exists(spear_tts_text_to_semantic) and exists(spear_tts_text_to_semantic.wav2vec):
assert not (exists(wav2vec_downsample_factor) or exists(wav2vec_target_sample_hz)), 'wav2vec downsample factor and sampling freq are auto-set from the text-to-semantic module passed in, as it contains the wav2vec instance'
self.wav2vec = spear_tts_text_to_semantic.wav2vec
self.wav2vec_target_sample_hz = self.wav2vec.target_sample_hz
self.wav2vec_downsample_factor = self.wav2vec.downsample_factor
else:
self.wav2vec = None
self.wav2vec_target_sample_hz = wav2vec_target_sample_hz
self.wav2vec_downsample_factor = wav2vec_downsample_factor

# whether to condition audio generation on text depends on whether the hyperparameters are supplied

self.should_condition = exists(self.wav2vec_downsample_factor) and exists(self.wav2vec_target_sample_hz)

# in the case that a text-to-semantic module is passed in

if self.should_condition:
assert exists(self.codec_target_sample_hz) and exists(self.codec_downsample_factor)

if exists(spear_tts_text_to_semantic):
self.semantic_token_emb = spear_tts_text_to_semantic.semantic_token_emb
self.semantic_cond_to_model_dim = nn.Linear(spear_tts_text_to_semantic.dim, net.dim)
self.semantic_pad_id = spear_tts_text_to_semantic.semantic_pad_id
else:
assert exists(num_semantic_token_ids), 'if you are conditioning, you must pass in the number of semantic token ids'
self.semantic_token_emb = nn.Embedding(num_semantic_token_ids, dim)
self.semantic_cond_to_model_dim = nn.Identity()
self.semantic_pad_id = semantic_pad_id
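# note: with a text-to-semantic module, its pretrained semantic embedding is reused and
# projected into the conformer dimension; without one, a fresh embedding is created
# directly at model dim, so no projection is needed (hence nn.Identity)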

# detect token critic settings

assert not (self_token_critic and exists(token_critic))

- dim = net.dim
- self.dim = dim
- self.num_tokens = net.codebook_size
self.num_quantizers = net.num_quantizers
self.grouped_quantizers = net.grouped_quantizers

@@ -575,36 +631,53 @@ def __init__(

self.critic_loss_weight = critic_loss_weight

@property
def device(self):
return next(self.net.parameters()).device

@torch.no_grad()
@eval_decorator
def generate(
self,
num_latents = None,
*,
cond_semantic_token_ids = None,
seconds = None,
batch_size = None,
start_temperature = 1.,
filter_thres = 0.7,
noise_level_scale = 1.,
**kwargs
):
assert not (exists(cond_semantic_token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa'

assert exists(num_latents) ^ exists(seconds)

if not exists(num_latents):
assert exists(self.soundstream), 'soundstream must be passed in to generate in seconds'
num_latents = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of
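# seconds * target_sample_hz gives the raw sample count; floor dividing by
# seq_len_multiple_of (the soundstream's overall downsample factor) converts samples to latents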

- sample_one = not exists(batch_size)
- batch_size = default(batch_size, 1)
# maybe condition

- device = next(self.net.parameters()).device
cond_tokens = self.maybe_get_condition(cond_semantic_token_ids)

- times = torch.linspace(0., 1., self.steps + 1)
# determine batch size and sequence length, which depend on whether conditioning is used

- # sequence length of the conformer is the number of latents
if exists(cond_tokens):
batch_size, num_latents = cond_tokens.shape[:2]
sample_one = batch_size == 1
else:
sample_one = not exists(batch_size)
batch_size = default(batch_size, 1)

seq_len = num_latents * self.grouped_quantizers * self.num_quantizers

# device and time

device = self.device

times = torch.linspace(0., 1., self.steps + 1)

# sequence starts off as all masked

shape = (batch_size, seq_len)
Expand All @@ -627,6 +700,7 @@ def generate(

logits, embeds = self.net(
seq,
cond = cond_tokens,
sum_embeds = self_cond,
return_logits_and_embeddings = True,
**kwargs
@@ -681,14 +755,64 @@ def generate(

return out

def maybe_get_condition(self, token_ids = None, length = None):
assert not (exists(token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa'

if not exists(token_ids):
return None

context = torch.no_grad if exists(self.text_to_semantic) else nullcontext

with context():
mask = token_ids != self.semantic_pad_id
token_ids = token_ids.masked_fill(~mask, 0)
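# pad positions hold an id (default -1) that is out of range for the embedding,
# so they are temporarily set to token 0 here, then zeroed out again after projection below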

semantic_tokens = self.semantic_token_emb(token_ids)
cond_tokens = self.semantic_cond_to_model_dim(semantic_tokens)

# just mask out the padding to 0s and let the network learn that for now
# eventually should add self-attention masking to the conformer, and calculate the correct number of masked tokens per variable-length batch row

cond_tokens = cond_tokens.masked_fill(~rearrange(mask, '... -> ... 1'), 0.)


# now need to interpolate the conditioning tokens
# to align semantic and vector quantized tokens, time-wise

cond_length = cond_tokens.shape[-2]

target_cond_length = math.ceil(cond_length * (self.wav2vec_downsample_factor / self.wav2vec_target_sample_hz) / (self.codec_downsample_factor / self.codec_target_sample_hz))
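# each semantic token spans (wav2vec_downsample_factor / wav2vec_target_sample_hz) seconds
# and each codec latent spans (codec_downsample_factor / codec_target_sample_hz) seconds,
# so the ratio of the two durations converts a semantic length into a codec latent length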

# pytorch does not interpolate 1d, so hack by converting to 2d

cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')
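# note: mode = 'linear' on the 3d 'b d n' layout would be an equivalent route;
# the 2d detour just reuses the bilinear api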

# whether to curtail or pad to length

cond_length = cond_tokens.shape[-2]

if exists(length):
if cond_length < length:
cond_tokens = F.pad(cond_tokens, (0, 0, 0, length - cond_length), value = 0.)
elif cond_length > length:
cond_tokens = cond_tokens[:, :length]

return cond_tokens

def forward(
self,
x,
*,
cond_semantic_token_ids = None,
only_train_generator = False,
only_train_critic = False,
generator_sample_temperature = None,
**kwargs
):
# if raw audio passed in, convert to residual quantized vectors

is_raw_audio = x.dtype == torch.float

if is_raw_audio:
@@ -697,8 +821,16 @@ def forward(
self.soundstream.eval()
_, x, _ = self.soundstream(x, return_encoded = True)

# shape

b, n, gq, device = *x.shape, x.device

# maybe condition

cond_tokens = self.maybe_get_condition(cond_semantic_token_ids, length = x.shape[-2])

# prepare masking

orig_seq = rearrange(x.clone(), 'b n q -> b (n q)')

t = torch.randint(0, n, (1,)).item()
@@ -749,7 +881,7 @@ def forward(

if sample_prob(self.self_cond_train_prob):
with torch.no_grad():
- self_cond = self.net(masked, return_embeddings = True, **kwargs).detach()
+ self_cond = self.net(masked, cond = cond_tokens, return_embeddings = True, **kwargs).detach()

kwargs.update(sum_embeds = self.to_self_cond(self_cond))

@@ -758,7 +890,7 @@ def forward(
context = torch.no_grad if only_train_critic else nullcontext

with context():
- logits = self.net(masked, **kwargs)
+ logits = self.net(masked, cond = cond_tokens, **kwargs)

# cross entropy loss

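For reference, a rough usage sketch of the conditioning path this commit adds. The hyperparameter values, tensor shapes, and sampling rates below are illustrative stand-ins, not values taken from this diff:

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 4,
    conformer = dict(dim = 512, depth = 2)
)

model = SoundStorm(
    conformer,
    steps = 18,
    schedule = 'cosine',
    num_semantic_token_ids = 500,        # semantic vocabulary size
    wav2vec_target_sample_hz = 16000,    # hypothetical semantic-side rates
    wav2vec_downsample_factor = 320,
    codec_target_sample_hz = 24000,      # hypothetical codec-side rates
    codec_downsample_factor = 320
)

# train on pre-encoded codec ids, conditioned on semantic token ids;
# the condition is interpolated, then padded or curtailed to the codec sequence length

codec_ids = torch.randint(0, 1024, (2, 300, 4))   # (batch, seq, num quantizers)
semantic_ids = torch.randint(0, 500, (2, 200))    # (batch, semantic seq)

loss, _ = model(codec_ids, cond_semantic_token_ids = semantic_ids)
loss.backward()

# generate; with conditioning present, the batch size and latent count
# are taken from the interpolated condition tokens

generated = model.generate(num_latents = 300, cond_semantic_token_ids = semantic_ids)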
