diff --git a/README.md b/README.md index 9c6845b7..024a4f0e 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ - **TTM**: Text to Music (👨‍💻 developing) - more… -In addition to the specific generation tasks, Amphion also includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. +In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Here is the Amphion v0.1 demo, whose voice, audio effects, and singing voice are generated by our models. Just enjoy it! @@ -33,6 +33,7 @@ Here is the Amphion v0.1 demo, whose voice, audio effects, and singing voice are ) ## 🚀 News +- **2024/6/17**: Amphion has a new release for its VALL-E models, it uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md) - **2024/03/12**: Amphion now support **NaturalSpeech3 FACodec** and release pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md) - **2024/02/22**: The first Amphion visualization tool, **SingVisio**, release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://github.com/open-mmlab/Amphion/assets/33707885/0a6e39e8-d5f1-4288-b0f8-32da5a2d6e96) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/visualization/SingVisio/README.md) - **2023/12/18**: Amphion v0.1 release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2312.09911) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink)](https://huggingface.co/amphion) [![youtube](https://img.shields.io/badge/YouTube-Demo-red)](https://www.youtube.com/watch?v=1aw0HhcggvQ) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/39) @@ -42,10 +43,10 @@ Here is the Amphion v0.1 demo, whose voice, audio effects, and singing voice are ### TTS: Text to Speech -- Amphion achieves state-of-the-art performance when compared with existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures: +- Amphion achieves state-of-the-art performance compared to existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures: - [FastSpeech2](https://arxiv.org/abs/2006.04558): A non-autoregressive TTS architecture that utilizes feed-forward Transformer blocks. 
- [VITS](https://arxiv.org/abs/2106.06103): An end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning - - [Vall-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes. + - [VALL-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes. - [NaturalSpeech2](https://arxiv.org/abs/2304.09116): An architecture for TTS that utilizes a latent diffusion model to generate natural-sounding voices. ### SVC: Singing Voice Conversion @@ -139,6 +140,7 @@ We appreciate all contributions to improve Amphion. Please refer to [CONTRIBUTIN - [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) and [jaywalnut310's VITS](https://github.com/jaywalnut310/vits) for model architecture code. - [lifeiteng's VALL-E](https://github.com/lifeiteng/vall-e) for training pipeline and model architecture design. +- [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) for semantic-distilled tokenizer design. - [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), [ContentVec](https://github.com/auspicious3000/contentvec), and [RawNet3](https://github.com/Jungjee/RawNet) for pretrained models and inference code. - [HiFi-GAN](https://github.com/jik876/hifi-gan) for GAN-based Vocoder's architecture design and training strategy. - [Encodec](https://github.com/facebookresearch/encodec) for well-organized GAN Discriminator's architecture and basic blocks. diff --git a/bins/tts/train.py b/bins/tts/train.py index 241ee933..b1132a07 100644 --- a/bins/tts/train.py +++ b/bins/tts/train.py @@ -11,6 +11,9 @@ from models.tts.vits.vits_trainer import VITSTrainer from models.tts.valle.valle_trainer import VALLETrainer from models.tts.naturalspeech2.ns2_trainer import NS2Trainer +from models.tts.VALLE_V2.valle_ar_trainer import ValleARTrainer as VALLE_V2_AR +from models.tts.VALLE_V2.valle_nar_trainer import ValleNARTrainer as VALLE_V2_NAR + from utils.util import load_config @@ -20,6 +23,8 @@ def build_trainer(args, cfg): "VITS": VITSTrainer, "VALLE": VALLETrainer, "NaturalSpeech2": NS2Trainer, + "VALLE_V2_AR": VALLE_V2_AR, + "VALLE_V2_NAR": VALLE_V2_NAR, } trainer_class = supported_trainer[cfg.model_type] @@ -32,6 +37,7 @@ def cuda_relevant(deterministic=False): # TF32 on Ampere and above torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False torch.backends.cudnn.allow_tf32 = True # Deterministic torch.backends.cudnn.deterministic = deterministic @@ -47,6 +53,13 @@ def main(): help="json files for configurations.", required=True, ) + parser.add_argument( + "--seed", + type=int, + default=1234, + help="random seed", + required=False, + ) parser.add_argument( "--exp_name", type=str, @@ -57,6 +70,9 @@ def main(): parser.add_argument( "--resume", action="store_true", help="The model name to restore" ) + parser.add_argument( + "--test", action="store_true", default=False, help="Test the model" + ) parser.add_argument( "--log_level", default="warning", help="logging level (debug, info, warning)" ) @@ -72,39 +88,62 @@ def main(): default=None, help="Checkpoint for resume training or finetuning.", ) - - VALLETrainer.add_arguments(parser) + parser.add_argument( + "--resume_from_ckpt_path", + type=str, + default="", + help="Checkpoint for resume training or finetuning.", + ) + # VALLETrainer.add_arguments(parser) args = parser.parse_args() 
    cfg = load_config(args.config)
 
     # Data Augmentation
-    if (
-        type(cfg.preprocess.data_augment) == list
-        and len(cfg.preprocess.data_augment) > 0
-    ):
-        new_datasets_list = []
-        for dataset in cfg.preprocess.data_augment:
-            new_datasets = [
-                f"{dataset}_pitch_shift" if cfg.preprocess.use_pitch_shift else None,
-                (
-                    f"{dataset}_formant_shift"
-                    if cfg.preprocess.use_formant_shift
-                    else None
-                ),
-                f"{dataset}_equalizer" if cfg.preprocess.use_equalizer else None,
-                f"{dataset}_time_stretch" if cfg.preprocess.use_time_stretch else None,
-            ]
-            new_datasets_list.extend(filter(None, new_datasets))
-        cfg.dataset.extend(new_datasets_list)
-
+    if hasattr(cfg, "preprocess"):
+        if hasattr(cfg.preprocess, "data_augment"):
+            if (
+                type(cfg.preprocess.data_augment) == list
+                and len(cfg.preprocess.data_augment) > 0
+            ):
+                new_datasets_list = []
+                for dataset in cfg.preprocess.data_augment:
+                    new_datasets = [
+                        (
+                            f"{dataset}_pitch_shift"
+                            if cfg.preprocess.use_pitch_shift
+                            else None
+                        ),
+                        (
+                            f"{dataset}_formant_shift"
+                            if cfg.preprocess.use_formant_shift
+                            else None
+                        ),
+                        (
+                            f"{dataset}_equalizer"
+                            if cfg.preprocess.use_equalizer
+                            else None
+                        ),
+                        (
+                            f"{dataset}_time_stretch"
+                            if cfg.preprocess.use_time_stretch
+                            else None
+                        ),
+                    ]
+                    new_datasets_list.extend(filter(None, new_datasets))
+                cfg.dataset.extend(new_datasets_list)
+
+    print("experiment name: ", args.exp_name)
     # # CUDA settings
     cuda_relevant()
 
     # Build trainer
+    print(f"Building {cfg.model_type} trainer")
     trainer = build_trainer(args, cfg)
-    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
-    trainer.train_loop()
+    print(f"Start training {cfg.model_type} model")
+    if args.test:
+        trainer.test_loop()
+    else:
+        trainer.train_loop()
 
 
 if __name__ == "__main__":
diff --git a/egs/tts/valle_v2/README.md b/egs/tts/valle_v2/README.md
new file mode 100644
index 00000000..cf209b45
--- /dev/null
+++ b/egs/tts/valle_v2/README.md
@@ -0,0 +1,169 @@
+# VALL-E
+## Introduction
+This is an unofficial PyTorch implementation of VALL-E, a zero-shot voice cloning model based on neural codec language modeling ([paper link](https://arxiv.org/abs/2301.02111)).
+If trained properly, this model can match the performance reported in the original paper.
+
+## Change notes
+This is a refined version of the first VALL-E release in Amphion. We have changed the underlying implementation to Llama,
+which provides better model performance, faster training speed, and more readable code.
+It can be a great resource if you want to learn about speech language models and their implementation.
+
+## Installation requirements
+
+Set up your environment as described in the Amphion README (you'll need a conda environment, and we recommend using Linux). A GPU is recommended if you want to train this model yourself.
+For inference with our pretrained models, you can generate samples even without a GPU.
+To ensure your transformers library is compatible with the code, we recommend additionally running:
+```bash
+pip install -U transformers==4.41.2
+```
+
+
+
+## Inference with pretrained VALL-E models
+### Download pretrained weights
+You need to download our pretrained weights from Hugging Face.
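+
+The `huggingface-cli` commands below are the recommended way to fetch them. If you prefer to do the same from Python, here is a minimal sketch using `huggingface_hub` (it simply downloads the same repository and file names as the CLI commands below):
+```python
+# Sketch: programmatic download of the pretrained AR/NAR and SpeechTokenizer checkpoints.
+from huggingface_hub import hf_hub_download
+
+for filename in [
+    "valle_ar_mls_196000.bin",
+    "valle_nar_mls_164000.bin",
+    "speechtokenizer_hubert_avg/SpeechTokenizer.pt",
+    "speechtokenizer_hubert_avg/config.json",
+]:
+    hf_hub_download(repo_id="amphion/valle", filename=filename, local_dir="ckpts")
+```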
+
+Script to download the AR and NAR model checkpoints:
+```bash
+huggingface-cli download amphion/valle valle_ar_mls_196000.bin valle_nar_mls_164000.bin --local-dir ckpts
+```
+Script to download the codec model (SpeechTokenizer) checkpoint:
+```bash
+huggingface-cli download amphion/valle speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts
+```
+
+### Inference in IPython notebook
+
+We provide a pretrained VALL-E model trained on 45k hours of the MLS dataset.
+The `demo.ipynb` file provides a working example of running inference with our pretrained VALL-E model. Give it a try!
+
+## Examining the model files
+Examining the model files of VALL-E is a great way to learn how it works.
+We provide examples that allow you to overfit a single batch (so no dataset download is required).
+
+The AR model is essentially a causal language model that "continues" a piece of speech. The NAR model is a modification of the AR model that allows bidirectional attention.
+
+
+The files `valle_ar.py` and `valle_nar.py` in the "models/tts/VALLE_V2" folder are the model files; they can be run directly via `python -m models.tts.VALLE_V2.valle_ar` (or `python -m models.tts.VALLE_V2.valle_nar`).
+This will invoke a test that overfits the model to a single example.
+
+## Training VALL-E from scratch
+### Preparing LibriTTS or LibriTTS-R dataset files
+
+We have tested our training script on LibriTTS and LibriTTS-R.
+You can download LibriTTS-R at [this link](https://www.openslr.org/141/) and LibriTTS at [this link](https://www.openslr.org/60).
+The "train-clean-360" split is currently used by our configuration.
+You can test the dataset script by running `python -m models.tts.VALLE_V2.libritts_dataset`.
+
+For your reference, our unzipped dataset has a file structure like this:
+```
+/path/to/LibriTTS_R
+├── BOOKS.txt
+├── CHAPTERS.txt
+├── dev-clean
+│   ├── 2412
+│   │   ├── 153947
+│   │   │   ├── 2412_153947_000014_000000.normalized.txt
+│   │   │   ├── 2412_153947_000014_000000.original.txt
+│   │   │   ├── 2412_153947_000014_000000.wav
+│   │   │   ├── 2412_153947_000017_000001.normalized.txt
+│   │   │   ├── 2412_153947_000017_000001.original.txt
+│   │   │   ├── 2412_153947_000017_000001.wav
+│   │   │   ├── 2412_153947_000017_000005.normalized.txt
+├── train-clean-360
+│   ├── 422
+│   │   └── 122949
+│   │       ├── 422_122949_000009_000007.normalized.txt
+│   │       ├── 422_122949_000009_000007.original.txt
+│   │       ├── 422_122949_000009_000007.wav
+│   │       ├── 422_122949_000013_000010.normalized.txt
+│   │       ├── 422_122949_000013_000010.original.txt
+│   │       ├── 422_122949_000013_000010.wav
+│   │       ├── 422_122949.book.tsv
+│   │       └── 422_122949.trans.tsv
+```
+
+
+Alternatively, you can write your own dataloader for your dataset.
+You can reference the `__getitem__` method in `models/tts/VALLE_V2/mls_dataset.py`.
+It should return a dict with a 1-dimensional tensor 'speech' (the 16kHz waveform) and a 1-dimensional tensor 'phone' (the phoneme sequence of the speech).
+As long as your dataset returns this in `__getitem__`, it should work (a minimal sketch is given in the appendix at the end of this README).
+
+### Changing batch size and dataset path in configuration file
+Our configuration file for training the VALL-E AR model is at "egs/tts/VALLE_V2/exp_ar_libritts.json", and the NAR model's at "egs/tts/VALLE_V2/exp_nar_libritts.json".
+
+To train your model, you need to modify the `dataset` variable in the JSON configurations.
+Currently it's at line 40; you should set "data_dir" to your dataset's root directory.
+```
+    "dataset": {
+        "dataset_list":["train-clean-360"], // You can also change to other splits like "dev-clean"
+        "data_dir": "/path/to/your/LibriTTS_R",
+    },
+```
+
+You should also select a reasonable batch size at the "batch_size" entry (currently it's set at 5).
+
+
+You can change other experiment settings in `egs/tts/VALLE_V2/exp_ar_libritts.json`, such as the learning rate, optimizer, and dataset.
+
+Here we choose the `libritts` dataset we added and set `use_dynamic_batchsize` to false.
+
+The `use_dynamic_batchsize` option addresses inconsistent sequence lengths and improves GPU utilization; here we set it to false for simplicity.
+
+```json
+"dataset": {
+        "use_dynamic_batchsize": false,
+        "name": "libritts"
+      },
+```
+
+We also recommend changing "num_hidden_layers" if your GPU memory is limited.
+
+**Set a smaller batch_size if you are out of memory 😢😢**
+
+We successfully ran with batch_size=3 on a single card; if you're out of memory, try a smaller value.
+
+```json
+    "batch_size": 3,
+    "max_tokens": 11000,
+    "max_sentences": 64,
+    "random_seed": 0
+```
+
+
+### Run the command to train the AR model
+(Make sure your current working directory is the Amphion root directory.)
+Run:
+```sh
+sh egs/tts/VALLE_V2/train_ar_libritts.sh
+```
+Your model checkpoint can be found at a path like `ckpt/VALLE_V2/ar_libritts/checkpoint/epoch-0000_step-0000000_loss-7.397293/pytorch_model.bin`.
+
+
+### Resume from an existing checkpoint
+Our framework supports resuming from an existing checkpoint.
+
+Run:
+```sh
+sh egs/tts/VALLE_V2/train_ar_libritts.sh --resume
+```
+
+### Run the command to train the NAR model
+(Make sure your current working directory is the Amphion root directory.)
+Run:
+```sh
+sh egs/tts/VALLE_V2/train_nar_libritts.sh
+```
+
+### Inference with your models
+Since the inference script is already provided, you can simply change the paths
+from our pretrained models to your newly trained models and run inference.
+
+## Future plans
+- [ ] Support more languages
+- [ ] More are coming...
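+
+## Appendix: custom dataset sketch
+For reference, here is a minimal, illustrative sketch of the `__getitem__` contract described in the training section above: each item is a dict with a 1-D 16kHz waveform under 'speech' and a 1-D phoneme-id sequence under 'phone'. The class name and the `metadata` format below are hypothetical placeholders; see `models/tts/VALLE_V2/mls_dataset.py` for the actual implementation.
+```python
+import torch
+import torchaudio
+from torch.utils.data import Dataset
+
+
+class MyCustomDataset(Dataset):
+    """Hypothetical dataset that returns the dict expected by the VALL-E trainers."""
+
+    def __init__(self, metadata):
+        # `metadata`: a list of (wav_path, phone_ids) pairs you prepare yourself,
+        # e.g. with the G2P front end shown in demo.ipynb.
+        self.metadata = metadata
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def __getitem__(self, idx):
+        wav_path, phone_ids = self.metadata[idx]
+        speech, sr = torchaudio.load(wav_path)  # (channels, samples)
+        if sr != 16000:
+            speech = torchaudio.functional.resample(speech, sr, 16000)
+        return {
+            "speech": speech.mean(dim=0),  # 1-D waveform at 16 kHz
+            "phone": torch.tensor(phone_ids, dtype=torch.long),  # 1-D phoneme ids
+        }
+```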
diff --git a/egs/tts/valle_v2/demo.ipynb b/egs/tts/valle_v2/demo.ipynb new file mode 100644 index 00000000..324c8307 --- /dev/null +++ b/egs/tts/valle_v2/demo.ipynb @@ -0,0 +1,263 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir('../../..')\n", + "print(os.getcwd()) # Ensure this is you Amphion root path, otherwise change the above path to you amphion root path\n", + "assert os.path.isfile('./README.md') # make sure the current path is Amphion root path\n", + "import sys\n", + "sys.path.append('.')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# put your cheackpoint file (.bin) in the root path of AmphionVALLEv2\n", + "# or use your own pretrained weights\n", + "ar_model_path = 'ckpts/valle_ar_mls_196000.bin' #huggingface-cli download jiaqili3/vallex valle_ar_mls_196000.bin valle_nar_mls_164000.bin --local-dir ckpts\n", + "nar_model_path = 'ckpts/valle_nar_mls_164000.bin'\n", + "speechtokenizer_path = 'ckpts/speechtokenizer_hubert_avg' # huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cpu' # change to 'cuda' if you have gpu" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from models.tts.valle_v2.valle_inference import ValleInference\n", + "# change to device='cuda' to use CUDA GPU for fast inference\n", + "# change \"use_vocos\" to True would give better sound quality\n", + "# If you meet problem with network, you could set \"use_vocos=False\", though would give bad quality\n", + "model = ValleInference(ar_path=ar_model_path, nar_path=nar_model_path, speechtokenizer_path=speechtokenizer_path, device=device)\n", + "# model = ValleInference(use_vocos=False, ar_path=ar_model_path, nar_path=nar_model_path, device='cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# prepare inference data\n", + "import librosa\n", + "import torch\n", + "wav, _ = librosa.load('./egs/tts/valle_v2/example.wav', sr=16000)\n", + "wav = torch.tensor(wav, dtype=torch.float32)\n", + "from IPython.display import Audio\n", + "Audio(wav, rate = 16000)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# The transcript of the prompt part\n", + "prompt_transcript_text = 'and keeping eternity before the eyes'\n", + "\n", + "# Here are the words you want the model to output\n", + "target_transcript_text = 'It presents a unified framework that is inclusive of diverse generation tasks and models with the added bonus of being easily extendable for new applications'\n", + "from models.tts.valle_v2.g2p_processor import G2pProcessor\n", + "g2p = G2pProcessor()\n", + "prompt_transcript = g2p(prompt_transcript_text, 'en')[1]\n", + "target_transcript = g2p(target_transcript_text, 'en')[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "prompt_transcript = torch.tensor(prompt_transcript).long()\n", + 
"target_transcript = torch.tensor(target_transcript).long()\n", + "transcript = torch.cat([prompt_transcript, target_transcript], dim=-1)\n", + "batch = {\n", + " 'speech': wav.unsqueeze(0),\n", + " 'phone_ids': transcript.unsqueeze(0),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'speech': tensor([[ 3.0518e-05, 3.0518e-05, 3.0518e-05, ..., -3.0518e-05,\n", + " -3.0518e-05, 3.0518e-05]]),\n", + " 'phone_ids': tensor([[ 5, 28, 149, 72, 219, 134, 127, 170, 115, 147, 219, 113, 185, 91,\n", + " 149, 30, 185, 123, 219, 65, 115, 106, 43, 172, 219, 73, 29, 219,\n", + " 59, 214, 6, 5, 116, 181, 219, 168, 173, 124, 218, 82, 149, 185,\n", + " 175, 219, 28, 219, 210, 200, 149, 30, 106, 64, 72, 219, 104, 173,\n", + " 100, 143, 209, 94, 135, 219, 73, 24, 181, 219, 116, 214, 219, 113,\n", + " 149, 136, 140, 200, 179, 115, 205, 219, 31, 205, 219, 71, 58, 206,\n", + " 91, 175, 219, 131, 85, 149, 88, 100, 178, 30, 145, 219, 180, 24,\n", + " 179, 136, 175, 219, 28, 149, 72, 219, 141, 15, 76, 30, 140, 214,\n", + " 219, 207, 118, 74, 219, 73, 29, 219, 22, 76, 30, 72, 219, 65,\n", + " 155, 149, 30, 175, 219, 31, 205, 219, 65, 127, 115, 147, 219, 125,\n", + " 218, 30, 140, 123, 219, 83, 136, 179, 185, 82, 149, 76, 30, 67,\n", + " 30, 139, 219, 104, 43, 172, 219, 144, 199, 219, 25, 170, 140, 30,\n", + " 136, 100, 178, 30, 149, 214, 6]])}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# print the contents of the model input\n", + "# `phone_ids` contains a concatenation of `prompt_transcript` and `target_transcript` \n", + "batch" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "configs = [dict(\n", + " top_p=0.9,\n", + " top_k=5,\n", + " temperature=0.95,\n", + " repeat_penalty=1.0,\n", + " max_length=2000,\n", + " num_beams=1,\n", + ")] # model inference hyperparameters\n", + "output_wav = model(batch, configs)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-1.2337e-06, -1.2981e-05, -4.0130e-05, ..., -4.1360e-05,\n", + " 1.1917e-05, -4.2949e-05]]])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_wav # The output wav is a tensor of shape [1,1,T]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "prompt_transcript : and keeping eternity before the eyes\n", + "target_transcript : It presents a unified framework that is inclusive of diverse generation tasks and models with the added bonus of being easily extendable for new applications\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(f'prompt_transcript : {prompt_transcript_text}')\n", + "print(f'target_transcript : {target_transcript_text}')\n", + "Audio(output_wav.squeeze(0), rate = 16000)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import torchaudio\n", + "torchaudio.save('out.wav', output_wav.squeeze(0), 24000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "amphion", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/egs/tts/valle_v2/example.wav b/egs/tts/valle_v2/example.wav new file mode 100644 index 00000000..0484674d Binary files /dev/null and b/egs/tts/valle_v2/example.wav differ diff --git a/egs/tts/valle_v2/exp_ar_libritts.json b/egs/tts/valle_v2/exp_ar_libritts.json new file mode 100644 index 00000000..81bd8a87 --- /dev/null +++ b/egs/tts/valle_v2/exp_ar_libritts.json @@ -0,0 +1,55 @@ +{ + "model_type": "VALLE_V2_AR", + "log_dir": "./ckpt/VALLE_V2", + "use_speechtokenizer": true, + "train": { + "gradient_accumulation_step": 1, + "find_unused_parameters": false, + "tracker": ["tensorboard"], + "max_epoch": 1000, + "save_checkpoint_stride": [500], + "keep_last": [1], + "run_eval": [true], + "dataloader": { + "num_worker": 4, + "pin_memory": true, + "persistent_workers": true + }, + "dataset": { + "use_dynamic_batchsize": false, + "name": "libritts" + }, + "optimizer": "adamW", + "adamw": { + "lr": 1e-4 + }, + "scheduler": { + "warmup_steps": 25000, + "total_steps": 800000, + "min_lr": 1e-5 + }, + "exponentiallr": { + "gamma": 0.999999 + }, + "batch_size": 5, + "max_tokens": 5000, + "max_sentences": 64, + "random_seed": 0 + }, + "dataset": { + "dataset_list":["train-clean-360"], + "data_dir": "/path/to/your/libritts" // You can also change to other splits like "dev-clean" + }, + "model": { + "phone_vocab_size": 300, + "target_vocab_size": 1024, + "pad_token_id": 1324, + "bos_target_id": 1325, + "eos_target_id": 1326, + "bos_phone_id": 1327, + "eos_phone_id": 1328, + "bos_prompt_id": 1329, + "eos_prompt_id": 1330, + "num_hidden_layers": 16 + } + } diff --git a/egs/tts/valle_v2/exp_nar_libritts.json b/egs/tts/valle_v2/exp_nar_libritts.json new file mode 100644 index 00000000..bba7beb7 --- /dev/null +++ b/egs/tts/valle_v2/exp_nar_libritts.json @@ -0,0 +1,55 @@ +{ + "model_type": "VALLE_V2_NAR", + "log_dir": "./ckpt/VALLE_V2", + "use_speechtokenizer": true, + "train": { + "gradient_accumulation_step": 1, + "find_unused_parameters": true, + "tracker": ["tensorboard"], + "max_epoch": 1000, + "save_checkpoint_stride": [500], + "keep_last": [1], + "run_eval": [true], + "dataloader": { + "num_worker": 4, + "pin_memory": true, + "persistent_workers": true + }, + "dataset": { + "use_dynamic_batchsize": false, + "name": "libritts" + }, + "optimizer": "adamw", + "adamw": { + "lr": 1e-4 + }, + "scheduler": { + "warmup_steps": 25000, + "total_steps": 800000, + "min_lr": 1e-5 + }, + "exponentiallr": { + "gamma": 0.999999 + }, + "batch_size": 5, + "max_tokens": 10500, + "max_sentences": 64, + "random_seed": 0 + }, + "dataset": { + "dataset_list":["train-clean-360"], + "data_dir": "/path/to/your/libritts" // You can also change to other splits like "dev-clean" + }, + "model": { + "phone_vocab_size": 300, + "target_vocab_size": 1024, + "pad_token_id": 1324, + "bos_target_id": 1325, + "eos_target_id": 1326, + "bos_phone_id": 1327, + "eos_phone_id": 1328, + "bos_prompt_id": 1329, + "eos_prompt_id": 1330, + "num_hidden_layers": 16 + } + } diff --git a/egs/tts/valle_v2/train_ar_libritts.sh b/egs/tts/valle_v2/train_ar_libritts.sh new file mode 100644 index 00000000..2166551e --- /dev/null +++ b/egs/tts/valle_v2/train_ar_libritts.sh @@ -0,0 +1,27 @@ +export PYTHONPATH="./" + +######## Build Experiment Environment 
###########
+exp_dir="./egs/tts/VALLE_V2"
+echo exp_dir: $exp_dir
+work_dir="./" # Amphion root folder
+echo work_dir: $work_dir
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+
+######## Set Config File Dir ##############
+if [ -z "$exp_config" ]; then
+    exp_config="${exp_dir}"/exp_ar_libritts.json
+fi
+echo "Experimental Configuration File: $exp_config"
+
+######## Set the experiment name ##########
+exp_name="ar_libritts"
+
+port=53333 # a random number for port
+
+######## Train Model ###########
+echo "Experiment Name: $exp_name"
+accelerate launch --main_process_port $port "${work_dir}"/bins/tts/train.py --config $exp_config \
+--exp_name $exp_name --log_level debug $1
diff --git a/egs/tts/valle_v2/train_nar_libritts.sh b/egs/tts/valle_v2/train_nar_libritts.sh
new file mode 100644
index 00000000..d753854c
--- /dev/null
+++ b/egs/tts/valle_v2/train_nar_libritts.sh
@@ -0,0 +1,27 @@
+export PYTHONPATH="./"
+
+######## Build Experiment Environment ###########
+exp_dir="./egs/tts/VALLE_V2"
+echo exp_dir: $exp_dir
+work_dir="./" # Amphion root folder
+echo work_dir: $work_dir
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+
+######## Set Config File Dir ##############
+if [ -z "$exp_config" ]; then
+    exp_config="${exp_dir}"/exp_nar_libritts.json
+fi
+echo "Experimental Configuration File: $exp_config"
+
+######## Set the experiment name ##########
+exp_name="nar_libritts"
+
+port=17004 # a random number for port
+
+######## Train Model ###########
+echo "Experiment Name: $exp_name"
+accelerate launch --main_process_port $port "${work_dir}"/bins/tts/train.py --config $exp_config \
+--exp_name $exp_name --log_level debug $1
diff --git a/env.sh b/env.sh
index 10ef7ff1..5e65a066 100644
--- a/env.sh
+++ b/env.sh
@@ -12,7 +12,9 @@ conda install -c conda-forge ffmpeg
 # Pip packages
 pip install setuptools ruamel.yaml tqdm colorama easydict tabulate loguru json5 Cython unidecode inflect argparse g2p_en tgt librosa==0.9.1 matplotlib typeguard einops omegaconf hydra-core humanfriendly pandas
-pip install tensorboard tensorboardX torch==2.0.1 torchaudio==2.0.2 torchvision==0.15.2 accelerate==0.24.1 transformers diffusers praat-parselmouth audiomentations pedalboard ffmpeg-python==0.2.0 pyworld diffsptk==1.0.1 nnAudio unidecode inflect ptwt
+pip install tensorboard tensorboardX torch==2.0.1 torchaudio==2.0.2 torchvision==0.15.2 accelerate==0.24.1 transformers==4.41.2 diffusers praat-parselmouth audiomentations pedalboard ffmpeg-python==0.2.0 pyworld diffsptk==1.0.1 nnAudio unidecode inflect ptwt
+
+pip install encodec vocos speechtokenizer g2p_en
 
 pip install torchmetrics pymcd openai-whisper frechet_audio_distance asteroid resemblyzer vector-quantize-pytorch==1.12.5
diff --git a/models/codec/speechtokenizer/model.py b/models/codec/speechtokenizer/model.py
new file mode 100644
index 00000000..b722d386
--- /dev/null
+++ b/models/codec/speechtokenizer/model.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2023 Amphion.
+#
+# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
+# Licensed under Apache License 2.0
+
+from .modules.seanet import SEANetEncoder, SEANetDecoder
+from .modules.quantization import ResidualVectorQuantizer
+import torch.nn as nn
+from einops import rearrange
+import torch
+import numpy as np
+
+
+class SpeechTokenizer(nn.Module):
+    def __init__(self, config):
+        """
+
+        Parameters
+        ----------
+        config : json
+            Model Config.
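+            Expected keys (read via ``config.get`` below) include "n_filters",
+            "dimension", "strides", "lstm_layers", "bidirectional",
+            "dilation_base", "residual_kernel_size", "n_residual_layers",
+            "activation", "sample_rate", "n_q", "semantic_dimension",
+            and "codebook_size".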
+ + """ + super().__init__() + self.encoder = SEANetEncoder( + n_filters=config.get("n_filters"), + dimension=config.get("dimension"), + ratios=config.get("strides"), + lstm=config.get("lstm_layers"), + bidirectional=config.get("bidirectional"), + dilation_base=config.get("dilation_base"), + residual_kernel_size=config.get("residual_kernel_size"), + n_residual_layers=config.get("n_residual_layers"), + activation=config.get("activation"), + ) + self.sample_rate = config.get("sample_rate") + self.n_q = config.get("n_q") + self.downsample_rate = np.prod(config.get("strides")) + if config.get("dimension") != config.get("semantic_dimension"): + self.transform = nn.Linear( + config.get("dimension"), config.get("semantic_dimension") + ) + else: + self.transform = nn.Identity() + self.quantizer = ResidualVectorQuantizer( + dimension=config.get("dimension"), + n_q=config.get("n_q"), + bins=config.get("codebook_size"), + ) + self.decoder = SEANetDecoder( + n_filters=config.get("n_filters"), + dimension=config.get("dimension"), + ratios=config.get("strides"), + lstm=config.get("lstm_layers"), + bidirectional=False, + dilation_base=config.get("dilation_base"), + residual_kernel_size=config.get("residual_kernel_size"), + n_residual_layers=config.get("n_residual_layers"), + activation=config.get("activation"), + ) + + @classmethod + def load_from_checkpoint(cls, config_path: str, ckpt_path: str): + """ + + Parameters + ---------- + config_path : str + Path of model configuration file. + ckpt_path : str + Path of model checkpoint. + + Returns + ------- + model : SpeechTokenizer + SpeechTokenizer model. + + """ + import json + + with open(config_path) as f: + cfg = json.load(f) + model = cls(cfg) + params = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(params) + return model + + def forward(self, x: torch.tensor, n_q: int = None, layers: list = [0]): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape: (batch, channels, timesteps). + n_q : int, optional + Number of quantizers in RVQ used to encode. The default is all layers. + layers : list[int], optional + Layers of RVQ should return quantized result. The default is the first layer. + + Returns + ------- + o : torch.tensor + Output wavs. Shape: (batch, channels, timesteps). + commit_loss : torch.tensor + Commitment loss from residual vector quantizers. + feature : torch.tensor + Output of RVQ's first layer. Shape: (batch, timesteps, dimension) + + """ + n_q = n_q if n_q else self.n_q + e = self.encoder(x) + quantized, codes, commit_loss, quantized_list = self.quantizer( + e, n_q=n_q, layers=layers + ) + feature = rearrange(quantized_list[0], "b d t -> b t d") + feature = self.transform(feature) + o = self.decoder(quantized) + return o, commit_loss, feature + + def forward_feature(self, x: torch.tensor, layers: list = None): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape should be (batch, channels, timesteps). + layers : list[int], optional + Layers of RVQ should return quantized result. The default is all layers. + + Returns + ------- + quantized_list : list[torch.tensor] + Quantized of required layers. + + """ + e = self.encoder(x) + layers = layers if layers else list(range(self.n_q)) + quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers) + return quantized_list + + def encode(self, x: torch.tensor, n_q: int = None, st: int = None): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape: (batch, channels, timesteps). 
+ n_q : int, optional + Number of quantizers in RVQ used to encode. The default is all layers. + st : int, optional + Start quantizer index in RVQ. The default is 0. + + Returns + ------- + codes : torch.tensor + Output indices for each quantizer. Shape: (n_q, batch, timesteps) + + """ + e = self.encoder(x) + if st is None: + st = 0 + n_q = n_q if n_q else self.n_q + codes = self.quantizer.encode(e, n_q=n_q, st=st) + return codes + + def decode(self, codes: torch.tensor, st: int = 0): + """ + + Parameters + ---------- + codes : torch.tensor + Indices for each quantizer. Shape: (n_q, batch, timesteps). + st : int, optional + Start quantizer index in RVQ. The default is 0. + + Returns + ------- + o : torch.tensor + Reconstruct wavs from codes. Shape: (batch, channels, timesteps) + + """ + quantized = self.quantizer.decode(codes, st=st) + o = self.decoder(quantized) + return o diff --git a/models/codec/speechtokenizer/modules/__init__.py b/models/codec/speechtokenizer/modules/__init__.py new file mode 100644 index 00000000..0581347c --- /dev/null +++ b/models/codec/speechtokenizer/modules/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch modules.""" + +# flake8: noqa +from .conv import ( + pad1d, + unpad1d, + NormConv1d, + NormConvTranspose1d, + NormConv2d, + NormConvTranspose2d, + SConv1d, + SConvTranspose1d, +) +from .lstm import SLSTM +from .seanet import SEANetEncoder, SEANetDecoder diff --git a/models/codec/speechtokenizer/modules/conv.py b/models/codec/speechtokenizer/modules/conv.py new file mode 100644 index 00000000..0352b8bf --- /dev/null +++ b/models/codec/speechtokenizer/modules/conv.py @@ -0,0 +1,346 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .norm import ConvLayerNorm + + +CONV_NORMALIZATIONS = frozenset( + [ + "none", + "weight_norm", + "spectral_norm", + "time_layer_norm", + "layer_norm", + "time_group_norm", + ] +) + + +def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == "weight_norm": + return weight_norm(module) + elif norm == "spectral_norm": + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +) -> nn.Module: + """Return the proper normalization module. 
If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == "layer_norm": + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == "time_group_norm": + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "zero", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. 
+ """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = "reflect", + ): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn( + "SConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." 
+ ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + padding_total = (kernel_size - 1) * dilation - (stride - 1) + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, + ): + super().__init__() + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/models/codec/speechtokenizer/modules/lstm.py b/models/codec/speechtokenizer/modules/lstm.py new file mode 100644 index 00000000..7f7e4312 --- /dev/null +++ b/models/codec/speechtokenizer/modules/lstm.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""LSTM layers module.""" + +from torch import nn + + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. 
+    Expects input as convolutional layout.
+    """
+
+    def __init__(
+        self,
+        dimension: int,
+        num_layers: int = 2,
+        skip: bool = True,
+        bidirectional: bool = False,
+    ):
+        super().__init__()
+        self.bidirectional = bidirectional
+        self.skip = skip
+        self.lstm = nn.LSTM(
+            dimension, dimension, num_layers, bidirectional=bidirectional
+        )
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.bidirectional:
+            x = x.repeat(1, 1, 2)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y
diff --git a/models/codec/speechtokenizer/modules/norm.py b/models/codec/speechtokenizer/modules/norm.py
new file mode 100644
index 00000000..ff5eaefd
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/norm.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to last dimensions
+    before running the normalization and moves them back to original position right after.
+    """
+
+    def __init__(
+        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
+    ):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        x = einops.rearrange(x, "b ... t -> b t ...")
+        x = super().forward(x)
+        x = einops.rearrange(x, "b t ... -> b ... t")
+        return x
diff --git a/models/codec/speechtokenizer/modules/quantization/__init__.py b/models/codec/speechtokenizer/modules/quantization/__init__.py
new file mode 100644
index 00000000..79d90a1a
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/quantization/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import QuantizedResult, ResidualVectorQuantizer
diff --git a/models/codec/speechtokenizer/modules/quantization/ac.py b/models/codec/speechtokenizer/modules/quantization/ac.py
new file mode 100644
index 00000000..5695ea84
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/quantization/ac.py
@@ -0,0 +1,317 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+ +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf( + pdf: torch.Tensor, + total_range_bits: int, + roundoff: float = 1e-8, + min_range: int = 2, + check: bool = True, +) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2**total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] + if ( + (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range + ).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. + + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take less bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. 
If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. + Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: torch.Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. + """ + while self.delta < 2**self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, ( + effective_low, + effective_high, + range_low, + range_high, + ) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream.""" + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. 
+ + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. + + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + self.current -= b1 << self.max_bit + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. 
+ """ + while self.delta < 2**self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/models/codec/speechtokenizer/modules/quantization/core_vq.py b/models/codec/speechtokenizer/modules/quantization/core_vq.py new file mode 100644 index 00000000..57997255 --- /dev/null +++ b/models/codec/speechtokenizer/modules/quantization/core_vq.py @@ -0,0 +1,388 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. 
Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from .distrib import broadcast_tensors, rank + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. 
Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + # broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + # broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
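+            # After expiring dead codes, the EMA update below decays the per-code
+            # usage counts and the running sum of assigned vectors, then renormalizes
+            # the codebook with Laplace smoothing so that rarely used codes do not
+            # cause a division by zero.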
+ self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward( + self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None + ): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + out_quantized = [] + + n_q = n_q or len(self.layers) + + for i, layer in enumerate(self.layers[:n_q]): + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + if layers and i in layers: + out_quantized.append(quantized) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses, out_quantized + + def encode( + self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None + ) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + st = st or 0 + for layer in self.layers[st:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[st + i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/models/codec/speechtokenizer/modules/quantization/distrib.py b/models/codec/speechtokenizer/modules/quantization/distrib.py new file mode 100644 index 00000000..7b9a9b83 --- /dev/null +++ b/models/codec/speechtokenizer/modules/quantization/distrib.py @@ -0,0 +1,135 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + # print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. 
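+        # (After the all_reduce, `tensor` holds the sum of len(params) over all workers,
+        # which equals len(params) * world_size() only when every worker has the same count.)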
+ raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, " + "at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + # src = int(rank()) # added code + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = torch.distributed.all_reduce( + buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True + ) + else: + handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce( + p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True + ) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.0): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/models/codec/speechtokenizer/modules/quantization/vq.py b/models/codec/speechtokenizer/modules/quantization/vq.py new file mode 100644 index 00000000..ec7df0f9 --- /dev/null +++ b/models/codec/speechtokenizer/modules/quantization/vq.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +from .core_vq import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 
+    penalty: tp.Optional[torch.Tensor] = None
+    metrics: dict = field(default_factory=dict)
+
+
+class ResidualVectorQuantizer(nn.Module):
+    """Residual Vector Quantizer.
+    Args:
+        dimension (int): Dimension of the codebooks.
+        n_q (int): Number of residual vector quantizers used.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            a randomly selected vector from the current batch.
+    """
+
+    def __init__(
+        self,
+        dimension: int = 256,
+        n_q: int = 8,
+        bins: int = 1024,
+        decay: float = 0.99,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 50,
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.n_q = n_q
+        self.dimension = dimension
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+        self.vq = ResidualVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        n_q: tp.Optional[int] = None,
+        layers: tp.Optional[list] = None,
+    ) -> QuantizedResult:
+        """Residual vector quantization on the given input tensor.
+        Args:
+            x (torch.Tensor): Input tensor.
+            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
+            layers (list): Layers whose quantized output should also be returned. Default: None.
+        Returns:
+            A tuple of the quantized (or approximately quantized) representation, the
+            codes for each quantizer, the averaged commitment loss, and the quantized
+            output of each layer requested via `layers`.
+        """
+        n_q = n_q if n_q else self.n_q
+        if layers and max(layers) >= n_q:
+            raise ValueError(
+                f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must be less than B."
+            )
+        quantized, codes, commit_loss, quantized_list = self.vq(
+            x, n_q=n_q, layers=layers
+        )
+        return quantized, codes, torch.mean(commit_loss), quantized_list
+
+    def encode(
+        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
+    ) -> torch.Tensor:
+        """Encode the given input tensor into discrete codes.
+        The RVQ encode method sets the appropriate number of quantizers to use
+        and returns indices for each quantizer.
+        Args:
+            x (torch.Tensor): Input tensor.
+            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
+            st (int): Index of the first quantizer layer to encode from. Default: 0.
+        """
+        n_q = n_q if n_q else self.n_q
+        st = st or 0
+        codes = self.vq.encode(x, n_q=n_q, st=st)
+        return codes
+
+    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        Args:
+            codes (torch.Tensor): Input indices for each quantizer.
+            st (int): Index of the first quantizer layer to decode from. Default: 0.
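+
+        An illustrative round trip (shapes inferred from the code in this file):
+
+            rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
+            x = torch.randn(2, 256, 50)   # [batch, dimension, frames]
+            codes = rvq.encode(x)         # [n_q, batch, frames] integer indices
+            x_hat = rvq.decode(codes)     # [batch, dimension, frames]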
+ """ + quantized = self.vq.decode(codes, st=st) + return quantized diff --git a/models/codec/speechtokenizer/modules/seanet.py b/models/codec/speechtokenizer/modules/seanet.py new file mode 100644 index 00000000..481de20c --- /dev/null +++ b/models/codec/speechtokenizer/modules/seanet.py @@ -0,0 +1,414 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Encodec SEANet-based encoder and decoder implementation.""" + +import typing as tp + +import numpy as np +import torch.nn as nn +import torch + +from . import SConv1d, SConvTranspose1d, SLSTM + + +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + Args: + dim (int): Dimension of the input/output + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3) + true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. 
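+        Note: `activation` is resolved by name from `torch.nn` (e.g. "ELU"), except for
+            the special value "Snake", which uses the channel-wise Snake1d defined above.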
+ """ + + def __init__( + self, + dim: int, + kernel_sizes: tp.List[int] = [3, 1], + dilations: tp.List[int] = [1, 1], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): + super().__init__() + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" + act = getattr(nn, activation) if activation != "Snake" else Snake1d + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params) if activation != "Snake" else act(in_chs), + SConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. 
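+        bidirectional (bool): Whether the optional LSTM is bidirectional; when True, the
+            channel count is doubled before the final projection.
+
+    Note: with the default ratios [8, 5, 4, 2] the hop length is 8 * 5 * 4 * 2 = 320
+        samples, so 24 kHz audio yields 75 latent frames per second (see `test()` below).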
+ """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + bidirectional: bool = False, + ): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) # 计算乘积 + + act = getattr(nn, activation) if activation != "Snake" else Snake1d + mult = 1 + model: tp.List[nn.Module] = [ + SConv1d( + channels, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + # Add downsampling layers + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + mult *= 2 + + if lstm: + model += [ + SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional) + ] + + mult = mult * 2 if bidirectional else mult + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. 
+ true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: tp.Optional[str] = None, + final_activation_params: tp.Optional[dict] = None, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + trim_right_ratio: float = 1.0, + bidirectional: bool = False, + ): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) if activation != "Snake" else Snake1d + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + SConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + + if lstm: + model += [ + SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional) + ] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params) if activation != "Snake" else act(n_filters), + SConv1d( + n_filters, + channels, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + # Add optional final activation to decoder (eg. 
tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [final_act(**final_activation_params)] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + print("z ", z.shape) + assert 1 == 2 + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == "__main__": + test() diff --git a/models/tts/valle_v2/base_trainer.py b/models/tts/valle_v2/base_trainer.py new file mode 100644 index 00000000..2f7f97b0 --- /dev/null +++ b/models/tts/valle_v2/base_trainer.py @@ -0,0 +1,810 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import random +import shutil +import time +from abc import abstractmethod +from pathlib import Path +import math +import accelerate +import json5 +import numpy as np +import torch +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from torch.utils.data import ConcatDataset, DataLoader +from tqdm import tqdm + +from models.base.base_sampler import build_samplers +from optimizer.optimizers import NoamLR + + +class MainProcessLogger: + def __init__(self, is_main_process=True, name=None, **kwargs): + import logging + + if name is None: + logger = logging.getLogger(__name__) + else: + logger = logging.getLogger(name) + self.logger = logger + self.is_main_process = is_main_process + + def info(self, msg): + if self.is_main_process: + print(msg) + # self.logger.info(msg) + + def debug(self, msg): + if self.is_main_process: + print(msg) + # self.logger.debug(msg) + + def warning(self, msg): + if self.is_main_process: + print(msg) + # self.logger.warning(msg) + + +class BaseTrainer(object): + r"""The base trainer for all tasks. Any trainer should inherit from this class.""" + + def __init__(self, args=None, cfg=None): + super().__init__() + + self.args = args + self.cfg = cfg + + cfg.exp_name = args.exp_name + + # init with accelerate + self._init_accelerator() + self.accelerator.wait_for_everyone() + + # Use accelerate logger for distributed training + with self.accelerator.main_process_first(): + self.logger = MainProcessLogger( + self.accelerator.is_main_process, + name=args.exp_name, + log_level=args.log_level, + ) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New training process started." 
+ "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # init counts + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + self.max_epoch = ( + self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + ) + self.logger.info( + "Max epoch: {}".format( + self.max_epoch if self.max_epoch < float("inf") else "Unlimited" + ) + ) + + # Check values + if self.accelerator.is_main_process: + self.__check_basic_configs() + # Set runtime configs + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.checkpoints_path = [ + [] for _ in range(len(self.save_checkpoint_stride)) + ] + self.keep_last = [ + i if i > 0 else float("inf") for i in self.cfg.train.keep_last + ] + self.run_eval = self.cfg.train.run_eval + + # set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(args.seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {args.seed}") + + # setup data_loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.train_dataloader, self.valid_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # setup model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + self.logger.debug(self.model) + self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") + self.logger.info( + f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M" + ) + # optimizer & scheduler + with self.accelerator.main_process_first(): + self.logger.info("Building optimizer and scheduler...") + start = time.monotonic_ns() + self.optimizer = self._build_optimizer() + self.scheduler = self._build_scheduler() + end = time.monotonic_ns() + self.logger.info( + f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" + ) + + # accelerate prepare + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + self._accelerator_prepare() + end = time.monotonic_ns() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") + + # create criterion + with self.accelerator.main_process_first(): + self.logger.info("Building criterion...") + start = time.monotonic_ns() + self.criterion = self._build_criterion() + end = time.monotonic_ns() + self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") + + # Resume or Finetune + with self.accelerator.main_process_first(): + if args.resume: + if args.resume_from_ckpt_path == "": + ## Automatically resume according to the current exprimental name + self.logger.info( + "Automatically resuming from latest checkpoint in {}...".format( + self.checkpoint_dir + ) + ) + start = time.monotonic_ns() + ckpt_path = self._load_model( + checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type + ) 
+ end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + else: + ## Resume from the given checkpoint path + if not os.path.exists(args.resume_from_ckpt_path): + raise ValueError( + "[Error] The resumed checkpoint path {} don't exist.".format( + args.resume_from_ckpt_path + ) + ) + self.logger.info( + "Resuming from {}...".format(args.resume_from_ckpt_path) + ) + start = time.monotonic_ns() + ckpt_path = self._load_model( + checkpoint_path=args.resume_from_ckpt_path, + resume_type=args.resume_type, + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + + # save config file path + self.config_save_path = os.path.join(self.exp_dir, "args.json") + + def _accelerator_prepare(self): + ( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) = self.accelerator.prepare( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) + + ### Following are abstract methods that should be implemented in child classes ### + @abstractmethod + def _build_dataset(self): + r"""Build dataset for model training/validating/evaluating.""" + pass + + @staticmethod + @abstractmethod + def _build_criterion(): + r"""Build criterion function for model loss calculation.""" + pass + + @abstractmethod + def _build_model(self): + r"""Build model for training/validating/evaluating.""" + pass + + @abstractmethod + def _forward_step(self, batch): + r"""One forward step of the neural network. This abstract method is trying to + unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation. + However, for special case that using different forward step pattern for + training and validating, you could just override this method with ``pass`` and + implement ``_train_step`` and ``_valid_step`` separately. 
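+
+        A minimal (hypothetical) override could look like the following, assuming the
+        batch dict and the "input"/"target" keys used by your dataset:
+
+            def _forward_step(self, batch):
+                output = self.model(batch["input"])
+                return self.criterion(output, batch["target"])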
+        """
+        pass
+
+    def save_checkpoint(self):
+        if self.accelerator.is_main_process:
+            keep_last = self.keep_last[0]
+            # List all checkpoint folders under self.checkpoint_dir
+            all_ckpts = os.listdir(self.checkpoint_dir)
+            all_ckpts = filter(lambda x: x.startswith("epoch"), all_ckpts)
+            all_ckpts = list(all_ckpts)
+            if len(all_ckpts) > keep_last:
+                # Keep only the newest `keep_last` folders in self.checkpoint_dir,
+                # sorted by step ("epoch-{:04d}_step-{:07d}_loss-{:.6f}")
+                all_ckpts = sorted(
+                    all_ckpts, key=lambda x: int(x.split("_")[1].split("-")[1])
+                )
+                for ckpt in all_ckpts[:-keep_last]:
+                    shutil.rmtree(os.path.join(self.checkpoint_dir, ckpt))
+            checkpoint_filename = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                self.epoch, self.step, self.current_loss
+            )
+            path = os.path.join(self.checkpoint_dir, checkpoint_filename)
+            self.logger.info("Saving state to {}...".format(path))
+            self.accelerator.save_state(path)
+            self.logger.info("Finished saving state.")
+
+    @abstractmethod
+    def _save_auxiliary_states(self):
+        r"""Save any auxiliary states when saving the model's checkpoint."""
+        pass
+
+    def echo_log(self, losses, mode="Training"):
+        message = [
+            "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
+                mode, self.epoch + 1, self.step, self.time_window.average
+            )
+        ]
+
+        for key in sorted(losses.keys()):
+            if isinstance(losses[key], dict):
+                for k, v in losses[key].items():
+                    message.append(
+                        str(k).split("/")[-1] + "=" + str(round(float(v), 5))
+                    )
+            else:
+                message.append(
+                    str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
+                )
+        self.logger.info(", ".join(message))
+
+    ### Abstract methods end ###
+
+    ### THIS IS MAIN ENTRY ###
+    def train_loop(self):
+        r"""Training loop. The public entry of the training process."""
+        # Wait for everyone to be ready before we move on
+        self.accelerator.wait_for_everyone()
+        # dump config file
+        if self.accelerator.is_main_process:
+            self.__dump_cfg(self.config_save_path)
+        self.model.train()
+        self.optimizer.zero_grad()
+        while self.epoch < self.max_epoch:
+            self.logger.info("\n")
+            self.logger.info("-" * 32)
+            self.logger.info("Epoch {}: ".format(self.epoch))
+
+            ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
+            ### It's inconvenient for the model with multiple losses
+            # Do training & validating epoch
+            train_loss = self._train_epoch()
+            self.logger.info("  |- Train/Loss: {:.6f}".format(train_loss))
+            valid_loss = self._valid_epoch()
+            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_loss))
+            self.accelerator.log(
+                {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
+                step=self.epoch,
+            )
+
+            self.accelerator.wait_for_everyone()
+
+            # Update info for each epoch
+            self.epoch += 1
+
+        # Finish training and save final checkpoint
+        self.accelerator.wait_for_everyone()
+        if self.accelerator.is_main_process:
+            self.accelerator.save_state(
+                os.path.join(
+                    self.checkpoint_dir,
+                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                        self.epoch, self.step, valid_loss
+                    ),
+                )
+            )
+            self._save_auxiliary_states()
+
+        self.accelerator.end_training()
+
+    def get_lr(self, it):
+        # 1) linear warmup for warmup_steps steps
+        if it < self.cfg.train.scheduler.warmup_steps:
+            return self.cfg.train.adamw.lr * it / self.cfg.train.scheduler.warmup_steps
+        # 2) if it > total_steps, return the minimum learning rate
+        if it > self.cfg.train.scheduler.total_steps:
+            return self.cfg.train.scheduler.min_lr
+        # 3) in between, use cosine decay down to the minimum learning rate
+        decay_ratio = (it - self.cfg.train.scheduler.warmup_steps) / (
self.cfg.train.scheduler.total_steps - self.cfg.train.scheduler.warmup_steps + ) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return self.cfg.train.scheduler.min_lr + coeff * ( + self.cfg.train.adamw.lr - self.cfg.train.scheduler.min_lr + ) + + ### Following are methods that can be used directly in child classes ### + def _train_epoch(self): + r"""Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.model.train() + epoch_sum_loss: float = 0.0 + ema_loss = None + + # profiler + start_this_step_time = time.time() + finish_last_step_time = time.time() + + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + assert batch is not None + + # start_this_step_time = time.time() + # print(f'load batch took: {start_this_step_time - finish_last_step_time:.6f}s') + + # update learning rate + lr = self.get_lr(self.step) + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + # Do training step and BP + with self.accelerator.accumulate(self.model): + loss = self._train_step(batch) + self.current_loss = loss.item() + ema_loss = ( + 0.99 * ema_loss + 0.01 * self.current_loss + if ema_loss is not None + else self.current_loss + ) + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.optimizer.zero_grad() + self.batch_count += 1 + + # if self.accelerator.is_main_process: + # print(self.current_loss) + + if self.accelerator.sync_gradients: + if self.step % self.cfg.train.save_checkpoint_stride[0] == 0: + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + try: + self.save_checkpoint() + except: + self.logger.info("Failed to save checkpoint, resuming...") + if self.accelerator.is_main_process: + if self.step % 100 == 0: + self.logger.info(f"EMA Loss: {ema_loss:.6f}") + self.accelerator.log( + { + "Step/Train Loss": loss, + "Step/Learning Rate": self.optimizer.param_groups[0]["lr"], + }, + step=self.step, + ) + epoch_sum_loss += loss + self.step += 1 + + # finish_last_step_time = time.time() + # print(f'load took: {finish_last_step_time - start_this_step_time:.6f}s') + return ( + epoch_sum_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + + @torch.inference_mode() + def _valid_epoch(self): + r"""Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.model.eval() + epoch_sum_loss = 0.0 + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + batch_loss = self._valid_step(batch) + epoch_sum_loss += batch_loss.item() + + return epoch_sum_loss / len(self.valid_dataloader) + + def _train_step(self, batch): + r"""Training forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_train_epoch`` for usage. + """ + return self._forward_step(batch) + + @torch.inference_mode() + def _valid_step(self, batch): + r"""Testing forward step. Should return average loss of a sample over + one batch. 
Provoke ``_forward_step`` is recommended except for special case. + See ``_test_epoch`` for usage. + """ + return self._forward_step(batch) + + def _load_model( + self, + checkpoint_dir: str = None, + checkpoint_path: str = None, + resume_type: str = "", + ): + r"""Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. + """ + if checkpoint_path is None: + try: + all_ckpts = os.listdir(checkpoint_dir) + all_ckpts = filter(lambda x: x.startswith("epoch"), all_ckpts) + ls = list(all_ckpts) + ls = [os.path.join(checkpoint_dir, i) for i in ls] + ls.sort( + key=lambda x: int(x.split("_")[-2].split("-")[-1]), reverse=True + ) + checkpoint_path = ls[0] + self.logger.info("Resume from {}".format(checkpoint_path)) + except Exception as e: + print( + "Failed to load checkpoint from {}, starting FROM SCRATCH...".format( + checkpoint_dir + ) + ) + return None + + if resume_type in ["resume", ""]: + # Load all the things, including model weights, optimizer, scheduler, and random states. + self.accelerator.load_state(input_dir=checkpoint_path) + + # set epoch and step + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + + elif resume_type == "finetune": + # Load only the model weights + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune...") + + else: + raise ValueError("Resume_type must be `resume` or `finetune`.") + + return checkpoint_path + + # TODO: LEGACY CODE + def _build_dataloader(self): + Dataset, Collator = self._build_dataset() + + # build dataset instance for each dataset and combine them by ConcatDataset + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = ConcatDataset(datasets_list) + train_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train") + self.logger.debug(f"train batch_sampler: {list(batch_sampler)}") + self.logger.debug(f"length: {train_dataset.cumulative_sizes}") + # TODO: use config instead of (sampler, shuffle, drop_last, batch_size) + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + + # Build valid dataloader + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = ConcatDataset(datasets_list) + valid_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid") + self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}") + self.logger.debug(f"length: {valid_dataset.cumulative_sizes}") + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + return train_loader, valid_loader + + @staticmethod + def _set_random_seed(seed): + r"""Set random seed for all possible random modules.""" + 
random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _check_nan(self, loss, y_pred, y_gt): + if torch.any(torch.isnan(loss)): + self.logger.fatal("Fatal Error: Training is down since loss has Nan!") + self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True) + if torch.any(torch.isnan(y_pred)): + self.logger.error( + f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True + ) + else: + self.logger.debug( + f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True + ) + if torch.any(torch.isnan(y_gt)): + self.logger.error( + f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True + ) + else: + self.logger.debug( + f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True + ) + if torch.any(torch.isnan(y_pred)): + self.logger.error(f"y_pred: {y_pred}", in_order=True) + else: + self.logger.debug(f"y_pred: {y_pred}", in_order=True) + if torch.any(torch.isnan(y_gt)): + self.logger.error(f"y_gt: {y_gt}", in_order=True) + else: + self.logger.debug(f"y_gt: {y_gt}", in_order=True) + + # TODO: still OK to save tracking? + self.accelerator.end_training() + raise RuntimeError("Loss has Nan! See log for more info.") + + ### Protected methods end ### + + ## Following are private methods ## + ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed. + def _build_optimizer(self): + r"""Build optimizer for model.""" + # Make case-insensitive matching + if self.cfg.train.optimizer.lower() == "adadelta": + optimizer = torch.optim.Adadelta( + self.model.parameters(), **self.cfg.train.adadelta + ) + self.logger.info("Using Adadelta optimizer.") + elif self.cfg.train.optimizer.lower() == "adagrad": + optimizer = torch.optim.Adagrad( + self.model.parameters(), **self.cfg.train.adagrad + ) + self.logger.info("Using Adagrad optimizer.") + elif self.cfg.train.optimizer.lower() == "adam": + optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam) + self.logger.info("Using Adam optimizer.") + elif self.cfg.train.optimizer.lower() == "adamw": + optimizer = torch.optim.AdamW( + self.model.parameters(), **self.cfg.train.adamw + ) + elif self.cfg.train.optimizer.lower() == "sparseadam": + optimizer = torch.optim.SparseAdam( + self.model.parameters(), **self.cfg.train.sparseadam + ) + elif self.cfg.train.optimizer.lower() == "adamax": + optimizer = torch.optim.Adamax( + self.model.parameters(), **self.cfg.train.adamax + ) + elif self.cfg.train.optimizer.lower() == "asgd": + optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd) + elif self.cfg.train.optimizer.lower() == "lbfgs": + optimizer = torch.optim.LBFGS( + self.model.parameters(), **self.cfg.train.lbfgs + ) + elif self.cfg.train.optimizer.lower() == "nadam": + optimizer = torch.optim.NAdam( + self.model.parameters(), **self.cfg.train.nadam + ) + elif self.cfg.train.optimizer.lower() == "radam": + optimizer = torch.optim.RAdam( + self.model.parameters(), **self.cfg.train.radam + ) + elif self.cfg.train.optimizer.lower() == "rmsprop": + optimizer = torch.optim.RMSprop( + self.model.parameters(), **self.cfg.train.rmsprop + ) + elif self.cfg.train.optimizer.lower() == "rprop": + optimizer = torch.optim.Rprop( + self.model.parameters(), **self.cfg.train.rprop + ) + elif self.cfg.train.optimizer.lower() == "sgd": + optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd) + else: + raise NotImplementedError( + f"Optimizer {self.cfg.train.optimizer} not supported yet!" 
+ ) + return optimizer + + def _build_scheduler(self): + r"""Build scheduler for optimizer.""" + # Make case-insensitive matching + if self.cfg.train.scheduler.lower() == "lambdalr": + scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, **self.cfg.train.lambdalr + ) + elif self.cfg.train.scheduler.lower() == "multiplicativelr": + scheduler = torch.optim.lr_scheduler.MultiplicativeLR( + self.optimizer, **self.cfg.train.multiplicativelr + ) + elif self.cfg.train.scheduler.lower() == "steplr": + scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, **self.cfg.train.steplr + ) + elif self.cfg.train.scheduler.lower() == "multisteplr": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + self.optimizer, **self.cfg.train.multisteplr + ) + elif self.cfg.train.scheduler.lower() == "constantlr": + scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, **self.cfg.train.constantlr + ) + elif self.cfg.train.scheduler.lower() == "linearlr": + scheduler = torch.optim.lr_scheduler.LinearLR( + self.optimizer, **self.cfg.train.linearlr + ) + elif self.cfg.train.scheduler.lower() == "exponentiallr": + scheduler = torch.optim.lr_scheduler.ExponentialLR( + self.optimizer, **self.cfg.train.exponentiallr + ) + elif self.cfg.train.scheduler.lower() == "polynomiallr": + scheduler = torch.optim.lr_scheduler.PolynomialLR( + self.optimizer, **self.cfg.train.polynomiallr + ) + elif self.cfg.train.scheduler.lower() == "cosineannealinglr": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, **self.cfg.train.cosineannealinglr + ) + elif self.cfg.train.scheduler.lower() == "sequentiallr": + scheduler = torch.optim.lr_scheduler.SequentialLR( + self.optimizer, **self.cfg.train.sequentiallr + ) + elif self.cfg.train.scheduler.lower() == "reducelronplateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, **self.cfg.train.reducelronplateau + ) + elif self.cfg.train.scheduler.lower() == "cycliclr": + scheduler = torch.optim.lr_scheduler.CyclicLR( + self.optimizer, **self.cfg.train.cycliclr + ) + elif self.cfg.train.scheduler.lower() == "onecyclelr": + scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, **self.cfg.train.onecyclelr + ) + elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts": + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + self.optimizer, **self.cfg.train.cosineannearingwarmrestarts + ) + elif self.cfg.train.scheduler.lower() == "noamlr": + scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler) + else: + raise NotImplementedError( + f"Scheduler {self.cfg.train.scheduler} not supported yet!" 
+ ) + return scheduler + + def _init_accelerator(self): + self.exp_dir = os.path.join( + os.path.abspath(self.cfg.log_dir), self.args.exp_name + ) + project_config = ProjectConfiguration( + project_dir=self.exp_dir, + logging_dir=os.path.join(self.exp_dir, "log"), + ) + from accelerate import DistributedDataParallelKwargs + + kwargs = DistributedDataParallelKwargs( + find_unused_parameters=self.cfg.train.find_unused_parameters + ) + + self.accelerator = accelerate.Accelerator( + gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step, + log_with=self.cfg.train.tracker, + project_config=project_config, + kwargs_handlers=[kwargs], + ) + if self.accelerator.is_main_process: + os.makedirs(project_config.project_dir, exist_ok=True) + os.makedirs(project_config.logging_dir, exist_ok=True) + with self.accelerator.main_process_first(): + self.accelerator.init_trackers(self.args.exp_name) + + def __check_basic_configs(self): + if self.cfg.train.gradient_accumulation_step <= 0: + self.logger.fatal("Invalid gradient_accumulation_step value!") + self.logger.error( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + self.accelerator.end_training() + raise ValueError( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + # TODO: check other values + + @staticmethod + def __count_parameters(model): + model_param = 0.0 + if isinstance(model, dict): + for key, value in model.items(): + model_param += sum(p.numel() for p in model[key].parameters()) + else: + model_param = sum(p.numel() for p in model.parameters()) + return model_param + + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + @torch.inference_mode() + def test_loop(self): + pass + + ### Private methods end ### diff --git a/models/tts/valle_v2/g2p_processor.py b/models/tts/valle_v2/g2p_processor.py new file mode 100644 index 00000000..43807fb1 --- /dev/null +++ b/models/tts/valle_v2/g2p_processor.py @@ -0,0 +1,363 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
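+#
+# Grapheme-to-phoneme (G2P) front end built on `g2p_en`. Each ARPAbet phone is
+# tagged with a word-position suffix ("B" = word-initial, "I" = word-internal,
+# "E" = word-final; single-phone words get only "B"), "|" marks a word boundary,
+# and the resulting symbols are mapped to integer IDs through PHONE_SET.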
+ +import json +import numpy as np +import os +import torch +import copy +from g2p_en import G2p +import re +import unicodedata +from g2p_en import G2p +from g2p_en.expand import normalize_numbers + +g2p = G2p() + +PHONE_SET = [ + "!", + ",", + ".", + ".B", + ":", + "", + "", + "", + "", + "?", + "AA0B", + "AA0E", + "AA0I", + "AA1B", + "AA1E", + "AA1I", + "AA2B", + "AA2E", + "AA2I", + "AE0B", + "AE0E", + "AE0I", + "AE1B", + "AE1E", + "AE1I", + "AE2B", + "AE2E", + "AE2I", + "AH0B", + "AH0E", + "AH0I", + "AH1B", + "AH1E", + "AH1I", + "AH2B", + "AH2E", + "AH2I", + "AO0B", + "AO0E", + "AO0I", + "AO1", + "AO1B", + "AO1E", + "AO1I", + "AO2B", + "AO2E", + "AO2I", + "AW0B", + "AW0E", + "AW0I", + "AW1B", + "AW1E", + "AW1I", + "AW2B", + "AW2E", + "AW2I", + "AY0B", + "AY0E", + "AY0I", + "AY1B", + "AY1E", + "AY1I", + "AY2B", + "AY2E", + "AY2I", + "BB", + "BE", + "BI", + "CHB", + "CHE", + "CHI", + "DB", + "DE", + "DHB", + "DHE", + "DHI", + "DI", + "EH0B", + "EH0E", + "EH0I", + "EH1B", + "EH1E", + "EH1I", + "EH2B", + "EH2E", + "EH2I", + "ER0B", + "ER0E", + "ER0I", + "ER1B", + "ER1E", + "ER1I", + "ER2B", + "ER2E", + "ER2I", + "EY0B", + "EY0E", + "EY0I", + "EY1B", + "EY1E", + "EY1I", + "EY2B", + "EY2E", + "EY2I", + "FB", + "FE", + "FI", + "GB", + "GE", + "GI", + "HHB", + "HHE", + "HHI", + "IH0B", + "IH0E", + "IH0I", + "IH1B", + "IH1E", + "IH1I", + "IH2B", + "IH2E", + "IH2I", + "IY0B", + "IY0E", + "IY0I", + "IY1B", + "IY1E", + "IY1I", + "IY2B", + "IY2E", + "IY2I", + "JHB", + "JHE", + "JHI", + "KB", + "KE", + "KI", + "L", + "LB", + "LE", + "LI", + "MB", + "ME", + "MI", + "NB", + "NE", + "NGB", + "NGE", + "NGI", + "NI", + "OW0B", + "OW0E", + "OW0I", + "OW1B", + "OW1E", + "OW1I", + "OW2B", + "OW2E", + "OW2I", + "OY0B", + "OY0E", + "OY0I", + "OY1B", + "OY1E", + "OY1I", + "OY2B", + "OY2E", + "OY2I", + "PB", + "PE", + "PI", + "RB", + "RE", + "RI", + "SB", + "SE", + "SHB", + "SHE", + "SHI", + "SI", + "TB", + "TE", + "THB", + "THE", + "THI", + "TI", + "UH0B", + "UH0E", + "UH0I", + "UH1B", + "UH2B", + "UH1E", + "UH1I", + "UH2E", + "UH2I", + "UW0B", + "UW0E", + "UW0I", + "UW1B", + "UW1E", + "UW1I", + "UW2B", + "UW2E", + "UW2I", + "VB", + "VE", + "VI", + "WB", + "WE", + "WI", + "YB", + "YE", + "YI", + "ZB", + "ZE", + "ZHB", + "ZHE", + "ZHI", + "ZI", + "|", +] +PHPONE2ID = {PHONE_SET[i]: i for i in range(len(PHONE_SET))} + +PUNCS = "!,.?;:" + + +def is_sil_phoneme(p): + return p == "" or not p[0].isalpha() + + +def add_bdr(txt_struct): + txt_struct_ = [] + for i, ts in enumerate(txt_struct): + txt_struct_.append(ts) + if ( + i != len(txt_struct) - 1 + and not is_sil_phoneme(txt_struct[i][0]) + and not is_sil_phoneme(txt_struct[i + 1][0]) + ): + txt_struct_.append(["|", ["|"]]) + return txt_struct_ + + +def preprocess_text(text): + text = normalize_numbers(text) + text = "".join( + char + for char in unicodedata.normalize("NFD", text) + if unicodedata.category(char) != "Mn" + ) # Strip accents + text = text.lower() + text = re.sub("['\"()]+", "", text) + text = re.sub("[-]+", " ", text) + text = re.sub(f"[^ a-z{PUNCS}]", "", text) + text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> ! + text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 
+ text = text.replace("i.e.", "that is") + text = text.replace("i.e.", "that is") + text = text.replace("etc.", "etc") + text = re.sub(f"([{PUNCS}])", r" ", text) # remove punctuations for now + text = re.sub(rf"\s+", r" ", text) + return text + + +def postprocess(txt_struct): + while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[0][0]): + txt_struct = txt_struct[1:] + while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[-1][0]): + txt_struct = txt_struct[:-1] + txt_struct = add_bdr(txt_struct) + txt_struct = [["", [""]]] + txt_struct + [["", [""]]] + return txt_struct + + +def process(txt, g2p): + txt = preprocess_text(txt).strip() + phs = g2p(txt) + txt_struct = [[w, []] for w in txt.split(" ")] + i_word = 0 + for p in phs: + if p == " ": + i_word += 1 + else: + txt_struct[i_word][1].append(p) + + txt_struct_ret = copy.deepcopy(txt_struct) + + for i_word in range(len(txt_struct)): + if not is_sil_phoneme(txt_struct[i_word][0]): + if len(txt_struct[i_word][1]) > 1: + txt_struct_ret[i_word][1][0] += "B" + for i in range(1, len(txt_struct[i_word][1]) - 1): + txt_struct_ret[i_word][1][i] += "I" + txt_struct_ret[i_word][1][-1] += "E" + else: + txt_struct_ret[i_word][1][0] += "B" + + txt_struct_ret = postprocess(txt_struct_ret) + + return txt_struct_ret, txt + + +def test(): + g2p = G2p() + txt = "This is a test sentence." + txt_struct, txt = process(txt, g2p) + print(txt_struct) + print(txt) + phone_seq = [p for w in txt_struct for p in w[1]] + print(phone_seq) + phone_id = [PHPONE2ID[p] for p in phone_seq] + print(phone_id) + + +class G2pProcessor: + def __init__(self): + self.g2p = G2p() + + def __call__(self, txt, lang="en"): + return self.txt2phoneid(txt) + + def txt2phoneid(self, txt): + txt_struct, txt = process(txt, self.g2p) + phone_seq = [p for w in txt_struct for p in w[1]] + phone_id = [PHPONE2ID[p] for p in phone_seq] + return None, phone_id + + def phoneid2txt(self, phone_id): + txt = [] + for i in phone_id: + txt.append(PHONE_SET[i]) + return txt + + +if __name__ == "__main__": + g2p = G2pProcessor() + txt = "This is a test sentence." + phoneid = g2p.txt2phoneid(txt)[1] + # output: [5, 73, 118, 175, 218, 116, 213, 218, 28, 218, 180, 82, 179, 181, 218, 174, 82, 149, 185, 30, 149, 175, 6] + # print(phoneid) + print(g2p.phoneid2txt(phoneid)) + # output: ['', 'DHB', 'IH1I', 'SE', '|', 'IH1B', 'ZE', '|', 'AH0B', '|', 'TB', 'EH1I', 'SI', 'TE', '|', 'SB', 'EH1I', 'NI', 'TI', 'AH0I', 'NI', 'SE', ''] + print(len(PHONE_SET)) + # output: 219 diff --git a/models/tts/valle_v2/libritts_dataset.py b/models/tts/valle_v2/libritts_dataset.py new file mode 100644 index 00000000..89b40e6f --- /dev/null +++ b/models/tts/valle_v2/libritts_dataset.py @@ -0,0 +1,271 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
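+#
+# LibriTTS dataset wrapper for VALL-E v2 training. It scans each requested split
+# for `*.book.tsv` / `*.trans.tsv` metadata, keeps utterances whose duration is
+# between 3 and 15 seconds, loads audio at 16 kHz with librosa, and returns a
+# dict with the raw waveform ("speech") and phone IDs ("phone") per item.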
+ +import random +import torch +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from tqdm import tqdm +from g2p_en import G2p +import librosa +from torch.utils.data import Dataset +import pandas as pd +import time +import io + +SAMPLE_RATE = 16000 +# g2p +from .g2p_processor import G2pProcessor + +phonemizer_g2p = G2pProcessor() + + +class VALLEDataset(Dataset): + def __init__(self, args): + print(f"Initializing VALLEDataset") + self.dataset_list = args.dataset_list + + print(f"using sampling rate {SAMPLE_RATE}") + + # set dataframe clumn name + book_col_name = [ + "ID", + "Original_text", + "Normalized_text", + "Aligned_or_not", + "Start_time", + "End_time", + "Signal_to_noise_ratio", + ] + trans_col_name = [ + "ID", + "Original_text", + "Normalized_text", + "Dir_path", + "Duration", + ] + self.metadata_cache = pd.DataFrame(columns=book_col_name) + self.trans_cache = pd.DataFrame(columns=trans_col_name) + # dataset_cache_dir = args.cache_dir # cache_dir + # print(f"args.cache_dir = ", args.cache_dir) + # os.makedirs(dataset_cache_dir, exist_ok=True) + + ######## add data dir to dataset2dir ########## + self.dataset2dir = { + "dev-clean": f"{args.data_dir}/dev-clean", + "dev-other": f"{args.data_dir}/dev-other", + "test-clean": f"{args.data_dir}/test-clean", + "test-other": f"{args.data_dir}/test-other", + "train-clean-100": f"{args.data_dir}/train-clean-100", + "train-clean-360": f"{args.data_dir}/train-clean-360", + "train-other-500": f"{args.data_dir}/train-other-500", + } + + ###### load metadata and transcripts ##### + for dataset_name in self.dataset_list: + print("Initializing dataset: ", dataset_name) + # get [book,transcripts,audio] files list + self.book_files_list = self.get_metadata_files( + self.dataset2dir[dataset_name] + ) + self.trans_files_list = self.get_trans_files(self.dataset2dir[dataset_name]) + + ## create metadata_cache (book.tsv file is not filtered, some file is not exist, but contain Duration and Signal_to_noise_ratio) + print("reading paths for dataset...") + for book_path in tqdm(self.book_files_list): + tmp_cache = pd.read_csv( + book_path, sep="\t", names=book_col_name, quoting=3 + ) + self.metadata_cache = pd.concat( + [self.metadata_cache, tmp_cache], ignore_index=True + ) + self.metadata_cache.set_index("ID", inplace=True) + + ## create transcripts (the trans.tsv file) + print("creating transcripts for dataset...") + for trans_path in tqdm(self.trans_files_list): + tmp_cache = pd.read_csv( + trans_path, sep="\t", names=trans_col_name, quoting=3 + ) + tmp_cache["Dir_path"] = os.path.dirname(trans_path) + self.trans_cache = pd.concat( + [self.trans_cache, tmp_cache], ignore_index=True + ) + self.trans_cache.set_index("ID", inplace=True) + + ## calc duration + self.trans_cache["Duration"] = ( + self.metadata_cache.End_time[self.trans_cache.index] + - self.metadata_cache.Start_time[self.trans_cache.index] + ) + ## add fullpath + # self.trans_cache['Full_path'] = os.path.join(self.dataset2dir[dataset_name],self.trans_cache['ID']) + + # filter_by_duration: filter_out files with duration < 3.0 or > 15.0 + print(f"Filtering files with duration between 3.0 and 15.0 seconds") + print(f"Before filtering: {len(self.trans_cache)}") + self.trans_cache = self.trans_cache[ + (self.trans_cache["Duration"] >= 3.0) + & (self.trans_cache["Duration"] <= 15.0) + ] + print(f"After filtering: {len(self.trans_cache)}") + + def get_metadata_files(self, directory): + book_files = [] + for root, _, files in os.walk(directory): + for file in files: + if 
file.endswith(".book.tsv") and file[0] != ".": + rel_path = os.path.join(root, file) + book_files.append(rel_path) + return book_files + + def get_trans_files(self, directory): + trans_files = [] + for root, _, files in os.walk(directory): + for file in files: + if file.endswith(".trans.tsv") and file[0] != ".": + rel_path = os.path.join(root, file) + trans_files.append(rel_path) + return trans_files + + def get_audio_files(self, directory): + audio_files = [] + for root, _, files in os.walk(directory): + for file in files: + if file.endswith((".flac", ".wav", ".opus")): + rel_path = os.path.relpath(os.path.join(root, file), directory) + audio_files.append(rel_path) + return audio_files + + def get_num_frames(self, index): + # get_num_frames(durations) by index + duration = self.meta_data_cache["Duration"][index] + # num_frames = duration * SAMPLE_RATE + num_frames = int(duration * 75) + + # file_rel_path = self.meta_data_cache['relpath'][index] + # uid = file_rel_path.rstrip('.flac').split('/')[-1] + # num_frames += len(self.transcripts[uid]) + return num_frames + + def __len__(self): + return len(self.trans_cache) + + def __getitem__(self, idx): + # Get the file rel path + file_dir_path = self.trans_cache["Dir_path"].iloc[idx] + # Get uid + uid = self.trans_cache.index[idx] + # Get the file name from cache uid + file_name = uid + ".wav" + # Get the full file path + full_file_path = os.path.join(file_dir_path, file_name) + + # get phone + phone = self.trans_cache["Normalized_text"][uid] + phone = phonemizer_g2p(phone, "en")[1] + # load speech + speech, _ = librosa.load(full_file_path, sr=SAMPLE_RATE) + # if self.resample_to_24k: + # speech = librosa.resample(speech, orig_sr=SAMPLE_RATE, target_sr=24000) + # speech = torch.tensor(speech, dtype=torch.float32) + # pad speech to multiples of 200 + + # remainder = speech.size(0) % 200 + # if remainder > 0: + # pad = 200 - remainder + # speech = torch.cat([speech, torch.zeros(pad, dtype=torch.float32)], dim=0) + + # inputs = self._get_reference_vc(speech, hop_length=200) + inputs = {} + # Get the speaker id + # speaker = self.meta_data_cache['speaker'][idx] + # speaker_id = self.speaker2id[speaker] + # inputs["speaker_id"] = speaker_id + inputs["speech"] = speech # 24khz speech, [T] + inputs["phone"] = phone # [T] + return inputs + + +def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if len(batch) == 0: + return 0 + if len(batch) == max_sentences: + return 1 + if num_tokens > max_tokens: + return 1 + return 0 + + +def batch_by_size( + indices, + num_tokens_fn, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). 
+ """ + bsz_mult = required_batch_size_multiple + + sample_len = 0 + sample_lens = [] + batch = [] + batches = [] + for i in range(len(indices)): + idx = indices[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + assert ( + sample_len <= max_tokens + ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format( + idx, sample_len, max_tokens + ) + num_tokens = (len(batch) + 1) * sample_len + + if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + mod_len = max( + bsz_mult * (len(batch) // bsz_mult), + len(batch) % bsz_mult, + ) + batches.append(batch[:mod_len]) + batch = batch[mod_len:] + sample_lens = sample_lens[mod_len:] + sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 + batch.append(idx) + if len(batch) > 0: + batches.append(batch) + return batches + + +def test(): + from utils.util import load_config + + cfg = load_config("./egs/tts/VALLE_V2/exp_ar_libritts.json") + dataset = VALLEDataset(cfg.dataset) + metadata_cache = dataset.metadata_cache + trans_cache = dataset.trans_cache + print(trans_cache.head(10)) + # print(dataset.book_files_list) + breakpoint() + + +if __name__ == "__main__": + test() diff --git a/models/tts/valle_v2/modeling_llama.py b/models/tts/valle_v2/modeling_llama.py new file mode 100644 index 00000000..e2b71b6d --- /dev/null +++ b/models/tts/valle_v2/modeling_llama.py @@ -0,0 +1,1043 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This code is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + +# Original work copyright +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
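+#
+# Note: compared with the upstream `modeling_llama.py`, `LlamaAttention.forward`
+# in this copy returns the pre-softmax attention scores (`unnormed_attn_weights`)
+# instead of the softmax-normalized probabilities, so `output_attentions=True`
+# exposes raw attention scores to callers.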
+""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.models.llama.modeling_llama import ACT2FN +from transformers.models.llama.modeling_llama import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.llama.modeling_llama import PreTrainedModel +from transformers.models.llama.modeling_llama import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.llama.modeling_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full( + (tgt_len, tgt_len), + torch.tensor(torch.finfo(dtype).min, device=device), + device=device, + ) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return (self.weight * hidden_states).to(input_dtype) + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. 
+ self.max_seq_len_cached = max_position_embeddings + t = torch.arange( + self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :], persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :], persistent=False + ) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :], persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :], persistent=False + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, **kwargs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings + ) + + if "layer_idx" in kwargs: + self.layer_idx = kwargs["layer_idx"] + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, + torch.tensor( + torch.finfo(attn_weights.dtype).min, device=attn_weights.device + ), + ) + + unnormed_attn_weights = attn_weights + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, unnormed_attn_weights, past_key_value + + +class 
LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, **kwargs): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if 
output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa 
Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/models/tts/valle_v2/valle_ar.py b/models/tts/valle_v2/valle_ar.py new file mode 100644 index 00000000..f50820fb --- /dev/null +++ b/models/tts/valle_v2/valle_ar.py @@ -0,0 +1,302 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
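+#
+# Autoregressive (AR) stage of VALL-E built on a causal Llama decoder. Phone and
+# acoustic tokens share a single vocabulary: acoustic (target) codes keep IDs in
+# [0, target_vocab_size), phone IDs are shifted up by target_vocab_size, and the
+# PAD/BOS/EOS specials sit above both (1281-1285 with the default sizes). The LM
+# loss is computed only on the acoustic segment; phone positions are labeled -100.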
+ +from .modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel +import torch +import torch.nn.functional as F +import numpy as np +import os +import torch.nn as nn + + +class ValleAR(nn.Module): + def __init__( + self, + phone_vocab_size=256, + target_vocab_size=1024, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=12, + num_attention_heads=16, + pad_token_id=1281, + bos_target_id=1282, + eos_target_id=1283, + bos_phone_id=1284, + eos_phone_id=1285, + use_input_embeds=False, + emb_dim=256, + **kwargs, + ): + super(ValleAR, self).__init__() + self.config = LlamaConfig( + vocab_size=phone_vocab_size + target_vocab_size + 10, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + pad_token_id=pad_token_id, + bos_token_id=bos_target_id, + eos_token_id=eos_target_id, + ) + self.phone_vocab_size = phone_vocab_size + self.target_vocab_size = target_vocab_size + self.pad_token_id = pad_token_id + self.bos_target_id = bos_target_id + self.eos_target_id = eos_target_id + self.bos_phone_id = bos_phone_id + self.eos_phone_id = eos_phone_id + self.model = LlamaForCausalLM(self.config) + + self.use_input_embeds = use_input_embeds + + # no input embedding is used to provide speaker information + if self.use_input_embeds: + self.emb_linear = nn.Linear(emb_dim, hidden_size) + self.emb_linear.weight.data.normal_(mean=0.0, std=0.01) + self.emb_linear.bias.data.zero_() + + def forward( + self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None + ): + if input_embeds is not None: + input_embeds = self.emb_linear(input_embeds) + phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label( + phone_ids, + phone_mask, + self.eos_phone_id, + self.bos_phone_id, + self.pad_token_id, + ) + target_ids, target_mask, target_label = self.add_target_eos_bos_label( + target_ids, + target_mask, + self.eos_target_id, + self.bos_target_id, + self.pad_token_id, + ) + input_token_ids = torch.cat([phone_ids, target_ids], dim=-1) + attention_mask = torch.cat([phone_mask, target_mask], dim=-1) + # breakpoint() + if input_embeds is not None: + raise NotImplementedError + attention_mask = torch.cat( + [ + torch.ones( + (input_embeds.shape[0], input_embeds.shape[1]), + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, + ], + dim=-1, + ) + labels = torch.cat([phone_label, target_label], dim=-1) + if input_embeds is not None: + raise NotImplementedError + labels = torch.cat( + [ + -100 + * torch.ones( + (input_embeds.shape[0], input_embeds.shape[1]), + dtype=labels.dtype, + device=labels.device, + ), + labels, + ], + dim=-1, + ) + + if input_embeds is not None: + raise NotImplementedError + inputs_embeds = torch.cat( + [input_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1 + ) + out = self.model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels, + return_dict=True, + ) + return out + + out = self.model( + input_token_ids, + attention_mask=attention_mask, + labels=labels, + return_dict=True, + ) + + # calcualte top1, top5, top10 accuracy + logits = out.logits + logits = logits[:, -target_ids.shape[1] :] + top1_acc = logits.argmax(-1)[..., :-1] == target_ids[:, 1:] + top1_acc = (top1_acc * target_mask[..., :-1]).sum() / target_mask.sum() + + top5_acc = torch.topk(logits[..., :-1, :], 5, dim=-1)[1] + top5_acc = top5_acc == target_ids[:, 1:].unsqueeze(-1) + top5_acc = ( + top5_acc * target_mask[..., :-1].unsqueeze(-1) + 
).sum() / target_mask.sum() + + top10_acc = torch.topk(logits[..., :-1, :], 10, dim=-1)[1] + top10_acc = top10_acc == target_ids[:, 1:].unsqueeze(-1) + top10_acc = ( + top10_acc * target_mask[..., :-1].unsqueeze(-1) + ).sum() / target_mask.sum() + + out.top1_acc = top1_acc + out.top5_acc = top5_acc + out.top10_acc = top10_acc + + return out + + def add_phone_eos_bos_label( + self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id + ): + # phone_ids: [B, T] + # phone_mask: [B, T] + + phone_ids = phone_ids + self.target_vocab_size * phone_mask + + phone_ids = phone_ids * phone_mask + phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad( + 1 - phone_mask, (0, 1), value=1 + ) # make pad token eos token, add eos token at the end + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask + phone_ids = phone_ids * phone_mask + pad_token_id * ( + 1 - phone_mask + ) # restore pad token ids + phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask + phone_label = -100 * torch.ones_like( + phone_ids + ) # loss for entire phone is not computed (passed to llama) + return phone_ids, phone_mask, phone_label + + def add_target_eos_bos_label( + self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id + ): + # target_ids: [B, T] + # target_mask: [B, T] + target_ids = target_ids * target_mask + target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad( + 1 - target_mask, (0, 1), value=1 + ) + target_mask = F.pad(target_mask, (1, 0), value=1) + target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask) + target_ids = F.pad(target_ids, (1, 0), value=target_bos_id) + target_mask = F.pad(target_mask, (1, 0), value=1) + target_label = target_ids * target_mask + (-100) * ( + 1 - target_mask + ) # loss for target is computed on unmasked tokens + return target_ids, target_mask, target_label + + def sample_hf( + self, + phone_ids, # the phones of prompt and target should be concatenated together + prompt_ids, + inputs_embeds=None, + max_length=2000, + temperature=1.0, + top_k=100, + top_p=0.9, + repeat_penalty=1.0, + num_beams=1, + ): + if inputs_embeds is not None: + inputs_embeds = self.emb_linear(inputs_embeds) + phone_mask = torch.ones_like(phone_ids) + prompt_mask = torch.ones_like(prompt_ids) + phone_ids, _, _ = self.add_phone_eos_bos_label( + phone_ids, + phone_mask, + self.eos_phone_id, + self.bos_phone_id, + self.pad_token_id, + ) + prompt_ids, _, _ = self.add_target_eos_bos_label( + prompt_ids, + prompt_mask, + self.eos_target_id, + self.bos_target_id, + self.pad_token_id, + ) + prompt_ids = prompt_ids[:, :-1] # remove end token. 
Make it continue mode + + input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1) + + if inputs_embeds is not None: + raise NotImplementedError + inputs_embeds = torch.cat( + [inputs_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1 + ) + generated_ids = self.model.generate( + inputs_embeds=inputs_embeds, + do_sample=True, + max_length=max_length, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_target_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repeat_penalty, + ) + gen_tokens = generated_ids[:, :-1] + return gen_tokens + + input_length = input_token_ids.shape[1] + generated_ids = self.model.generate( + input_token_ids, + do_sample=True, + max_length=max_length, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_target_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repeat_penalty, + num_beams=num_beams, + ) + + gen_tokens = generated_ids[:, input_length:-1] + + return gen_tokens + + +def test(): + model = ValleAR() + + phone_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6]]) + phone_mask = torch.LongTensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) + target_ids = torch.LongTensor([765, 234, 123, 234, 123, 599]).expand(2, -1) + target_mask = torch.LongTensor([1, 1, 1, 1, 0, 0]).expand(2, -1) + + optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) + + for i in range(15): + optimizer.zero_grad() + out = model( + phone_ids=phone_ids, + phone_mask=phone_mask, + target_ids=target_ids, + target_mask=target_mask, + ) + loss = out.loss + + loss.backward() + + optimizer.step() + + print(f"iter={i}, {loss}.") + + phone_ids = torch.LongTensor([1, 2, 3]).reshape(1, -1) + target_ids = torch.LongTensor([765, 234]).reshape(1, -1) + sampled = model.sample_hf(phone_ids, target_ids) + + breakpoint() + + +if __name__ == "__main__": + test() diff --git a/models/tts/valle_v2/valle_ar_trainer.py b/models/tts/valle_v2/valle_ar_trainer.py new file mode 100644 index 00000000..5dc1aa07 --- /dev/null +++ b/models/tts/valle_v2/valle_ar_trainer.py @@ -0,0 +1,371 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import shutil +import torch +import time +from pathlib import Path +import torch +from tqdm import tqdm +import torch.nn as nn +from .base_trainer import BaseTrainer + + +def make_pad_mask( + lengths: torch.Tensor, max_len: int = 0, left_pad=False +) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + left_pad: + A boolean indicating whether to left pad the mask. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. 
+ + >>> lengths = torch.tensor([1, 3, 2, 5]) + >>> make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + mask = expaned_lengths >= lengths.unsqueeze(-1) + + if left_pad: + mask = mask.flip(dims=[1]) + + return mask + + +class ValleARTrainer(BaseTrainer): + def __init__(self, args=None, cfg=None): + super().__init__(args, cfg) + if self.cfg.use_speechtokenizer: + from models.codec.speechtokenizer.model import SpeechTokenizer + + config_path = "./ckpts/speechtokenizer_hubert_avg/config.json" + ckpt_path = "./ckpts/speechtokenizer_hubert_avg/SpeechTokenizer.pt" + assert os.path.isfile( + config_path + ), f"codec model {config_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts" + assert os.path.isfile( + ckpt_path + ), f"codec model {ckpt_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts" + self.codec_encoder = SpeechTokenizer.load_from_checkpoint( + config_path, ckpt_path + ) + self.codec_encoder.eval() + self.codec_encoder.to(self.accelerator.device) + print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}") + else: + from encodec import EncodecModel + + with self.accelerator.main_process_first(): + self.codec_encoder = EncodecModel.encodec_model_24khz() + self.codec_encoder.set_target_bandwidth(6.0) + self.codec_encoder.to(self.accelerator.device) + self.codec_decoder = None + print("Loaded EncodecModel") + self.top1_accuracies = [] + self.top5_accuracies = [] + self.top10_accuracies = [] + + if hasattr(self.cfg, "flatten_first_2_layers"): + self.flatten_first_2_layers = self.cfg.flatten_first_2_layers + print("flattened:", self.flatten_first_2_layers) + else: + self.flatten_first_2_layers = False + + if hasattr(self.cfg, "num_prediction_heads"): + self.num_prediction_heads = self.cfg.num_prediction_heads + print("num_prediction_heads:", self.num_prediction_heads) + + def _accelerator_prepare(self): + # if self.accelerator.is_main_process: + # breakpoint() + # self.accelerator.wait_for_everyone() + + ( + self.model, + self.optimizer, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + ) + + def _build_criterion(self): + pass # loss is directly returned from model + + def _build_scheduler(self): + from transformers import ( + get_cosine_schedule_with_warmup, + get_constant_schedule_with_warmup, + ) + + return get_cosine_schedule_with_warmup( + self.optimizer, + num_warmup_steps=self.cfg.train.scheduler.warmup_steps, + num_training_steps=self.cfg.train.scheduler.total_steps, + ) + + def _build_model(self): + if hasattr(self.cfg.model, "num_prediction_heads"): + from .valle_ar_multihead import ValleAR + else: + from .valle_ar import ValleAR + return ValleAR(**self.cfg.model) + + def _train_step(self, batch): + # inference codec + """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens') + speech: [B, T] + speech_len: [B] + phone_ids: [B, T] + phone_lens: [B] + """ + device = self.accelerator.device + for k, v in batch.items(): + if isinstance(v, 
torch.Tensor): + batch[k] = v.to(device) + with torch.no_grad(): + if self.cfg.use_speechtokenizer: + # Extract discrete codes from SpeechTokenizer + vq_id = self.codec_encoder.encode( + batch["speech"].unsqueeze(1) + ) # [B,1,T] -> (n_q, B, T) + else: + vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1)) + vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose( + 0, 1 + ) + + # recovered_audio = self.codec_decoder(vq_emb, vq=False) + # torchaudio.save('a.wav', recovered_audio[0], 16000) + # vq_id: [8, B, T//320] + if self.flatten_first_2_layers: + first_layer = vq_id[0] + second_layer = vq_id[1] + # flatten the first two layers + batch["speech"] = torch.stack( + [first_layer, second_layer], dim=-1 + ).flatten(-2, -1) + batch["speech_len"] = batch["speech_len"] // 160 + elif hasattr(self.cfg.model, "num_prediction_heads"): + batch["speech"] = vq_id[:2] # first two layers + batch["speech_len"] = ( + batch["speech_len"] // 320 + ) # our codec downsamples 320x + else: + batch["speech"] = vq_id[0] # use first layer + batch["speech_len"] = ( + batch["speech_len"] // 320 + ) # our codec downsamples 320x + assert batch["speech_len"].max() <= batch["speech"].shape[-1] + + phone_mask = 1 - make_pad_mask( + batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False + ).to(torch.long) + speech_mask = 1 - make_pad_mask( + batch["speech_len"], max_len=batch["speech"].size(1) + ).to(torch.long) + + out = self.model( + phone_ids=batch["phone_ids"], + phone_mask=phone_mask, + target_ids=batch["speech"], + target_mask=speech_mask, + ) + loss = out.loss + # if self.accelerator.is_main_process: + # print(loss) + # if hasattr(out, 'top1_acc'): + # self.top1_accuracies.append(out.top1_acc) + # self.top5_accuracies.append(out.top5_acc) + # self.top10_accuracies.append(out.top10_acc) + # print(f'avgs: top1: {sum(self.top1_accuracies)/len(self.top1_accuracies)}, top5: {sum(self.top5_accuracies)/len(self.top5_accuracies)}, top10: {sum(self.top10_accuracies)/len(self.top10_accuracies)}') + # breakpoint() + return loss + + ##########add your own dataloader to the trainer############# + def _build_dataloader(self): + from torch.utils.data import ConcatDataset, DataLoader + + if self.cfg.train.dataset.name == "emilia": + from .emilia_dataset import EmiliaDataset as VALLEDataset + + train_dataset = VALLEDataset() + elif self.cfg.train.dataset.name == "mls": + from .mls_dataset import VALLEDataset as VALLEDataset + + train_dataset = VALLEDataset(self.cfg.dataset, resample_to_24k=False) + elif self.cfg.train.dataset.name == "libritts": + from .libritts_dataset import VALLEDataset as VALLEDataset + + train_dataset = VALLEDataset(self.cfg.dataset) + + from .valle_collator import VALLECollator + import numpy as np + + print("length of train_dataset:", len(train_dataset)) + + collator = VALLECollator() + + if self.cfg.train.dataset.use_dynamic_batchsize: + if self.accelerator.is_main_process: + self.logger.info("Use Dynamic Batchsize......") + from .mls_dataset import batch_by_size + + batch_sampler = batch_by_size( + train_dataset.num_frame_indices, + train_dataset.get_num_frames, + max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes, + max_sentences=self.cfg.train.max_sentences + * self.accelerator.num_processes, + required_batch_size_multiple=self.accelerator.num_processes, + ) + np.random.shuffle(batch_sampler) + print(batch_sampler[0]) + batches = [ + x[ + self.accelerator.local_process_index :: self.accelerator.num_processes + ] + for x in batch_sampler + if 
len(x) % self.accelerator.num_processes == 0 + ] + from models.base.base_sampler import VariableSampler + + train_loader = DataLoader( + train_dataset, + collate_fn=collator, + num_workers=self.cfg.train.dataloader.num_worker, + batch_sampler=VariableSampler( + batches, drop_last=True, use_random_sampler=True + ), + pin_memory=self.cfg.train.dataloader.pin_memory, + persistent_workers=self.cfg.train.dataloader.persistent_workers, + prefetch_factor=4, + ) + print( + f"process {self.accelerator.local_process_index} has {len(batches)} batches" + ) + self.accelerator.wait_for_everyone() + + else: + sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=self.accelerator.num_processes, + rank=self.accelerator.local_process_index, + shuffle=True, + ) + train_loader = DataLoader( + train_dataset, + batch_size=self.cfg.train.batch_size, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + collate_fn=collator, + sampler=sampler, + ) + print( + f"process {self.accelerator.local_process_index} has {len(train_loader)} batches" + ) + + return train_loader, None + + def _test_step(self, batch): + # inference codec + """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens') + speech: [B, T] + speech_len: [B] + phone_ids: [B, T] + phone_lens: [B] + """ + import torchaudio + + device = self.accelerator.device + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + with torch.no_grad(): + if self.cfg.use_speechtokenizer: + # Extract discrete codes from SpeechTokenizer + vq_id = self.codec_encoder.encode( + batch["speech"].unsqueeze(1) + ) # [B,1,T] -> (n_q, B, T) + else: + vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1)) + vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose( + 0, 1 + ) + # recovered_audio = self.codec_decoder(vq_emb, vq=False) + # torchaudio.save('a.wav', recovered_audio[0], 16000) + # vq_id: [8, B, T//200] + + # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1) + # recovered_audio = self.codec_decoder(vq_emb, vq=False) + # recovered_audio.shape: torch.Size([1, 1, 50200]) + + if self.flatten_first_2_layers: + first_layer = vq_id[0] + second_layer = vq_id[1] + # flatten the first two layers + batch["speech"] = torch.stack( + [first_layer, second_layer], dim=-1 + ).flatten(-2, -1) + batch["speech_len"] = batch["speech_len"] // 160 + elif hasattr(self.cfg.model, "num_prediction_heads"): + batch["speech"] = vq_id[:2] # first two layers + batch["speech_len"] = ( + batch["speech_len"] // 320 + ) # our codec downsamples 320x + else: + batch["speech"] = vq_id[0] # use first layer + batch["speech_len"] = ( + batch["speech_len"] // 320 + ) # our codec downsamples 320x + + # save gt + breakpoint() + recovered_audio = self.codec_encoder.decode(vq_id[:1, :1]) + # recovered_audio = self.codec_encoder.decode([(vq_id[:1].transpose(0,1), None)]) + torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000) + out_vq_ids = self.model.sample_hf( + batch["phone_ids"][:1, ...], batch["speech"][:1, :225], temperature=0.9 + ) + # out_vq_ids = torch.cat([batch['speech'][:1, :225], out_vq_ids[:1, ...]], dim=1) + + # reconstruct form tokens + recovered_audio = self.codec_encoder.decode(out_vq_ids.unsqueeze(0)) + # recovered_audio = self.codec_encoder.decode([(out_vq_ids, None)]) + torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000) + breakpoint() + print() + + @torch.inference_mode() + def _valid_epoch(self): + r"""Testing 
epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + epoch_sum_loss = 0.0 + return epoch_sum_loss + + def _inference(self): + pass + + def test_loop(self): + self.model.eval() + for batch in self.train_dataloader: + self._test_step(batch) diff --git a/models/tts/valle_v2/valle_collator.py b/models/tts/valle_v2/valle_collator.py new file mode 100644 index 00000000..29db1b32 --- /dev/null +++ b/models/tts/valle_v2/valle_collator.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.nn.utils.rnn import pad_sequence + + +class VALLECollator: + def __init__(self, cfg=None): + self.cfg = cfg + + def __call__(self, batch): + """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens') + speech: [B, T] + speech_len: [B] + phone_ids: [B, T] + phone_lens: [B] + """ + assert len(batch) != 0, "batch is empty before None checking" + batch = [b for b in batch if b is not None] + assert len(batch) != 0, "batch is empty after None checking" + packed_batch_features = {} + + # Function to handle tensor copying + def process_tensor(data, dtype=torch.float32): + if isinstance(data, torch.Tensor): + return data.detach() + else: + return torch.tensor(data, dtype=dtype) + + # Process 'speech' data + speeches = [process_tensor(b["speech"]) for b in batch] + packed_batch_features["speech_len"] = torch.tensor( + [len(s) for s in speeches], dtype=torch.long + ) + packed_batch_features["speech"] = pad_sequence( + speeches, batch_first=True, padding_value=0 + ) + + # right-padding 'phone' data + phones = [process_tensor(b["phone"], dtype=torch.long) for b in batch] + packed_batch_features["phone_lens"] = torch.tensor( + [len(phone) for phone in phones], dtype=torch.long + ) + packed_batch_features["phone_ids"] = pad_sequence( + phones, batch_first=True, padding_value=0 + ) + + # # Process 'phone' data, with left padding + # phones = [process_tensor(b['phone'], dtype=torch.long).flip(0) for b in batch] # first reverse the whole sequence + # packed_batch_features['phone_lens'] = torch.tensor([len(phone) for phone in phones], dtype=torch.long) + # packed_batch_features['phone_ids'] = pad_sequence(phones, batch_first=True, padding_value=0) # do the right padding + # packed_batch_features['phone_ids'] = packed_batch_features['phone_ids'].flip(1) # flip back to original order (left padding) + + return packed_batch_features diff --git a/models/tts/valle_v2/valle_inference.py b/models/tts/valle_v2/valle_inference.py new file mode 100644 index 00000000..39efc6f2 --- /dev/null +++ b/models/tts/valle_v2/valle_inference.py @@ -0,0 +1,169 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
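+
+# Illustrative usage sketch (the checkpoint paths and sampling values below are
+# placeholders, not shipped defaults). `phone_ids` should cover the transcript
+# of the prompt and of the target text concatenated, `speech` is the prompt
+# waveform, and calling the module returns a [B, 1, T] waveform tensor.
+#
+#   inference = ValleInference(
+#       use_speechtokenizer=True,
+#       ar_path="/path/to/valle_ar_checkpoint.bin",
+#       nar_path="/path/to/valle_nar_checkpoint.bin",
+#       speechtokenizer_path="/path/to/speechtokenizer_hubert_avg",
+#   )
+#   chunk_configs = [
+#       dict(top_p=0.9, top_k=100, temperature=0.9, num_beams=1,
+#            repeat_penalty=1.0, max_length=2000)
+#   ]
+#   audio = inference(
+#       {"speech": prompt_wav, "phone_ids": phone_ids},  # both [B, T] tensors
+#       chunk_configs,
+#   )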
+ +import torch +import torchaudio + + +class ValleInference(torch.nn.Module): + def __init__( + self, + use_vocos=False, + use_speechtokenizer=True, + ar_path=None, + nar_path=None, + speechtokenizer_path=None, + device="cuda", + ): + super().__init__() + + self.device = device + + # prepare pretrained VALLE AR model + from .valle_ar import ValleAR + + self.ar_model = ValleAR( + phone_vocab_size=300, + target_vocab_size=1024, + pad_token_id=1324, + bos_target_id=1325, + eos_target_id=1326, + bos_phone_id=1327, + eos_phone_id=1328, + bos_prompt_id=1329, + eos_prompt_id=1330, + num_hidden_layers=16, + ) + # change the following path to your trained model path + assert ar_path is not None + self.ar_model.load_state_dict(torch.load(ar_path, map_location="cpu")) + self.ar_model.eval().to(self.device) + + # prepare pretrained VALLE NAR model + from .valle_nar import ValleNAR + + self.nar_model = ValleNAR( + phone_vocab_size=300, + target_vocab_size=1024, + pad_token_id=1324, + bos_target_id=1325, + eos_target_id=1326, + bos_phone_id=1327, + eos_phone_id=1328, + bos_prompt_id=1329, + eos_prompt_id=1330, + num_hidden_layers=16, + ) + assert nar_path is not None + self.nar_model.load_state_dict(torch.load(nar_path, map_location="cpu")) + self.nar_model.eval().to(self.device) + + # prepare codec encoder + assert not ( + use_speechtokenizer and use_vocos + ), "Only one of use_speechtokenizer and use_vocos can be True" + self.use_speechtokenizer = use_speechtokenizer + if use_speechtokenizer: + from models.codec.speechtokenizer.model import SpeechTokenizer + + # download from https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg + config_path = speechtokenizer_path + "/config.json" + ckpt_path = speechtokenizer_path + "/SpeechTokenizer.pt" + self.codec_encoder = SpeechTokenizer.load_from_checkpoint( + config_path, ckpt_path + ) + self.codec_encoder.eval() + self.codec_encoder.to(device) + print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}") + else: + # use Encodec + from encodec import EncodecModel + + self.codec_encoder = EncodecModel.encodec_model_24khz() + self.codec_encoder.set_target_bandwidth(6.0) + self.codec_encoder.to(self.device) + if use_vocos: + from vocos import Vocos + + self.codec_decoder = Vocos.from_pretrained( + "charactr/vocos-encodec-24khz" + ) + self.codec_decoder.to(self.device) + print("Loaded Vocos") + print("Loaded EncodecModel") + + self.use_vocos = use_vocos + + def decode(self, vq_ids): + """vq_ids.shape: [8, B, T], + returns: [B, 1, T]""" + if self.use_speechtokenizer: + # infer speechtokenizer + return self.codec_encoder.decode(vq_ids) # [B, 1, T] + else: + if not self.use_vocos: + # vocos decoder + return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)]) + else: + # encodec decoder + features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1)) + bandwidth_id = torch.tensor([2], device=vq_ids.device) + return self.codec_decoder.decode( + features, bandwidth_id=bandwidth_id + ).unsqueeze(0) + + def forward(self, batch, chunk_configs: list, return_prompt=False, prompt_len=None): + """batch: dict( + speech: [B, T] + phone_ids: [B, T] + ) + returns: [B, 1, T] audio + """ + if prompt_len is None: + prompt_len = 100000 # no prompt length limiting + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(self.device) + with torch.no_grad(): + if self.use_speechtokenizer: + vq_id = self.codec_encoder.encode( + batch["speech"].unsqueeze(1) + ) # [B,1,T] -> (n_q, B, T) + else: + vq_id = 
self.codec_encoder.encode(batch["speech"].unsqueeze(1)) + vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose( + 0, 1 + ) + + # typically we only require one config in the chunk, + # but we can also use multiple configs to, for example, use different sampling temperature at different positions + for chunk in chunk_configs: + ar_vq_ids = self.ar_model.sample_hf( + batch["phone_ids"], + vq_id[0, :, :prompt_len], + top_p=chunk["top_p"], + top_k=chunk["top_k"], + temperature=chunk["temperature"], + num_beams=chunk["num_beams"], + repeat_penalty=chunk["repeat_penalty"], + max_length=chunk["max_length"], + ) + # recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0)) + # torchaudio.save('recovered_audio_ar.wav', recovered_audio_ar[0].cpu(), 24000) + + nar_vq_ids = self.nar_model.sample_hf( + phone_ids=batch["phone_ids"], + prompt_ids=vq_id[:, :, :prompt_len], + first_stage_ids=ar_vq_ids, + # first_stage_ids=vq_id[0, :, prompt_len:], + ) + + if return_prompt: + nar_vq_ids = torch.cat( + [vq_id[..., :prompt_len], nar_vq_ids], dim=-1 + ) + + recovered_audio = self.decode(nar_vq_ids) + return recovered_audio # [B, 1, T] diff --git a/models/tts/valle_v2/valle_nar.py b/models/tts/valle_v2/valle_nar.py new file mode 100644 index 00000000..7d2ff021 --- /dev/null +++ b/models/tts/valle_v2/valle_nar.py @@ -0,0 +1,801 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +import torch +import torch.nn.functional as F +import numpy as np +import os +import torch.nn as nn +from typing import List, Optional, Tuple, Union + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +NUM_QUANTIZERS = 8 # number of quantizers in total, currently assumes first layer AR. +START_QUANTIZATION_LAYER = 1 # start quantization layer +END_QUANTIZATION_LAYER = 7 # end quantization layer + + +class LlamaAdaptiveRMSNorm(nn.Module): + def __init__(self, hidden_size=1024, eps=1e-9, dim_cond=1024): + super().__init__() + self.to_weight = nn.Linear(dim_cond, hidden_size) + nn.init.normal_(self.to_weight.weight, mean=0.0, std=0.02) + # nn.init.zeros_(self.to_weight.weight) + # nn.init.ones_(self.to_weight.bias) + self.variance_epsilon = eps + self._is_hf_initialized = True # disable automatic init + + def forward(self, hidden_states, cond_embedding): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + weight = self.to_weight(cond_embedding) + + return (weight * hidden_states).to(input_dtype) + + +class LlamaNARDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig): + """Override to adaptive layer norm""" + super().__init__(config=config, layer_idx=0) # init attention, mlp, etc. 
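+        # The parent LlamaDecoderLayer builds plain LlamaRMSNorm layers; the
+        # two assignments below replace them with LlamaAdaptiveRMSNorm so that
+        # each decoder layer is conditioned on the embedding of the
+        # quantization layer currently being predicted (adaptive layer norm).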
+ self.input_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + self.post_attention_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + + # add `cond` in forward function + def forward( + self, + hidden_states: torch.Tensor, + cond_embedding: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm( + hidden_states, cond_embedding=cond_embedding + ) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm( + hidden_states, cond_embedding=cond_embedding + ) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +from transformers.models.llama.modeling_llama import BaseModelOutputWithPast + + +class MultiEmbedding(nn.Module): + """Embedding for multiple quantization layers, summing up the embeddings of each layer.""" + + def __init__( + self, + num_embeddings=1034, + embedding_dim=1024, + num_quantization_layers=NUM_QUANTIZERS, + ): + super().__init__() + self.embeddings = nn.ModuleList( + [ + nn.Embedding(num_embeddings, embedding_dim) + for _ in range(num_quantization_layers) + ] + ) + + # initialize embeddings + for i in range(num_quantization_layers): + self.embeddings[i].weight.data.normal_(mean=0.0, std=0.02) + self._is_hf_initialized = True # disable automatic init + + def forward(self, input_ids): + """Input: [num_quant, B, T] -> Output: [B, T, H]""" + num_quant, B, T = input_ids.shape + summed_embeddings = torch.zeros( + B, T, self.embeddings[0].embedding_dim, device=input_ids.device + ) + for i in range(num_quant): + summed_embeddings += self.embeddings[i](input_ids[i]) + return summed_embeddings + + +class LlammaNARModel(LlamaModel): + def __init__(self, config): + """Adding adaptive layer norm, conditional embeddings, and multi-level input embeddings to the decoder layer""" + super().__init__(config) + self.layers = 
nn.ModuleList( + [LlamaNARDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + + self.embed_cond = nn.Embedding( + NUM_QUANTIZERS, config.hidden_size + ) # 7 quantization layers + + for layer in self.layers: + layer.input_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + layer.post_attention_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + + self.post_init() + + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create noncausal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + def _expand_mask( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None + ): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = ( + mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + ) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, # [num_quant, B, T] + cond: torch.LongTensor = None, # index for conditional embeddings, [B] + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + # retrieve some shape info + batch_size, seq_length, _ = input_ids.shape + + inputs_embeds = input_ids # [B, T, H] + # embed cond + cond_embedding = self.embed_cond(cond) # [B, H] + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = 
position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + raise NotImplementedError + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cond_embedding=cond_embedding, # using cond embed + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states, cond_embedding=cond_embedding) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +from transformers.models.llama.modeling_llama import LlamaPreTrainedModel +from transformers.models.llama.modeling_llama import CrossEntropyLoss +from easydict import EasyDict as edict + + +class LlamaForNARModeling(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LlammaNARModel(config) + + self.lm_head = nn.ModuleList( + [ + nn.Linear(config.hidden_size, config.vocab_size, bias=False) + for i in range(END_QUANTIZATION_LAYER - START_QUANTIZATION_LAYER + 1) + ] + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + cond: torch.LongTensor, # added + prediction_target: torch.LongTensor = None, # added. No shifting. 
-100 means no loss + input_ids: torch.LongTensor = None, # expect an embedding, [B, T, H] + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + # labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + """Prediction target: [B, T]""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + cond=cond, # added + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head[cond - START_QUANTIZATION_LAYER](hidden_states) + + loss = None + loss_fct = CrossEntropyLoss() + + if prediction_target is not None: + # calculate loss if prediction_target is provided + logits_tmp = logits.view(-1, logits.size(-1)) + prediction_target = prediction_target.view(-1) + loss = loss_fct(logits_tmp, prediction_target) + + return edict( + loss=loss, + logits=logits, + ) + + +class ValleNAR(nn.Module): + def __init__( + self, + phone_vocab_size=256, + target_vocab_size=1024, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=12, + num_attention_heads=16, + pad_token_id=1024 + 256, + bos_target_id=1282, + eos_target_id=1283, + bos_phone_id=1284, + eos_phone_id=1285, + bos_prompt_id=1286, + eos_prompt_id=1287, + use_input_embeds=False, + emb_dim=256, + ): + super(ValleNAR, self).__init__() + self.config = LlamaConfig( + vocab_size=phone_vocab_size + target_vocab_size + 10, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + pad_token_id=pad_token_id, + bos_token_id=bos_target_id, + eos_token_id=eos_target_id, + use_cache=False, + ) + self.phone_vocab_size = phone_vocab_size + self.target_vocab_size = target_vocab_size + self.pad_token_id = pad_token_id + self.bos_target_id = bos_target_id + self.eos_target_id = eos_target_id + self.bos_phone_id = bos_phone_id + self.eos_phone_id = eos_phone_id + self.bos_prompt_id = bos_prompt_id + self.eos_prompt_id = eos_prompt_id + self.model = LlamaForNARModeling(self.config) + + self.use_input_embeds = use_input_embeds + + self.phone_embedder = nn.Embedding( + self.phone_vocab_size + 10, hidden_size + ) # use phone_embedder to embed all eos, bos tokens + self.prompt_embedder = MultiEmbedding( + num_embeddings=self.target_vocab_size, + embedding_dim=hidden_size, + num_quantization_layers=NUM_QUANTIZERS, + ) + self.phone_embedder.weight.data.normal_(mean=0.0, std=0.02) + + # use linear mask schedule when training + # another option is uniform + self.mask_layer_schedule = "uniform" + + # no input embedding is used to provide speaker information + if self.use_input_embeds: + 
self.emb_linear = nn.Linear(emb_dim, hidden_size) + self.emb_linear.weight.data.normal_(mean=0.0, std=0.01) + self.emb_linear.bias.data.zero_() + + def forward( + self, + phone_ids, + phone_mask, + target_ids, + target_mask, + target_quantization_layer=None, + prompt_len=None, + dropout=0.0, + ): + """ + phone_ids: [B, T] + phone_mask: [B, T] + target_ids: [8,B,T] + target_mask: [B, T] + dropout: rate of dropping out the target tokens + """ + assert (target_ids < 1024).all(), "target_ids should be less than 1024" + phone_ids = phone_ids + self.target_vocab_size + phone_ids = phone_ids * phone_mask + (1 - phone_mask) * self.pad_token_id + # assert (phone_ids >= 1024).all(), "phone_ids should be greater than 1024" + # phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label( + # phone_ids, + # phone_mask, + # self.eos_phone_id, + # self.bos_phone_id, + # self.pad_token_id, + # ) + phone_label = -100 * (1 - phone_mask) + # get phone embedding + phone_embedding = self.phone_embedder( + phone_ids - self.target_vocab_size + ) # [B, T, H] + + if prompt_len is not None: + assert not self.training # inference stage fix prompt len to input + NUM_PROMPT_TOKENS = prompt_len + else: + assert self.training + # randomly select a prompt length + assert self.training # randomize prompt len in training + NUM_PROMPT_TOKENS = np.random.randint( + min(target_ids.shape[-1] // 4, 5), target_ids.shape[-1] // 2 + ) + + # extract 8-level prompts + prompt_tokens = target_ids[:, :, :NUM_PROMPT_TOKENS] # [Q, B, T] + prompt_mask = torch.ones_like(prompt_tokens[0]) + prompt_label = -100 * prompt_mask + # get prompt embedding + prompt_embedding = self.prompt_embedder(prompt_tokens) # [B, T, H] + + # randomly select a target qnt layer to predict + # total quant layer is 0 to 7 + if target_quantization_layer is None: + if self.mask_layer_schedule == "linear": + weights = torch.tensor( + [ + NUM_QUANTIZERS - i + for i in range( + START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1 + ) + ] + ) + weights = weights / weights.sum() + mask_layer = ( + torch.multinomial(weights, 1, replacement=True) + + START_QUANTIZATION_LAYER + ) + assert ( + mask_layer >= START_QUANTIZATION_LAYER + and mask_layer <= END_QUANTIZATION_LAYER + ) + target_quantization_layer = mask_layer.item() + elif self.mask_layer_schedule == "cosine": + weights = torch.tensor( + [ + np.cos(i / NUM_QUANTIZERS * np.pi / 2) + for i in range( + START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1 + ) + ] + ) + weights = weights / weights.sum() + mask_layer = ( + torch.multinomial(weights, 1, replacement=True) + + START_QUANTIZATION_LAYER + ) + assert ( + mask_layer >= START_QUANTIZATION_LAYER + and mask_layer <= END_QUANTIZATION_LAYER + ) + target_quantization_layer = mask_layer.item() + breakpoint() + elif self.mask_layer_schedule == "uniform": + target_quantization_layer = np.random.randint( + START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1 + ) + + # print(f'target layer: {target_quantization_layer}') + # prompt of the target part + target_prompt_ids = target_ids[ + :target_quantization_layer, :, NUM_PROMPT_TOKENS: + ] + + def randomly_set_elements(tensor, fraction, value): + """ + Randomly set a fraction of the elements in a tensor to a specific value. + + Args: + tensor (torch.Tensor): The input tensor. + fraction (float): The fraction of elements to set to the specified value (between 0 and 1). + value (float or int): The value to set the elements to. + + Returns: + torch.Tensor: The tensor with some elements set to the specified value. 
+ """ + # Create a mask with the same shape as the tensor + mask = torch.rand_like(tensor, dtype=torch.float32) < fraction + # Clone the tensor to avoid modifying the original tensor + result_tensor = tensor.clone() + # Set the elements where the mask is True to the specified value + result_tensor[mask] = value + return result_tensor + + if dropout != 0.0: + target_prompt_ids = randomly_set_elements( + target_prompt_ids, dropout, self.target_vocab_size + ) + + target_embedding = self.prompt_embedder(target_prompt_ids) + + # mask of the target part + target_mask = target_mask[:, NUM_PROMPT_TOKENS:] + + target_labels = target_ids[ + target_quantization_layer, :, NUM_PROMPT_TOKENS: + ] * target_mask + (-100 * (1 - target_mask)) + + # input embeddings + input_embeddings = torch.cat( + [phone_embedding, prompt_embedding, target_embedding], dim=1 + ) + input_mask = torch.cat([phone_mask, prompt_mask, target_mask], dim=1) # [B, T] + prediction_target = torch.cat( + [phone_label, prompt_label, target_labels], dim=1 + ) # [B, T] + + out = self.model( + cond=torch.tensor( + target_quantization_layer, + device=prediction_target.device, + dtype=torch.long, + ), + input_ids=input_embeddings, + prediction_target=prediction_target, + attention_mask=input_mask, + return_dict=True, + ) + logits = out.logits[:, -target_embedding.shape[1] :, :] + targets = prediction_target[..., -target_embedding.shape[1] :] + top1_acc = logits.argmax(-1) == targets + top1_acc = (top1_acc * target_mask).sum() / target_mask.sum() + + top5_acc = (logits.topk(5, dim=-1).indices == targets.unsqueeze(-1)).any(-1) + top5_acc = (top5_acc * target_mask).sum() / target_mask.sum() + + top10_acc = (logits.topk(10, dim=-1).indices == targets.unsqueeze(-1)).any(-1) + top10_acc = (top10_acc * target_mask).sum() / target_mask.sum() + + out.target_quantization_layer = target_quantization_layer + out.top1_acc = top1_acc + out.top5_acc = top5_acc + out.top10_acc = top10_acc + + return out + + def add_phone_eos_bos_label( + self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id + ): + # phone_ids: [B, T] + # phone_mask: [B, T] + + phone_ids = phone_ids + self.target_vocab_size * phone_mask + + phone_ids = phone_ids * phone_mask + phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad( + 1 - phone_mask, (0, 1), value=1 + ) # make pad token eos token, add eos token at the end + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask + phone_ids = phone_ids * phone_mask + pad_token_id * ( + 1 - phone_mask + ) # restore pad token ids + phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask + phone_label = -100 * torch.ones_like( + phone_ids + ) # loss for entire phone is not computed (passed to llama) + return phone_ids, phone_mask, phone_label + + @torch.no_grad() + def sample_hf( + self, + phone_ids, # [B, T] + prompt_ids, # [8, B, T] + first_stage_ids, # [B, T] + top_k=50, + top_p=1, + temperature=1.1, + first_stage_ids_gt=None, # [Q, B, T] + first_stage_ids_gt_end_layer=None, # 2 to 8 + ): + """ + phone_ids: [B, T] + prompt_ids: [8, B, T] + first_stage_ids: [B, T] result from first quant layer. 
Should be continuation of prompt_ids + """ + phone_mask = torch.ones_like(phone_ids, dtype=torch.long) + + assert prompt_ids.shape[-1] >= 5, "prompt_ids should have at least 5 tokens" + target_ids = torch.cat( + [prompt_ids, first_stage_ids.expand(prompt_ids.shape[0], -1, -1)], dim=-1 + ) + target_mask = torch.ones_like(target_ids[0], dtype=torch.long) + + if first_stage_ids_gt is not None: + target_ids[ + :first_stage_ids_gt_end_layer, :, -first_stage_ids_gt.shape[-1] : + ] = first_stage_ids_gt[:first_stage_ids_gt_end_layer] + + gen_len = first_stage_ids.shape[-1] + + start_qnt_layer = 1 + if first_stage_ids_gt_end_layer is not None: + start_qnt_layer = first_stage_ids_gt_end_layer + for qnt_level in range(start_qnt_layer, 8): + out = self.forward( + phone_ids=phone_ids, + phone_mask=phone_mask, + target_ids=target_ids, + target_mask=target_mask, + target_quantization_layer=qnt_level, + prompt_len=prompt_ids.shape[-1], + ) + logits = out.logits + gen_tokens = torch.argmax(logits, dim=-1).reshape(-1)[ + -gen_len: + ] # [T], generated tokens in this level + + # overwrite the target_ids with the generated tokens + target_ids[qnt_level, :, -gen_len:] = gen_tokens + + return target_ids[:, :, -gen_len:] + + +def test(): + model = ValleNAR().cuda() + + phone_ids = torch.LongTensor([1, 2, 3, 4, 5]).reshape(1, -1).cuda() + phone_mask = torch.LongTensor([1, 1, 1, 1, 1]).reshape(1, -1).cuda() + target_ids = torch.randint(high=1024, size=(8, 1, 250), dtype=torch.long).cuda() + target_mask = torch.ones(1, 250, dtype=torch.long).cuda() + optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) + + for i in range(200): + optimizer.zero_grad() + out = model( + phone_ids=phone_ids, + phone_mask=phone_mask, + target_ids=target_ids, + target_mask=target_mask, + # target_quantization_layer=1+i%6, + ) + loss = out.loss + + loss.backward() + + optimizer.step() + + print(f"iter={i}, {loss}.") + target_ids_short = target_ids[:, :, :240] + + model.eval() + sampled = model.sample_hf( + phone_ids, prompt_ids=target_ids_short, first_stage_ids=target_ids[0, :, 240:] + ) + + print(target_ids[:, :, -10:]) + print(sampled) + + print((sampled == target_ids[:, :, -10:]).all()) + + +if __name__ == "__main__": + test() diff --git a/models/tts/valle_v2/valle_nar_trainer.py b/models/tts/valle_v2/valle_nar_trainer.py new file mode 100644 index 00000000..88e5e8a3 --- /dev/null +++ b/models/tts/valle_v2/valle_nar_trainer.py @@ -0,0 +1,205 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
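+
+# Trainer for the non-autoregressive (NAR) second stage. Each training step
+# encodes the batch speech into all 8 RVQ code layers ([8, B, T] ids), and
+# ValleNAR predicts one quantization layer in [1, 7] (picked by the model's
+# default uniform schedule): the prompt region is embedded with all code
+# layers summed, the continuation region only with the layers below the
+# target layer, and the loss plus top-1/5/10 accuracies are computed on that
+# layer's codes and logged as `Train/NAR L{layer} ...`.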
+ +import torch +import torchaudio +import numpy as np +import time +from .valle_ar_trainer import ValleARTrainer, make_pad_mask + + +class ValleNARTrainer(ValleARTrainer): + def __init__(self, args=None, cfg=None): + super().__init__(args, cfg) + print("simple NAR") + self.top1_accuracies = { + 1: [], + 2: [], + 3: [], + 4: [], + 5: [], + 6: [], + 7: [], + } + self.top5_accuracies = { + 1: [], + 2: [], + 3: [], + 4: [], + 5: [], + 6: [], + 7: [], + } + self.top10_accuracies = { + 1: [], + 2: [], + 3: [], + 4: [], + 5: [], + 6: [], + 7: [], + } + + def _build_model(self): + from .valle_nar import ValleNAR + + return ValleNAR(**self.cfg.model) + + def _train_step(self, batch): + # inference codec + """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens') + speech: [B, T] + speech_len: [B] + phone_ids: [B, T] + phone_lens: [B] + """ + device = self.accelerator.device + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + + with torch.no_grad(): + if self.cfg.use_speechtokenizer: + # Extract discrete codes from SpeechTokenizer + # 16k + vq_id = self.codec_encoder.encode( + batch["speech"].unsqueeze(1) + ) # [B,T] -> (n_q, B, T) + # RVQ_1 = codes[:1, :, :] # Contain content info, can be considered as semantic tokens + # RVQ_supplement = codes[1:, :, :] # Contain timbre info, complete info lost by the first quantizer + # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding + # wav = self.codec_encoder.decode(vq_id) + # torchaudio.save('a.wav', wav[0].cpu(), 16000) + + # # Decoding from RVQ-i:j tokens from the ith quantizers to the jth quantizers + # wav = model.decode(codes[i: (j + 1)], st=i) + else: + # using encodec, 24k + vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1)) + vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose( + 0, 1 + ) + + # recovered_audio = self.codec_decoder(vq_emb, vq=False) + # torchaudio.save('a.wav', recovered_audio[0], 16000) + # vq_id: [8, B, T//320] + batch["speech"] = vq_id + batch["speech_len"] = batch["speech_len"] // 320 # our codec downsamples 320x + assert batch["speech_len"].max() <= batch["speech"].shape[-1] + + phone_mask = 1 - make_pad_mask( + batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False + ).to(torch.long) + speech_mask = 1 - make_pad_mask( + batch["speech_len"], max_len=batch["speech"].size(-1) + ).to(torch.long) + + np.random.seed(int(time.time()) - 5 * self.accelerator.process_index) + + if hasattr(self.cfg.train, "dropout"): + dropout = self.cfg.train.dropout + else: + dropout = 0.0 + + out = self.model( + phone_ids=batch["phone_ids"], + phone_mask=phone_mask, + target_ids=batch["speech"], + target_mask=speech_mask, + dropout=dropout, + ) + loss = out.loss + + self.accelerator.log( + {f"Train/NAR L{out.target_quantization_layer} Top1 acc": out.top1_acc}, + step=self.step, + ) + self.accelerator.log( + {f"Train/NAR L{out.target_quantization_layer} Top5 acc": out.top5_acc}, + step=self.step, + ) + self.accelerator.log( + {f"Train/NAR L{out.target_quantization_layer} Top10 acc": out.top10_acc}, + step=self.step, + ) + + # if hasattr(out, 'top1_acc'): + # idx = out.target_quantization_layer + # self.top1_accuracies[idx].append(out.top1_acc) + # self.top5_accuracies[idx].append(out.top5_acc) + # self.top10_accuracies[idx].append(out.top10_acc) + # if len(self.top1_accuracies[idx]) >= 160: + # breakpoint() + # if self.accelerator.is_main_process: + # print(loss) + return loss + + def _test_step(self, batch): + # 
inference codec + """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens') + speech: [B, T] + speech_len: [B] + phone_ids: [B, T] + phone_lens: [B] + """ + import torchaudio + + device = self.accelerator.device + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + with torch.no_grad(): + if self.cfg.use_speechtokenizer: + # Extract discrete codes from SpeechTokenizer + # 16k + vq_id = self.codec_encoder.encode( + batch["speech"].unsqueeze(1) + ) # [B,1,T] -> (n_q, B, T) + # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding + # wav = self.codec_encoder.decode(vq_id) + # torchaudio.save('a.wav', wav[0].cpu(), 16000) + + else: + vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1)) + vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose( + 0, 1 + ) + # recovered_audio = self.codec_encoder.decode([(vq_id.transpose(0,1), None)]) + # recovered_audio = self.codec_decoder(vq_emb, vq=False) + # torchaudio.save('a.wav', recovered_audio[0], 16000) + # vq_id: [8, B, T//200] + + # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1) + # recovered_audio = self.codec_decoder(vq_emb, vq=False) + # recovered_audio.shape: torch.Size([1, 1, 50200]) + + batch["speech"] = vq_id + + # save gt + if self.cfg.use_speechtokenizer: + recovered_audio = self.codec_encoder.decode(vq_id) + else: + recovered_audio = self.codec_encoder.decode( + [(vq_id.transpose(0, 1), None)] + ) + torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000) + self.model.eval() + out_vq_ids = self.model.sample_hf( + phone_ids=batch["phone_ids"][:1], + prompt_ids=batch["speech"][:, :1, :150], + first_stage_ids=batch["speech"][0, :1, 150:], + ) + # breakpoint() + # out_vq_ids = torch.cat([batch['speech'][:, :225], out_vq_ids], dim=1) + + # reconstruct form tokens + if self.cfg.use_speechtokenizer: + recovered_audio = self.codec_encoder.decode(out_vq_ids) + else: + recovered_audio = self.codec_encoder.decode( + [(out_vq_ids.transpose(0, 1)[:1], None)] + ) + torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000) + breakpoint()
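+        # The trailing breakpoint() pauses execution once gt.wav (a codec
+        # round-trip of the ground-truth codes) and a.wav (a decode of the
+        # NAR-sampled codes) have been written, so both can be inspected by
+        # hand; this test path is a debugging aid rather than an automated
+        # evaluation loop.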