From 1a46052ce58177ebb379c207ab557c683f2f74b0 Mon Sep 17 00:00:00 2001
From: Jiaqi Li <120090727@link.cuhk.edu.cn>
Date: Tue, 18 Jun 2024 10:38:22 +0800
Subject: [PATCH] remove unused files

---
 models/tts/valle_v2/mls_dataset.py | 544 -----------------------------
 models/tts/valle_v2/run_infer.py   | 328 -----------------
 2 files changed, 872 deletions(-)
 delete mode 100644 models/tts/valle_v2/mls_dataset.py
 delete mode 100644 models/tts/valle_v2/run_infer.py

diff --git a/models/tts/valle_v2/mls_dataset.py b/models/tts/valle_v2/mls_dataset.py
deleted file mode 100644
index 5326a698..00000000
--- a/models/tts/valle_v2/mls_dataset.py
+++ /dev/null
@@ -1,544 +0,0 @@
-# Copyright (c) 2023 Amphion.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import random
-import torch
-from utils.data_utils import *
-from tqdm import tqdm
-import librosa
-from petrel_client.client import Client
-from torch.utils.data import Dataset
-import pandas as pd
-import rir_generator as rir
-import time
-import io
-from multiprocessing import Pool, Lock
-
-NUM_WORKERS = 32
-lock = Lock()
-SAMPLE_RATE = 16000
-
-
-def get_duration(file_path):
-    duration = librosa.get_duration(path=file_path, sr=SAMPLE_RATE)
-    return file_path, duration
-
-
-# g2p
-# from utils.g2p.g2p import phonemizer_g2p
-
-# override g2p with g2p_en library
-from .g2p_processor import G2pProcessor
-
-phonemizer_g2p = G2pProcessor()
-
-# lang2token ={
-#     'zh': "[ZH]",
-#     'ja': "[JA]",
-#     "en": "[EN]",
-#     "fr": "[FR]",
-#     "kr": "[KR]",
-#     "de": "[DE]",
-# }
-# LANG2CODE = {
-#     'en': 655,
-#     'zh': 654,
-# }
-import logging
-
-
-class PhonemizerWarningFilter(logging.Filter):
-    def filter(self, record):
-        # only filter out WARNING-level logs from phonemizer
-        if record.name == "phonemizer" and record.levelno == logging.WARNING:
-            return False
-        return False
-
-
-logger = logging.getLogger("phonemizer")
-filter = PhonemizerWarningFilter()
-logger.addFilter(filter)
-logging.basicConfig(level=logging.INFO)
-
-
-class VALLEDataset(Dataset):
-    def __init__(self, args, is_valid=False, resample_to_24k=False):
-        print(f"Initializing VALLEDataset")
-        dataset_list = args.dataset_list
-        dataset_cache_dir = args.cache_dir  # cache_dir
-        print(f"args.cache_dir = ", args.cache_dir)
-        os.makedirs(dataset_cache_dir, exist_ok=True)
-        # create dataset2dir
-
-        self.client = Client("/mnt/petrelfs/hehaorui/petreloss.conf")
-        self.resample_to_24k = resample_to_24k
-        if self.resample_to_24k:
-            assert SAMPLE_RATE == 24000
-            print(f"Using 24k resampling.")
-
-        print(f"data sampling rate is {SAMPLE_RATE}")
-
-        self.dataset2dir = {
-            "mls_train": "public-dataset-p2:s3://public-dataset-p2/Multilingual-LibriSpeech/data_0321/unzip/mls_english1/train/audio",
-            "mls_dev": "public-dataset-p2:s3://public-dataset-p2/Multilingual-LibriSpeech/data_0321/unzip/mls_english1/dev/audio",
-            "mls_test": "public-dataset-p2:s3://public-dataset-p2/Multilingual-LibriSpeech/data_0321/unzip/mls_english1/test/audio",
-            "librilight_small": "amphion:s3://amphion/Libri-light/small_15s",
-            "librilight_medium": "amphion:s3://amphion/Libri-light/medium_15s",
-            "librilight_large": "amphion:s3://amphion/Libri-light/large_15s",
-            "mls_german": "public-dataset-p2:s3://public-dataset-p2/Multilingual-LibriSpeech/data_0321/unzip/mls_german/train/audio",
-        }
-
-        self.use_speaker = args.use_speaker
-        self.use_noise = args.use_noise
-        print(f"Using speaker: {self.use_speaker}, using noise: {self.use_noise}")
-
-        self.dataset_list = dataset_list
-
-        self.meta_data_cache = None
-
-        self.transcripts = None
-
-        for dataset_name in self.dataset_list:
-            if dataset_name == "mls_train":
-                self.meta_data_cache_path = os.path.join(
-                    dataset_cache_dir, "mls_train_metadata_cache.csv"
-                )
-                # read meta data cache: MAIN_metadata_cache.csv
-                print(f"Loaded metadata cache from {self.meta_data_cache_path}")
-
-                # write language info
-                tmp_cache = pd.read_csv(self.meta_data_cache_path, encoding="utf-8")
-                tmp_cache["language"] = "en"
-
-                if self.meta_data_cache == None:
-                    self.meta_data_cache = tmp_cache
-                else:
-                    self.meta_data_cache.append(tmp_cache)
-
-                if len(self.meta_data_cache) == 0:
-                    print(f"Empty metadata cache!")
-                    raise ValueError("Empty metadata cache!")
-                elif len(self.meta_data_cache) < 10731070:
-                    print(f"Need to reload metadata cache!")
-                    print(f"Current size: {len(self.meta_data_cache)}")
-                    raise ValueError("Need to reload metadata cache!")
-                print(f"Loaded {len(self.meta_data_cache)} metadata_cache")
-
-                import pickle
-
-                # load mls en transcripts
-                if not os.path.isfile(
-                    "/mnt/petrelfs/hehaorui/jiaqi/vc-dev/mls_en_transcripts.pkl"
-                ):
-                    # read MLS dataset transcript txt into dict
-                    self.transcript_path = os.path.join(
-                        self.dataset2dir["mls_train"].rstrip("audio/"),
-                        "transcripts.txt",
-                    )
-                    file_bytes = self.client.get(self.transcript_path)
-                    assert file_bytes is not None
-                    buffer = io.BytesIO(file_bytes)
-                    transcripts = buffer.getvalue()
-                    del buffer
-                    transcripts = transcripts.decode("utf-8")
-
-                    # read MLS dataset transcript txt into dict
-                    self.transcripts = {}
-                    for line in transcripts.split("\n"):
-                        if line == "":
-                            continue
-                        uid, transcript = line.split("\t")
-                        self.transcripts[uid] = transcript
-
-                    # dump cache
-                    pickle.dump(self.transcripts, open("mls_en_transcripts.pkl", "wb"))
-                self.transcripts = pickle.load(
-                    open(
-                        "/mnt/petrelfs/hehaorui/jiaqi/vc-dev/mls_en_transcripts.pkl",
-                        "rb",
-                    )
-                )
-            elif dataset_name == "librilight_medium":
-                self.meta_data_cache_path = os.path.join(
-                    dataset_cache_dir, f"{dataset_name}_metadata_cache.csv"
-                )
-                print(f"Loaded metadata cache from {self.meta_data_cache_path}")
-
-                # write language info
-                tmp_cache = pd.read_csv(self.meta_data_cache_path, encoding="utf-8")
-                tmp_cache["language"] = "en"
-
-                if self.meta_data_cache == None:
-                    self.meta_data_cache = tmp_cache
-                else:
-                    self.meta_data_cache.append(tmp_cache)
-                breakpoint()
-                # TODO: load transcripts
-                raise NotImplementedError
-
-            elif dataset_name == "mls_german":
-                raise NotImplementedError
-                transcripts = pickle.load(
-                    open(
-                        "/mnt/petrelfs/hehaorui/jiaqi/gpt-tts/mls_german_transcripts.pkl",
-                        "rb",
-                    )
-                )
-
-        # set random_state to current time
-        current_time = int(time.time())
-        self.meta_data_cache = self.meta_data_cache.sample(
-            frac=1.0, random_state=current_time
-        ).reset_index(drop=True)
-
-        # filter_by_length: filter_out files with duration < 3.0 or > 25.0
-        print(f"Filtering files with duration between 3.0 and 25.0 seconds")
-        print(f"Before filtering: {len(self.meta_data_cache)}")
-        self.meta_data_cache = self.meta_data_cache[
-            (self.meta_data_cache["duration"] >= 3.0)
-            & (self.meta_data_cache["duration"] <= 25.0)
-        ]
-        print(f"After filtering: {len(self.meta_data_cache)}")
-        # create speaker2speaker_id
-        # self.speaker2id = self.create_speaker2id()
-        self.all_num_frames = (self.meta_data_cache["duration"] * SAMPLE_RATE).to_list()
-        self.num_frame_sorted = np.array(sorted(self.all_num_frames))
-        self.num_frame_indices = np.array(
-            sorted(
-                range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
-            )
-        )
-
-    def save_cache_files(
-        self,
-        relpath2duration_path,
-        relpath2speaker_path,
-        index2relpath_path,
-        relpath2duration,
-        relpath2speaker,
-        index2relpath,
-    ):
-        def safe_write_to_file(data, file_path, mode="w"):
-            try:
-                with lock, open(file_path, mode, encoding="utf-8") as f:
-                    json.dump(data, f)
-                    f.flush()
-                    os.fsync(f.fileno())
-            except IOError as e:
-                print(f"Error writing to {file_path}: {e}")
-
-        safe_write_to_file(relpath2duration, relpath2duration_path)
-        print(f"Saved relpath2duration to {relpath2duration_path}")
-        safe_write_to_file(relpath2speaker, relpath2speaker_path)
-        print(f"Saved relpath2speaker to {relpath2speaker_path}")
-        safe_write_to_file(index2relpath, index2relpath_path)
-        print(f"Saved index2relpath to {index2relpath_path}")
-
-    def create_metadata_cache(self, dataset, cache_dir):
-        dataset_relpath2duration_path = os.path.join(
-            cache_dir, f"{dataset}_relpath2duration.json"
-        )
-        dataset_relpath2speaker_path = os.path.join(
-            cache_dir, f"{dataset}_relpath2speaker.json"
-        )
-        dataset_index2relpath_path = os.path.join(
-            cache_dir, f"{dataset}_index2relpath.json"
-        )
-        dataset_meta_data_cache_path = os.path.join(
-            cache_dir, f"{dataset}_metadata_cache.csv"
-        )
-
-        # if os.path.exists(dataset_relpath2duration_path) and os.path.exists(dataset_relpath2speaker_path) and os.path.exists(dataset_index2relpath_path):
-        #     print(f"Loading cache for {dataset}")
-        #     with open(dataset_relpath2duration_path, 'r', encoding='utf-8') as f:
-        #         relpath2duration = json.load(f)
-        #     with open(dataset_relpath2speaker_path, 'r', encoding='utf-8') as f:
-        #         relpath2speaker = json.load(f)
-        #     with open(dataset_index2relpath_path, 'r', encoding='utf-8') as f:
-        #         index2relpath = json.load(f)
-        #     print(f"Loaded cache for {dataset} with {len(relpath2duration)} files")
-        # else:
-        if True:
-            print(f"Creating cache for {dataset}")
-            relpath2duration = {}
-            relpath2speaker = {}
-            index2relpath = {}
-            audio_rel_paths = self.get_audio_files(self.dataset2dir[dataset])
-            random.shuffle(audio_rel_paths)
-            print(f"Loaded {len(audio_rel_paths)} files from {dataset}")
-            print(f"Generating cache for {dataset}")
-            relpath2duration, relpath2speaker, index2relpath = (
-                self.get_duration_speaker_and_filter(dataset, audio_rel_paths)
-            )
-            print(f"Generated cache for {dataset} with {len(relpath2duration)} files")
-            print(f"Saving cache for {dataset}")
-            self.save_cache_files(
-                dataset_relpath2duration_path,
-                dataset_relpath2speaker_path,
-                dataset_index2relpath_path,
-                relpath2duration,
-                relpath2speaker,
-                index2relpath,
-            )
-            print(f"Saved cache for {dataset}")
-
-        meta_datas = []
-        print(f"Generating metadata cache for {dataset}")
-        for idx, relpath in tqdm(index2relpath.items()):
-            temp_item = {
-                "uid": f"{dataset}#{str(idx)}",
-                "relpath": relpath,
-                "duration": relpath2duration[relpath],
-                "speaker": relpath2speaker[relpath],
-            }
-            meta_datas.append(temp_item)
-        dataset_meta_data_cache = pd.DataFrame(meta_datas)
-        dataset_meta_data_cache.to_csv(
-            dataset_meta_data_cache_path, index=False, encoding="utf-8"
-        )
-        return dataset_meta_data_cache
-
-    def get_duration_speaker_and_filter(self, dataset, audio_rel_paths):
-        print(f"Processing metadata...")
-        rel_path2duration = {}
-        rel_path2speaker = {}
-        idx2rel_path = {}
-        base_dir = self.dataset2dir[dataset]
-        full_paths = [os.path.join(base_dir, rel_path) for rel_path in audio_rel_paths]
-        with Pool(processes=NUM_WORKERS) as pool:
-            results = list(
-                tqdm(
-                    pool.imap_unordered(get_duration, full_paths),
-                    total=len(audio_rel_paths),
-                )
-            )
-
-        idx = 0
-        print(f"Filtering files with duration between 3.0 and 25.0 seconds")
-        for file, duration in tqdm(results):
-            if duration > 3.0 and duration < 25.0:
-                rel_path = os.path.relpath(file, base_dir)
-                rel_path2duration[rel_path] = duration
-                speaker_id = file.split(os.sep)[-3]
-                speaker = f"{dataset}_{speaker_id}"
-                rel_path2speaker[rel_path] = speaker
-                idx2rel_path[idx] = rel_path
-                idx += 1
-        return rel_path2duration, rel_path2speaker, idx2rel_path
-
-    def get_audio_files(self, directory):
-        audio_files = []
-        for root, _, files in os.walk(directory):
-            for file in files:
-                if file.endswith((".flac", ".wav", ".opus")):
-                    rel_path = os.path.relpath(os.path.join(root, file), directory)
-                    audio_files.append(rel_path)
-        return audio_files
-
-    # only includes audio tokens
-    def get_num_frames(self, index):
-        # get_num_frames(durations) by index
-        duration = self.meta_data_cache["duration"][index]
-        # num_frames = duration * SAMPLE_RATE
-        num_frames = int(duration * 50)
-
-        # file_rel_path = self.meta_data_cache['relpath'][index]
-        # uid = file_rel_path.rstrip('.flac').split('/')[-1]
-        # num_frames += len(self.transcripts[uid])
-        return num_frames
-
-    def create_speaker2id(self):
-        all_speakers = self.meta_data_cache["speaker"].unique()
-        speaker2id = {}
-        for idx, speaker in enumerate(all_speakers):
-            speaker2id[speaker] = idx
-        return speaker2id
-
-    def snr_mixer(self, clean, noise, snr):
-        # Normalizing to -25 dB FS
-        rmsclean = (clean**2).mean() ** 0.5
-        epsilon = 1e-10
-        rmsclean = max(rmsclean, epsilon)
-        scalarclean = 10 ** (-25 / 20) / rmsclean
-        clean = clean * scalarclean
-
-        rmsnoise = (noise**2).mean() ** 0.5
-        rmsnoise = max(rmsnoise, epsilon)
-        if rmsnoise == epsilon:
-            return clean / scalarclean
-        scalarnoise = 10 ** (-25 / 20) / rmsnoise
-        noise = noise * scalarnoise
-        rmsnoise = (noise**2).mean() ** 0.5
-
-        # Set the noise level for a given SNR
-        noisescalar = np.sqrt(rmsclean / (10 ** (snr / 20)) / rmsnoise)
-        noisenewlevel = noise * noisescalar
-        noisyspeech = clean + noisenewlevel
-        noisyspeech_tensor = torch.tensor(noisyspeech, dtype=torch.float32)
-        return noisyspeech_tensor
-
-    def add_noise(self, clean):
-        # self.noise_filenames: list of noise files
-        random_idx = np.random.randint(0, np.size(self.noise_filenames))
-        selected_noise_file = self.noise_filenames[random_idx]
-        noise, _ = librosa.load(selected_noise_file, sr=SAMPLE_RATE)
-        clean = clean.cpu().numpy()
-        if len(noise) >= len(clean):
-            noise = noise[0 : len(clean)]  # truncate the noise to the length of the clean speech
-        else:
-            while len(noise) <= len(clean):  # if the noise is shorter than the speech
-                random_idx = (random_idx + 1) % len(
-                    self.noise_filenames
-                )  # randomly pick another noise file
-                newnoise, fs = librosa.load(selected_noise_file, sr=SAMPLE_RATE)
-                noiseconcat = np.append(
-                    noise, np.zeros(int(fs * 0.2))
-                )  # append 0.2 s of silence after the noise
-                noise = np.append(noiseconcat, newnoise)  # concatenate the noise clips
-            noise = noise[0 : len(clean)]  # truncate the noise to the length of the clean speech
-        # randomly sample an SNR greater than 0 and less than 20
-        snr = random.uniform(0.0, 15.0)
-        noisyspeech = self.snr_mixer(
-            clean=clean, noise=noise, snr=snr
-        )  # mix clean speech and noise at the sampled SNR to generate noisy audio
-        del noise
-        return noisyspeech
-
-    def add_reverb(self, speech):
-        room_dim = [
-            np.random.uniform(1, 12) for _ in range(3)
-        ]  # [length, width, height]
-        mic_pos = [np.random.uniform(0, dim) for dim in room_dim]  # randomly choose the microphone position
-        distance = np.random.normal(2, 4)  # determine the distance between source and microphone
-        while distance <= 0 or distance > 5:
-            distance = np.random.normal(2, 4)
-        source_pos = [
-            mic_pos[0] + distance,
-            mic_pos[1],
-            mic_pos[2],
-        ]  # randomly choose the source position, keeping it within a sphere centered on the microphone
-        rt60 = np.random.uniform(0.05, 1.0)  # randomly choose the RT60 value
-        try:
-            rir_filter = rir.generate(
-                c=340,  # speed of sound
-                fs=SAMPLE_RATE,
-                r=[mic_pos],  # microphone position
-                s=source_pos,  # source position
-                L=room_dim,  # room dimensions
-                reverberation_time=rt60,  # RT60 value
-                nsample=4096,  # IR length
-            )
-            # apply reverberation
-            speech_reverb = np.convolve(
-                speech.cpu().numpy(), rir_filter[:, 0], mode="same"
-            )
-            speech = torch.tensor(speech_reverb, dtype=torch.float32)
-            return speech
-        except:
-            return speech  # if "ValueError: s is outside the room" occurs, return the audio without reverb
-
-    def __len__(self):
-        return len(self.meta_data_cache)
-
-    def __getitem__(self, idx):
-        # Get the file rel path
-        file_rel_path = self.meta_data_cache["relpath"][idx]
-        # Get the dataset from cache uid
-        dataset_name = self.meta_data_cache["uid"][idx].split("#")[0]
-        # Get the full file path
-        full_file_path = os.path.join(self.dataset2dir[dataset_name], file_rel_path)
-
-        # get transcript
-        uid = file_rel_path.rstrip(".flac").split("/")[-1]
-        phone = self.transcripts[uid]
-        phone = phonemizer_g2p(phone, "en")[1]
-        # phone = [LANG2CODE['en']] + phone
-        # phone = torch.tensor(phone, dtype=torch.long)
-
-        file_bytes = self.client.get(full_file_path)
-        assert file_bytes is not None, f"file {full_file_path} not found"
-        buffer = io.BytesIO(file_bytes)
-        speech, _ = librosa.load(buffer, sr=SAMPLE_RATE)
-        speech = torch.tensor(speech, dtype=torch.float32)
-        # pad speech to multiples of 320
-        remainder = speech.size(0) % 320
-        if remainder > 0:
-            pad = 320 - remainder
-            speech = torch.cat([speech, torch.zeros(pad, dtype=torch.float32)], dim=0)
-
-        # inputs = self._get_reference_vc(speech, hop_length=200)
-        inputs = {}
-        # Get the speaker id
-        # speaker = self.meta_data_cache['speaker'][idx]
-        # speaker_id = self.speaker2id[speaker]
-        # inputs["speaker_id"] = speaker_id
-        inputs["speech"] = speech  # 24khz speech, [T]
-        inputs["phone"] = phone  # [T]
-        return inputs
-
-
-def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
-    if len(batch) == 0:
-        return 0
-    if len(batch) == max_sentences:
-        return 1
-    if num_tokens > max_tokens:
-        return 1
-    return 0
-
-
-def batch_by_size(
-    indices,
-    num_tokens_fn,
-    max_tokens=None,
-    max_sentences=None,
-    required_batch_size_multiple=1,
-):
-    """
-    Yield mini-batches of indices bucketed by size. Batches may contain
-    sequences of different lengths.
-
-    Args:
-        indices (List[int]): ordered list of dataset indices
-        num_tokens_fn (callable): function that returns the number of tokens at
-            a given index
-        max_tokens (int, optional): max number of tokens in each batch
-            (default: None).
-        max_sentences (int, optional): max number of sentences in each
-            batch (default: None).
-        required_batch_size_multiple (int, optional): require batch size to
-            be a multiple of N (default: 1).
- """ - bsz_mult = required_batch_size_multiple - - sample_len = 0 - sample_lens = [] - batch = [] - batches = [] - for i in range(len(indices)): - idx = indices[i] - num_tokens = num_tokens_fn(idx) - sample_lens.append(num_tokens) - sample_len = max(sample_len, num_tokens) - - assert ( - sample_len <= max_tokens - ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format( - idx, sample_len, max_tokens - ) - num_tokens = (len(batch) + 1) * sample_len - - if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): - mod_len = max( - bsz_mult * (len(batch) // bsz_mult), - len(batch) % bsz_mult, - ) - batches.append(batch[:mod_len]) - batch = batch[mod_len:] - sample_lens = sample_lens[mod_len:] - sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 - batch.append(idx) - if len(batch) > 0: - batches.append(batch) - return batches diff --git a/models/tts/valle_v2/run_infer.py b/models/tts/valle_v2/run_infer.py deleted file mode 100644 index 0184b42c..00000000 --- a/models/tts/valle_v2/run_infer.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -import torch -import random -import glob -import librosa - -# from utils.g2p.g2p import phonemizer_g2p as g2p -from .g2p_processor import G2pProcessor - -g2p = G2pProcessor() # use g2p_en as g2p - -import os -import torchaudio -import re -import numpy as np -import shutil - -SAMPLE_RATE = 16000 - -test_wer = True -test_sim = True -test_fid = False - - -class WER: - def __init__(self): - print("Loading WER") - from transformers import Wav2Vec2Processor, HubertForCTC - - from evaluate import load - - wer = load("wer") - - self.wer = wer - self.processor = Wav2Vec2Processor.from_pretrained( - "facebook/hubert-large-ls960-ft" - ) - self.model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft") - self.model = self.model.to("cuda") - - def calc(self, transcript_text, target_text): - transcript_text = transcript_text.lower() - transcript_text = re.sub(r"[^\w\s]", "", transcript_text) - transcript_text = re.sub(r"\s+", " ", transcript_text) - transcript_text = transcript_text.strip() - - target_text = target_text.lower() - target_text = re.sub(r"[^\w\s]", "", target_text) - target_text = re.sub(r"\s+", " ", target_text) - target_text = target_text.strip() - - predictions = [transcript_text] - references = [target_text] - wer_score = self.wer.compute(predictions=predictions, references=references) - return wer_score, transcript_text, target_text - - def __call__(self, audio, gt_text): - # need 16khz audio, 1-dimensional - assert len(audio.shape) == 1 - audio = np.array(audio.cpu()) - input_values = self.processor(audio, return_tensors="pt").input_values.to( - "cuda" - ) - logits = self.model(input_values=input_values).logits - predicted_ids = torch.argmax(logits, dim=-1) - transcript_text = self.processor.decode(predicted_ids[0]) - # remove special characters - transcript_text = re.sub(r"[^\w\s]", "", transcript_text) - - wer_score, transcript_text, target_text = self.calc(transcript_text, gt_text) - return wer_score, transcript_text, target_text - - -class SIM: - def __init__(self): - from evaluation_test.eval import ( - WAVLM_LARGE_FINTUNED_PATH, - load, - init_model, - pipeline, - Tasks, - ) - - print("Loading WavLM-large-finetuned") - self.speaker_encoder = ( - init_model(checkpoint=WAVLM_LARGE_FINTUNED_PATH).to("cuda").eval() - ) - - def __call__(self, audio1, audio2): 
-        # need 16khz audio, 1-dimensional, torch tensor
-        audio1 = audio1.unsqueeze(0).to("cuda")
-        audio2 = audio2.unsqueeze(0).to("cuda")
-        with torch.no_grad():
-            embedding1 = self.speaker_encoder(audio1)
-            embedding2 = self.speaker_encoder(audio2)
-        sim = torch.nn.functional.cosine_similarity(embedding1, embedding2, dim=1)
-        return sim.item()
-
-
-class FID:
-    pass
-
-
-class LibriSpeechDevDataset(torch.utils.data.Dataset):
-    def __init__(self, data_dir=None, use_vocos=False):
-        self.data_dir = "/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/*/*"
-        self.wav_list = glob.glob(self.data_dir + "/*.flac") + glob.glob(
-            self.data_dir + "/*.wav"
-        )
-        random.shuffle(self.wav_list)
-
-        self.transcript_file = glob.glob(self.data_dir + "/*.txt")
-        self.transcripts = {}
-        for f_transcript in self.transcript_file:
-            with open(f_transcript, "r") as f:
-                for line in f:
-                    line = line.strip().split()
-                    self.transcripts[line[0]] = " ".join(line[1:])
-
-    def __len__(self):
-        return len(self.wav_list)
-
-    def __getitem__(self, idx):
-        wav_file = self.wav_list[idx]
-        transcript = self.transcripts[os.path.basename(wav_file)[:-5]]
-        orig_transcript = transcript
-        transcript = g2p(transcript, "en")[1]
-        transcript = torch.tensor(transcript, dtype=torch.long)
-
-        speech, _ = librosa.load(wav_file, sr=SAMPLE_RATE)
-        speech = torch.tensor(speech, dtype=torch.float32)
-
-        return {
-            "speech": speech,
-            "phone_ids": transcript,
-            "transcript": orig_transcript,
-            "target_transcript": orig_transcript,
-            "output_path": os.path.basename(wav_file)[:-5] + ".wav",
-        }
-
-
-import json
-
-
-class LibriSpeechTestDataset(torch.utils.data.Dataset):
-    def __init__(self, data_dir=None, use_vocos=False):
-        self.data_dir = "/mnt/petrelfs/hehaorui/jiaqi/vc-dev/Wave16k16bNormalized"
-        self.wav_list = []
-        self.transcripts = {}
-        self.target_transcripts = {}
-
-        # load json file
-        with open(
-            "/mnt/petrelfs/hehaorui/jiaqi/vc-dev/librispeech_ref_dur_3_test_full_with_punc_wdata.json",
-            "r",
-        ) as f:
-            json_data = f.read()
-        data = json.loads(json_data)
-
-        test_data = data["test_cases"]
-
-        self.output_path = []
-        for wav_info in test_data:
-            wav_path = os.path.join(self.data_dir, wav_info["wav_path"].split("/")[-1])
-            self.wav_list.append(wav_path)
-            # print(wav_info["wav_path"])
-            wav_path = wav_info["wav_path"].split("/")[-1][:-4]
-            self.transcripts[wav_path] = (
-                wav_info["text"] + " " + wav_info["target_text"]
-            )
-            self.target_transcripts[wav_path] = wav_info["target_text"]
-            # print(self.transcripts[wav_path])
-            output_file_name = wav_info["uid"] + ".wav"
-            self.output_path.append(output_file_name)
-
-    def __len__(self):
-        return len(self.wav_list)
-
-    def __getitem__(self, idx):
-        wav_file = self.wav_list[idx]
-        transcript = self.transcripts[os.path.basename(wav_file)[:-4]]
-        target_transcript = self.target_transcripts[os.path.basename(wav_file)[:-4]]
-        # remove punctuation
-        transcript = "".join(
-            e for e in transcript if e.isalnum() or e.isspace()
-        ).lower()
-        orig_transcript = transcript
-        transcript = g2p(transcript, "en")[1]
-        # transcript = [LANG2CODE['en']] + transcript
-        transcript = torch.tensor(transcript, dtype=torch.long)
-
-        speech, _ = librosa.load(wav_file, sr=SAMPLE_RATE)
-        speech = torch.tensor(speech, dtype=torch.float32)
-
-        return {
-            "speech": speech,  # prompt speech. do not include gt
-            "phone_ids": transcript,
-            "orig_transcript": orig_transcript,
-            "target_transcript": target_transcript,
-            "output_path": self.output_path[idx],
-        }
-
-
-def test():
-    dataset = LibriSpeechDevDataset()
-    # dataset = LibriSpeechTestDataset()
-    from .valle_inference import ValleInference
-
-    inference = ValleInference(
-        use_vocos=False,
-        use_speechtokenizer=True,
-        ar_path="/mnt/petrelfs/hehaorui/jiaqi/vc-dev/ckpt/valle_v2/ar_mls_speechtokenizer/checkpoint/epoch-0004_step-0190000_loss-0.813551/pytorch_model.bin",
-        nar_path="/mnt/petrelfs/hehaorui/jiaqi/AmphionVALLEv2/ckpt/valle_v2/nar_mls_speechtokenizer/checkpoint/epoch-0001_step-0164000_loss-1.848536/pytorch_model.bin",
-    )
-    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
-    if test_wer:
-        wer = WER()
-    if test_sim:
-        sim = SIM()
-
-    import tqdm
-
-    wer_scores = []
-    similarity_scores = []
-    fid_scores = []
-    total_cnt = 0
-
-    shutil.rmtree("infer", ignore_errors=True)
-    shutil.rmtree("wer_abnormals_output", ignore_errors=True)
-    os.mkdir("infer")
-    os.mkdir("wer_abnormals_output")
-    for num_beams in [1]:
-        for top_k in [30]:
-            for top_p in [0.9]:
-                for repeat_penalty in [1.0]:
-                    for temperature in [0.95]:
-
-                        for batch in tqdm.tqdm(dataloader):
-                            if (
-                                batch["speech"].shape[-1] < 10 * SAMPLE_RATE
-                                or batch["speech"].shape[-1] > 20 * SAMPLE_RATE
-                            ):
-                                continue
-                            # breakpoint()
-                            print(batch["target_transcript"][0].lower())
-                            chunks = [
-                                dict(
-                                    top_p=top_p,
-                                    top_k=top_k,
-                                    temperature=temperature,
-                                    num_beams=num_beams,
-                                    repeat_penalty=repeat_penalty,
-                                    max_length=2000,
-                                )
-                            ]
-
-                            if isinstance(dataset, LibriSpeechDevDataset):
-                                output_wav = inference(
-                                    batch, chunks, return_prompt=True
-                                )
-                            else:
-                                output_wav = inference(
-                                    batch, chunks, return_prompt=False
-                                )
-
-                            # output_wav = batch['speech'].unsqueeze(0)
-
-                            torchaudio.save(
-                                f"infer/{batch['output_path'][0]}",
-                                output_wav[0].cpu(),
-                                SAMPLE_RATE,
-                            )
-                            print(f"saved to " + f"infer/{batch['output_path'][0]}")
-
-                            # breakpoint()
-                            # torchaudio.save('gt.wav', batch['speech'][0].unsqueeze(0).cpu(), SAMPLE_RATE)
-
-                            # resample to 16k
-                            output_wav_resampled = torchaudio.functional.resample(
-                                output_wav, orig_freq=SAMPLE_RATE, new_freq=16000
-                            )
-                            if test_wer:
-                                # get wer score
-                                wer_score, transcribed, gt_text = wer(
-                                    output_wav_resampled.squeeze(0).squeeze(0),
-                                    batch["target_transcript"][0],
-                                )
-                                print(f"WER: {wer_score}")
-                                wer_scores.append(wer_score)
-                                print(f"average wer: {sum(wer_scores)/len(wer_scores)}")
-
-                                # if wer_score > 0.1:
-                                #     # save
-                                #     torchaudio.save(f'wer_abnormals_output/{batch["output_path"][0]}', output_wav[0].cpu(), SAMPLE_RATE)
-                                #     # torchaudio.save(f'wer_abnormals_gt/{batch["output_path"][0]}', output_wav[0].cpu(), SAMPLE_RATE)
-                                #     with open(f'wer_abnormals_output/{batch["output_path"][0][:-4]}.txt', 'w') as f:
-                                #         f.write('target: ')
-                                #         f.write(gt_text)
-                                #         f.write('\n')
-                                #         f.write('transcribed: ')
-                                #         f.write(transcribed)
-                                #         f.write('\n')
-                                #         f.write(f'wer: {wer_score}')
-                                #     print(f'target: {batch["target_transcript"][0]}, transcribed: {transcribed.lower()}')
-                                #     print(f'wer_abnormals_output/{batch["output_path"][0][:-4]}.txt')
-                            if test_sim:
-                                # get similarity score
-                                batch_speech_resampled = torchaudio.functional.resample(
-                                    batch["speech"],
-                                    orig_freq=SAMPLE_RATE,
-                                    new_freq=16000,
-                                )
-                                sim_score = sim(
-                                    output_wav_resampled.squeeze(0).squeeze(0),
-                                    batch_speech_resampled.squeeze(0),
-                                )
-                                similarity_scores.append(sim_score)
-                                print(f"SIM: {sim_score}")
-                                print(
-                                    f"average sim: {sum(similarity_scores)/len(similarity_scores)}"
-                                )
-
-
-if __name__ == "__main__":
-    test()