Commit c29219a

Update CSS inference

JusperLee committed Oct 31, 2024
1 parent 4387f55 commit c29219a
Showing 2 changed files with 90 additions and 26 deletions.
114 changes: 89 additions & 25 deletions separation/inference.py
@@ -24,6 +24,78 @@
import look2hear.models
from look2hear.metrics import MetricsTracker
from pyannote.audio import Pipeline
from speechbrain.pretrained import EncoderClassifier

import torch


def process_audio_segments(ests_out_lists, ests_out_emb_lists):
    """Reorder and concatenate separated audio segments into two complete tracks.

    Speaker order can flip between segments (permutation ambiguity), so each
    segment is aligned to the first segment by comparing speaker embeddings.

    Args:
        ests_out_lists (list): List of audio segments, each segment is [s1, s2].
        ests_out_emb_lists (list): List of embeddings, each segment is [e1, e2].

    Returns:
        final_s1 (torch.Tensor): Concatenated audio for speaker 1.
        final_s2 (torch.Tensor): Concatenated audio for speaker 2.
    """
    # Reordered s1 and s2 audio segments
    s1_list = []
    s2_list = []

    # Use the first segment's audio and embeddings as the reference
    first_s1, first_s2 = ests_out_lists[0]
    first_e1, first_e2 = ests_out_emb_lists[0]

    s1_list.append(first_s1)
    s2_list.append(first_s2)

    # Align each subsequent segment to the reference
    for i in range(1, len(ests_out_lists)):
        s1_i, s2_i = ests_out_lists[i]
        e1_i, e2_i = ests_out_emb_lists[i]

        # Pairwise cosine similarities between the reference embeddings and
        # this segment's embeddings (embeddings are assumed to be 1-D vectors)
        sim11 = torch.nn.functional.cosine_similarity(first_e1, e1_i, dim=0)
        sim22 = torch.nn.functional.cosine_similarity(first_e2, e2_i, dim=0)
        sim12 = torch.nn.functional.cosine_similarity(first_e1, e2_i, dim=0)
        sim21 = torch.nn.functional.cosine_similarity(first_e2, e1_i, dim=0)

        # Reduce to scalars: average if more than one element remains,
        # then extract the Python float
        if sim11.numel() > 1:
            sim11 = sim11.mean()
            sim22 = sim22.mean()
            sim12 = sim12.mean()
            sim21 = sim21.mean()
        if sim11.numel() == 1:
            sim11 = sim11.item()
            sim22 = sim22.item()
            sim12 = sim12.item()
            sim21 = sim21.item()

        # Keep the original order if the matched assignment (1->1, 2->2) is at
        # least as similar as the swapped one; otherwise swap s1 and s2
        if (sim11 + sim22) >= (sim12 + sim21):
            s1_list.append(s1_i)
            s2_list.append(s2_i)
        else:
            s1_list.append(s2_i)
            s2_list.append(s1_i)

    # Concatenate along the time axis (each segment is expected to be (1, T))
    final_s1 = torch.cat(s1_list, dim=1)
    final_s2 = torch.cat(s2_list, dim=1)

    return final_s1, final_s2

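For intuition, here is a minimal, hypothetical sanity check of the alignment logic above (toy tensors and an arbitrary 192-dim embedding size, not part of the commit): the second segment's speaker order is flipped, and the embedding similarities should trigger the swap.

# Hypothetical usage sketch for process_audio_segments (illustration only):
a = torch.randn(1, 100)                      # speaker A audio, segment 1
b = torch.randn(1, 100)                      # speaker B audio, segment 1
ea, eb = torch.randn(192), torch.randn(192)  # toy speaker embeddings
segs = [[a, b], [b, a]]                      # speaker order flipped in segment 2
embs = [[ea, eb], [eb, ea]]                  # embeddings flipped the same way
s1, s2 = process_audio_segments(segs, embs)  # segment 2 gets swapped back
assert s1.shape == (1, 200) and s2.shape == (1, 200)

The (sim11 + sim22) >= (sim12 + sim21) test is the two-speaker special case of a maximum-similarity assignment; with more speakers this would grow into a small assignment problem.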
def test(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    if cfg.model.get("_target_") == "look2hear.models.dptnet.DPTNetModel":
@@ -34,19 +106,22 @@ def test(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    else:
        cfg.model.pop("_target_", None)
        model = look2hear.models.ConvTasNet.from_pretrain(os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"), **cfg.model).cuda()
-    vad_model = Pipeline.from_pretrained("pyannote/voice-activity-detection", use_auth_token="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", cache_dir="/home/likai/data5/huggingface_models")
+    vad_model = Pipeline.from_pretrained("pyannote/voice-activity-detection", use_auth_token="AUTH_TOKEN", cache_dir="./huggingface_models")

    initial_params = {"onset": 0.3, "offset": 0.2,
                      "min_duration_on": 0.0, "min_duration_off": 0.0}
    vad_model.instantiate(initial_params)
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "results/"), exist_ok=True)
    metrics = MetricsTracker(save_file=os.path.join(cfg.exp.dir, cfg.exp.name, "results/")+"metrics.csv")

    classifier = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="./huggingface_models/speechbrain/spkrec-ecapa-voxceleb",
        use_auth_token="AUTH_TOKEN")

    mix = torchaudio.load("tests/noise/mix.wav")[0]
    spk = [torchaudio.load("tests/noise/s1.wav")[0], torchaudio.load("tests/noise/s2.wav")[0]]
    spks = torch.cat(spk, dim=0)

    mix = mix.squeeze(0).cuda()
    spks = spks.cuda()

    ests_out_lists = []
    ests_out_emb_lists = []

    with torch.no_grad():
        vad_results = vad_model("tests/noise/mix.wav")
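The loop header that consumes vad_results is collapsed in this diff; judging from the start_end variable used in the next hunk, it plausibly reads like the sketch below (an assumption, not shown in the commit):

        # hypothetical reconstruction of the collapsed loop header:
        for start_end in vad_results.get_timeline().support():
            ...  # loop body shown in the next hunk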
@@ -56,8 +131,6 @@ def test(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            start = int(start_end.start * 16000)
            end = int(start_end.end * 16000)
            mix_tmp = mix[start:end].cpu()
            spks_tmp = spks[:, start:end]

            if end - start <= 320:
                continue
            try:
@@ -71,25 +144,16 @@ def test(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
                        ests_out = model(mix_tmp.unsqueeze(0)).to("cuda")
                    else:
                        raise e

                if ests_out.dim() == 3:
                    ests_out = ests_out.squeeze(0)

                metrics(
                    mix=mix_tmp,
                    clean=spks_tmp,
                    estimate=ests_out,
                    key="test",
                    spks_id=["s1", "s2"],
                    start_idx=start,
                    end_idx=end
                )
                # Keep each estimated track as a (1, T) tensor so the segments can
                # later be concatenated along the time axis; this assumes the model
                # output is laid out as (n_src, T) after the squeeze above.
                ests_out_lists.append([ests_out[0:1], ests_out[1:2]])
                ests_out_emb_lists.append([classifier.encode_batch(ests_out[0:1]).view(-1), classifier.encode_batch(ests_out[1:2]).view(-1)])
            finally:
                torch.cuda.empty_cache()

    print(metrics.update())
    metrics.final()

    final_s1, final_s2 = process_audio_segments(ests_out_lists, ests_out_emb_lists)
    # torchaudio.save expects CPU tensors, so move the tracks off the GPU first
    torchaudio.save("tests/noise/s1_est.wav", final_s1.cpu(), 16000)
    torchaudio.save("tests/noise/s2_est.wav", final_s2.cpu(), 16000)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
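As a shape walk-through for the encode_batch calls above (a sketch; the (1, 1, 192) output size matches the documented embedding size of speechbrain/spkrec-ecapa-voxceleb, but treat it as an assumption here):

# hypothetical shape walk-through, illustration only
track = torch.randn(1, 16000)         # one separated track, (1, T) at 16 kHz
emb = classifier.encode_batch(track)  # ECAPA-TDNN embedding, roughly (1, 1, 192)
emb = emb.view(-1)                    # flatten to 1-D for cosine_similarity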
2 changes: 1 addition & 1 deletion separation/test.py
@@ -55,7 +55,7 @@ def test(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    else:
        cfg.model.pop("_target_", None)
        model = look2hear.models.ConvTasNet.from_pretrain(os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"), **cfg.model).cuda()
-    vad_model = Pipeline.from_pretrained("pyannote/voice-activity-detection", use_auth_token="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", cache_dir="./huggingface_models")
+    vad_model = Pipeline.from_pretrained("pyannote/voice-activity-detection", use_auth_token="AUTH_TOKEN", cache_dir="./huggingface_models")
    initial_params = {"onset": 0.3, "offset": 0.2,
                      "min_duration_on": 0.0, "min_duration_off": 0.0}
    vad_model.instantiate(initial_params)