diff --git a/video_salmonn/README.md b/video_salmonn/README.md new file mode 100644 index 0000000..36900dc --- /dev/null +++ b/video_salmonn/README.md @@ -0,0 +1,52 @@ +## Inference + +### Preparation +Install the environment with the following specified config: +``` +conda env create -f videosalmonn.yml +``` +Create directory to store checkpoints (If modify the structure/rename directories, need to change config files and model files accordingly) +``` +mkdir -p ckpt/MultiResQFormer +mkdir -p ckpt/pretrained_ckpt +``` +Then download the following model checkpoints: + +1. Main video-SALMONN model [checkpoint](https://huggingface.co/tsinghua-ee/Video-SALMONN/tree/main), then put it under `ckpt/MultiResQFormer` +2. InstructBLIP [checkpoint](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna13b_trimmed.pth) for Vicuna-13B model, then put it under `ckpt/pretrained_ckpt` +3. EVA_VIT model [checkpoint](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth) for InstructBLIP, then put it under `ckpt/pretrained_ckpt` +4. BEATs encoder [checkpoint](https://huggingface.co/spaces/fffiloni/SALMONN-7B-gradio/blob/677c0125de736ab92751385e1e8664cd03c2ce0d/beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt), then put it under `ckpt/pretrained_ckpt` + + +### Run inference +``` +python inference.py --cfg-path config/test.yaml +``` + +### Check the result +The result is saved in the following path: +``` +./ckpt/MultiResQFormer//eval_result.json +``` + +Expecting the following result: +``` +[ + { + "id": "./dummy/4405327307.mp4_Describe the video and audio in detail", + "conversation": [ + { + "from": "human", + "value": "Describe the video and audio in detail" + }, + { + "from": "gpt", + "value": "None" + } + ], + "task": "audiovisual_video_input", + "ref_answer": "None", + "gen_answer": "The video shows a group of musicians performing on stage, with a man singing into a microphone and playing the piano. There is also a drum set and a saxophone on stage. The audience is not visible in the video. The music is upbeat and energetic, and the performers seem to be enjoying themselves." + } +] +``` diff --git a/video_salmonn/config/__init__.py b/video_salmonn/config/__init__.py new file mode 100644 index 0000000..826b6ef --- /dev/null +++ b/video_salmonn/config/__init__.py @@ -0,0 +1,37 @@ +import yaml + +def load_model_config(model, mode): + # load special config for each model + config_path = f'config/{model}.yaml' + print(f'[!] load configuration from {config_path}') + with open(config_path) as f: + configuration = yaml.load(f, Loader=yaml.FullLoader) + new_config = {} + for key, value in configuration.items(): + if key in ['train', 'test', 'validation']: + if mode == key: + new_config.update(value) + else: + new_config[key] = value + configuration = new_config + return configuration + +def load_config(args): + '''the configuration of each model can rewrite the base configuration''' + # base config + base_configuration = load_base_config() + + # load one model config + configuration = load_model_config(args['model'], args['mode']) + + # update and append the special config for base config + base_configuration.update(configuration) + configuration = base_configuration + return configuration + +def load_base_config(): + config_path = f'config/base.yaml' + with open(config_path) as f: + configuration = yaml.load(f, Loader=yaml.FullLoader) + print(f'[!] load base configuration: {config_path}') + return configuration diff --git a/video_salmonn/config/base.yaml b/video_salmonn/config/base.yaml new file mode 100644 index 0000000..867e5c8 --- /dev/null +++ b/video_salmonn/config/base.yaml @@ -0,0 +1,20 @@ +models: + openllama: + model_name: OpenLLAMAModel + agent_name: DeepSpeedAgent + stage1_train_dataset: SupervisedDataset + test_dataset: SelfInstructTestDataset + openllama_peft: + model_name: OpenLLAMAPEFTModel + agent_name: DeepSpeedAgent + stage1_train_dataset: SupervisedDataset + test_dataset: SelfInstructTestDataset + openllama_peft_small: + model_name: OpenLLAMAPEFTModel + agent_name: DeepSpeedAgent + stage1_train_dataset: SupervisedDataset + test_dataset: SelfInstructTestDataset + +# ========= Global configuration ========== # +logging_step: 5 +# ========= Global configuration ========== # diff --git a/video_salmonn/config/config.py b/video_salmonn/config/config.py new file mode 100644 index 0000000..93ef09f --- /dev/null +++ b/video_salmonn/config/config.py @@ -0,0 +1,29 @@ +from omegaconf import OmegaConf + +class Config: + def __init__(self, args): + self.config = {} + + self.args = args + user_config = self._build_opt_list(self.args.options) + config = OmegaConf.load(self.args.cfg_path) + config = OmegaConf.merge(config, user_config) + self.config = config + + def _convert_to_dot_list(self, opts): + if opts is None: + opts = [] + + if len(opts) == 0: + return opts + + has_equal = opts[0].find("=") != -1 + + if has_equal: + return opts + + return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] + + def _build_opt_list(self, opts): + opts_dot_list = self._convert_to_dot_list(opts) + return OmegaConf.from_dotlist(opts_dot_list) diff --git a/video_salmonn/config/openllama_peft.yaml b/video_salmonn/config/openllama_peft.yaml new file mode 100644 index 0000000..76f36ab --- /dev/null +++ b/video_salmonn/config/openllama_peft.yaml @@ -0,0 +1,22 @@ +# generation hyper-parameters +max_len: 512 +penalty_alpha: 0.6 +top_k: 10 +top_p: 0.7 +random_prefix_len: 5 +sample_num: 2 +decoding_method: sampling +generate_len: 512 + +# lora hyper-parameters +lora_r: 8 +lora_alpha: 32 +lora_dropout: 0.1 + +# some train configuration, more can be found under dsconfig folder +train: + seed: 1337 # 0 + warmup_rate: 0.2 + epochs: 10 + max_length: 2000 + max_shard_size: 80GB diff --git a/video_salmonn/config/openllama_peft_small.yaml b/video_salmonn/config/openllama_peft_small.yaml new file mode 100644 index 0000000..9022a95 --- /dev/null +++ b/video_salmonn/config/openllama_peft_small.yaml @@ -0,0 +1,22 @@ +# generation hyper-parameters +max_len: 512 +penalty_alpha: 0.6 +top_k: 10 +top_p: 0.7 +random_prefix_len: 5 +sample_num: 2 +decoding_method: sampling +generate_len: 512 + +# lora hyper-parameters +lora_r: 32 +lora_alpha: 32 +lora_dropout: 0.1 + +# some train configuration, more can be found under dsconfig folder +train: + seed: 0 + warmup_rate: 0.3 + epochs: 10 + max_length: 1024 + max_shard_size: 80GB diff --git a/video_salmonn/config/test.yaml b/video_salmonn/config/test.yaml new file mode 100644 index 0000000..180e06b --- /dev/null +++ b/video_salmonn/config/test.yaml @@ -0,0 +1,51 @@ +model: openllama_peft +imagebind_ckpt_path: "" +vicuna_ckpt_path: /scratch/LLM/LLM.ckpts/vicuna-13b-v1.5 # Should be modified to your own place +orig_delta_path: "" +delta_ckpt_path: ./ckpt/MultiResQFormer/pytorch_model_4_5001.pt + +all_decode_info: [ + ["audiovideoimage", "audiovisual_video_input", "example.json"] +] + +stage: 2y +max_tgt_len: 512 # 32000 +yu_lora_r: 32 # 8 +yu_lora_alpha: 32 +yu_lora_dropout: 0.1 +lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] # ['q_proj', 'v_proj'] +use_lora: "true" +qformer: "true" +use_whisper: "true" +use_blip: "true" +instructblip: "true" +proj_checkpoint: "" +num_video_query: 30 +instructblip_video: "false" +video_window_size: 240 +skip_vqformer: "false" +speech_qformer: "false" +early_align: "true" +cascaded: "" +causal: "false" +diversity_loss: "false" +causal_attention: "true" # "false" +groupsize: 10 +alignmode: 2 +pure_aud: False +num_speech_query: 1 +second_per_frame: 0.333333 +second_stride: 0.333333 +sin_pos: False +use_beats: True # True +return_raw: True # True +n_pos: 120 +flash_attn: False +batch_size: 1 +infer_mode: 2 +bilinear_pooling: False +# ext_groupsize: [1, 30] +low_groupsize: 1 +# # high_groupsize: 20 +ext_same_qformer: True +cache_dir: ./ckpt/pretrained_ckpt \ No newline at end of file diff --git a/video_salmonn/datasets/__init__.py b/video_salmonn/datasets/__init__.py new file mode 100644 index 0000000..2bcd824 --- /dev/null +++ b/video_salmonn/datasets/__init__.py @@ -0,0 +1,230 @@ +# from header import * +from torch.utils.data import Dataset, DataLoader +import torch +from .samplers import DistributedBatchSampler +from .sft_dataset import SupervisedAudioVisualDataset, SupervisedDataset +from .sft_dataset_nomix import SupervisedAudioVisualDataset4Test + +''' +def get_tokenizer(model): + tokenizer = LlamaTokenizer.from_pretrained(model) + tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2 + tokenizer.pad_token = tokenizer.eos_token + return tokenizer +''' + +def load_sft_dataset(args): + ''' + tokenizer = get_tokenizer(args['model_path']) + dataset_name = args['models'][args['model']]['stage1_train_dataset'] # SupervisedDataset, str + data_path = args["data_path"] + data = globals()[dataset_name](data_path, tokenizer, args['max_length']) #SupervisedDataset + ''' + if args["data_type"] == "video": + data = SupervisedAudioVisualDataset( + args['data_type'], + video_data_path=args['data_path'], + video_root_path=args['image_root_path'], + ) + elif args["data_type"] == "image": + data = SupervisedAudioVisualDataset( + args['data_type'], + image_data_path=args['image_data_path'], + image_root_path=args['llava_root_path'], + ) + elif args["data_type"] == "videoimage": + data = SupervisedAudioVisualDataset( + args['data_type'], + video_data_path=args['data_path'], + video_root_path=args['image_root_path'], + image_data_path=args['image_data_path'], + image_root_path=args['llava_root_path'], + ) + elif args["data_type"] == "audio": + data = SupervisedAudioVisualDataset( + args['data_type'], + audio_data_path=args['audio_data_path'], + audio_root_path=args['image_root_path'], + use_whisper=args["use_whisper"], + sin_pos=args["sin_pos"], + return_raw=args["return_raw"] + ) + elif args["data_type"] == "audioimage": + data = SupervisedAudioVisualDataset( + args['data_type'], + audio_data_path=args['audio_data_path'], + audio_root_path=args['image_root_path'], + image_data_path=args['image_data_path'], + image_root_path=args['llava_root_path'], + use_whisper=args["use_whisper"], + ) + elif args["data_type"] == "audiovideoimage": + data = SupervisedAudioVisualDataset( + args['data_type'], + audio_data_path=args['audio_data_path'], + audio_root_path=args['image_root_path'], + image_data_path=args['image_data_path'], + image_root_path=args['llava_root_path'], + video_data_path=args['data_path'], + video_root_path=args['image_root_path'], + use_whisper=args["use_whisper"], + sin_pos=args["sin_pos"], + return_raw=args["return_raw"], + audio_only=args.get('audio_only', False), + video_only=args.get('video_only', False), + use_npy=args.get('use_npy', False), + ) + else: + data = SupervisedDataset(args['data_path'], args['image_root_path']) + + sampler = torch.utils.data.RandomSampler(data) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + batch_size = args['world_size'] * args['dschf'].config['train_micro_batch_size_per_gpu'] + batch_sampler = DistributedBatchSampler( + sampler, + batch_size, + True, + rank, + world_size + ) + iter_ = DataLoader( + data, + batch_sampler=batch_sampler, + num_workers=3, + collate_fn=data.collate, + pin_memory=True + ) + return data, iter_, sampler + + +def load_sft_dataset_val(args, drop_last=True): + ''' + tokenizer = get_tokenizer(args['model_path']) + dataset_name = args['models'][args['model']]['stage1_train_dataset'] # SupervisedDataset, str + data_path = args["data_path"] + data = globals()[dataset_name](data_path, tokenizer, args['max_length']) #SupervisedDataset + ''' + if args["data_type"] == "video" or args["data_type"] == "videoimage": + data = SupervisedAudioVisualDataset( + args['data_type'], + video_data_path=args['val_data_path'], + video_root_path=args['image_root_path'], + training=False, + ) + elif args["data_type"] == "image": + data = SupervisedAudioVisualDataset( + args['data_type'], + image_data_path=args['image_val_data_path'], + image_root_path=args['llava_root_path'], + training=False, + ) + elif args["data_type"] == "audio": + data = SupervisedAudioVisualDataset( + args['data_type'], + audio_data_path=args['audio_val_data_path'], + audio_root_path=args['image_root_path'], + use_whisper=args["use_whisper"], + training=False, + sin_pos=args["sin_pos"], + return_raw=args["return_raw"] + ) + elif args["data_type"] == "audioimage": + data = SupervisedAudioVisualDataset( + args['data_type'], + audio_data_path=args['audio_val_data_path'], + audio_root_path=args['image_root_path'], + image_data_path=args['image_val_data_path'], + image_root_path=args['llava_root_path'], + use_whisper=args["use_whisper"], + training=False, + ) + # visualdata = SupervisedAudioVisualDataset( + # "image", + # image_data_path=args['image_val_data_path'], + # image_root_path=args['llava_root_path'], + # training=False, + # ) + # data = [visualdata, avdata] + + elif args["data_type"] == "audiovideoimage": + avdata = SupervisedAudioVisualDataset( + args['data_type'], + video_data_path=args['val_data_path'], + video_root_path=args['image_root_path'], + use_whisper=args["use_whisper"], + training=False, + sin_pos=args["sin_pos"], + return_raw=args["return_raw"], + video_only=args.get('video_only', False), + ) + if args.get('video_only', False): + data = avdata + else: + visualdata = SupervisedAudioVisualDataset( + "audioimage", + audio_data_path=args['audio_val_data_path'], + audio_root_path=args['image_root_path'], + image_data_path=args['image_val_data_path'], + image_root_path=args['llava_root_path'], + use_whisper=args["use_whisper"], + training=False, + sin_pos=args["sin_pos"], + return_raw=args["return_raw"] + ) + data = [visualdata, avdata] + else: + data = SupervisedDataset(args['data_path'], args['image_root_path']) + + sampler = torch.utils.data.RandomSampler(data) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + batch_size = args['world_size'] * args['dschf'].config['train_micro_batch_size_per_gpu'] + batch_sampler = DistributedBatchSampler( + sampler, + batch_size, + drop_last, + rank=rank, + world_size=world_size + ) + if isinstance(data, list): + audio_sampler = torch.utils.data.RandomSampler(avdata) + audio_batch_sampler = DistributedBatchSampler( + audio_sampler, + batch_size, + True, + rank, + world_size + ) + audio_iter_ = DataLoader( + avdata, + batch_sampler=audio_batch_sampler, + num_workers=0, + collate_fn=avdata.collate, + pin_memory=True + ) + video_sampler = torch.utils.data.RandomSampler(visualdata) + video_batch_sampler = DistributedBatchSampler( + video_sampler, + batch_size, + True, + rank, + world_size + ) + image_iter_ = DataLoader( + visualdata, + batch_sampler=video_batch_sampler, + num_workers=0, + collate_fn=visualdata.collate, + pin_memory=True + ) + iter_ = [image_iter_, audio_iter_] + else: + iter_ = DataLoader( + data, + batch_sampler=batch_sampler, + num_workers=4, + collate_fn=data.collate, + pin_memory=True + ) + return data, iter_, sampler \ No newline at end of file diff --git a/video_salmonn/datasets/samplers.py b/video_salmonn/datasets/samplers.py new file mode 100644 index 0000000..d3ce1e9 --- /dev/null +++ b/video_salmonn/datasets/samplers.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""batch samplers that work with either random or sequential data samplers""" +import math +import os +import sys + +import torch +from torch.utils import data +import numpy as np + + +class RandomSampler(data.sampler.Sampler): + r""" + Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, + but this class lets the user set an epoch like DistributedSampler + Samples elements randomly. If without replacement, then sample from a shuffled dataset. + If with replacement, then user can specify ``num_samples`` to draw. + Arguments: + data_source (Dataset): dataset to sample from + num_samples (int): number of samples to draw, default=len(dataset) + replacement (bool): samples are drawn with replacement if ``True``, default=False + """ + + def __init__(self, data_source, replacement=False, num_samples=None): + super(RandomSampler, self).__init__(data_source) + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.epoch = -1 + + if self._num_samples is not None and replacement is False: + raise ValueError("With replacement=False, num_samples should not be specified, " + "since a random permute will be performed.") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(self.num_samples)) + if not isinstance(self.replacement, bool): + raise ValueError("replacement should be a boolean value, but got " + "replacement={}".format(self.replacement)) + + @property + def num_samples(self): + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self): + n = len(self.data_source) + g = torch.Generator() + if self.epoch >= 0: + g.manual_seed(self.epoch) + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, + generator=g).tolist() + else: + yield from torch.randperm(n, generator=self.generator).tolist() + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class DistributedSequentialSampler(data.sampler.Sampler): + def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2): + super().__init__(num_samples) + if rank == -1: + rank = 0 + world_size = 1 + self.num_samples = num_samples + self.rank = rank + self.world_size = world_size + self.start_iter = 0 + self.train_iters = train_iters + self.batch_size = batch_size + self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)] + + def __iter__(self): + for idx in range(self.start_iter, self.train_iters * 10): + batch = [(idx + bias) % self.num_samples for bias in self.batch_bias] + tbatch = self._batch(batch) + yield tbatch + + def __len__(self): + return self.train_iters + + def _batch(self, batch): + """extracts samples only pertaining to this worker's batch""" + start = self.rank*self.batch_size//self.world_size + end = (self.rank+1)*self.batch_size//self.world_size + return batch[start:end] + + +class DistributedBatchSampler(data.sampler.BatchSampler): + """ + similar to normal implementation of distributed sampler, except implementation is at the + batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary + data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. + """ + def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None): + super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) + if rank == -1: + assert False, 'should not be here' + self.rank = rank + self.world_size = world_size + self.sampler.wrap_around = 0 + self.wrap_around = 0 + self.wrap_last = wrap_last + self.start_iter = 0 + self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps + + def __iter__(self): + batch = [] + i = 0 + for idx in self.data_iterator(self.sampler, wrap_around=False): + batch.append(idx) + if len(batch) == self.batch_size: + tbatch = self._batch(batch) + if i >= self.start_iter * self.effective_batch_size: + yield tbatch + self.start_iter = 0 + i += len(batch) + batch = [] + batch_len = len(batch) + if batch_len > 0 and not self.drop_last: + if self.wrap_last: + self.sampler.wrap_around -= (self.batch_size) + self.wrap_around += (len(batch)) + self.wrap_around %= self.batch_size + yield self._batch(batch) + if self.wrap_last: + self.sampler.wrap_around += self.batch_size + + def data_iterator(self, _iter, wrap_around=False): + """iterates through data and handles wrap around""" + for i, idx in enumerate(_iter): + if i < self.wrap_around%self.batch_size: + continue + if wrap_around: + self.wrap_around += 1 + self.wrap_around %= self.batch_size + yield idx + + def _batch(self, batch): + """extracts samples only pertaining to this worker's batch""" + start = self.rank*self.batch_size//self.world_size + end = (self.rank+1)*self.batch_size//self.world_size + return batch[start:end] diff --git a/video_salmonn/datasets/sft_dataset.py b/video_salmonn/datasets/sft_dataset.py new file mode 100644 index 0000000..bd3ed92 --- /dev/null +++ b/video_salmonn/datasets/sft_dataset.py @@ -0,0 +1,513 @@ +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import json +import csv +from tqdm import tqdm +import random +from torch.nn.utils.rnn import pad_sequence +from dataclasses import dataclass, field +from typing import Callable, Dict, Sequence +from fractions import Fraction +import soundfile as sf + +import torch +import torch.distributed as dist +import transformers +from torch.utils.data import Dataset +import numpy as np +from tqdm import tqdm +from pytorchvideo import transforms as pv_transforms +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler +from pytorchvideo.data.encoded_video import EncodedVideo +from transformers import WhisperFeatureExtractor + + +AUDIO_EXISTANCE = ["Is there any sound?", "Can you hear anything?", "Is there audio with this video?"] +AUDIO_VIDEO_MATCHING = [ + "Is the audio compatible with the video?", + "Does the audio come from the same source as the video?", + "Is the audio related to the video?" +] +video_specaug_params = { + "mask_rate": 0.0, +} + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, image_root_path: str): + super(SupervisedDataset, self).__init__() + + with open(data_path, 'r') as f: + json_data = json.load(f) + + self.image_path_list, self.caption_list = [], [] + for item in json_data: + one_image_name, one_caption = item["image_name"], item["conversation"] + # TODO: stage 2 dataset format is invalid + if not one_image_name.endswith('.jpg'): + one_image_name += '.jpg' + one_image_path = image_root_path + '/{}'.format(one_image_name) + self.image_path_list.append(one_image_path) + self.caption_list.append(one_caption) + print(f'[!] collect {len(self.image_path_list)} samples for training') + + def __len__(self): # number of instances + return len(self.image_path_list) + + #def __getitem__(self, i) -> Dict[str, torch.Tensor]: # how to get item, 取一个样本 + def __getitem__(self, i): + print(i) + return dict(image_paths=self.image_path_list[i], output_texts=self.caption_list[i]) + + def collate(self, instances): + image_paths, output_texts = tuple([instance[key] for instance in instances] for key in ("image_paths", "output_texts")) + return dict( + image_paths=image_paths, + output_texts=output_texts + ) + + +class SupervisedAudioVisualDataset(Dataset): + """Dataset for supervised fine-tuning with audio captioning.""" + + def __init__(self, + data_type: str, + audio_data_path: str = "", + audio_root_path: str = "", + video_data_path: str = "", + image_data_path: str = "", + video_root_path: str = "", + image_root_path: str = "", + sample_rate: int = 16000, + sample_per_clip: int = 2, + clip_duration: int = 1, + use_whisper: str = "", + use_blip: str = "", + training: bool = True, + # [Yu] + sin_pos: bool = False, + return_raw: bool = False, + audio_only: bool = False, + video_only: bool = False, + use_nemo: bool = False, + # [npy] + use_npy: bool = False + ): + super(SupervisedAudioVisualDataset, self).__init__() + if audio_data_path == "" and video_data_path == "" and image_data_path == "": + raise + self.modality_range = [] + self.audiofiles = [] + self.spokencocofiles = [] + self.training = training + # [Yu] + self.sin_pos = sin_pos + self.return_raw = return_raw + self.audio_path_list, self.audio_caption_list = [], [] + self.audio_only = audio_only + self.video_only = video_only + self.use_nemo = use_nemo + # [npy] + self.use_npy = use_npy + if audio_data_path != "" and "audio" in data_type and audio_data_path is not None: + self.audio_path_list, self.audio_caption_list = self.get_data_json( + audio_data_path, audio_root_path, modality="audio", + ) + self.modality_range.append("audio") + self.image_path_list, self.image_caption_list = [], [] + if image_data_path != "" and "image" in data_type and image_data_path is not None: + self.image_path_list, self.image_caption_list = self.get_data_json( + image_data_path, image_root_path, modality="image", + ) + self.modality_range.append("image") + self.video_path_list, self.video_caption_list = [], [] + if video_data_path != "" and "video" in data_type and video_data_path is not None: + self.video_path_list, self.video_caption_list = self.get_data_json( + video_data_path, video_root_path, modality="video", + ) + if data_type != "audiovideoimage": + self.modality_range.append("video") + self.frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=sample_per_clip) + self.clip_sampler = UniformClipSampler( + clip_duration=clip_duration, backpad_last=True + ) + self.sample_per_clip = sample_per_clip + self.clip_duration = clip_duration + self.use_whisper = use_whisper + self.use_blip = use_blip + self.sample_rate = sample_rate + self.data_type = data_type + if self.data_type == "audiovideoimage" and self.training: + if audio_only: + self.modality_range = ["audiovideoimage", "audioimage", "audio"] + elif video_only: + self.modality_range = ["audiovideoimage"] + else: + self.modality_range = ["audiovideoimage", "audioimage"] + print(self.modality_range) + elif self.data_type == "audiovideoimage": + self.modality_range = ["audiovideoimage"] + self.modality = random.choice(self.modality_range) + if self.use_whisper == "true": + # whispermodel = "/mnt/bn/audio-visual-llm-data/yuwenyi/ckpt/whisper/whisper-large-v3" + whispermodel = "/mnt/bn/audio-visual-llm-data/yuwenyi/ckpt/whisper/whisper_large_v2" + self.transform = WhisperFeatureExtractor.from_pretrained(whispermodel) + self.use_whisper = True + + def get_data_json(self, data_path, root_path, modality='image'): + with open(data_path, 'r') as f: + json_data = json.load(f) + # if not self.training: + # json_data = json_data[:2000] + if self.video_only and not self.training: + json_data = json_data[:100] + elif self.video_only and self.training: + json_data = json_data[:1000] + + path_list, caption_list = [], [] + for item in json_data: + one_image_name, one_caption = item["image_name"], item["conversation"] + if isinstance(one_image_name, list) and "SpokenCOCO" in one_image_name[1]: + self.spokencocofiles.append(one_image_name[1]) + elif "audiocaps" in one_image_name: + self.audiofiles.append(one_image_name) + if modality in ["image", "video", "audio"]: + one_path = one_image_name + else: + one_path = root_path + '/{}'.format(one_image_name) + # if modality == "image" or os.path.exists(one_path): + path_list.append(one_path) + caption_list.append(one_caption) + print(f'[!] collect {len(path_list)} {modality} samples for {"train" if self.training else "valid"}') + return path_list, caption_list + + def __len__(self): # number of instances + return len(self.audio_path_list) + len(self.image_path_list) + len(self.video_path_list) + + def get_audio(self, i, audiopath=None): + i = i % max(len(self.audio_path_list), 1) + if audiopath is None: + audiopath = self.audio_path_list[i] + if self.use_whisper: + audio, _ = sf.read(audiopath) + if len(audio.shape) == 2: + audio = audio[:, 0] + if audio.shape[0] < 3 * self.sample_rate: + audio = np.concatenate((audio, np.zeros((3 * self.sample_rate - audio.shape[0]), dtype=float)), axis=0) + if len(audio) > 30 * self.sample_rate and self.sin_pos: + audio_list = [audio[i: i + 30 * self.sample_rate] for i in range(0, len(audio), 30 * self.sample_rate)] + spectrogram_list = [] + for audio_piece in audio_list: + spectrogram_piece = self.transform( + audio_piece, + sampling_rate=self.sample_rate, + return_tensors="pt", + max_length=30 * self.sample_rate, + ) + spectrogram_list.append(spectrogram_piece["input_features"].squeeze()) + spectrogram = torch.stack(spectrogram_list, dim=0) + return dict( + image_paths=spectrogram, + output_texts=copy.deepcopy(self.audio_caption_list[i]) if self.audio_caption_list != [] else None, + modality="audio", + orig_paths=audiopath, + raw_audio=audio_list if self.return_raw else None, + ) + else: + spectrogram = self.transform( + audio, + sampling_rate=self.sample_rate, + return_tensors="pt", + max_length=30 * self.sample_rate, + ) + spectrogram = spectrogram["input_features"].squeeze() + return dict( + image_paths=spectrogram, + output_texts=copy.deepcopy(self.audio_caption_list[i]) if self.audio_caption_list != [] else None, + modality="audio", + orig_paths=audiopath, + raw_audio=[audio[:30 * self.sample_rate]] if self.return_raw else None, + ) + elif self.use_nemo: + pass + else: + return dict( + image_paths=audiopath, + output_texts=copy.deepcopy(self.audio_caption_list[i]) if self.audio_caption_list != [] else None, + modality="audio", + ) + + def get_image(self, i): + if i >= len(self.image_path_list): + i = i % len(self.image_path_list) + if not self.training and isinstance(self.image_path_list[i], list): + imagepath = self.image_path_list[i][0] + else: + imagepath = self.image_path_list[i] + return dict(image_paths=imagepath, output_texts=copy.deepcopy(self.image_caption_list[i]), modality="image") + + def get_video(self, i, videopath=None): + if videopath is None: + if i >= len(self.video_path_list): + i = i % len(self.video_path_list) + videopath = self.video_path_list[i] + if isinstance(videopath, list): + videopath = videopath[0] + + if self.training: + if self.use_npy: # npy training + return dict(image_paths=videopath, output_texts=copy.deepcopy(self.video_caption_list[i]), modality="video") + video = EncodedVideo.from_path( + videopath, + decoder="decord", + decode_audio=False, + **{"sample_rate": self.sample_rate}, + ) + if "egovideos" in videopath or "how2videos" in videopath: + durations = videopath[:-4].split("_")[-2:] + if durations[-1] == "sum": + duration = 30 + else: + duration = float(durations[1]) - float(durations[0]) + else: + duration = video.duration + + all_clips_timepoints = self.get_clip_timepoints( + self.clip_sampler, duration) + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + try: + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + video_clip = self.frame_sampler(clip["video"]) + if "mask_rate" in video_specaug_params and random.random() < video_specaug_params["mask_rate"]: + video_clip = video_clip * 0 # mask specific video frame + video_clip = video_clip / 255.0 # since this is float, need 0-1 + all_video.append(video_clip) + except: + print("skipped frame {}".format(clip_timepoints)) + print(videopath) + pass + return dict(image_paths=all_video, output_texts=copy.deepcopy(self.video_caption_list[i]), modality="video") + + def get_audioimage(self, i): + image_data = self.get_image(i) + if isinstance(image_data["image_paths"], list): + audiopath = image_data["image_paths"][1] + image_data["image_paths"] = image_data["image_paths"][0] + prompt = image_data["output_texts"] + avmask = 1 + if image_data["output_texts"][1]["value"] == "audio_text_matching": + if random.random() > 0.5: + audio_data = self.get_audio(i, random.choice(self.spokencocofiles)) + prompt[1]["value"] = "No" + else: + audio_data = self.get_audio(i, audiopath) + prompt[1]["value"] = "Yes" + else: + audio_data = self.get_audio(i, audiopath) + else: + if self.audio_only: + audio_data = self.get_audio(i) + prompt = image_data["output_texts"] + avmask = 0 + else: + audio_data = self.get_audio(i) + image_data["output_texts"][0]['value'] = "In the image, " + image_data["output_texts"][0]['value'].lower() + promptlist = [audio_data["output_texts"], image_data["output_texts"]] + random.shuffle(promptlist) + avmask = 1 + if random.random() < 0.5: + prompt = promptlist[0] + else: + userprompt = promptlist[0][0]['value'] + ", and " + promptlist[1][0]['value'].lower() + gptresponse = promptlist[0][1]['value'] + ", and " + promptlist[1][1]['value'].lower() + prompt = [{'from': 'human', 'value': userprompt}, {'from': 'gpt', 'value': gptresponse}] + return dict( + image_paths=[audio_data["image_paths"], image_data["image_paths"]], + output_texts=prompt, + modality="audioimage", + mask_audio=avmask, + orig_paths=audio_data["orig_paths"], + raw_audio=audio_data["raw_audio"] + ) + + def get_videoaudioimage(self, i): + if i >= len(self.video_path_list): + i = i % len(self.video_path_list) + videopath = self.video_path_list[i] + # print(videopath) + if isinstance(videopath, list): + videopath, audiopath = videopath + video_data = self.get_video(i, videopath) + audio_data = self.get_audio(i, audiopath) + if "only_need_video" in videopath: + avmask = [1, 0] + # elif "egovideos" in videopath[0]: + # avmask = [1, 1] if random.random() > 0.2 else [0, 1] + # elif "how2videos" in videopath[0]: + # avmask = [1, 1] if random.random() > 0.2 else [1, 0] + else: + avmask = [1, 1] + output_texts = video_data["output_texts"] + if self.use_npy and self.training: # npy training + pass + elif random.random() > 0.9 and len(self.audiofiles) != 0 and "yuwenyi" not in audiopath and self.training: + output_texts[0]["value"] = random.choice(AUDIO_VIDEO_MATCHING) + if random.random() > 0.5: + audio_data = self.get_audio(i, random.choice(self.audiofiles)) + output_texts[1]["value"] = "No." + else: + output_texts[1]["value"] = "Yes." + return dict( + image_paths=[audio_data["image_paths"], video_data["image_paths"]], + output_texts=output_texts, + modality=self.data_type, + mask_audio=avmask, + orig_paths=audio_data["orig_paths"], + raw_audio=audio_data["raw_audio"] + ) + else: + video_data = self.get_video(i, videopath) + audio_data = self.get_audio(i) + if random.random() < 0.8: + output_texts = video_data["output_texts"] + mask_audio = [1, 0] + # mask_audio = [1, 1] + else: + video_data["output_texts"][0]['value'] = "In the video, " + video_data["output_texts"][0]['value'].lower() + promptlist = [audio_data["output_texts"], video_data["output_texts"]] + random.shuffle(promptlist) + userprompt = promptlist[0][0]['value'] + ", and, " + promptlist[1][0]['value'].lower() + gptresponse = promptlist[0][1]['value'] + ", and, " + promptlist[1][1]['value'].lower() + output_texts = [{'from': 'human', 'value': userprompt}, {'from': 'gpt', 'value': gptresponse}] + mask_audio = [1, 1] + # promptlist = [audio_data["output_texts"], video_data["output_texts"]] + return dict( + image_paths=[audio_data["image_paths"], video_data["image_paths"]], + output_texts=output_texts, + modality=self.data_type, + mask_audio=mask_audio, + orig_paths=audio_data["orig_paths"], + raw_audio=audio_data["raw_audio"] + ) + + def __getitem__(self, i): + if self.data_type == "audioimage" or self.modality == "audioimage": + return self.get_audioimage(i) + elif self.modality == "audiovideoimage": + return self.get_videoaudioimage(i) + elif self.modality == "audio": + return self.get_audio(i) + elif self.modality == "image": + return self.get_image(i) + elif self.modality == "video": + return self.get_video(i) + + def sample_modality(self): + self.modality = random.choice(self.modality_range) + + def collate(self, instances): + image_paths = [] + output_texts = [] + first_modality = instances[0]["modality"] + audiomasks = [] + orig_paths = [] + raw_audios = [] + trigger_reduce = 0 + if "video" in first_modality: + length_thred = int(30 / self.clip_duration * self.sample_per_clip) + for instance in instances: + assert instance["modality"] == first_modality # should have the same modality in one minibatch + if instance["modality"] == "video": + if len(instance["image_paths"]) < length_thred: + image_paths.append(instance["image_paths"]) + output_texts.append(instance["output_texts"]) + elif instance["modality"] in ["image", "audio", "audioimage"]: + image_paths.append(instance["image_paths"]) + output_texts.append(instance["output_texts"]) + if "mask_audio" in instance: + if instance["mask_audio"] == 1: + instance["mask_audio"] = [1, 1] + else: + instance["mask_audio"] = [1, 0] + audiomasks.append(instance["mask_audio"] if "mask_audio" in instance else [1, 1]) + orig_paths.append(instance["orig_paths"] if "orig_paths" in instance else "") + raw_audios.append(instance["raw_audio"] if "raw_audio" in instance else None) + elif instance["modality"] == "audiovideoimage": + if self.use_npy or len(instance["image_paths"][1]) < length_thred: + image_paths.append(instance["image_paths"]) + output_texts.append(instance["output_texts"]) + audiomasks.append(instance["mask_audio"] if "mask_audio" in instance else [1, 1]) + orig_paths.append(instance["orig_paths"] if "orig_paths" in instance else "") + raw_audios.append(instance["raw_audio"] if "raw_audio" in instance else None) + # reduce if long + # if len(instance["output_texts"][1]["value"].split()) > 80: + # trigger_reduce = max(trigger_reduce, len(instance["output_texts"][1]["value"].split()) // 80) + if len(instance["output_texts"][1]["value"].split()) > 500: + trigger_reduce = max(trigger_reduce, len(instance["output_texts"][1]["value"].split()) // 500) + elif len(instance["output_texts"]) > 2: + trigger_reduce = 3 + + if "/AMI/BeamformIt/" in instance["orig_paths"]: + image_paths = [instance["image_paths"]] + output_texts = [instance["output_texts"]] + audiomasks = [instance["mask_audio"] if "mask_audio" in instance else 0] + orig_paths = [instance["orig_paths"] if "orig_paths" in instance else ""] + raw_audios = [instance["raw_audio"] if "raw_audio" in instance else None] + break + + if image_paths == []: + if first_modality == "audiovideoimage": + image_paths.append( + [instances[0]["image_paths"][0], instances[0]["image_paths"][1][:length_thred]] + ) + audiomasks.append(instances[0]["mask_audio"]) + else: + image_paths.append(instances[0]["image_paths"][:length_thred]) + output_texts.append(instances[0]["output_texts"]) + orig_paths.append(instances[0]["orig_paths"]) + raw_audios.append(instances[0]["raw_audio"] if "raw_audio" in instances[0] else None) + elif len(image_paths) >= 2 and trigger_reduce > 0 and self.training: + cut_len = len(image_paths) // (trigger_reduce + 1) + 1 + # print("reducing batchsize to {}".format(cut_len)) + image_paths = image_paths[:cut_len] + output_texts = output_texts[:cut_len] + audiomasks = audiomasks[:cut_len] + orig_paths = orig_paths[:cut_len] + raw_audios = raw_audios[:cut_len] + + self.sample_modality() + return dict( + image_paths=image_paths, + output_texts=output_texts, + modality=first_modality, + audiomasks=torch.tensor(audiomasks) if audiomasks != [] else None, + orig_paths=orig_paths, + raw_audios=None if None in raw_audios else raw_audios, + ) + + def get_clip_timepoints(self, clip_sampler, duration): + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints \ No newline at end of file diff --git a/video_salmonn/datasets/sft_dataset_nomix.py b/video_salmonn/datasets/sft_dataset_nomix.py new file mode 100644 index 0000000..8500f4c --- /dev/null +++ b/video_salmonn/datasets/sft_dataset_nomix.py @@ -0,0 +1,503 @@ +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import json +import csv +from tqdm import tqdm +import random +from torch.nn.utils.rnn import pad_sequence +from dataclasses import dataclass, field +from typing import Callable, Dict, Sequence +from fractions import Fraction +import soundfile as sf + +import torch +import torch.distributed as dist +import transformers +from torch.utils.data import Dataset +import numpy as np +from tqdm import tqdm +from pytorchvideo import transforms as pv_transforms +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler +from pytorchvideo.data.encoded_video import EncodedVideo +from transformers import WhisperFeatureExtractor + + +AUDIO_EXISTANCE = ["Is there any sound?", "Can you hear anything?", "Is there audio with this video?"] +AUDIO_VIDEO_MATCHING = [ + "Is the audio compatible with the video?", + "Does the audio come from the same source as the video?", + "Is the audio related to the video?" +] +video_specaug_params = { + "mask_rate": 0.0, +} + +# class SupervisedDataset(Dataset): +# """Dataset for supervised fine-tuning.""" + +# def __init__(self, data_path: str, image_root_path: str): +# super(SupervisedDataset, self).__init__() + +# with open(data_path, 'r') as f: +# json_data = json.load(f) + +# self.image_path_list, self.caption_list = [], [] +# for item in json_data: +# one_image_name, one_caption = item["image_name"], item["conversation"] +# # TODO: stage 2 dataset format is invalid +# if not one_image_name.endswith('.jpg'): +# one_image_name += '.jpg' +# one_image_path = image_root_path + '/{}'.format(one_image_name) +# self.image_path_list.append(one_image_path) +# self.caption_list.append(one_caption) +# print(f'[!] collect {len(self.image_path_list)} samples for training') + +# def __len__(self): # number of instances +# return len(self.image_path_list) + +# #def __getitem__(self, i) -> Dict[str, torch.Tensor]: # how to get item, 取一个样本 +# def __getitem__(self, i): +# print(i) +# return dict(image_paths=self.image_path_list[i], output_texts=self.caption_list[i]) + +# def collate(self, instances): +# image_paths, output_texts = tuple([instance[key] for instance in instances] for key in ("image_paths", "output_texts")) +# return dict( +# image_paths=image_paths, +# output_texts=output_texts +# ) + + +class SupervisedAudioVisualDataset4Test(Dataset): + """Dataset for supervised fine-tuning with audio captioning.""" + + def __init__(self, + data_type: str, + audio_data_path: str = "", + audio_root_path: str = "", + video_data_path: str = "", + image_data_path: str = "", + video_root_path: str = "", + image_root_path: str = "", + sample_rate: int = 16000, + sample_per_clip: int = 2, + clip_duration: int = 1, + use_whisper: str = "", + use_blip: str = "", + training: bool = True, + # [Yu] + sin_pos: bool = False, + return_raw: bool = False, + cache_dir: str = "", + ): + super(SupervisedAudioVisualDataset4Test, self).__init__() + if audio_data_path == "" and video_data_path == "" and image_data_path == "": + raise + self.modality_range = [] + self.audiofiles = [] + self.spokencocofiles = [] + self.training = training + # [Yu] + self.sin_pos = sin_pos + self.return_raw = return_raw + self.audio_path_list, self.audio_caption_list = [], [] + if audio_data_path != "" and "audio" in data_type: + self.audio_path_list, self.audio_caption_list = self.get_data_json( + audio_data_path, audio_root_path, modality="audio", + ) + self.modality_range.append("audio") + self.image_path_list, self.image_caption_list = [], [] + if image_data_path != "" and "image" in data_type: + self.image_path_list, self.image_caption_list = self.get_data_json( + image_data_path, image_root_path, modality="image", + ) + self.modality_range.append("image") + self.video_path_list, self.video_caption_list = [], [] + if video_data_path != "" and "video" in data_type: + self.video_path_list, self.video_caption_list = self.get_data_json( + video_data_path, video_root_path, modality="video", + ) + if data_type != "audiovideoimage": + self.modality_range.append("video") + self.frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=sample_per_clip) + self.clip_sampler = UniformClipSampler( + clip_duration=clip_duration, backpad_last=True + ) + self.sample_per_clip = sample_per_clip + self.clip_duration = clip_duration + self.use_whisper = use_whisper + self.use_blip = use_blip + self.sample_rate = sample_rate + self.data_type = data_type + if self.data_type == "audiovideoimage" and self.training: + self.modality_range = ["audiovideoimage", "audioimage", "audio"] + elif self.data_type == "audiovideoimage": + self.modality_range = ["audiovideoimage"] + self.modality = random.choice(self.modality_range) + if self.use_whisper == "true": + whispermodel = "openai/whisper-large-v2" + self.transform = WhisperFeatureExtractor.from_pretrained(whispermodel, cache_dir=cache_dir) + self.use_whisper = True + + def get_data_json(self, data_path, root_path, modality='image'): + with open(data_path, 'r') as f: + json_data = json.load(f) + # if not self.training: + # json_data = json_data[:2000] + + path_list, caption_list = [], [] + for item in json_data: + one_image_name, one_caption = item["image_name"], item["conversation"] + if isinstance(one_image_name, list) and "SpokenCOCO" in one_image_name[1]: + self.spokencocofiles.append(one_image_name[1]) + elif "audiocaps" in one_image_name: + self.audiofiles.append(one_image_name) + if modality in ["image", "video", "audio"]: + one_path = one_image_name + else: + one_path = root_path + '/{}'.format(one_image_name) + # if modality == "image" or os.path.exists(one_path): + path_list.append(one_path) + caption_list.append(one_caption) + print(f'[!] collect {len(path_list)} {modality} samples for {"train" if self.training else "valid"}') + return path_list, caption_list + + def __len__(self): # number of instances + if self.training: + return len(self.audio_path_list) + len(self.image_path_list) + len(self.video_path_list) + else: + if self.data_type == "audio": + return len(self.audio_path_list) + elif self.data_type == "audiovideoimage": + return len(self.video_path_list) + elif self.data_type == "audioimage": + return len(self.image_path_list) + + def get_audio(self, i, audiopath=None): + i = i % max(len(self.audio_path_list), 1) + if audiopath is None: + audiopath = self.audio_path_list[i] + if self.use_whisper: + audio, _ = sf.read(audiopath) + if len(audio.shape) == 2: + audio = audio[:, 0] + if audio.shape[0] < 3 * self.sample_rate: + audio = np.concatenate((audio, np.zeros((3 * self.sample_rate - audio.shape[0]), dtype=float)), axis=0) + if len(audio) > 30 * self.sample_rate and self.sin_pos: + audio_list = [audio[i: i + 30 * self.sample_rate] for i in range(0, len(audio), 30 * self.sample_rate)] + spectrogram_list = [] + for audio_piece in audio_list: + spectrogram_piece = self.transform( + audio_piece, + sampling_rate=self.sample_rate, + return_tensors="pt", + max_length=30 * self.sample_rate, + ) + spectrogram_list.append(spectrogram_piece["input_features"].squeeze()) + spectrogram = torch.stack(spectrogram_list, dim=0) + return dict( + image_paths=spectrogram, + output_texts=self.audio_caption_list[i] if self.audio_caption_list != [] else None, + modality="audio", + orig_paths=[audiopath, None], + raw_audio=audio_list if self.return_raw else None, + ) + else: + spectrogram = self.transform( + audio, + sampling_rate=self.sample_rate, + return_tensors="pt", + max_length=30 * self.sample_rate, + ) + spectrogram = spectrogram["input_features"].squeeze() + return dict( + image_paths=spectrogram, + output_texts=self.audio_caption_list[i] if self.audio_caption_list != [] else None, + modality="audio", + orig_paths=[audiopath, None], + raw_audio=[audio[:30 * self.sample_rate]] if self.return_raw else None, + ) + else: + return dict( + image_paths=audiopath, + output_texts=self.audio_caption_list[i] if self.audio_caption_list != [] else None, + modality="audio", + ) + + def get_image(self, i): + if i >= len(self.image_path_list): + i = i % len(self.image_path_list) + # if not self.training and isinstance(self.image_path_list[i], list): + # imagepath = self.image_path_list[i][0] + # else: + # imagepath = self.image_path_list[i] + imagepath = self.image_path_list[i] + return dict(image_paths=imagepath, output_texts=self.image_caption_list[i], modality="image", orig_paths=[None, imagepath]) + + def get_video(self, i, videopath=None): + if videopath is None: + if i >= len(self.video_path_list): + i = i % len(self.video_path_list) + videopath = self.video_path_list[i] + if isinstance(videopath, list): + videopath = videopath[0] + + if self.training: + video = EncodedVideo.from_path( + videopath, + decoder="decord", + decode_audio=False, + **{"sample_rate": self.sample_rate}, + ) + else: + video = EncodedVideo.from_path( + videopath, + decoder="pyav", + decode_audio=False, + # **{"sample_rate": sample_rate}, + ) + if "egovideos" in videopath or "how2videos" in videopath: + durations = videopath[:-4].split("_")[-2:] + if durations[-1] == "sum": + duration = 30 + else: + duration = float(durations[1]) - float(durations[0]) + else: + duration = video.duration + + all_clips_timepoints = self.get_clip_timepoints( + self.clip_sampler, duration) + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + try: + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + video_clip = self.frame_sampler(clip["video"]) + if "mask_rate" in video_specaug_params and random.random() < video_specaug_params["mask_rate"]: + video_clip = video_clip * 0 # mask specific video frame + video_clip = video_clip / 255.0 # since this is float, need 0-1 + all_video.append(video_clip) + except: + print("skipped frame {}".format(clip_timepoints)) + print(videopath) + pass + return dict(image_paths=all_video, output_texts=self.video_caption_list[i], modality="video", orig_path=[None, videopath]) + + def get_audioimage(self, i): + image_data = self.get_image(i) + if isinstance(image_data["image_paths"], list): + audiopath = image_data["image_paths"][1] + image_data["image_paths"] = image_data["image_paths"][0] + prompt = image_data["output_texts"] + avmask = 1 + if image_data["output_texts"][1]["value"] == "audio_text_matching": + if random.random() > 0.5: + audio_data = self.get_audio(i, random.choice(self.spokencocofiles)) + prompt[1]["value"] = "No" + else: + audio_data = self.get_audio(i, audiopath) + prompt[1]["value"] = "Yes" + else: + audio_data = self.get_audio(i, audiopath) + else: + audio_data = self.get_audio(i) + prompt = image_data["output_texts"] + # image_data["output_texts"][0]['value'] = "In the image, " + image_data["output_texts"][0]['value'].lower() + # promptlist = [audio_data["output_texts"], image_data["output_texts"]] + # random.shuffle(promptlist) + avmask = 0 + # if random.random() < 0.5: + # prompt = promptlist[0] + # else: + # userprompt = promptlist[0][0]['value'] + ", and " + promptlist[1][0]['value'].lower() + # gptresponse = promptlist[0][1]['value'] + ", and " + promptlist[1][1]['value'].lower() + # prompt = [{'from': 'human', 'value': userprompt}, {'from': 'gpt', 'value': gptresponse}] + return dict( + image_paths=[audio_data["image_paths"], image_data["image_paths"]], + output_texts=prompt, + modality="audioimage", + mask_audio=avmask, + orig_paths=[audio_data["orig_paths"], image_data["image_paths"]], + raw_audio=audio_data["raw_audio"] + ) + + def get_videoaudioimage(self, i): + if i >= len(self.video_path_list): + i = i % len(self.video_path_list) + videopath = self.video_path_list[i] + # print(videopath) + if isinstance(videopath, list): + videopath, audiopath = videopath + video_data = self.get_video(i, videopath) + audio_data = self.get_audio(i, audiopath) + # if "egovideos" in videopath[0]: + # avmask = [1, 1] if random.random() > 0.2 else [0, 1] + # elif "how2videos" in videopath[0]: + # avmask = [1, 1] if random.random() > 0.2 else [1, 0] + # else: + # avmask = [1, 1] + avmask = [1, 1] + output_texts = video_data["output_texts"] + if random.random() > 0.9 and len(self.audiofiles) != 0 and "yuwenyi" not in audiopath and self.training: + output_texts[0]["value"] = random.choice(AUDIO_VIDEO_MATCHING) + if random.random() > 0.5: + audio_data["image_paths"] = self.get_audio(i, random.choice(self.audiofiles))["image_paths"] + output_texts[1]["value"] = "No." + else: + output_texts[1]["value"] = "Yes." + return dict( + image_paths=[audio_data["image_paths"], video_data["image_paths"]], + output_texts=output_texts, + modality=self.data_type, + mask_audio=avmask, + orig_paths=videopath, + raw_audio=audio_data["raw_audio"] + ) + else: + video_data = self.get_video(i, videopath) + audio_data = self.get_audio(i) + if random.random() < 1.0: + output_texts = video_data["output_texts"] + mask_audio = [1, 0] + # mask_audio = [1, 1] + else: + video_data["output_texts"][0]['value'] = "In the video, " + video_data["output_texts"][0]['value'].lower() + promptlist = [audio_data["output_texts"], video_data["output_texts"]] + random.shuffle(promptlist) + userprompt = promptlist[0][0]['value'] + ", and, " + promptlist[1][0]['value'].lower() + gptresponse = promptlist[0][1]['value'] + ", and, " + promptlist[1][1]['value'].lower() + output_texts = [{'from': 'human', 'value': userprompt}, {'from': 'gpt', 'value': gptresponse}] + mask_audio = [1, 1] + # promptlist = [audio_data["output_texts"], video_data["output_texts"]] + return dict( + image_paths=[audio_data["image_paths"], video_data["image_paths"]], + output_texts=output_texts, + modality=self.data_type, + mask_audio=mask_audio, + orig_paths=[audio_data["orig_paths"], videopath], + raw_audio=audio_data["raw_audio"] + ) + + def __getitem__(self, i): + if self.data_type == "audioimage" or self.modality == "audioimage": + return self.get_audioimage(i) + elif self.modality == "audiovideoimage": + return self.get_videoaudioimage(i) + elif self.modality == "audio": + return self.get_audio(i) + elif self.modality == "image": + return self.get_image(i) + elif self.modality == "video": + return self.get_video(i) + + def sample_modality(self): + self.modality = random.choice(self.modality_range) + + def collate(self, instances): + image_paths = [] + output_texts = [] + first_modality = instances[0]["modality"] + audiomasks = [] + orig_paths = [] + raw_audios = [] + trigger_reduce = 0 + if "video" in first_modality: + length_thred = int(30 / self.clip_duration * self.sample_per_clip) + for instance in instances: + assert instance["modality"] == first_modality # should have the same modality in one minibatch + if instance["modality"] == "video": + if len(instance["image_paths"]) < length_thred: + image_paths.append(instance["image_paths"]) + output_texts.append(instance["output_texts"]) + orig_paths.append(instance["orig_paths"]) + elif instance["modality"] in ["image", "audio", "audioimage"]: + image_paths.append(instance["image_paths"]) + output_texts.append(instance["output_texts"]) + if "mask_audio" in instance: + if instance["mask_audio"] == 1: + instance["mask_audio"] = [1, 1] + else: + instance["mask_audio"] = [1, 0] + audiomasks.append(instance["mask_audio"] if "mask_audio" in instance else [1, 1]) + # audiomasks.append([1, 1] if instance["mask_audio"] == 1 else [1, 0]) + # orig_paths.append(instance["orig_paths"] if "orig_paths" in instance else "") + orig_paths.append(instance["orig_paths"]) + raw_audios.append(instance["raw_audio"] if "raw_audio" in instance else None) + elif instance["modality"] == "audiovideoimage": + if len(instance["image_paths"][1]) < length_thred: + image_paths.append(instance["image_paths"]) + output_texts.append(instance["output_texts"]) + audiomasks.append(instance["mask_audio"] if "mask_audio" in instance else [1, 1]) + # orig_paths.append(instance["orig_paths"] if "orig_paths" in instance else "") + orig_paths.append(instance["orig_paths"]) + raw_audios.append(instance["raw_audio"] if "raw_audio" in instance else None) + # reduce if long + # if len(instance["output_texts"][1]["value"].split()) > 80: + # trigger_reduce = max(trigger_reduce, len(instance["output_texts"][1]["value"].split()) // 80) + # if len(instance["output_texts"][1]["value"].split()) > 500: + # trigger_reduce = max(trigger_reduce, len(instance["output_texts"][1]["value"].split()) // 500) + # elif len(instance["output_texts"]) > 2: + # trigger_reduce = 3 + + # if "/AMI/BeamformIt/" in instance["orig_paths"]: + # image_paths = [instance["image_paths"]] + # output_texts = [instance["output_texts"]] + # audiomasks = [instance["mask_audio"] if "mask_audio" in instance else 0] + # orig_paths = [instance["orig_paths"] if "orig_paths" in instance else ""] + # raw_audios = [instance["raw_audio"] if "raw_audio" in instance else None] + # break + + if image_paths == []: + if first_modality == "audiovideoimage": + image_paths.append( + [instances[0]["image_paths"][0], instances[0]["image_paths"][1][:length_thred]] + ) + audiomasks.append(instances[0]["mask_audio"]) + else: + image_paths.append(instances[0]["image_paths"][:length_thred]) + output_texts.append(instances[0]["output_texts"]) + # orig_paths.append(instances[0]["orig_paths"] if "orig_paths" in instances[0] else "") + orig_paths.append(instance["orig_paths"]) + raw_audios.append(instances[0]["raw_audio"] if "raw_audio" in instances[0] else None) + elif len(image_paths) >= 2 and trigger_reduce > 0 and self.training: + cut_len = len(image_paths) // (trigger_reduce + 1) + 1 + # print("reducing batchsize to {}".format(cut_len)) + image_paths = image_paths[:cut_len] + output_texts = output_texts[:cut_len] + audiomasks = audiomasks[:cut_len] + orig_paths = orig_paths[:cut_len] + raw_audios = raw_audios[:cut_len] + + self.sample_modality() + return dict( + image_paths=image_paths, + output_texts=output_texts, + modality=first_modality, + audiomasks=torch.tensor(audiomasks) if audiomasks != [] else None, + orig_paths=orig_paths, + raw_audios=None if None in raw_audios else raw_audios, + ) + + def get_clip_timepoints(self, clip_sampler, duration): + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints \ No newline at end of file diff --git a/video_salmonn/dummy/1272-128104-0000.flac b/video_salmonn/dummy/1272-128104-0000.flac new file mode 100644 index 0000000..3b0c844 Binary files /dev/null and b/video_salmonn/dummy/1272-128104-0000.flac differ diff --git a/video_salmonn/dummy/4405327307.mp4 b/video_salmonn/dummy/4405327307.mp4 new file mode 100644 index 0000000..ee67f05 Binary files /dev/null and b/video_salmonn/dummy/4405327307.mp4 differ diff --git a/video_salmonn/dummy/4405327307.wav b/video_salmonn/dummy/4405327307.wav new file mode 100644 index 0000000..b199432 Binary files /dev/null and b/video_salmonn/dummy/4405327307.wav differ diff --git a/video_salmonn/dummy/761183272.jpg b/video_salmonn/dummy/761183272.jpg new file mode 100644 index 0000000..4580730 Binary files /dev/null and b/video_salmonn/dummy/761183272.jpg differ diff --git a/video_salmonn/dummy/args.txt b/video_salmonn/dummy/args.txt new file mode 100644 index 0000000..7c62c96 --- /dev/null +++ b/video_salmonn/dummy/args.txt @@ -0,0 +1 @@ +--model openllama_peft --stage 3y --data_path null --val_data_path null --audio_data_path /mnt/bn/audio-visual-llm-data/datasets/multitask_json/bs_how2_300h_train_longest200.json --audio_val_data_path /mnt/bn/audio-visual-llm-data/datasets/multitask_json/bs_how2_300h_val.json --image_root_path null --data_type audio --imagebind_ckpt_path /mnt/bn/audio-visual-llm-data/guangzhisun/audio_visual_llm/pandagpt2/pretrained_ckpt/imagebind_ckpt --vicuna_ckpt_path /mnt/bn/audio-visual-llm-data/yuwenyi/ckpt/vicuna/vicuna.13b --max_tgt_len 2000 --save_path /mnt/bn/audio-visual-llm-data/yuwenyi/playground/pandagpt/code/output/debug/ --log_path /mnt/bn/audio-visual-llm-data/yuwenyi/playground/pandagpt/code/output/debug/log/ --use_lora true --image_data_path null --image_val_data_path null --llava_root_path null --qformer true --use_blip true --use_whisper true --instructblip true --early_align false --alignmode 2 --num_video_query 32 --groupsize 10 --causal_attention false --diversity_loss true --diversity_loss_factor 0.01 --divsche 1 --pure_aud True --speech_qformer true --num_speech_query 1 --second_per_frame 0.333333 --second_stride 0.333333 \ No newline at end of file diff --git a/video_salmonn/dummy/dummy_audio.json b/video_salmonn/dummy/dummy_audio.json new file mode 100644 index 0000000..547a073 --- /dev/null +++ b/video_salmonn/dummy/dummy_audio.json @@ -0,0 +1,15 @@ +[ + { + "image_name": "./dummy/1272-128104-0000.flac", + "conversation": [ + { + "from": "human", + "value": "Describe the audio" + }, + { + "from": "gpt", + "value": "Constant rattling noise and sharp vibrations." + } + ] + } +] \ No newline at end of file diff --git a/video_salmonn/example.json b/video_salmonn/example.json new file mode 100644 index 0000000..40fb4a6 --- /dev/null +++ b/video_salmonn/example.json @@ -0,0 +1,18 @@ +[ + { + "image_name": [ + "./dummy/4405327307.mp4", + "./dummy/4405327307.wav" + ], + "conversation": [ + { + "from": "human", + "value": "Describe the video and audio in detail" + }, + { + "from": "gpt", + "value": "None" + } + ] + } +] \ No newline at end of file diff --git a/video_salmonn/header.py b/video_salmonn/header.py new file mode 100644 index 0000000..e2a7d06 --- /dev/null +++ b/video_salmonn/header.py @@ -0,0 +1,35 @@ +import torch +import datetime +import types +# import deepspeed +# from transformers.deepspeed import HfDeepSpeedConfig +import transformers +import numpy as np +from collections import OrderedDict +from torch.utils.data import Dataset, DataLoader +from torch.nn.utils import clip_grad_norm_ +from torch.cuda.amp import autocast, GradScaler +from torch.nn import DataParallel +from torch.optim import lr_scheduler +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +import os +import re +import math +import random +import json +import time +import logging +from copy import deepcopy +# import ipdb +import argparse +import data +from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig +from torch.nn.utils.rnn import pad_sequence +from peft import LoraConfig, TaskType, get_peft_model + +logging.getLogger("transformers").setLevel(logging.WARNING) +logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) +os.environ['TOKENIZERS_PARALLELISM'] = 'false' diff --git a/video_salmonn/inference.py b/video_salmonn/inference.py new file mode 100644 index 0000000..31df7cc --- /dev/null +++ b/video_salmonn/inference.py @@ -0,0 +1,128 @@ +import os +from config.config import Config +import argparse +import yaml +import json +from omegaconf import OmegaConf + +from datasets import SupervisedAudioVisualDataset4Test +from model.openllama import OpenLLAMAPEFTModel +from torch.utils.data import DataLoader +from tqdm import tqdm +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from datetime import datetime + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +# Set device +device = "cuda" if torch.cuda.is_available() else "cpu" + +# Load arguments +args = parse_args() +args = Config(args).config + +# Load the list of test files +all_decode_info = args.all_decode_info + +# Set the decoder output directory +decode_root = os.path.dirname(args.delta_ckpt_path) +current_time = datetime.now() +timestamp = current_time.strftime("%Y%m%d%H%M") +decode_root = os.path.join(decode_root, timestamp) +os.makedirs(decode_root, exist_ok=True) +OmegaConf.save(args, os.path.join(decode_root, "config.yaml")) + +# Initialise the model +ds_engine = OpenLLAMAPEFTModel(**args) +delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu')) +ds_engine.load_state_dict(delta_ckpt, strict=False) +ds_engine = ds_engine.eval().half().to(device) + +# Load test data as a list of dataloaders +dataloader_lst = [] +for modality, task, data_path in all_decode_info: + print("Loading data from: {}".format(data_path)) + + if modality == "audio": + dataset = SupervisedAudioVisualDataset4Test( + 'audio', + audio_data_path=data_path, + use_whisper=args["use_whisper"], + training=False, + sin_pos=args["sin_pos"], + return_raw=args["return_raw"], + cache_dir=args["cache_dir"], + ) + elif modality == "audioimage": + dataset = SupervisedAudioVisualDataset4Test( + 'audioimage', + audio_data_path="./dummy/dummy_audio.json", + image_data_path=data_path, + use_whisper=args["use_whisper"], + training=False, + sin_pos=args["sin_pos"], + return_raw=args["return_raw"], + cache_dir=args["cache_dir"], + ) + elif modality == "audiovideoimage": + dataset = SupervisedAudioVisualDataset4Test( + 'audiovideoimage', + audio_data_path="./dummy/dummy_audio.json", + video_data_path=data_path, + use_whisper=args["use_whisper"], + training=False, + sin_pos=args["sin_pos"], + return_raw=args["return_raw"], + cache_dir=args["cache_dir"], + ) + + dataloader = DataLoader( + dataset=dataset, + batch_size=args['batch_size'], + num_workers=3, + shuffle=False, + collate_fn=dataset.collate, + drop_last=False + ) + + dataloader_lst.append([dataloader, task]) + +# Start inference +results = [] +pbar = tqdm(total=sum([len(dataloader) for dataloader, _ in dataloader_lst]), desc="Decoding", position=0) + +for dataloader, task in dataloader_lst: + for batch_i, batch in enumerate(dataloader): + with torch.no_grad(): + text = ds_engine(batch, generate=True) + print(text) + for gen, ref, id in zip(text, batch['output_texts'], batch['orig_paths']): + results.append( + { + "id": f"{str(id)}_{ref[0]['value']}", + "conversation": ref, + "task": task, + "ref_answer": ref[1]['value'], + "gen_answer": gen + } + ) + pbar.update(1) + +# Write the results out +with open(os.path.join(decode_root, f"eval_result.json"), "w", encoding='utf-8') as f: + json.dump(results, f, indent=4, ensure_ascii=False) \ No newline at end of file diff --git a/video_salmonn/model/ImageBind/CODE_OF_CONDUCT.md b/video_salmonn/model/ImageBind/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..f913b6a --- /dev/null +++ b/video_salmonn/model/ImageBind/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq \ No newline at end of file diff --git a/video_salmonn/model/ImageBind/CONTRIBUTING.md b/video_salmonn/model/ImageBind/CONTRIBUTING.md new file mode 100644 index 0000000..63d0b75 --- /dev/null +++ b/video_salmonn/model/ImageBind/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to ImageBind +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to Omnivore, you agree that your contributions will be licensed +under the [LICENSE](LICENSE) file in the root directory of this source tree. diff --git a/video_salmonn/model/ImageBind/LICENSE b/video_salmonn/model/ImageBind/LICENSE new file mode 100644 index 0000000..bfef380 --- /dev/null +++ b/video_salmonn/model/ImageBind/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/video_salmonn/model/ImageBind/README.md b/video_salmonn/model/ImageBind/README.md new file mode 100644 index 0000000..028fa98 --- /dev/null +++ b/video_salmonn/model/ImageBind/README.md @@ -0,0 +1,155 @@ +# ImageBind: One Embedding Space To Bind Them All + +**[FAIR, Meta AI](https://ai.facebook.com/research/)** + +Rohit Girdhar*, +Alaaeldin El-Nouby*, +Zhuang Liu, +Mannat Singh, +Kalyan Vasudev Alwala, +Armand Joulin, +Ishan Misra* + +To appear at CVPR 2023 (*Highlighted paper*) + +[[`Paper`](https://facebookresearch.github.io/ImageBind/paper)] [[`Blog`](https://ai.facebook.com/blog/imagebind-six-modalities-binding-ai/)] [[`Demo`](https://imagebind.metademolab.com/)] [[`Supplementary Video`](https://dl.fbaipublicfiles.com/imagebind/imagebind_video.mp4)] [[`BibTex`](#citing-imagebind)] + +PyTorch implementation and pretrained models for ImageBind. For details, see the paper: **[ImageBind: One Embedding Space To Bind Them All](https://facebookresearch.github.io/ImageBind/paper)**. + +ImageBind learns a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. It enables novel emergent applications ‘out-of-the-box’ including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation. + + + +![ImageBind](https://user-images.githubusercontent.com/8495451/236859695-ffa13364-3e39-4d99-a8da-fbfab17f9a6b.gif) + +## ImageBind model + +Emergent zero-shot classification performance. + + + + + + + + + + + + + + + + + + + + + + + +
ModelIN1kK400NYU-DESCLLVIPEgo4Ddownload
imagebind_huge77.750.054.066.963.425.0checkpoint
+ +## Usage + +Install pytorch 1.13+ and other 3rd party dependencies. + +```shell +conda create --name imagebind python=3.8 -y +conda activate imagebind + +pip install -r requirements.txt +``` + +For windows users, you might need to install `soundfile` for reading/writing audio files. (Thanks @congyue1977) + +``` +pip install soundfile +``` + + +Extract and compare features across modalities (e.g. Image, Text and Audio). + +```python +import data +import torch +from models import imagebind_model +from models.imagebind_model import ModalityType + +text_list=["A dog.", "A car", "A bird"] +image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"] +audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"] + +device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# Instantiate model +model = imagebind_model.imagebind_huge(pretrained=True) +model.eval() +model.to(device) + +# Load data +inputs = { + ModalityType.TEXT: data.load_and_transform_text(text_list, device), + ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), +} + +with torch.no_grad(): + embeddings = model(inputs) + +print( + "Vision x Text: ", + torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1), +) +print( + "Audio x Text: ", + torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1), +) +print( + "Vision x Audio: ", + torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1), +) + +# Expected output: +# +# Vision x Text: +# tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05], +# [3.3836e-05, 9.9994e-01, 2.4118e-05], +# [4.7997e-05, 1.3496e-02, 9.8646e-01]]) +# +# Audio x Text: +# tensor([[1., 0., 0.], +# [0., 1., 0.], +# [0., 0., 1.]]) +# +# Vision x Audio: +# tensor([[0.8070, 0.1088, 0.0842], +# [0.1036, 0.7884, 0.1079], +# [0.0018, 0.0022, 0.9960]]) + +``` + +## Model card +Please see the [model card](model_card.md) for details. + +## License + +ImageBind code and model weights are released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details. + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Citing ImageBind + +If you find this repository useful, please consider giving a star :star: and citation + +``` +@inproceedings{girdhar2023imagebind, + title={ImageBind: One Embedding Space To Bind Them All}, + author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang +and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan}, + booktitle={CVPR}, + year={2023} +} +``` diff --git a/video_salmonn/model/ImageBind/__init__.py b/video_salmonn/model/ImageBind/__init__.py new file mode 100644 index 0000000..d872d07 --- /dev/null +++ b/video_salmonn/model/ImageBind/__init__.py @@ -0,0 +1,2 @@ +from .models import imagebind_model +from .models.imagebind_model import ModalityType diff --git a/video_salmonn/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz b/video_salmonn/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..7b5088a Binary files /dev/null and b/video_salmonn/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz differ diff --git a/video_salmonn/model/ImageBind/data.py b/video_salmonn/model/ImageBind/data.py new file mode 100644 index 0000000..10eb145 --- /dev/null +++ b/video_salmonn/model/ImageBind/data.py @@ -0,0 +1,721 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +import torchaudio +import logging + +from .models.multimodal_preprocessors import SimpleTokenizer +from PIL import Image +from pytorchvideo import transforms as pv_transforms +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler +from pytorchvideo.data.encoded_video import EncodedVideo + +from torchvision import transforms +from torchvision.transforms._transforms_video import NormalizeVideo +from torchvision.transforms.functional import InterpolationMode + +DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds + +BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz" + + +def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length): + # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102 + waveform -= waveform.mean() + fbank = torchaudio.compliance.kaldi.fbank( + waveform, + htk_compat=True, + sample_frequency=sample_rate, + use_energy=False, + window_type="hanning", + num_mel_bins=num_mel_bins, + dither=0.0, + frame_length=25, + frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, + ) + # Convert to [mel_bins, num_frames] shape + fbank = fbank.transpose(0, 1) + # Pad to target_length + n_frames = fbank.size(1) + p = target_length - n_frames + # if p is too large (say >20%), flash a warning + if abs(p) / n_frames > 0.2: + logging.warning( + "Large gap between audio n_frames(%d) and " + "target_length (%d). Is the audio_target_length " + "setting correct?", + n_frames, + target_length, + ) + # cut and pad + if p > 0: + fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0) + elif p < 0: + fbank = fbank[:, 0:target_length] + # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 + # channel image + fbank = fbank.unsqueeze(0) + return fbank + + +def get_clip_timepoints(clip_sampler, duration): + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints + + +def load_and_transform_vision_data(image_paths, device): + if image_paths is None: + return None + + image_ouputs = [] + for image_path in image_paths: + data_transform = transforms.Compose( + [ + transforms.Resize( + 224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + with open(image_path, "rb") as fopen: + image = Image.open(fopen).convert("RGB") + + image = data_transform(image).to(device) + image_ouputs.append(image) + return torch.stack(image_ouputs, dim=0) + + +def load_and_transform_vision_data_blip(image_paths, device, training=False, hi_rs=False, hi_rs_cfg=None): + if image_paths is None: + return None + + if training and not hi_rs: + data_transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + 224, + scale=(0.5, 1.0), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + else: + data_transform = transforms.Compose( + [ + transforms.Resize( + (224, 224), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + image_ouputs = [] + for image_path in image_paths: + with open(image_path, "rb") as fopen: + image = Image.open(fopen).convert("RGB") + + if hi_rs: + if hi_rs_cfg is None: + hi_rs_cfg = (4, 1) + n_split, dup = hi_rs_cfg + width, height = image.size + image_blocks = [] + for _ in range(dup): + image_blocks.append( + data_transform(image).to(device).unsqueeze(0) + ) + + dx = width // n_split * 2 + dy = height // n_split * 2 + for ny in range(n_split - 1): + for nx in range(n_split - 1): + x = width // n_split * nx + y = height // n_split * ny + box = (x, y, x + dx, y + dy) + for _ in range(dup): + image_blocks.append( + data_transform( + image.crop(box) + ).to(device).unsqueeze(0) + ) + + # for y in range(0, height, height // 4): + # for x in range(0, width, width // 4): + # box = (x, y, x + width // 4 * 2, y + height // 4 * 2) + # image_blocks.append( + # data_transform( + # image.crop(box) + # ).to(device).unsqueeze(0) + # ) + image_blocks = torch.cat(image_blocks, dim=0) + image_ouputs.append(image_blocks) + + else: + image = data_transform(image).to(device) + image_ouputs.append(image) + + if hi_rs: + image_lens = [img.shape[0] for img in image_ouputs] + max_image_len = max(image_lens) + img_mask = torch.arange(max_image_len).unsqueeze(0) < torch.tensor(image_lens).unsqueeze(1) + return pad_sequence(image_ouputs, batch_first=True), img_mask.to(device) + else: + return torch.stack(image_ouputs, dim=0) + + +def load_and_transform_thermal_data(thermal_paths, device): + if thermal_paths is None: + return None + + thermal_ouputs = [] + for thermal_path in thermal_paths: + data_transform = transforms.Compose( + [ + transforms.Resize( + 224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(224), + transforms.ToTensor(), + ] + ) + with open(thermal_path, "rb") as fopen: + thermal = Image.open(fopen).convert("L") + thermal = data_transform(thermal).to(device) + thermal_ouputs.append(thermal) + return torch.stack(thermal_ouputs, dim=0) + + +def load_and_transform_text(text, device): + if text is None: + return None + tokenizer = SimpleTokenizer(bpe_path=BPE_PATH) + tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text] + tokens = torch.cat(tokens, dim=0) + return tokens + +def load_and_transform_audio_data_fulllen( + audio_paths, + device, + num_mel_bins=128, + target_length=204, + sample_rate=16000, + clip_duration=2, + mean=-4.268, + std=9.138, + maxlen=30, +): + if audio_paths is None: + return None + + audio_outputs = [] + # clip_sampler = ConstantClipsPerVideoSampler( + # clip_duration=clip_duration, clips_per_video=clips_per_video + # ) + + for audio_path in audio_paths: + waveform, sr = torchaudio.load(audio_path) + if sample_rate != sr: + waveform = torchaudio.functional.resample( + waveform, orig_freq=sr, new_freq=sample_rate + ) + full_lengths = waveform.size(1) + if full_lengths < maxlen * sample_rate: + diffsize = maxlen * sample_rate - full_lengths - 1 + waveform = torch.cat( + [waveform, waveform.new_zeros(waveform.size(0), diffsize)], dim=-1) + full_lengths = min(waveform.size(1), maxlen * sample_rate) + all_clips = [] + start = 0 + stepsize = clip_duration * sample_rate + while start < full_lengths: + end = min(start + stepsize, full_lengths) + waveform_clip = waveform[ + :, + int(start) : int(end), + ] + if int(end) - int(start) < stepsize: + diffsize = stepsize - int(end) + int(start) + waveform_clip = torch.cat( + [waveform_clip, waveform_clip.new_zeros(waveform_clip.size(0), diffsize)], dim=-1) + waveform_melspec = waveform2melspec( + waveform_clip, sample_rate, num_mel_bins, target_length + ) + all_clips.append(waveform_melspec) + start = start + stepsize + + normalize = transforms.Normalize(mean=mean, std=std) + all_clips = [normalize(ac).to(device) for ac in all_clips] + + all_clips = torch.stack(all_clips, dim=0) + audio_outputs.append(all_clips) + # for audio in audio_outputs: + # if audio.size(0) > 5: + # import pdb; pdb.set_trace() + + return torch.stack(audio_outputs, dim=0) + + +def load_and_transform_audio_data( + audio_paths, + device, + num_mel_bins=128, + target_length=204, + sample_rate=16000, + clip_duration=2, + clips_per_video=3, + mean=-4.268, + std=9.138, +): + if audio_paths is None: + return None + + audio_outputs = [] + clip_sampler = ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + + for audio_path in audio_paths: + waveform, sr = torchaudio.load(audio_path) + if sample_rate != sr: + waveform = torchaudio.functional.resample( + waveform, orig_freq=sr, new_freq=sample_rate + ) + all_clips_timepoints = get_clip_timepoints( + clip_sampler, waveform.size(1) / sample_rate + ) + all_clips = [] + for clip_timepoints in all_clips_timepoints: + waveform_clip = waveform[ + :, + int(clip_timepoints[0] * sample_rate) : int( + clip_timepoints[1] * sample_rate + ), + ] + waveform_melspec = waveform2melspec( + waveform_clip, sample_rate, num_mel_bins, target_length + ) + all_clips.append(waveform_melspec) + + normalize = transforms.Normalize(mean=mean, std=std) + all_clips = [normalize(ac).to(device) for ac in all_clips] + + all_clips = torch.stack(all_clips, dim=0) + audio_outputs.append(all_clips) + + return torch.stack(audio_outputs, dim=0) + + +def get_clip_timepoints(clip_sampler, duration): + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +class SpatialCrop(nn.Module): + """ + Convert the video into 3 smaller clips spatially. Must be used after the + temporal crops to get spatial crops, and should be used with + -2 in the spatial crop at the slowfast augmentation stage (so full + frames are passed in here). Will return a larger list with the + 3x spatial crops as well. + """ + + def __init__(self, crop_size: int = 224, num_crops: int = 3): + super().__init__() + self.crop_size = crop_size + if num_crops == 3: + self.crops_to_ext = [0, 1, 2] + self.flipped_crops_to_ext = [] + elif num_crops == 1: + self.crops_to_ext = [1] + self.flipped_crops_to_ext = [] + else: + raise NotImplementedError("Nothing else supported yet") + + def forward(self, videos): + """ + Args: + videos: A list of C, T, H, W videos. + Returns: + videos: A list with 3x the number of elements. Each video converted + to C, T, H', W' by spatial cropping. + """ + assert isinstance(videos, list), "Must be a list of videos after temporal crops" + assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" + res = [] + for video in videos: + for spatial_idx in self.crops_to_ext: + res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) + if not self.flipped_crops_to_ext: + continue + flipped_video = transforms.functional.hflip(video) + for spatial_idx in self.flipped_crops_to_ext: + res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) + return res + + +class ToUint8(object): + def __init__(self): + pass + + def __call__(self, tensor): + return tensor.to(torch.uint8) + + def __repr__(self): + return self.__class__.__name__ + + +class ToTHWC(object): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C) + """ + + def __init__(self): + pass + + def __call__(self, tensor): + return tensor.permute(1, 2, 3, 0) + + def __repr__(self): + return self.__class__.__name__ + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError( + f"target size should be tuple (height, width), instead got {target_size}" + ) + return torch.nn.functional.interpolate( + clip, size=target_size, mode=interpolation_mode, align_corners=False + ) + +class ResizeVideo(object): + def __init__(self, target_size, interpolation_mode="bilinear"): + self.target_size = target_size + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: central cropping of video clip. Size is + (C, T, crop_size, crop_size) + """ + return resize(clip, self.target_size, self.interpolation_mode) + + def __repr__(self): + return self.__class__.__name__ + "(resize_size={0})".format(self.target_size) + +def load_and_transform_video_data_full( + video_paths, + device, + clip_duration=1, + sample_per_clip=2, + sample_rate=16000, +): + if video_paths is None: + return None + + video_outputs = [] + video_transform = transforms.Compose( + [ + ResizeVideo((224, 224), interpolation_mode="bicubic"), + NormalizeVideo( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + clip_sampler = UniformClipSampler( + clip_duration=clip_duration, backpad_last=True + ) + frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=sample_per_clip) + + maxlen = 0 + for video_path in video_paths: + if not isinstance(video_path, list): + video = EncodedVideo.from_path( + video_path, + decoder="decord", + decode_audio=False, + **{"sample_rate": sample_rate}, + ) + + all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) + + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + if clip is None: + raise ValueError("No clip found") + video_clip = frame_sampler(clip["video"]) + video_clip = video_clip / 255.0 # since this is float, need 0-1 + + all_video.append(video_clip) + else: + all_video = video_path + + all_video = [video_transform(clip) for clip in all_video] + # all_video = SpatialCrop(224, num_crops=3)(all_video) + if len(all_video) > maxlen: + maxlen = len(all_video) + all_video = torch.stack(all_video, dim=0) + video_outputs.append(all_video) + + padded_video_outputs = [] + padded_video_mask = [] + for video in video_outputs: + if video.size(0) < maxlen: + diffsize = maxlen - video.size(0) + padded_video_mask.append([1] * video.size(0) + [0] * diffsize) + video = torch.cat([video, video.new_zeros( + diffsize, video.size(1), video.size(2), video.size(3), video.size(4))], dim=0) + else: + padded_video_mask.append([1] * video.size(0)) + padded_video_outputs.append(video) + + return torch.stack(padded_video_outputs, dim=0).to(device), torch.tensor(padded_video_mask).to(device) + +def load_and_transform_video_data_blip( + video_paths, + device, + clip_duration=1, + sample_per_clip=2, + sample_rate=16000, +): + if video_paths is None: + return None + + video_outputs = [] + video_transform = transforms.Compose( + [ + ResizeVideo((224, 224), interpolation_mode="bicubic"), + NormalizeVideo( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + clip_sampler = UniformClipSampler( + clip_duration=clip_duration, backpad_last=True + ) + frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=sample_per_clip) + + maxlen = 0 + for all_video in video_paths: + if not isinstance(all_video, list): + video = EncodedVideo.from_path( + all_video, + decoder="pyav", + decode_audio=False, + # **{"sample_rate": sample_rate}, + ) + + all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) + + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + if clip is None: + raise ValueError("No clip found") + video_clip = frame_sampler(clip["video"]) + video_clip = video_clip / 255.0 # since this is float, need 0-1 + + all_video.append(video_clip) + # Hard set here to be less than 60 seconds + if len(all_video) > 60: + all_video = all_video[:60] + + all_video = torch.cat(all_video, dim=1) + all_video = video_transform(all_video).transpose(0, 1) # C, T, H, W -> T, C, H, W + if all_video.size(0) > maxlen: + maxlen = all_video.size(0) + video_outputs.append(all_video) + + padded_video_outputs = [] + padded_video_mask = [] + for video in video_outputs: + if video.size(0) < maxlen: + diffsize = maxlen - video.size(0) + padded_video_mask.append([1] * video.size(0) + [0] * diffsize) + video = torch.cat([video, video.new_zeros( + diffsize, video.size(1), video.size(2), video.size(3))], dim=0) + else: + padded_video_mask.append([1] * video.size(0)) + padded_video_outputs.append(video) + + return torch.stack(padded_video_outputs, dim=0).to(device), torch.tensor(padded_video_mask).to(device) + +def load_and_transform_video_data( + video_paths, + device, + clip_duration=2, + clips_per_video=5, + sample_rate=16000, +): + if video_paths is None: + return None + + video_outputs = [] + video_transform = transforms.Compose( + [ + pv_transforms.ShortSideScale(224), + NormalizeVideo( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + clip_sampler = ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) + + for video_path in video_paths: + video = EncodedVideo.from_path( + video_path, + decoder="decord", + decode_audio=False, + **{"sample_rate": sample_rate}, + ) + + all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) + + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + if clip is None: + raise ValueError("No clip found") + video_clip = frame_sampler(clip["video"]) + video_clip = video_clip / 255.0 # since this is float, need 0-1 + + all_video.append(video_clip) + + all_video = [video_transform(clip) for clip in all_video] + all_video = SpatialCrop(224, num_crops=3)(all_video) + + all_video = torch.stack(all_video, dim=0) + video_outputs.append(all_video) + + return torch.stack(video_outputs, dim=0).to(device) diff --git a/video_salmonn/model/ImageBind/model_card.md b/video_salmonn/model/ImageBind/model_card.md new file mode 100644 index 0000000..c7bb265 --- /dev/null +++ b/video_salmonn/model/ImageBind/model_card.md @@ -0,0 +1,94 @@ +# Model Card for ImageBind + +Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images. +Input any of the six modalities and get the same sized embedding that can be used for cross-modal and multimodal tasks. + +# Model Details + +## Model Description + + +Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images + +- **Developed by:** Meta AI +- **Model type:** Multimodal model +- **Language(s) (NLP):** en +- **License:** CC BY-NC-SA 4.0 +- **Resources for more information:** + - [GitHub Repo](https://github.com/facebookresearch/ImageBind) + + +# Uses + + +This model is intended only for research purposes. It provides a joint embedding space for different modalities -- image/video, text, audio, depth, IMU and thermal images. +We hope that these joint embeddings can be used for a variety of different cross-modal research, e.g., cross-modal retrieval and combining embeddings from different modalities. + +## Out-of-Scope Use + + + + +This model is *NOT* intended to be used in any real world application -- commercial or otherwise. +It may produce harmful associations with different inputs. +The model needs to be investigated and likely re-trained on specific data for any such application. +The model is expected to work better on web-based visual data since it was trained on such data. +The text encoder is likely to work only on English language text because of the underlying training datasets. + +# Bias, Risks, and Limitations + + +Open-domain joint embedding models are prone to producing specific biases, e.g., study from [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md#bias-and-fairness). +Since our model uses such models as initialization, it will exhibit such biases too. +Moreover, for learning joint embeddings for other modalities such as audio, thermal, depth, and IMU we leverage datasets that are relatively small. These joint embeddings are thus limited to the concepts present in the datasets. For example, the thermal datasets we used are limited to outdoor street scenes, while the depth datasets are limited to indoor scenes. + + + +# Training Details + +## Training Data + + + +ImageBind uses image-paired data for training -- (image, X) where X is one of text, audio, depth, IMU or thermal data. +In particular, we initialize and freeze the image and text encoders using an OpenCLIP ViT-H encoder. +We train audio embeddings using Audioset, depth embeddings using the SUN RGB-D dataset, IMU using the Ego4D dataset and thermal embeddings using the LLVIP dataset. +We provide the exact training data details in the paper. + + +## Training Procedure + + +Please refer to the research paper and github repo for exact details on this. + +# Evaluation + +## Testing Data, Factors & Metrics + +We evaluate the model on a variety of different classification benchmarks for each modality. +The evaluation details are presented in the paper. +The models performance is measured using standard classification metrics such as accuracy and mAP. + +# Citation + + + +**BibTeX:** +``` +@inproceedings{girdhar2023imagebind, + title={ImageBind: One Embedding Space To Bind Them All}, + author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang +and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan}, + booktitle={CVPR}, + year={2023} +} +``` + + +# Model Card Contact + +Please reach out to the authors at: rgirdhar@meta.com imisra@meta.com alaaelnouby@gmail.com + +# How to Get Started with the Model + +Our github repo provides a simple example to extract embeddings from images, audio etc. diff --git a/video_salmonn/model/ImageBind/models/__init__.py b/video_salmonn/model/ImageBind/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/video_salmonn/model/ImageBind/models/helpers.py b/video_salmonn/model/ImageBind/models/helpers.py new file mode 100644 index 0000000..049e1f1 --- /dev/null +++ b/video_salmonn/model/ImageBind/models/helpers.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import einops +import numpy as np +import torch + +import torch.nn as nn + + +class Normalize(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=self.dim, p=2) + + +class LearnableLogitScaling(nn.Module): + def __init__( + self, + logit_scale_init: float = 1 / 0.07, + learnable: bool = True, + max_logit_scale: float = 100, + ) -> None: + super().__init__() + self.max_logit_scale = max_logit_scale + self.logit_scale_init = logit_scale_init + self.learnable = learnable + log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) + if learnable: + self.log_logit_scale = nn.Parameter(log_logit_scale) + else: + self.register_buffer("log_logit_scale", log_logit_scale) + + def forward(self, x): + return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x + + def extra_repr(self): + st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}" + return st + + +class EinOpsRearrange(nn.Module): + def __init__(self, rearrange_expr: str, **kwargs) -> None: + super().__init__() + self.rearrange_expr = rearrange_expr + self.kwargs = kwargs + + def forward(self, x): + assert isinstance(x, torch.Tensor) + return einops.rearrange(x, self.rearrange_expr, **self.kwargs) + + +class VerboseNNModule(nn.Module): + """ + Wrapper around nn.Module that prints registered buffers and parameter names. + """ + + @staticmethod + def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: + st = ( + "(" + + name + + "): " + + "tensor(" + + str(tuple(tensor[1].shape)) + + ", requires_grad=" + + str(tensor[1].requires_grad) + + ")\n" + ) + return st + + def extra_repr(self) -> str: + named_modules = set() + for p in self.named_modules(): + named_modules.update([p[0]]) + named_modules = list(named_modules) + + string_repr = "" + for p in self.named_parameters(): + name = p[0].split(".")[0] + if name not in named_modules: + string_repr += self.get_readable_tensor_repr(name, p) + + for p in self.named_buffers(): + name = p[0].split(".")[0] + string_repr += self.get_readable_tensor_repr(name, p) + + return string_repr + + +def cast_if_src_dtype( + tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype +): + updated = False + if tensor.dtype == src_dtype: + tensor = tensor.to(dtype=tgt_dtype) + updated = True + return tensor, updated + + +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class SelectElement(nn.Module): + def __init__(self, index) -> None: + super().__init__() + self.index = index + + def forward(self, x): + assert x.ndim >= 3 + return x[:, self.index, ...] + + +class SelectEOSAndProject(nn.Module): + """ + Text Pooling used in OpenCLIP + """ + + def __init__(self, proj: nn.Module) -> None: + super().__init__() + self.proj = proj + + def forward(self, x, seq_len): + assert x.ndim == 3 + # x is of shape B x L x D + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), seq_len] + x = self.proj(x) + return x diff --git a/video_salmonn/model/ImageBind/models/imagebind_model.py b/video_salmonn/model/ImageBind/models/imagebind_model.py new file mode 100644 index 0000000..e4b5334 --- /dev/null +++ b/video_salmonn/model/ImageBind/models/imagebind_model.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import urllib +from functools import partial +from types import SimpleNamespace + +import torch +import torch.nn as nn + +from .helpers import ( + EinOpsRearrange, + LearnableLogitScaling, + Normalize, + SelectElement, + SelectEOSAndProject, +) +from .multimodal_preprocessors import ( + AudioPreprocessor, + IMUPreprocessor, + PadIm2Video, + PatchEmbedGeneric, + RGBDTPreprocessor, + SpatioTemporalPosEmbeddingHelper, + TextPreprocessor, + ThermalPreprocessor, +) + +from .transformer import MultiheadAttention, SimpleTransformer + + +ModalityType = SimpleNamespace( + VISION="vision", + TEXT="text", + AUDIO="audio", + THERMAL="thermal", + DEPTH="depth", + IMU="imu", +) + + +class ImageBindModel(nn.Module): + def __init__( + self, + video_frames=2, + kernel_size=(2, 14, 14), + audio_kernel_size=16, + audio_stride=10, + out_embed_dim=768, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_num_mel_bins=128, + audio_target_len=204, + audio_drop_path=0.1, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + depth_embed_dim=384, + depth_kernel_size=16, + depth_num_blocks=12, + depth_num_heads=8, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_kernel_size=16, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_kernel_size=8, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + super().__init__() + + self.modality_preprocessors = self._create_modality_preprocessors( + video_frames, + vision_embed_dim, + kernel_size, + text_embed_dim, + audio_embed_dim, + audio_kernel_size, + audio_stride, + audio_num_mel_bins, + audio_target_len, + depth_embed_dim, + depth_kernel_size, + thermal_embed_dim, + thermal_kernel_size, + imu_embed_dim, + ) + + self.modality_trunks = self._create_modality_trunks( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + text_embed_dim, + text_num_blocks, + text_num_heads, + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + audio_drop_path, + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + depth_drop_path, + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + thermal_drop_path, + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + imu_drop_path, + ) + + self.modality_heads = self._create_modality_heads( + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ) + + self.modality_postprocessors = self._create_modality_postprocessors( + out_embed_dim + ) + + def _create_modality_preprocessors( + self, + video_frames=2, + vision_embed_dim=1024, + kernel_size=(2, 14, 14), + text_embed_dim=768, + audio_embed_dim=768, + audio_kernel_size=16, + audio_stride=10, + audio_num_mel_bins=128, + audio_target_len=204, + depth_embed_dim=768, + depth_kernel_size=16, + thermal_embed_dim=768, + thermal_kernel_size=16, + imu_embed_dim=512, + ): + rgbt_stem = PatchEmbedGeneric( + proj_stem=[ + PadIm2Video(pad_type="repeat", ntimes=2), + nn.Conv3d( + in_channels=3, + kernel_size=kernel_size, + out_channels=vision_embed_dim, + stride=kernel_size, + bias=False, + ), + ] + ) + rgbt_preprocessor = RGBDTPreprocessor( + img_size=[3, video_frames, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=rgbt_stem, + depth_stem=None, + ) + + text_preprocessor = TextPreprocessor( + context_length=77, + vocab_size=49408, + embed_dim=text_embed_dim, + causal_masking=True, + ) + + audio_stem = PatchEmbedGeneric( + proj_stem=[ + nn.Conv2d( + in_channels=1, + kernel_size=audio_kernel_size, + stride=audio_stride, + out_channels=audio_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), + ) + audio_preprocessor = AudioPreprocessor( + img_size=[1, audio_num_mel_bins, audio_target_len], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + audio_stem=audio_stem, + ) + + depth_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=depth_kernel_size, + in_channels=1, + out_channels=depth_embed_dim, + stride=depth_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), + ) + + depth_preprocessor = RGBDTPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=None, + depth_stem=depth_stem, + ) + + thermal_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=thermal_kernel_size, + in_channels=1, + out_channels=thermal_embed_dim, + stride=thermal_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), + ) + thermal_preprocessor = ThermalPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + thermal_stem=thermal_stem, + ) + + imu_stem = PatchEmbedGeneric( + [ + nn.Linear( + in_features=48, + out_features=imu_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), + ) + + imu_preprocessor = IMUPreprocessor( + img_size=[6, 2000], + num_cls_tokens=1, + kernel_size=8, + embed_dim=imu_embed_dim, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + imu_stem=imu_stem, + ) + + modality_preprocessors = { + ModalityType.VISION: rgbt_preprocessor, + ModalityType.TEXT: text_preprocessor, + ModalityType.AUDIO: audio_preprocessor, + ModalityType.DEPTH: depth_preprocessor, + ModalityType.THERMAL: thermal_preprocessor, + ModalityType.IMU: imu_preprocessor, + } + + return nn.ModuleDict(modality_preprocessors) + + def _create_modality_trunks( + self, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_drop_path=0.0, + depth_embed_dim=768, + depth_num_blocks=12, + depth_num_heads=12, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + def instantiate_trunk( + embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path + ): + return SimpleTransformer( + embed_dim=embed_dim, + num_blocks=num_blocks, + ffn_dropout_rate=0.0, + drop_path_rate=drop_path, + attn_target=partial( + MultiheadAttention, + embed_dim=embed_dim, + num_heads=num_heads, + bias=True, + add_bias_kv=add_bias_kv, + ), + pre_transformer_layer=nn.Sequential( + nn.LayerNorm(embed_dim, eps=1e-6) + if pre_transformer_ln + else nn.Identity(), + EinOpsRearrange("b l d -> l b d"), + ), + post_transformer_layer=EinOpsRearrange("l b d -> b l d"), + ) + + modality_trunks = {} + modality_trunks[ModalityType.VISION] = instantiate_trunk( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + pre_transformer_ln=True, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.TEXT] = instantiate_trunk( + text_embed_dim, + text_num_blocks, + text_num_heads, + pre_transformer_ln=False, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.AUDIO] = instantiate_trunk( + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=audio_drop_path, + ) + modality_trunks[ModalityType.DEPTH] = instantiate_trunk( + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=depth_drop_path, + ) + modality_trunks[ModalityType.THERMAL] = instantiate_trunk( + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=thermal_drop_path, + ) + modality_trunks[ModalityType.IMU] = instantiate_trunk( + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=imu_drop_path, + ) + + return nn.ModuleDict(modality_trunks) + + def _create_modality_heads( + self, + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ): + modality_heads = {} + + modality_heads[ModalityType.VISION] = nn.Sequential( + nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(vision_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.TEXT] = SelectEOSAndProject( + proj=nn.Sequential( + nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), + nn.Linear(text_embed_dim, out_embed_dim, bias=False), + ) + ) + + modality_heads[ModalityType.AUDIO] = nn.Sequential( + nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(audio_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.DEPTH] = nn.Sequential( + nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(depth_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.THERMAL] = nn.Sequential( + nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.IMU] = nn.Sequential( + nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Dropout(p=0.5), + nn.Linear(imu_embed_dim, out_embed_dim, bias=False), + ) + + return nn.ModuleDict(modality_heads) + + def _create_modality_postprocessors(self, out_embed_dim): + modality_postprocessors = {} + + modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) + modality_postprocessors[ModalityType.TEXT] = nn.Sequential( + Normalize(dim=-1), LearnableLogitScaling(learnable=True) + ) + modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=20.0, learnable=False), + ) + modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=10.0, learnable=False), + ) + modality_postprocessors[ModalityType.IMU] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + return nn.ModuleDict(modality_postprocessors) + + def forward(self, inputs): + outputs = {} + for modality_key, modality_value in inputs.items(): + reduce_list = ( + modality_value.ndim >= 5 + ) # Audio and Video inputs consist of multiple clips + if reduce_list: + B, S = modality_value.shape[:2] + modality_value = modality_value.reshape( + B * S, *modality_value.shape[2:] + ) + + if modality_value is not None: + modality_value = self.modality_preprocessors[modality_key]( + **{modality_key: modality_value} + ) + trunk_inputs = modality_value["trunk"] + head_inputs = modality_value["head"] + modality_value = self.modality_trunks[modality_key](**trunk_inputs) + modality_value = self.modality_heads[modality_key]( + modality_value, **head_inputs + ) + if modality_key in [ModalityType.AUDIO]: + modality_value = self.modality_postprocessors[modality_key][0]( + modality_value + ) + else: + modality_value = self.modality_postprocessors[modality_key]( + modality_value + ) + + if reduce_list: + modality_value = modality_value.reshape(B, S, -1) + # modality_value = modality_value.mean(dim=1) + + outputs[modality_key] = modality_value + + return outputs + + +def imagebind_huge(pretrained=False, store_path=r'.checkpoints'): + model = ImageBindModel( + vision_embed_dim=1280, + vision_num_blocks=32, + vision_num_heads=16, + text_embed_dim=1024, + text_num_blocks=24, + text_num_heads=16, + out_embed_dim=1024, + audio_drop_path=0.1, + imu_drop_path=0.7, + ) + + if pretrained: + if not os.path.exists("{}/imagebind_huge.pth".format(store_path)): + print( + "Downloading imagebind weights to {}/imagebind_huge.pth ...".format(store_path) + ) + os.makedirs(store_path, exist_ok=True) + torch.hub.download_url_to_file( + "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", + "{}/imagebind_huge.pth".format(store_path), + progress=True, + ) + + model.load_state_dict(torch.load("{}/imagebind_huge.pth".format(store_path))) + + return model, 1024 diff --git a/video_salmonn/model/ImageBind/models/multimodal_preprocessors.py b/video_salmonn/model/ImageBind/models/multimodal_preprocessors.py new file mode 100644 index 0000000..44de961 --- /dev/null +++ b/video_salmonn/model/ImageBind/models/multimodal_preprocessors.py @@ -0,0 +1,687 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import gzip +import html +import io +import math +from functools import lru_cache +from typing import Callable, List, Optional + +import ftfy + +import numpy as np +import regex as re +import torch +import torch.nn as nn +from iopath.common.file_io import g_pathmgr +from timm.models.layers import trunc_normal_ + +from .helpers import cast_if_src_dtype, VerboseNNModule + + +def get_sinusoid_encoding_table(n_position, d_hid): + """Sinusoid position encoding table""" + + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): + N = pos_embed.shape[1] + if N == target_spatial_size: + return pos_embed + dim = pos_embed.shape[-1] + # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 + pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(target_spatial_size / N), + mode="bicubic", + ) + if updated: + pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=None, + first_patch_idx=1, +): + assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" + N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists + if npatch_per_img == N: + return pos_embed + + assert ( + patches_layout[-1] == patches_layout[-2] + ), "Interpolation of pos embed not supported for non-square layouts" + + class_emb = pos_embed[:, :first_patch_idx] + pos_embed = pos_embed[:, first_patch_idx:] + + if input_shape is None or patches_layout[0] == 1: + # simple 2D pos embedding, no temporal component + pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed) + elif patches_layout[0] > 1: + # pos embed has a temporal component + assert len(input_shape) == 4, "temporal interpolation not supported" + # we only support 2D interpolation in this case + num_frames = patches_layout[0] + num_spatial_tokens = patches_layout[1] * patches_layout[2] + pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) + # interpolate embedding for zeroth frame + pos_embed = interpolate_pos_encoding_2d( + npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0) + ) + else: + raise ValueError("This type of interpolation isn't implemented") + + return torch.cat((class_emb, pos_embed), dim=1) + + +def _get_pos_embedding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape, + first_patch_idx=1, +): + pos_embed = interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=input_shape, + first_patch_idx=first_patch_idx, + ) + return pos_embed + + +class PatchEmbedGeneric(nn.Module): + """ + PatchEmbed from Hydra + """ + + def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None): + super().__init__() + + if len(proj_stem) > 1: + self.proj = nn.Sequential(*proj_stem) + else: + # Special case to be able to load pre-trained models that were + # trained with a standard stem + self.proj = proj_stem[0] + self.norm_layer = norm_layer + + def get_patch_layout(self, img_size): + with torch.no_grad(): + dummy_img = torch.zeros( + [ + 1, + ] + + img_size + ) + dummy_out = self.proj(dummy_img) + embed_dim = dummy_out.shape[1] + patches_layout = tuple(dummy_out.shape[2:]) + num_patches = np.prod(patches_layout) + return patches_layout, num_patches, embed_dim + + def forward(self, x): + x = self.proj(x) + # B C (T) H W -> B (T)HW C + x = x.flatten(2).transpose(1, 2) + if self.norm_layer is not None: + x = self.norm_layer(x) + return x + + +class SpatioTemporalPosEmbeddingHelper(VerboseNNModule): + def __init__( + self, + patches_layout: List, + num_patches: int, + num_cls_tokens: int, + embed_dim: int, + learnable: bool, + ) -> None: + super().__init__() + self.num_cls_tokens = num_cls_tokens + self.patches_layout = patches_layout + self.num_patches = num_patches + self.num_tokens = num_cls_tokens + num_patches + self.learnable = learnable + if self.learnable: + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) + trunc_normal_(self.pos_embed, std=0.02) + else: + self.register_buffer( + "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim) + ) + + def get_pos_embedding(self, vision_input, all_vision_tokens): + input_shape = vision_input.shape + pos_embed = _get_pos_embedding( + all_vision_tokens.size(1) - self.num_cls_tokens, + pos_embed=self.pos_embed, + patches_layout=self.patches_layout, + input_shape=input_shape, + first_patch_idx=self.num_cls_tokens, + ) + return pos_embed + + +class RGBDTPreprocessor(VerboseNNModule): + def __init__( + self, + rgbt_stem: PatchEmbedGeneric, + depth_stem: PatchEmbedGeneric, + img_size: List = (3, 224, 224), + num_cls_tokens: int = 1, + pos_embed_fn: Callable = None, + use_type_embed: bool = False, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = rgbt_stem if rgbt_stem is not None else depth_stem + ( + self.patches_layout, + self.num_patches, + self.embed_dim, + ) = stem.get_patch_layout(img_size) + self.rgbt_stem = rgbt_stem + self.depth_stem = depth_stem + self.use_pos_embed = pos_embed_fn is not None + self.use_type_embed = use_type_embed + self.num_cls_tokens = num_cls_tokens + + if self.use_pos_embed: + self.pos_embedding_helper = pos_embed_fn( + patches_layout=self.patches_layout, + num_cls_tokens=num_cls_tokens, + num_patches=self.num_patches, + embed_dim=self.embed_dim, + ) + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + if self.use_type_embed: + self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.use_pos_embed: + nn.init.normal_(self.pos_embedding_helper.pos_embed) + self.pos_embedding_helper.pos_embed *= scale + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + if self.use_type_embed: + nn.init.normal_(self.type_embed) + + def tokenize_input_and_cls_pos(self, input, stem, mask): + # tokens is of shape B x L x D + tokens = stem(input) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) + tokens = tokens + pos_embed + if self.use_type_embed: + tokens = tokens + self.type_embed.expand(B, -1, -1) + return tokens + + def forward(self, vision=None, depth=None, patch_mask=None): + if patch_mask is not None: + raise NotImplementedError() + + if vision is not None: + vision_tokens = self.tokenize_input_and_cls_pos( + vision, self.rgbt_stem, patch_mask + ) + + if depth is not None: + depth_tokens = self.tokenize_input_and_cls_pos( + depth, self.depth_stem, patch_mask + ) + + # aggregate tokens + if vision is not None and depth is not None: + final_tokens = vision_tokens + depth_tokens + else: + final_tokens = vision_tokens if vision is not None else depth_tokens + return_dict = { + "trunk": { + "tokens": final_tokens, + }, + "head": {}, + } + return return_dict + + +class AudioPreprocessor(RGBDTPreprocessor): + def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) + + def forward(self, audio=None): + return super().forward(vision=audio) + + +class ThermalPreprocessor(RGBDTPreprocessor): + def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs) + + def forward(self, thermal=None): + return super().forward(vision=thermal) + + +def build_causal_attention_mask(context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(context_length, context_length, requires_grad=False) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + +class TextPreprocessor(VerboseNNModule): + def __init__( + self, + vocab_size: int, + context_length: int, + embed_dim: int, + causal_masking: bool, + supply_seq_len_to_head: bool = True, + num_cls_tokens: int = 0, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.context_length = context_length + self.token_embedding = nn.Embedding(vocab_size, embed_dim) + self.pos_embed = nn.Parameter( + torch.empty(1, self.context_length + num_cls_tokens, embed_dim) + ) + self.causal_masking = causal_masking + if self.causal_masking: + mask = build_causal_attention_mask(self.context_length) + # register the mask as a buffer so it can be moved to the right device + self.register_buffer("mask", mask) + + self.supply_seq_len_to_head = supply_seq_len_to_head + self.num_cls_tokens = num_cls_tokens + self.embed_dim = embed_dim + if num_cls_tokens > 0: + assert self.causal_masking is False, "Masking + CLS token isn't implemented" + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style="openclip"): + # OpenCLIP style initialization + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def forward(self, text): + # text tokens are of shape B x L x D + text_tokens = self.token_embedding(text) + # concat CLS tokens if any + if self.num_cls_tokens > 0: + B = text_tokens.shape[0] + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + text_tokens = torch.cat((class_tokens, text_tokens), dim=1) + text_tokens = text_tokens + self.pos_embed + return_dict = { + "trunk": { + "tokens": text_tokens, + }, + "head": {}, + } + # Compute sequence length after adding CLS tokens + if self.supply_seq_len_to_head: + text_lengths = text.argmax(dim=-1) + return_dict["head"] = { + "seq_len": text_lengths, + } + if self.causal_masking: + return_dict["trunk"].update({"attn_mask": self.mask}) + return return_dict + + +class Im2Video(nn.Module): + """Convert an image into a trivial video.""" + + def __init__(self, time_dim=2): + super().__init__() + self.time_dim = time_dim + + def forward(self, x): + if x.ndim == 4: + # B, C, H, W -> B, C, T, H, W + return x.unsqueeze(self.time_dim) + elif x.ndim == 5: + return x + else: + raise ValueError(f"Dimension incorrect {x.shape}") + + +class PadIm2Video(Im2Video): + def __init__(self, ntimes, pad_type, time_dim=2): + super().__init__(time_dim=time_dim) + assert ntimes > 0 + assert pad_type in ["zero", "repeat"] + self.ntimes = ntimes + self.pad_type = pad_type + + def forward(self, x): + x = super().forward(x) + if x.shape[self.time_dim] == 1: + if self.pad_type == "repeat": + new_shape = [1] * len(x.shape) + new_shape[self.time_dim] = self.ntimes + x = x.repeat(new_shape) + elif self.pad_type == "zero": + padarg = [0, 0] * len(x.shape) + padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] + x = nn.functional.pad(x, padarg) + return x + + +# Modified from github.com/openai/CLIP +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str, context_length=77): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with g_pathmgr.open(bpe_path, "rb") as fh: + bpe_bytes = io.BytesIO(fh.read()) + merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.context_length = context_length + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + def __call__(self, texts, context_length=None): + if not context_length: + context_length = self.context_length + + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, : len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result + + +class IMUPreprocessor(VerboseNNModule): + def __init__( + self, + kernel_size: int, + imu_stem: PatchEmbedGeneric, + embed_dim: int, + img_size: List = (6, 2000), + num_cls_tokens: int = 1, + pos_embed_fn: Callable = None, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = imu_stem + self.imu_stem = imu_stem + self.embed_dim = embed_dim + self.use_pos_embed = pos_embed_fn is not None + self.num_cls_tokens = num_cls_tokens + self.kernel_size = kernel_size + self.pos_embed = nn.Parameter( + torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) + ) + + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def tokenize_input_and_cls_pos(self, input, stem): + # tokens is of shape B x L x D + tokens = stem.norm_layer(stem.proj(input)) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + tokens = tokens + self.pos_embed + return tokens + + def forward(self, imu): + # Patchify + imu = imu.unfold( + -1, + self.kernel_size, + self.kernel_size, + ).permute(0, 2, 1, 3) + imu = imu.reshape(imu.size(0), imu.size(1), -1) + + imu_tokens = self.tokenize_input_and_cls_pos( + imu, + self.imu_stem, + ) + + return_dict = { + "trunk": { + "tokens": imu_tokens, + }, + "head": {}, + } + return return_dict diff --git a/video_salmonn/model/ImageBind/models/transformer.py b/video_salmonn/model/ImageBind/models/transformer.py new file mode 100644 index 0000000..98902ac --- /dev/null +++ b/video_salmonn/model/ImageBind/models/transformer.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Code modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ; +# https://github.com/facebookresearch/deit/blob/main/models.py +# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py + + +import copy +import fnmatch +import logging +from functools import partial +from typing import Callable, List + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import DropPath, trunc_normal_ + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, + # can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MultiheadAttention(nn.MultiheadAttention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + +class ViTAttention(Attention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + assert attn_mask is None + return super().forward(x) + + +class BlockWithMasking(nn.Module): + def __init__( + self, + dim: int, + attn_target: Callable, + mlp_ratio: int = 4, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ffn_dropout_rate: float = 0.0, + drop_path: float = 0.0, + layer_scale_type: str = None, + layer_scale_init_value: float = 1e-4, + ): + super().__init__() + + assert not isinstance( + attn_target, nn.Module + ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" + self.attn = attn_target() + if drop_path > 0.0: + self.drop_path = DropPath(drop_path) + else: + self.drop_path = nn.Identity() + self.norm_1 = norm_layer(dim) + mlp_hidden_dim = int(mlp_ratio * dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=ffn_dropout_rate, + ) + self.norm_2 = norm_layer(dim) + self.layer_scale_type = layer_scale_type + if self.layer_scale_type is not None: + assert self.layer_scale_type in [ + "per_channel", + "scalar", + ], f"Found Layer scale type {self.layer_scale_type}" + if self.layer_scale_type == "per_channel": + # one gamma value per channel + gamma_shape = [1, 1, dim] + elif self.layer_scale_type == "scalar": + # single gamma value for all channels + gamma_shape = [1, 1, 1] + # two gammas: for each part of the fwd in the encoder + self.layer_scale_gamma1 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + self.layer_scale_gamma2 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + if self.layer_scale_type is None: + x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + x = x + self.drop_path(self.mlp(self.norm_2(x))) + else: + x = ( + x + + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + * self.layer_scale_gamma1 + ) + x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 + return x + + +_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) + + +class SimpleTransformer(nn.Module): + def __init__( + self, + attn_target: Callable, + embed_dim: int, + num_blocks: int, + block: Callable = BlockWithMasking, + pre_transformer_layer: Callable = None, + post_transformer_layer: Callable = None, + drop_path_rate: float = 0.0, + drop_path_type: str = "progressive", + norm_layer: Callable = _LAYER_NORM, + mlp_ratio: int = 4, + ffn_dropout_rate: float = 0.0, + layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar" + layer_scale_init_value: float = 1e-4, # from cait; float + weight_init_style: str = "jax", # possible values jax or pytorch + ): + """ + Simple Transformer with the following features + 1. Supports masked attention + 2. Supports DropPath + 3. Supports LayerScale + 4. Supports Dropout in Attention and FFN + 5. Makes few assumptions about the input except that it is a Tensor + """ + super().__init__() + self.pre_transformer_layer = pre_transformer_layer + if drop_path_type == "progressive": + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] + elif drop_path_type == "uniform": + dpr = [drop_path_rate for i in range(num_blocks)] + else: + raise ValueError(f"Unknown drop_path_type: {drop_path_type}") + + self.blocks = nn.Sequential( + *[ + block( + dim=embed_dim, + attn_target=attn_target, + mlp_ratio=mlp_ratio, + ffn_dropout_rate=ffn_dropout_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + layer_scale_type=layer_scale_type, + layer_scale_init_value=layer_scale_init_value, + ) + for i in range(num_blocks) + ] + ) + self.post_transformer_layer = post_transformer_layer + self.weight_init_style = weight_init_style + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + if self.weight_init_style == "jax": + # Based on MAE and official Jax ViT implementation + torch.nn.init.xavier_uniform_(m.weight) + elif self.weight_init_style == "pytorch": + # PyTorch ViT uses trunc_normal_ + trunc_normal_(m.weight, std=0.02) + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + tokens: torch.Tensor, + attn_mask: torch.Tensor = None, + use_checkpoint: bool = False, + checkpoint_every_n: int = 1, + checkpoint_blk_ids: List[int] = None, + ): + """ + Inputs + - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) + - attn: mask of shape L x L + + Output + - x: data of shape N x L x D (or L x N x D depending on the attention implementation) + """ + if self.pre_transformer_layer: + tokens = self.pre_transformer_layer(tokens) + if use_checkpoint and checkpoint_blk_ids is None: + checkpoint_blk_ids = [ + blk_id + for blk_id in range(len(self.blocks)) + if blk_id % checkpoint_every_n == 0 + ] + if checkpoint_blk_ids: + checkpoint_blk_ids = set(checkpoint_blk_ids) + for blk_id, blk in enumerate(self.blocks): + if use_checkpoint and blk_id in checkpoint_blk_ids: + tokens = checkpoint.checkpoint( + blk, tokens, attn_mask, use_reentrant=False + ) + else: + tokens = blk(tokens, attn_mask=attn_mask) + if self.post_transformer_layer: + tokens = self.post_transformer_layer(tokens) + return tokens diff --git a/video_salmonn/model/ImageBind/requirements.txt b/video_salmonn/model/ImageBind/requirements.txt new file mode 100644 index 0000000..572ae07 --- /dev/null +++ b/video_salmonn/model/ImageBind/requirements.txt @@ -0,0 +1,10 @@ +--extra-index-url https://download.pytorch.org/whl/cu113 +torchvision==0.14.0 +torchaudio==0.13.0 +pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d +timm==0.6.7 +ftfy +regex +einops +fvcore +decord==0.6.0 diff --git a/video_salmonn/model/Qformer.py b/video_salmonn/model/Qformer.py new file mode 100644 index 0000000..877ce4d --- /dev/null +++ b/video_salmonn/model/Qformer.py @@ -0,0 +1,1268 @@ +""" +Adapted from salesforce@LAVIS. Below is the original copyright: + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention, is_causal_attention=False): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + + if is_causal_attention: + self.attention_head_size = int(config.encoder_width / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention or is_causal_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config, is_causal_attention=False): + super().__init__() + if is_causal_attention: + self.dense = nn.Linear(config.encoder_width, config.encoder_width) + self.LayerNorm = nn.LayerNorm(config.encoder_width, eps=config.layer_norm_eps) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, is_causal_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention, is_causal_attention) + self.output = BertSelfOutput(config, is_causal_attention) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Causal encoder + if getattr(self.config, "causal_encoder", False): + self.encoder_causal_attention = BertAttention(config, is_causal_attention=True) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + encoder_causal_mask=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + if encoder_causal_mask is not None: + encoder_hidden_states_output = self.encoder_causal_attention( + encoder_hidden_states, + encoder_causal_mask.squeeze(1), + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + encoder_hidden_states = encoder_hidden_states_output[0] + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + encoder_causal_mask=None, + segsize=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + encoder_causal_mask, + ) + + hidden_states = layer_outputs[0] + if segsize != 0: + if i != 0: + hidden_states = hidden_states[:, :-query_length] + prev_hidden_states = hidden_states[:, :query_length].reshape(-1, segsize*query_length, hidden_states.size(-1)) + prev_hidden_states = prev_hidden_states[:, :-query_length] + prev_hidden_states = torch.cat( + [prev_hidden_states.new_zeros(prev_hidden_states.size(0), query_length, prev_hidden_states.size(-1)), prev_hidden_states], + dim=1, + ) + prev_hidden_states = prev_hidden_states.view(hidden_states.size(0), query_length, -1) + hidden_states = torch.cat([hidden_states, prev_hidden_states], dim=1) + if attention_mask.size(-1) != hidden_states.size(1): + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros(attention_mask.size(0), 1, 1, query_length)], dim=-1) + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + segsize=0, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + # Encoder causal masks + encoder_causal_mask = None + if getattr(self.config, "causal_encoder", False) and encoder_sequence_length % 32 == 0: + encoder_causal_mask = torch.triu(torch.ones(encoder_sequence_length//32, encoder_sequence_length//32)).to(encoder_hidden_states.device) + mat2 = torch.ones(32, 32).to(encoder_hidden_states.device) + encoder_causal_mask = torch.kron(encoder_causal_mask, mat2) + encoder_causal_mask = encoder_causal_mask.unsqueeze(0).repeat(encoder_batch_size, 1, 1) + encoder_causal_mask = self.invert_attention_mask(encoder_causal_mask).unsqueeze(1) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + encoder_causal_mask=encoder_causal_mask, + segsize=segsize, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/video_salmonn/model/__init__.py b/video_salmonn/model/__init__.py new file mode 100644 index 0000000..8d11e46 --- /dev/null +++ b/video_salmonn/model/__init__.py @@ -0,0 +1,17 @@ +from .agent import DeepSpeedAgent +from .openllama import OpenLLAMAPEFTModel + +def load_model(args): + agent_name = args['models'][args['model']]['agent_name'] + model_name = args['models'][args['model']]['model_name'] + model = globals()[model_name](**args) + for name, module in model.named_children(): + if hasattr(module, "gradient_checkpointing_enable"): + try: + module.config.use_cache = False + module.enable_input_require_grads() + module.gradient_checkpointing_enable() + except Exception as e: + print(e) + agent = globals()[agent_name](model, args) + return agent diff --git a/video_salmonn/model/agent.py b/video_salmonn/model/agent.py new file mode 100644 index 0000000..910165b --- /dev/null +++ b/video_salmonn/model/agent.py @@ -0,0 +1,142 @@ +from header import * + +class DeepSpeedAgent: + + def __init__(self, model, args): + super(DeepSpeedAgent, self).__init__() + self.args = args + self.model = model + if args['stage'] == 2: + self.load_stage_1_parameters(args["delta_ckpt_path"]) + print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}') + + # load config parameters of deepspeed + ds_params = json.load(open(self.args['ds_config_path'])) + # ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps'] + # ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps'] / ds_params['train_micro_batch_size_per_gpu'] + ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps'] / 2 + ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate']) / 8) + if self.args['world_size'] * ds_params['gradient_accumulation_steps'] * ds_params['train_micro_batch_size_per_gpu'] != ds_params['train_batch_size']: + print("Force setting train batch size") + ds_params['train_batch_size'] = self.args['world_size'] * ds_params['gradient_accumulation_steps'] * ds_params['train_micro_batch_size_per_gpu'] + self.ds_engine, self.optimizer, _ , _ = deepspeed.initialize( + model=self.model, + model_parameters=self.model.parameters(), + config_params=ds_params, + dist_init_required=True, + args=types.SimpleNamespace(**args) + ) + + @torch.no_grad() + def predict(self, batch): + self.model.eval() + string = self.model.generate_one_sample(batch) + return string + + def train_model(self, batch, current_step=0, pbar=None): + self.ds_engine.module.train() + loss, mle_acc = self.ds_engine(batch) + # print("Begin backward, {}".format(int(os.getenv('RANK', '0')))) + self.ds_engine.backward(loss) + # print("After backward, {}".format(int(os.getenv('RANK', '0')))) + self.ds_engine.step() + # print("After optimizer step, {}".format(int(os.getenv('RANK', '0')))) + pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}') + pbar.update(1) + if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0: + elapsed = pbar.format_dict['elapsed'] + rate = pbar.format_dict['rate'] + remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0 + remaining = str(datetime.timedelta(seconds=remaining)) + logging.info(f'[!] progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}') + + mle_acc *= 100 + return mle_acc + + @torch.no_grad() + def valid_model(self, batch): + self.model.eval() + loss, mle_acc = self.ds_engine(batch) + return loss.item(), mle_acc + + def _zero3_consolidated_16bit_state_dict(self): + """ + Get a full non-partitioned state_dict with fp16 weights on cpu. + Important: this function must be called on all ranks and not just rank 0. + This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but: + 1. consolidates the weights from different partitions on gpu0 + 2. works on one layer at a time to require as little gpu0 memory as possible, by + moving the already consolidated weights to cpu + 3. takes care to keep the shared params shared when gradually copying the params to cpu + Returns: + a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks + """ + state_dict = OrderedDict() + shared_params = {} + + def get_layer_state_dict(module, prefix=""): + # gather one layer at a time to be memory-efficient + # must use modifier_rank=0 to release GPU memory after each layer gathered + #see_memory_usage("before GatheredParameters", force=True) + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if int(os.getenv('RANK', '0')) == 0: + # handle params + for name, param in module.named_parameters(recurse=False): + if param is None: + continue + key = prefix + name + if param.requires_grad: + # can't rely on param.data_ptr() as it will be reused as weights gets + # gathered and reduced, but param.ds_id is unique across all zero weights + # (and shared params will have the same param.ds_id) + if param.ds_id in shared_params: + # shared weights + # print(f"`{key}` is shared with `{shared_params[param.ds_id]}`") + state_dict[key] = state_dict[shared_params[param.ds_id]] + else: + state_dict[key] = param.detach().cpu() + shared_params[param.ds_id] = key + # print(f"param {param.ds_id} {param.shape} {key} ") + + # now buffers - not sure if need to take care of potentially shared weights here + for name, buf in module.named_buffers(recurse=False): + if (buf is not None and name not in module._non_persistent_buffers_set): + state_dict[prefix + name] = buf.detach().cpu() + #see_memory_usage("after GatheredParameters", force=True) + + for name, child in module.named_children(): + if child is not None: + get_layer_state_dict(child, prefix + name + ".") + + # Prepare for checkpoint save by ensuring all parameters are partitioned + self.optimizer.checkpoint_event_prologue() + + # see_memory_usage("before get_layer_state_dict", force=False) + get_layer_state_dict(self.ds_engine.module, prefix="") + # see_memory_usage("after get_layer_state_dict", force=False) + + # self.ds_engine.optimizer.checkpoint_event_epilogue() + + return state_dict + + def save_model(self, path, current_step): + # only save trainable model parameters + if self.ds_engine.zero_gather_16bit_weights_on_model_save(): + # state_dict = self.ds_engine._zero3_consolidated_16bit_state_dict() + checkpoint = self._zero3_consolidated_16bit_state_dict() + else: + checkpoint = OrderedDict() + for k, v in self.ds_engine.module.named_parameters(): + if v.requires_grad: + checkpoint[k] = v + if int(os.getenv('RANK', '0')) == 0: + torch.save(checkpoint, f'{path}/pytorch_model_{current_step}.pt') + # save tokenizer + self.model.llama_tokenizer.save_pretrained(path) + # save configuration + self.model.llama_model.config.save_pretrained(path) + print(f'[!] save model into {path}') + + def load_stage_1_parameters(self, path): + delta_ckpt = torch.load(path, map_location=torch.device('cpu')) + self.model.load_state_dict(delta_ckpt, strict=False) diff --git a/video_salmonn/model/beats/BEATs.py b/video_salmonn/model/beats/BEATs.py new file mode 100644 index 0000000..17fc6f6 --- /dev/null +++ b/video_salmonn/model/beats/BEATs.py @@ -0,0 +1,186 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +import torchaudio.compliance.kaldi as ta_kaldi + +from .backbone import ( + TransformerEncoder, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class BEATsConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # label predictor + self.finetuned_model: bool = False # whether the model is a fine-tuned model. + self.predictor_dropout: float = 0.1 # dropout probability for the predictor + self.predictor_class: int = 527 # target class number for the predictor + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class BEATs(nn.Module): + def __init__( + self, + cfg: BEATsConfig, + ) -> None: + super().__init__() + logger.info(f"BEATs Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + if cfg.finetuned_model: + self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) + self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) + else: + self.predictor = None + + # [Yu] Add device for BEATs + @property + def device(self): + return next(self.parameters()).device + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + feature_only=False, + ): + # fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32) + fbank = self.preprocess(source.to(torch.float32), fbank_mean=fbank_mean, fbank_std=fbank_std).to(source.dtype) # [Yu] modify to support fp16 training + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + if not feature_only and self.predictor is not None: + x = self.predictor_dropout(x) + logits = self.predictor(x) + + if padding_mask is not None and padding_mask.any(): + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) + else: + logits = logits.mean(dim=1) + + lprobs = torch.sigmoid(logits) + + return lprobs, padding_mask + else: + return x, padding_mask \ No newline at end of file diff --git a/video_salmonn/model/beats/Tokenizers.py b/video_salmonn/model/beats/Tokenizers.py new file mode 100644 index 0000000..ece8019 --- /dev/null +++ b/video_salmonn/model/beats/Tokenizers.py @@ -0,0 +1,172 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +import torchaudio.compliance.kaldi as ta_kaldi + +from .backbone import ( + TransformerEncoder, +) +from .quantizer import ( + NormEMAVectorQuantizer, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TokenizersConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # quantizer + self.quant_n: int = 1024 # codebook number in quantizer + self.quant_dim: int = 256 # codebook dimension in quantizer + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class Tokenizers(nn.Module): + def __init__( + self, + cfg: TokenizersConfig, + ) -> None: + super().__init__() + logger.info(f"Tokenizers Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.quantize = NormEMAVectorQuantizer( + n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99, + ) + self.quant_n = cfg.quant_n + self.quantize_layer = nn.Sequential( + nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), + nn.Tanh(), + nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize + ) + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_labels( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + quantize_input = self.quantize_layer(x) + quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) + + return embed_ind diff --git a/video_salmonn/model/beats/__init__.py b/video_salmonn/model/beats/__init__.py new file mode 100644 index 0000000..b1c5846 --- /dev/null +++ b/video_salmonn/model/beats/__init__.py @@ -0,0 +1 @@ +from .BEATs import BEATs, BEATsConfig \ No newline at end of file diff --git a/video_salmonn/model/beats/backbone.py b/video_salmonn/model/beats/backbone.py new file mode 100644 index 0000000..d876929 --- /dev/null +++ b/video_salmonn/model/beats/backbone.py @@ -0,0 +1,783 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import numpy as np +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +import torch.nn.functional as F +from torch.nn import LayerNorm, Parameter +from .modules import ( + GradMultiply, + SamePad, + get_activation_fn, + GLU_Linear, + quant_noise, +) + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + deep_norm=args.deep_norm, + has_relative_attention_bias=self.relative_position_embedding, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + encoder_layers=args.encoder_layers, + ) + for i in range(args.encoder_layers) + ] + ) + if self.relative_position_embedding: + for i in range(1, args.encoder_layers): + del self.layers[i].self_attn.relative_attention_bias + self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + if args.deep_norm: + deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) + for i in range(args.encoder_layers): + nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) + + self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1) + + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + if self.layer_wise_gradient_decay_ratio != 1.0: + x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio) + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + deep_norm: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + encoder_layers: int = 0, + ) -> None: + + super().__init__() + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.final_layer_norm = LayerNorm(self.embedding_dim) + + self.deep_norm = deep_norm + if self.deep_norm: + self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) + else: + self.deep_norm_alpha = 1 + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual * self.deep_norm_alpha + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual * self.deep_norm_alpha + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + alpha = 32 + q *= 1 / alpha + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) + + attn_weights = attn_weights + attn_mask_rel_pos + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) \ No newline at end of file diff --git a/video_salmonn/model/beats/modules.py b/video_salmonn/model/beats/modules.py new file mode 100644 index 0000000..18e2d20 --- /dev/null +++ b/video_salmonn/model/beats/modules.py @@ -0,0 +1,218 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +import torch +from torch import Tensor, nn +import torch.nn.functional as F + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module diff --git a/video_salmonn/model/beats/quantizer.py b/video_salmonn/model/beats/quantizer.py new file mode 100644 index 0000000..704be4c --- /dev/null +++ b/video_salmonn/model/beats/quantizer.py @@ -0,0 +1,215 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on VQGAN code bases +# https://github.com/CompVis/taming-transformers +# --------------------------------------------------------' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as distributed + +try: + from einops import rearrange, repeat +except ImportError: + pass + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + self.decay = decay + self.eps = eps + if codebook_init_path == '': + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = l2norm(weight) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f"load init codebook weight from {codebook_init_path}") + codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + # self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + print("Performing Kemans init for codebook") + embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) + self.weight.data.copy_(embed_normalized) + + +def norm_ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(l2norm(moving_avg.data)) + + +class NormEMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(n_embed)) + if distributed.is_available() and distributed.is_initialized(): + print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") + self.all_reduce_fn = distributed.all_reduce + else: + self.all_reduce_fn = nn.Identity() + + def reset_cluster_size(self, device): + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + # z = rearrange(z, 'b c h w -> b h w c') + # z = z.transpose(1, 2) + z = l2norm(z) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # EMA cluster size + + bins = encodings.sum(0) + self.all_reduce_fn(bins) + + # self.embedding.cluster_size_ema_update(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = z_flattened.t() @ encodings + self.all_reduce_fn(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + + embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, + embed_normalized) + norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + # z_q = rearrange(z_q, 'b h w c -> b c h w') + # z_q = z_q.transpose(1, 2) + return z_q, loss, encoding_indices \ No newline at end of file diff --git a/video_salmonn/model/eva_vit.py b/video_salmonn/model/eva_vit.py new file mode 100644 index 0000000..aee412b --- /dev/null +++ b/video_salmonn/model/eva_vit.py @@ -0,0 +1,438 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + self.use_checkpoint = use_checkpoint + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) +# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) +# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None +# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) +# if isinstance(self.head, nn.Linear): +# trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() +# if isinstance(self.head, nn.Linear): +# self.head.weight.data.mul_(init_scale) +# self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + return x +# x = self.norm(x) + +# if self.fc_norm is not None: +# t = x[:, 1:, :] +# return self.fc_norm(t.mean(1)) +# else: +# return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) +# x = self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'].float() + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +# if isinstance(l, (nn.MultiheadAttention, Attention)): +# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: +# tensor = getattr(l, attr) +# if tensor is not None: +# tensor.data = tensor.data.half() + + model.apply(_convert_weights_to_fp16) + + +def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"): + model = VisionTransformer( + img_size=img_size, + patch_size=14, + use_mean_pooling=False, + embed_dim=1408, + depth=39, + num_heads=1408//88, + mlp_ratio=4.3637, + qkv_bias=True, + drop_path_rate=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + use_checkpoint=use_checkpoint, + ) + cached_file = "/scratch/OpenSource/SALMONN/video_salmonn/ckpt/pretrained_ckpt/eva_vit_g.pth" # Change this to your local path + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model,state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + print(incompatible_keys) + + if precision == "fp16": +# model.to("cuda") + convert_weights_to_fp16(model) + return model \ No newline at end of file diff --git a/video_salmonn/model/llama_attn_replace.py b/video_salmonn/model/llama_attn_replace.py new file mode 100644 index 0000000..31af0ce --- /dev/null +++ b/video_salmonn/model/llama_attn_replace.py @@ -0,0 +1,237 @@ +from typing import Optional, Tuple +import warnings + +import torch +from torch import nn +from .modeling_llama import apply_rotary_pos_emb +from .modeling_llama import LlamaAttention, LlamaModel + +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.Tensor] = None, +# past_key_value: Optional[Tuple[torch.Tensor]] = None, +# output_attentions: bool = False, +# use_cache: bool = False, +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +# if output_attentions: +# warnings.warn( +# "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." +# ) + +# bsz, q_len, _ = hidden_states.size() + +# query_states = ( +# self.q_proj(hidden_states) +# .view(bsz, q_len, self.num_heads, self.head_dim) +# .transpose(1, 2) +# ) +# key_states = ( +# self.k_proj(hidden_states) +# .view(bsz, q_len, self.num_heads, self.head_dim) +# .transpose(1, 2) +# ) +# value_states = ( +# self.v_proj(hidden_states) +# .view(bsz, q_len, self.num_heads, self.head_dim) +# .transpose(1, 2) +# ) # shape: (b, num_heads, s, head_dim) + +# kv_seq_len = key_states.shape[-2] +# if past_key_value is not None: +# kv_seq_len += past_key_value[0].shape[-2] +# cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) +# query_states, key_states = apply_rotary_pos_emb( +# query_states, key_states, cos, sin, position_ids +# ) + +# if past_key_value is not None: +# # reuse k, v +# key_states = torch.cat([past_key_value[0], key_states], dim=2) +# value_states = torch.cat([past_key_value[1], value_states], dim=2) + +# past_key_value = (key_states, value_states) if use_cache else None + +# # Transform the data into the format required by flash attention +# qkv = torch.stack([query_states, key_states, value_states], dim=2) +# qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] +# key_padding_mask = attention_mask + +# if key_padding_mask is None: +# qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) +# cu_q_lens = torch.arange( +# 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device +# ) +# max_s = q_len +# output = flash_attn_varlen_qkvpacked_func( +# qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True +# ) +# output = output.view(bsz, q_len, -1) +# else: +# qkv = qkv.reshape(bsz, q_len, -1) +# qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) +# qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) +# output_unpad = flash_attn_varlen_qkvpacked_func( +# qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True +# ) +# output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) +# output = pad_input(output_unpad, indices, bsz, q_len) + +# return self.o_proj(output), None, past_key_value + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if attention_mask is None: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + output = flash_attn_func(query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True).view( + bsz, q_len, -1 + ) + else: + q, indices, cu_q_lens, max_s = unpad_input(query_states.transpose(1, 2), attention_mask[:, -q_len:]) + kv = torch.stack((key_states, value_states), dim=2).transpose(1, 3) + kv = kv.reshape(bsz, kv_seq_len, -1) + kv, _, cu_k_lens, max_k = unpad_input(kv, attention_mask) + kv = kv.view(-1, 2, self.num_heads, self.head_dim) + output_unpad = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_q_lens, + cu_k_lens, + max_s, + max_k, + 0.0, + softmax_scale=None, + causal=True, + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +def _prepare_decoder_attention_mask_inference( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + ( + torch.full( + (input_shape[0], past_key_values_length), + True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, + ), + dim=-1, + ) + + if attention_mask is not None and torch.all(attention_mask): + return None # This uses the faster call when training with full samples + + return attention_mask + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +# def _prepare_decoder_attention_mask( +# self, attention_mask, input_shape, inputs_embeds, past_key_values_length +# ): +# # [bsz, seq_len] +# if past_key_values_length > 0 and attention_mask is not None: +# attention_mask = torch.cat( +# ( +# torch.full( +# (input_shape[0], past_key_values_length), +# True, +# dtype=attention_mask.dtype, +# device=attention_mask.device, +# ), +# attention_mask, +# ), +# dim=-1, +# ) + +# if attention_mask is not None and torch.all(attention_mask): +# return None # This uses the faster call when training with full samples + +# return attention_mask + + +def replace_llama_attn_with_flash_attn(inference=False): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." + "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + if inference: + LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask_inference + ) + else: + LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + LlamaAttention.forward = forward \ No newline at end of file diff --git a/video_salmonn/model/modeling_llama.py b/video_salmonn/model/modeling_llama.py new file mode 100644 index 0000000..dbdbd60 --- /dev/null +++ b/video_salmonn/model/modeling_llama.py @@ -0,0 +1,768 @@ +# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.models.llama.configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + # print(hidden_states.size()) + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + modality_lengths: int = 0, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + batch_size, seq_length, _ = inputs_embeds.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + if modality_lengths > 0: + attn_mask_blk1 = (1 - attention_mask[:, :modality_lengths]).unsqueeze(1).unsqueeze(1).repeat(1, 1, modality_lengths, 1) * -1e9 + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length-modality_lengths), inputs_embeds, modality_lengths + ) + attn_mask_blk2 = attention_mask.new_ones(batch_size, 1, modality_lengths, seq_length-modality_lengths) * -1e9 + attn_mask_blk = torch.cat([attn_mask_blk1, attn_mask_blk2], dim=-1) + attention_mask = torch.cat([attn_mask_blk, attention_mask], dim=2) + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + modality_lengths: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + modality_lengths=modality_lengths, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + query_embeds = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "query_embeds": query_embeds, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + diff --git a/video_salmonn/model/modeling_whisper.py b/video_salmonn/model/modeling_whisper.py new file mode 100644 index 0000000..ca63982 --- /dev/null +++ b/video_salmonn/model/modeling_whisper.py @@ -0,0 +1,1772 @@ +# This script is based on https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py + +""" PyTorch Whisper model.""" +""" Added by Yu. """ + +import math +import random +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.models.whisper.configuration_whisper import WhisperConfig +from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "WhisperConfig" +_CHECKPOINT_FOR_DOC = "openai/whisper-tiny" + + +WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "openai/whisper-base", + # See all Whisper models at https://huggingface.co/models?filter=whisper +] + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +class WhisperPositionalEmbedding(nn.Embedding): + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__(num_positions, embedding_dim) + + def forward(self, input_ids, past_key_values_length=0): + return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + + +class WhisperAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper +class WhisperEncoderLayer(nn.Module): + def __init__(self, config: WhisperConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper +class WhisperDecoderLayer(nn.Module): + def __init__(self, config: WhisperConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = WhisperAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class WhisperPreTrainedModel(PreTrainedModel): + config_class = WhisperConfig + base_model_prefix = "model" + main_input_name = "input_features" + supports_gradient_checkpointing = True + _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (WhisperDecoder, WhisperEncoder)): + module.gradient_checkpointing = value + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + input_lengths = (input_lengths - 1) // 2 + 1 + + return input_lengths + + +WHISPER_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`WhisperConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +WHISPER_INPUTS_DOCSTRING = r""" + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read + [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +WHISPER_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class WhisperEncoder(WhisperPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`WhisperEncoderLayer`]. + + Args: + config: WhisperConfig + """ + + def __init__(self, config: WhisperConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + + self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + _, seqlen, _ = inputs_embeds.size() # [Yu] Support variable length. + hidden_states = inputs_embeds + embed_pos[:seqlen] # [Yu] Support variable length. + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + None, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class WhisperDecoder(WhisperPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`] + + Args: + config: WhisperConfig + """ + + def __init__(self, config: WhisperConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) + + self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) + + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare Whisper Model outputting raw hidden-states without any specific head on top.", + WHISPER_START_DOCSTRING, +) +class WhisperModel(WhisperPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"proj_out.weight"] + + def __init__(self, config: WhisperConfig): + super().__init__(config) + + self.encoder = WhisperEncoder(config) + self.decoder = WhisperDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_encoder(self): + """ + Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will + not be updated during training. + """ + self.encoder._freeze_parameters() + + def _mask_input_features( + self, + input_features: torch.FloatTensor, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return input_features + + # generate indices & apply SpecAugment along time axis + batch_size, hidden_size, sequence_length = input_features.size() + + if self.config.mask_time_prob > 0 and self.training: + # generate indices & apply SpecAugment along time axis + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool) + mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1) + input_features[mask_time_indices] = 0 + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool) + input_features[mask_feature_indices] = 0 + + return input_features + + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperModel + >>> from datasets import load_dataset + + >>> model = WhisperModel.from_pretrained("openai/whisper-base") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 512] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.", + WHISPER_START_DOCSTRING, +) +class WhisperForConditionalGeneration(WhisperPreTrainedModel): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + r"encoder.version", + r"decoder.version", + r"proj_out.weight", + ] + _keys_to_ignore_on_save = [ + r"proj_out.weight", + ] + + def __init__(self, config: WhisperConfig): + super().__init__(config) + self.model = WhisperModel(config) + self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens) + return new_embeddings + + def get_output_embeddings(self): + return self.proj_out + + def set_output_embeddings(self, new_embeddings): + self.proj_out = new_embeddings + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + + def freeze_encoder(self): + """ + Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will + not be updated during training. + """ + self.model.encoder._freeze_parameters() + + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_timestamps=None, + task=None, + language=None, + is_multilingual=None, + **kwargs, + ): + """ + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + return_timestamps (`bool`, *optional*): + Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. + task (`bool`, *optional*): + Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` + will be updated accordingly. + language (`bool`, *optional*): + Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can + find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. + is_multilingual (`bool`, *optional*): + Whether or not the model is multilingual. + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + if generation_config is None: + generation_config = self.generation_config + + if return_timestamps is not None: + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You are trying to return timestamps, but the generation config is not properly set." + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`." + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + ) + + generation_config.return_timestamps = return_timestamps + else: + generation_config.return_timestamps = False + + if language is not None: + language = language.lower() + generation_config.language = language + if task is not None: + generation_config.task = task + + forced_decoder_ids = [] + if task is not None or language is not None: + if hasattr(generation_config, "language"): + if generation_config.language in generation_config.lang_to_id.keys(): + language_token = generation_config.language + elif generation_config.language in TO_LANGUAGE_CODE.keys(): + language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" + elif generation_config.language in TO_LANGUAGE_CODE.values(): + language_token = f"<|{generation_config.language}|>" + else: + is_language_code = len(generation_config.language) == 2 + raise ValueError( + f"Unsupported language: {generation_config.language}. Language should be one of:" + f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." + ) + forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) + else: + forced_decoder_ids.append((1, None)) # automatically detect the language + + if hasattr(generation_config, "task"): + if generation_config.task in TASK_IDS: + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + raise ValueError( + f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" + ) + else: + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe + if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + # Legacy code for backward compatibility + elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: + forced_decoder_ids = self.config.forced_decoder_ids + elif ( + hasattr(self.generation_config, "forced_decoder_ids") + and self.generation_config.forced_decoder_ids is not None + ): + forced_decoder_ids = self.generation_config.forced_decoder_ids + + if generation_config.return_timestamps: + logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)] + + if len(forced_decoder_ids) > 0: + generation_config.forced_decoder_ids = forced_decoder_ids + + return super().generate( + inputs, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + use_cache=None, + encoder_outputs=None, + attention_mask=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "use_cache": use_cache, + "decoder_attention_mask": None, + } + + # + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings( + """ + Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks + like SUPERB Keyword Spotting. + """, + WHISPER_ENCODER_INPUTS_DOCSTRING, +) +class WhisperForAudioClassification(WhisperPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.encoder = WhisperEncoder(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_encoder(self): + """ + Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will + not be updated during training. Only the projection layers and classification head will be updated. + """ + self.encoder._freeze_parameters() + + def get_input_embeddings(self) -> nn.Module: + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module): + self.encoder.set_input_embeddings(value) + + @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification + >>> from datasets import load_dataset + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + + >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) + >>> sample = next(iter(ds)) + + >>> inputs = feature_extractor( + ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt" + ... ) + >>> input_features = inputs.input_features + + >>> with torch.no_grad(): + ... logits = model(input_features).logits + + >>> predicted_class_ids = torch.argmax(logits).item() + >>> predicted_label = model.config.id2label[predicted_class_ids] + >>> predicted_label + 'af_za' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = torch.stack(encoder_outputs, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = hidden_states.mean(dim=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) \ No newline at end of file diff --git a/video_salmonn/model/openllama.py b/video_salmonn/model/openllama.py new file mode 100644 index 0000000..2094ec5 --- /dev/null +++ b/video_salmonn/model/openllama.py @@ -0,0 +1,2279 @@ +from header import * +import torch.nn.functional as F +from .ImageBind import * +from .ImageBind import data +from .modeling_llama import LlamaForCausalLM, LlamaConfig +# from .llama_attn_replace import replace_llama_attn_with_flash_attn +from .Qformer import BertConfig, BertLMHeadModel +from .modeling_whisper import WhisperModel +from .beats import BEATs, BEATsConfig +from .eva_vit import create_eva_vit_g +from transformers import StoppingCriteria, StoppingCriteriaList, BertTokenizer +from transformers import WhisperFeatureExtractor +import soundfile as sf +from peft.tuners.lora import LoraLayer +try: + import nemo.collections.asr as nemo_asr +except: + print("no nemo!") + +import torch +from torch.nn.utils import rnn + +class StoppingCriteriaSub(StoppingCriteria): + + def __init__(self, stops = [], encounters=1): + super().__init__() + self.stops = stops + self.ENCOUNTERS = encounters + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + stop_count = 0 + for stop in self.stops: + stop_count = (stop == input_ids[0]).sum().item() + if stop_count >= self.ENCOUNTERS: + return True + return False + +def modify_lora_layer(model, lora_alpha): + for name, layer in model.named_children(): + if isinstance(layer, LoraLayer): + layer.lora_alpha['default'] = lora_alpha + layer.scaling['default'] = lora_alpha / layer.r['default'] + if isinstance(layer, nn.Module): + modify_lora_layer(layer, lora_alpha) + +def build_one_instance(tokenizer, conversation, generate=False, prompt=None, use_llama2=False): + # with open("/mnt/bn/audio-visual-llm-data/yuwenyi/playground/pandagpt/code/prompt/alignment_speech_multitask.json", "r", encoding='utf-8') as f: + # prompt = json.load(f) + text_list = [] + turn_num = len(conversation) + input_ids, target_ids, instructs = [], [], [] + for i in range(turn_num): + turn = conversation[i] + role = turn['from'] + if i == 0: # the first human turn + assert role == 'human' + if turn['value'][1: -1] in prompt: + pc = random.choice(prompt[turn['value'][1: -1]]) + if use_llama2: + text = ' ' + pc + ' [/INST]' + else: + text = ' ' + pc + '\nASSISTANT:' + instructs.append(pc) + else: + if use_llama2: + text = ' ' + turn['value'] + ' [/INST]' + else: + text = ' ' + turn['value'] + '\nASSISTANT:' + instructs.append(turn['value']) + one_input_id = tokenizer(text, add_special_tokens=False).input_ids + input_ids += one_input_id + target_ids += [-100]*len(one_input_id) # do not perform loss regression on human prompt + if generate: + return None, input_ids, target_ids, instructs + else: + if role == 'human': + if turn['value'][1: -1] in prompt: + pc = random.choice(prompt[turn['value'][1: -1]]) + if use_llama2: + text = '[INST] ' + pc + ' [/INST]' + else: + text = 'USER: ' + pc + '\nASSISTANT:' + instructs.append(pc) + else: + if use_llama2: + text = '[INST] ' + turn['value'] + ' [/INST]' + else: + text = 'USER: ' + turn['value'] + '\nASSISTANT:' + instructs.append(turn['value']) + one_input_id = tokenizer(text, add_special_tokens=False).input_ids + input_ids += one_input_id + target_ids += [-100]*len(one_input_id) + elif role == 'gpt': + # text = turn['value'] + '\n###' + text = turn['value'] + '' + one_input_id = tokenizer(text, add_special_tokens=False).input_ids + input_ids += one_input_id + target_ids += one_input_id + else: + raise Exception('Wrong Role!!!') + text_list.append(text) + assert len(input_ids) == len(target_ids) + return text_list, input_ids, target_ids, instructs + +def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len, modality='image', generate=False, prompt=None, use_llama2=False): + batch_input_ids, batch_target_ids, instructs = [], [], [] + for conversation in batch_of_conversations: + text_list, one_input_ids, one_target_ids, inst = build_one_instance(tokenizer, conversation, generate=generate, prompt=prompt, use_llama2=use_llama2) + batch_input_ids.append(torch.LongTensor(one_input_ids)) + batch_target_ids.append(torch.LongTensor(one_target_ids)) + instructs.append(inst[0]) # [Yu] TODO support multi-turn training + input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100) + assert input_ids.size() == target_ids.size() + input_ids = input_ids[:,:max_tgt_len] + target_ids = target_ids[:,:max_tgt_len] + attention_mask = input_ids.ne(tokenizer.pad_token_id) + assert attention_mask.size() == input_ids.size() + return input_ids, target_ids, attention_mask.long(), instructs + +def sinusoidal_position(max_len, d_model): + position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * -(torch.log(torch.tensor(10000.0)) / d_model)) + pos_enc = torch.zeros((max_len, d_model)) + pos_enc[:, 0::2] = torch.sin(position * div_term) + pos_enc[:, 1::2] = torch.cos(position * div_term) + return pos_enc + +PROMPT_START = 'USER: ' +dummy_image_path = ["./dummy/761183272.jpg"] +dummy_audio_path = ["./dummy/1272-128104-0000.flac"] +dummy_raw_audio, _ = sf.read("./dummy/1272-128104-0000.flac") +dummy_raw_audio = [[dummy_raw_audio]] + +class OpenLLAMAPEFTModel(nn.Module): + + '''LoRA for LLaMa model''' + + def __init__(self, **args): + super(OpenLLAMAPEFTModel, self).__init__() + self.args = args + imagebind_ckpt_path = args['imagebind_ckpt_path'] + vicuna_ckpt_path = args['vicuna_ckpt_path'] + max_tgt_len = args['max_tgt_len'] + stage = args['stage'] + self.use_whisper = (args["use_whisper"] == "true") + self.use_blip = (args["use_blip"] == "true") + self.instructblip = (args["instructblip"] == "true") + self.instructblip_video = (args["instructblip_video"] == "true") + self.skip_vqformer = (args["skip_vqformer"] == "true") + self.video_window_size = args["video_window_size"] + self.speech_qformer = (args["speech_qformer"] == "true") + self.early_align = (args["early_align"] == "true") + self.cascaded = args["cascaded"] + self.causal = (args["causal"] == "true") + self.diversity = (args["diversity_loss"] == "true") + self.diversity_loss_factor = args.get("diversity_loss_factor", 0.01) + self.causal_encoder = (args.get("causal_attention", "false") == "true") + self.groupsize = args.get("groupsize", 0) + self.alignmode = args.get("alignmode", 1) + self.modalitymask = (args.get("modalitymask", "false") == "true") + self.xsegalign = (args.get("xsegalign", "false") == "true") + self.seglen = args.get("seglen", 3) + self.pure_aud = args.get("pure_aud", False) + self.second_per_frame = args.get("second_per_frame", False) + self.second_stride = args.get("second_stride", False) + self.sin_pos = args.get("sin_pos", False) + self.use_beats = args.get("use_beats", False) + self.ps_instruct = args.get("ps_instruct", False) + self.ps_n_qformer_layers = args.get("ps_n_qformer_layers", 2) + self.n_pos = args.get("n_pos", 120) + self.use_nemo = args.get('use_nemo', False) + self.bilinear_pooling = args.get('bilinear_pooling', False) + self.ext_groupsize = args.get('ext_groupsize', None) + self.low_groupsize = args.get('low_groupsize', None) + self.high_groupsize = args.get('high_groupsize', None) + self.ext_same_qformer = args.get('ext_same_qformer', False) + self.add_time = args.get('add_time', False) + self.img_hi_rs = args.get('img_hi_rs', False) + self.img_hi_rs_cfg = args.get('img_hi_rs_cfg', None) + self.use_llama2 = args.get('use_llama2', False) + self.cache_dir = args.get('cache_dir', False) + # [npy] + self.use_npy = args.get("use_npy", False) + + self.PROMPT_START = 'USER: ' if not self.use_llama2 else '[INST] ' + + with open("./prompt/alignment_speech_multitask.json", "r", encoding='utf-8') as f: + self.prompt = json.load(f) + + if not self.pure_aud: + if self.use_blip: + print("Loading visual encoder ViT") + self.visual_encoder = create_eva_vit_g( + 224, 0, False, "fp16" + ) + print("Finished loading visual encoder ViT") + self.ln_vision = nn.LayerNorm(self.visual_encoder.num_features) + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + print("Loading Qformer") + self.num_query_token = 32 + self.Qformer, self.query_tokens = self.init_video_Qformer( + num_query_token = self.num_query_token, + vision_width=self.visual_encoder.num_features, + num_hidden_layers = -1, + cache_dir=self.cache_dir, + ) + if self.instructblip: + self.bert_tokenizer = BertTokenizer.from_pretrained( + "bert-base-uncased", truncation_side="left", cache_dir=self.cache_dir) + self.bert_tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.Qformer.resize_token_embeddings(len(self.bert_tokenizer)) + else: + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.Qformer.cls = None + + qformer_path = "./ckpt/pretrained_ckpt/instruct_blip_vicuna13b_trimmed.pth" + state_dict = torch.load(qformer_path, map_location="cpu")["model"] + msg = self.load_state_dict(state_dict, strict=False) + # print(f"Finished loading Qformer, Unused parameters: {msg}") + # Freeze Qformer + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.query_tokens.requires_grad = False + self.visual_hidden_size = self.Qformer.config.hidden_size + else: + print (f'Initializing visual encoder from {imagebind_ckpt_path} ...') + self.visual_encoder, self.visual_hidden_size = \ + imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path) + # freeze vision encoder + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder.eval() + print ('Visual encoder initialized.') + + if self.use_whisper: + print("Loading Whisper Model") + whispermodel = "openai/whisper-large-v2" + self.whispertransform = WhisperFeatureExtractor.from_pretrained(whispermodel, cache_dir=self.cache_dir) + self.speech_encoder = WhisperModel.from_pretrained(whispermodel, cache_dir=self.cache_dir).encoder + for name, param in self.speech_encoder.named_parameters(): + param.requires_grad = False + self.speech_encoder = self.speech_encoder.eval() + print("Freeze Whisper and Loading Whisper done.") + + if self.use_beats: + beatsmodel = "./ckpt/pretrained_ckpt/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt" + beats_checkpoint = torch.load(beatsmodel, map_location="cpu") + beats_cfg = BEATsConfig(beats_checkpoint['cfg']) + beats = BEATs(beats_cfg) + beats.load_state_dict(beats_checkpoint['model']) + self.beats = beats + for name, param in self.beats.named_parameters(): + param.requires_grad = False + self.beats.eval() + print("Freeze BEATs and Loading BEATs done.") + + if args['flash_attn']: + replace_llama_attn_with_flash_attn() + + print (f'Initializing language decoder from {vicuna_ckpt_path} ...') + # add the lora module + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=self.args.get('yu_lora_r', 8), + lora_alpha=self.args.get('yu_lora_alpha', 32), + lora_dropout=self.args.get('yu_lora_dropout', 0.1), + # target_modules=['q_proj', 'v_proj'], + target_modules=self.args.get("lora_target_modules", ['q_proj', 'v_proj']), + ) + + self.llama_model = LlamaForCausalLM.from_pretrained( + vicuna_ckpt_path, + torch_dtype=torch.float16, + ) + if self.args['use_lora'] == 'true': + self.llama_model = get_peft_model(self.llama_model, peft_config) + self.llama_model.print_trainable_parameters() + else: + print("Not updating vicuna at all!") + # free vicuna + for name, param in self.llama_model.named_parameters(): + param.requires_grad = False + + self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False) + self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + self.llama_tokenizer.padding_side = "right" + print ('Language decoder initialized.') + + self.use_qformer = args["qformer"] if "qformer" in args else "false" + if not self.pure_aud: + if self.use_qformer == 'true': + if self.use_whisper and not self.speech_qformer: + self.speech_pre_qformer_proj = nn.Linear(self.speech_encoder.config.d_model, self.visual_hidden_size) + if self.use_beats: + self.beats_pre_qformer_proj = nn.Linear(self.beats.cfg.encoder_embed_dim, self.visual_hidden_size) + if self.use_nemo and not self.speech_qformer: + self.speech_pre_qformer_proj = nn.Linear(512, self.visual_hidden_size) # TODO change 512 to a number in config + if self.early_align and self.alignmode == 2: + if self.speech_qformer: + self.visual_hidden_size = self.visual_hidden_size + self.speech_encoder.config.d_model + else: + if self.bilinear_pooling: + self.bp_proj = nn.Linear(self.visual_hidden_size, self.visual_hidden_size, bias=False) + self.bp_vis = nn.Linear(self.visual_hidden_size, self.visual_hidden_size, bias=False) + self.bp_whisper = nn.Linear(self.visual_hidden_size, self.visual_hidden_size, bias=False) + if self.use_beats: + self.bp_beats = nn.Linear(self.visual_hidden_size, self.visual_hidden_size, bias=False) + else: + if self.use_beats: + self.visual_hidden_size = self.visual_hidden_size * 3 + else: + self.visual_hidden_size = self.visual_hidden_size * 2 + self.ln_video = nn.LayerNorm(self.visual_hidden_size) + # if not self.use_whisper: + self.video_frame_position_embedding = nn.Embedding(4096, self.visual_hidden_size) + self.num_video_query_token = args['num_video_query'] + self.video_Qformer, self.video_query_tokens = self.init_video_Qformer( + num_query_token = self.num_video_query_token, + vision_width=self.visual_hidden_size, + num_hidden_layers = self.seglen if self.xsegalign else 2, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + if self.instructblip_video: + self.video_Qformer.resize_token_embeddings(len(self.bert_tokenizer)) + Qformer_embeddings = self.Qformer.bert.embeddings.state_dict() + self.video_Qformer.bert.embeddings.load_state_dict(Qformer_embeddings) + else: + self.video_Qformer.bert.embeddings.word_embeddings = None + self.video_Qformer.bert.embeddings.position_embeddings = None + if not self.xsegalign: + for layer in self.video_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.video_Qformer.cls = None + if self.ext_groupsize is not None: + self.llama_proj = nn.Linear( + self.video_Qformer.config.hidden_size * 3, self.llama_model.config.hidden_size + ) + elif self.low_groupsize is not None or self.high_groupsize is not None: + self.llama_proj = nn.Linear( + self.video_Qformer.config.hidden_size * 2, self.llama_model.config.hidden_size + ) + else: + self.llama_proj = nn.Linear( + self.video_Qformer.config.hidden_size, self.llama_model.config.hidden_size + ) + + if self.speech_qformer: + # A separate speech Qformer + self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model) + self.speech_Qformer, self.speech_query_tokens = self.init_video_Qformer( + num_query_token = args.get("num_speech_query", self.num_video_query_token), + vision_width=self.speech_encoder.config.d_model, + num_hidden_layers = 2, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + if self.instructblip_video: + self.speech_Qformer.resize_token_embeddings(len(self.bert_tokenizer)) + else: + self.speech_Qformer.bert.embeddings.word_embeddings = None + self.speech_Qformer.bert.embeddings.position_embeddings = None + for layer in self.speech_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.speech_Qformer.cls = None + self.llama_proj_speech = nn.Linear( + self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size + ) + else: + self.speech_Qformer = self.video_Qformer + self.speech_query_tokens = self.video_query_tokens + self.llama_proj_speech = self.llama_proj + self.ln_speech = self.ln_video + + if self.early_align and self.alignmode == 3: + # A separate speech Qformer + self.speech_pre_qformer_proj = nn.Linear(self.speech_encoder.config.d_model, self.visual_hidden_size) + self.joint_size = self.visual_hidden_size * 2 + self.joint_frame_position_embedding = nn.Embedding(4096, self.joint_size) + self.ln_joint = nn.LayerNorm(self.joint_size) + self.joint_Qformer, self.joint_query_tokens = self.init_video_Qformer( + num_query_token = self.num_video_query_token, + vision_width=self.joint_size, + num_hidden_layers = 2, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + self.joint_Qformer.bert.embeddings.word_embeddings = None + self.joint_Qformer.bert.embeddings.position_embeddings = None + for layer in self.joint_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.joint_Qformer.cls = None + self.llama_proj_joint = nn.Linear( + self.joint_Qformer.config.hidden_size, self.llama_model.config.hidden_size + ) + elif self.early_align and self.alignmode == 2: + self.joint_Qformer = self.video_Qformer + self.joint_query_tokens = self.video_query_tokens + self.llama_proj_joint = self.llama_proj + self.ln_joint = self.ln_video + self.joint_frame_position_embedding = self.video_frame_position_embedding + if self.ext_groupsize is not None: + self.low_Qformer, self.low_query_tokens = self.init_video_Qformer( + num_query_token = int(self.num_video_query_token * self.ext_groupsize[0] / self.groupsize), + vision_width=self.visual_hidden_size, + num_hidden_layers = self.seglen if self.xsegalign else 2, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + self.high_Qformer, self.high_query_tokens = self.init_video_Qformer( + num_query_token = int(self.num_video_query_token * self.ext_groupsize[1] / self.groupsize), + vision_width=self.visual_hidden_size, + num_hidden_layers = self.seglen if self.xsegalign else 2, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + + self.low_Qformer.bert.embeddings.word_embeddings = None + self.low_Qformer.bert.embeddings.position_embeddings = None + if not self.xsegalign: + for layer in self.low_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.low_Qformer.cls = None + + self.high_Qformer.bert.embeddings.word_embeddings = None + self.high_Qformer.bert.embeddings.position_embeddings = None + if not self.xsegalign: + for layer in self.high_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.high_Qformer.cls = None + elif self.low_groupsize is not None: + self.low_Qformer, self.low_query_tokens = self.init_video_Qformer( + num_query_token = int(self.num_video_query_token * self.low_groupsize / self.groupsize), + vision_width=self.visual_hidden_size, + num_hidden_layers = self.seglen if self.xsegalign else 2, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + + self.low_Qformer.bert.embeddings.word_embeddings = None + self.low_Qformer.bert.embeddings.position_embeddings = None + if not self.xsegalign: + for layer in self.low_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.low_Qformer.cls = None + elif self.high_groupsize is not None: + self.high_Qformer, self.high_query_tokens = self.init_video_Qformer( + num_query_token = int(self.num_video_query_token * self.high_groupsize / self.groupsize), + vision_width=self.visual_hidden_size, + num_hidden_layers = self.seglen if self.xsegalign else 2, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + + self.high_Qformer.bert.embeddings.word_embeddings = None + self.high_Qformer.bert.embeddings.position_embeddings = None + if not self.xsegalign: + for layer in self.high_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.high_Qformer.cls = None + else: + self.llama_proj = nn.Linear( + self.visual_hidden_size, self.llama_model.config.hidden_size + ) + else: + if self.speech_qformer: + self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model) + if self.use_beats: + self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim) + self.speech_Qformer, self.speech_query_tokens = self.init_video_Qformer( + num_query_token = args.get("num_speech_query"), + vision_width=self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim, + num_hidden_layers = self.ps_n_qformer_layers, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + else: + self.speech_Qformer, self.speech_query_tokens = self.init_video_Qformer( + num_query_token = args.get("num_speech_query"), + vision_width=self.speech_encoder.config.d_model, + num_hidden_layers = self.ps_n_qformer_layers, + causal_encoder=self.causal_encoder, + cache_dir=self.cache_dir, + ) + if self.ps_instruct: + bert_model = torch.load("/mnt/bn/audio-visual-llm-data/torch_home/hub/checkpoints/bert-base-uncased/pytorch_model.bin", map_location='cpu') + bert_emb_state_dict = {k: v for k, v in bert_model.items() if "bert.embeddings" in k} + self.speech_Qformer.load_state_dict(bert_emb_state_dict, strict=False) + + self.bert_tokenizer = BertTokenizer.from_pretrained( + "bert-base-uncased", truncation_side="left", cache_dir=self.cache_dir + ) + self.bert_tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.speech_Qformer.resize_token_embeddings(len(self.bert_tokenizer)) + + for name, param in self.speech_Qformer.bert.embeddings.word_embeddings.named_parameters(): + param.requires_grad = False + for name, param in self.speech_Qformer.bert.embeddings.position_embeddings.named_parameters(): + param.requires_grad = False + + else: + self.speech_Qformer.bert.embeddings.word_embeddings = None + self.speech_Qformer.bert.embeddings.position_embeddings = None + for layer in self.speech_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.speech_Qformer.cls = None + self.llama_proj_speech = nn.Linear( + self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size + ) + else: + pass # [Yu] not implemented. + + if args["proj_checkpoint"] != "": + proj_state = torch.load(args["proj_checkpoint"]) + msg = self.load_state_dict(proj_state, strict=False) + + self.max_tgt_len = max_tgt_len + self.device = torch.cuda.current_device() + if args['delta_ckpt_path']: + delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=self.llama_model.device) + self.load_state_dict(delta_ckpt, strict=False) + + @classmethod + def init_video_Qformer(cls, num_query_token, vision_width, + num_hidden_layers=2, causal_encoder=False, cache_dir=""): + encoder_config = BertConfig.from_pretrained("bert-base-uncased", cache_dir=cache_dir) + if num_hidden_layers > 0: + encoder_config.num_hidden_layers = num_hidden_layers + encoder_config.cross_attention_freq = 1 + encoder_config.causal_encoder = causal_encoder + else: + encoder_config.cross_attention_freq = 2 + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.query_length = num_query_token + encoder_config.use_cache = False + # encoder_config.gradient_checkpointing = True + encoder_config.gradient_checkpointing = False + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + def encode_video(self, video_paths, instruction_inds=None, instruction_embs=None, earlyalign=False, audio_query=None, is_img=False): + if self.use_blip: + if is_img and self.img_hi_rs: + inputs, video_masks = data.load_and_transform_vision_data_blip(video_paths, self.device, self.training, hi_rs=True, hi_rs_cfg=self.img_hi_rs_cfg) + else: + inputs, video_masks = data.load_and_transform_video_data_blip(video_paths, self.device) + bsize, nframes = inputs.size(0), inputs.size(1) + inputs = inputs.to(self.llama_model.dtype).view( + bsize * nframes, inputs.size(2), inputs.size(3), inputs.size(4)) + with torch.no_grad(): + video_embeds = self.ln_vision(self.visual_encoder(inputs)) + video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(video_embeds.device) + query_tokens = self.query_tokens.expand(video_embeds.shape[0], -1, -1) + + if self.instructblip and instruction_inds is not None: + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(video_embeds.device) + instruction_mask = instruction_inds.attention_mask.unsqueeze(1).repeat(1, nframes, 1).view(bsize * nframes, -1) + Qformer_atts = torch.cat([query_atts, instruction_mask], dim=1) + input_ids = instruction_inds.input_ids.unsqueeze(1).repeat(1, nframes, 1).view(bsize * nframes, -1) + query_output = self.Qformer.bert( + input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=video_embeds, + encoder_attention_mask=video_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=video_embeds, + encoder_attention_mask=video_atts, + return_dict=True, + ) + video_embeds = query_output.last_hidden_state # (B * T) * Q * H + if self.instructblip and instruction_inds is not None: + video_embeds = video_embeds[:, :self.num_query_token] + # delete later + # orig_video_embeds = video_embeds.reshape(bsize, nframes, self.num_query_token, video_embeds.size(-1)) + # sel_index = torch.linspace(0, nframes-1, steps=5).long().to(self.device) + # orig_video_embeds = orig_video_embeds[torch.arange(bsize), sel_index] + # orig_video_embeds = orig_video_embeds.reshape(bsize, 5*self.num_query_token, video_embeds.size(-1)) + + video_embeds = video_embeds.reshape(bsize, nframes * self.num_query_token, video_embeds.size(-1)) + video_masks = video_masks.unsqueeze(-1).repeat(1, 1, self.num_query_token).view(bsize, -1) + if self.video_window_size < nframes: + pad_len = (nframes // self.video_window_size + 1) * self.video_window_size - nframes + pad_len = int(pad_len * self.num_query_token) + n_windows = int(nframes // self.video_window_size + 1) + pad_video_embeds = video_embeds.new_zeros(bsize, pad_len, video_embeds.size(-1)) + pad_video_masks = video_masks.new_zeros(bsize, pad_len) + video_embeds = torch.cat([video_embeds, pad_video_embeds], dim=1).view( + bsize * n_windows, -1, video_embeds.size(-1)) + video_masks = torch.cat([video_masks, pad_video_masks], dim=1).view( + bsize * n_windows, -1) + else: + video_features, video_masks = data.load_and_transform_video_data_full(video_paths, self.device) + inputs = {ModalityType.VISION: video_features} + # convert into visual dtype + inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs} + with torch.no_grad(): + embeddings = self.visual_encoder(inputs) + video_embeds = embeddings[ModalityType.VISION] # bsz x T x 1024 + + if earlyalign and self.alignmode != 3: + return video_embeds, video_masks, video_embeds + elif earlyalign and self.alignmode == 3: + pre_qformer_embeds = video_embeds + + if self.use_qformer == 'true': + video_embeds = self.ln_video(video_embeds) + position_ids = torch.arange(video_embeds.size(1), dtype=torch.long, device=video_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(video_embeds.size(0), -1) + # print(video_embeds.size(1)) + frame_position_embeddings = self.video_frame_position_embedding(position_ids) + frame_hidden_state = frame_position_embeddings + video_embeds + frame_atts = video_masks.long() + video_query_tokens = self.video_query_tokens.expand(frame_hidden_state.shape[0], -1, -1) + if audio_query is not None: + video_query_tokens = torch.cat([video_query_tokens, audio_query], dim=1) + if self.instructblip_video and instruction_inds is not None: + query_atts = torch.ones(video_query_tokens.size()[:-1], dtype=torch.long).to(video_embeds.device) + if self.video_window_size < nframes: + instruction_mask = instruction_inds.attention_mask.unsqueeze(1).repeat( + 1, n_windows, 1).view(bsize * n_windows, -1) + input_ids = instruction_inds.input_ids.unsqueeze(1).repeat( + 1, n_windows, 1).view(bsize * n_windows, -1) + else: + instruction_mask = instruction_inds.attention_mask + input_ids = instruction_inds.input_ids + Qformer_atts = torch.cat([query_atts, instruction_mask], dim=1) + video_query_output = self.video_Qformer.bert( + input_ids, + attention_mask=Qformer_atts, + query_embeds=video_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + # video_query_tokens = torch.cat([video_query_tokens, instruction_embs], dim=1) + else: + video_query_output = self.video_Qformer.bert( + query_embeds=video_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + video_embeds = video_query_output.last_hidden_state + video_embeds = video_embeds[:, :self.num_video_query_token] + if self.video_window_size < nframes: + video_embeds = video_embeds.view(bsize, -1, video_embeds.size(-1)) # bsz x Q*n_windows x embsize + inputs_llama = self.llama_proj(video_embeds) # bsz x Q x llama_size + atts_llama = (video_masks.sum(dim=-1) != 0).unsqueeze(1).repeat( + 1, self.num_query_token).view(bsize, -1) + else: + inputs_llama = self.llama_proj(video_embeds) # bsz x Q x llama_size + # delete later + # inputs_llama = self.llm_proj(orig_video_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 + else: + atts_llama = video_masks.long() + inputs_llama = self.llama_proj(video_embeds) # bsz x T x llama_size + if earlyalign and self.alignmode == 3: + return pre_qformer_embeds, atts_llama, inputs_llama + else: + return inputs_llama, atts_llama, video_embeds + def load_video_npy_train(self, video_paths, instructs): + maxlen = 0 + video_embed_list = [] + padded_video_embeds_list = [] + padded_video_masks_list = [] + for idx, video_path in enumerate(video_paths): + video_npy_dict = np.load(video_path, allow_pickle=True) + instruct = instructs[idx] + if instruct not in video_npy_dict.item().keys(): + new_instruct = random.choice(list(video_npy_dict.item().keys())) + print(f'error, instruct "{instruct}" not in npy, use another to replace "{new_instruct}"') + instruct = new_instruct + video_embed = video_npy_dict.item()[instruct] + video_embed = torch.from_numpy(video_embed).squeeze(0).to(self.device).to(self.llama_model.dtype) + if video_embed.size(0) > maxlen: + maxlen = video_embed.size(0) + video_embed_list.append(video_embed) + + for video_embed in video_embed_list: + if video_embed.size(0) < maxlen: + diffsize = maxlen - video_embed.size(0) + padded_video_masks_list.append([1] * video_embed.size(0) + [0] * diffsize) + video_embed = torch.cat([video_embed, video_embed.new_zeros( + diffsize, video_embed.size(1))], dim=0) + else: + padded_video_masks_list.append([1] * video_embed.size(0)) + padded_video_embeds_list.append(video_embed) + + video_embeds = torch.stack(padded_video_embeds_list, dim=0).to(self.device) + video_masks = torch.tensor(padded_video_masks_list).to(self.device) + video_threadhold = int(30 * 2 * self.num_query_token) + if maxlen > video_threadhold: # npy training hardcode + video_embeds = video_embeds[:, :video_threadhold, :] + video_masks = video_masks[:, :video_threadhold] + return video_embeds, video_masks, video_embeds + + def encode_audio(self, audio_paths, instruction_inds=None, raw_audios=None, earlyalign=False, visual_query=None): + if self.use_whisper: + # For test only + if len(audio_paths) == 1 and isinstance(audio_paths[0], str): + audio, _ = sf.read(audio_paths[0]) + if len(audio.shape) == 2: + audio = audio[:, 0] + if len(audio) > 30 * 16000 and self.sin_pos: + audio_list = [audio[i: i + 30 * 16000] for i in range(0, len(audio), 30 * 16000)] + spectrogram_list = [] + for audio_piece in audio_list: + spectrogram_piece = self.whispertransform( + audio_piece, + sampling_rate=16000, + return_tensors="pt", + max_length=30 * 16000, + ) + spectrogram_list.append(spectrogram_piece["input_features"].squeeze()) + audio_paths = [torch.stack(spectrogram_list, dim=0)] + if self.use_beats: + raw_audios = [audio_list] + else: + spectrogram = self.whispertransform( + audio, + sampling_rate=16000, + return_tensors="pt", + max_length=30 * 16000, + ) + audio_paths = [spectrogram["input_features"].squeeze()] + if self.use_beats: + raw_audios = [[audio]] + + if isinstance(audio_paths, tuple): + audio_paths = list(audio_paths) + + for i in range(len(audio_paths)): + if audio_paths[i].dim() == 2: + audio_paths[i] = audio_paths[i].unsqueeze(0) + num_seg = [audio.shape[0] for audio in audio_paths] + + with torch.no_grad(): + audio_paths = torch.cat(audio_paths, dim=0) + audio_embeds = self.speech_encoder( + audio_paths.to(self.llama_model.dtype).to(self.speech_encoder.device), return_dict=True).last_hidden_state + if self.use_beats: + beats_features = [] + for raw_audio in raw_audios: + beats_feature = [torch.from_numpy(audio) for audio in raw_audio] + beats_feature_lens = torch.tensor([feature.shape[0] for feature in beats_feature]) + beats_feature = pad_sequence(beats_feature, batch_first=True, padding_value=0) + # if beats_feature.ndim == 1: + # beats_feature.unsqueeze(0) + beats_feature_mask = torch.arange(beats_feature.shape[1]).unsqueeze(0) >= beats_feature_lens.unsqueeze(1) + beats_features.append( + self.beats.extract_features(beats_feature.to(self.llama_model.dtype).to(self.beats.device), padding_mask=beats_feature_mask.to(self.beats.device), feature_only=True)[0] + ) + max_feature_len = max([feature.size(1) for feature in beats_features]) + for i in range(len(beats_features)): + if beats_features[i].size(1) < max_feature_len: + beats_features[i] = F.pad(beats_features[i], (0, 0, 0, max_feature_len - beats_features[i].size(1)), 'constant', 0) + beats_features = torch.cat(beats_features, dim=0) + if not self.speech_qformer and not self.pure_aud: + audio_embeds = self.speech_pre_qformer_proj(audio_embeds) + if self.use_beats: + if beats_features.size(1) < audio_embeds.size(1): + beats_features = F.pad(beats_features, (0, 0, 0, audio_embeds.size(1) - beats_features.size(1)), 'constant', 0).to(audio_embeds.device) + beats_features = self.beats_pre_qformer_proj(beats_features) + audio_embeds = torch.cat([audio_embeds, beats_features], dim=-1) + elif self.use_nemo: + # For test only + if len(audio_paths) == 1 and isinstance(audio_paths[0], str): + audio, _ = sf.read(audio_paths[0]) + if len(audio.shape) == 2: + audio = audio[:, 0] + if len(audio) > 30 * 16000 and self.sin_pos: + audio_list = [audio[i: i + 30 * 16000] for i in range(0, len(audio), 30 * 16000)] + spectrogram_list = [] + for audio_piece in audio_list: + spectrogram_piece = self.whispertransform( + audio_piece, + sampling_rate=16000, + return_tensors="pt", + max_length=30 * 16000, + ) + spectrogram_list.append(spectrogram_piece["input_features"].squeeze()) + audio_paths = [torch.stack(spectrogram_list, dim=0)] + if self.use_beats: + raw_audios = [audio_list] + else: + spectrogram = self.whispertransform( + audio, + sampling_rate=16000, + return_tensors="pt", + max_length=30 * 16000, + ) + audio_paths = [spectrogram["input_features"].squeeze()] + if self.use_beats: + raw_audios = [[audio]] + + if isinstance(audio_paths, tuple): + audio_paths = list(audio_paths) + + for i in range(len(audio_paths)): + if audio_paths[i].dim() == 2: + audio_paths[i] = audio_paths[i].unsqueeze(0) + num_seg = [audio.shape[0] for audio in audio_paths] + + with torch.no_grad(): + audio_paths = torch.cat(audio_paths, dim=0) + audio_embeds = self.speech_encoder( + audio_paths.to(self.llama_model.dtype).to(self.speech_encoder.device), return_dict=True).last_hidden_state + if self.use_beats: + beats_features = [] + for raw_audio in raw_audios: + beats_feature = [torch.from_numpy(audio) for audio in raw_audio] + beats_feature_lens = torch.tensor([feature.shape[0] for feature in beats_feature]) + beats_feature = pad_sequence(beats_feature, batch_first=True, padding_value=0) + # if beats_feature.ndim == 1: + # beats_feature.unsqueeze(0) + beats_feature_mask = torch.arange(beats_feature.shape[1]).unsqueeze(0) >= beats_feature_lens.unsqueeze(1) + beats_features.append( + self.beats.extract_features(beats_feature.to(self.llama_model.dtype).to(self.beats.device), padding_mask=beats_feature_mask.to(self.beats.device), feature_only=True)[0] + ) + max_feature_len = max([feature.size(1) for feature in beats_features]) + for i in range(len(beats_features)): + if beats_features[i].size(1) < max_feature_len: + beats_features[i] = F.pad(beats_features[i], (0, 0, 0, max_feature_len - beats_features[i].size(1)), 'constant', 0) + beats_features = torch.cat(beats_features, dim=0) + if not self.speech_qformer and not self.pure_aud: + audio_embeds = self.speech_pre_qformer_proj(audio_embeds) + if self.use_beats: + if beats_features.size(1) < audio_embeds.size(1): + beats_features = F.pad(beats_features, (0, 0, 0, audio_embeds.size(1) - beats_features.size(1)), 'constant', 0).to(audio_embeds.device) + beats_features = self.beats_pre_qformer_proj(beats_features) + audio_embeds = torch.cat([audio_embeds, beats_features], dim=-1) + else: + inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data_fulllen(audio_paths, self.device)} + # convert into visual dtype + inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs} + with torch.no_grad(): + embeddings = self.visual_encoder(inputs) + audio_embeds = embeddings[ModalityType.AUDIO] # bsz x T x 1024 + + if earlyalign and self.alignmode != 3: + return audio_embeds, None, None + elif earlyalign and self.alignmode == 3: + pre_qformer_embed = audio_embeds + + if self.use_qformer == 'true': + audio_embeds = self.ln_speech(audio_embeds) + if self.use_beats: + beats_features = self.ln_audio(beats_features) + if beats_features.size(1) < audio_embeds.size(1): + beats_features = F.pad(beats_features, (0, 0, 0, audio_embeds.size(1) - beats_features.size(1)), 'constant', 0).to(audio_embeds.device) + audio_embeds = torch.cat([audio_embeds, beats_features], dim=-1) + if not self.pure_aud: + position_ids = torch.arange(audio_embeds.size(1), dtype=torch.long, device=audio_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(audio_embeds.size(0), -1) + frame_position_embeddings = self.video_frame_position_embedding(position_ids).mean() * 0 + frame_hidden_state = frame_position_embeddings + audio_embeds # audio/speech do not use pos enc + else: + frame_hidden_state = audio_embeds # audio/speech do not use pos enc + frame_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device) + # frame_atts = video_masks.long() + speech_query_tokens = self.speech_query_tokens.expand(audio_embeds.shape[0], -1, -1) + if visual_query is not None and not self.pure_aud: + speech_query_tokens = torch.cat([speech_query_tokens, visual_query], dim=1) + if self.instructblip_video and instruction_inds is not None and not self.pure_aud: + query_atts = torch.ones(speech_query_tokens.size()[:-1], dtype=torch.long).to(audio_embeds.device) + Qformer_atts = torch.cat([query_atts, instruction_inds.attention_mask],dim=1) + audio_query_output = self.speech_Qformer.bert( + instruction_inds.input_ids, + attention_mask=Qformer_atts, + query_embeds=speech_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + # video_query_tokens = torch.cat([video_query_tokens, instruction_embs], dim=1) + else: + # seg_pos_embs = sinusoidal_position(20, frame_hidden_state.shape[-1]).to(audio_embeds.device).to(frame_hidden_state.dtype) + seg_pos_embs = sinusoidal_position(self.n_pos, frame_hidden_state.shape[-1]).to(audio_embeds.device).to(frame_hidden_state.dtype) + seg_hidden_state = list(torch.split(frame_hidden_state, num_seg)) + fold_seg_hidden_state = [] + fold_size = [] + for seg, n in zip(seg_hidden_state, num_seg): + if self.sin_pos: + seg = (seg + seg_pos_embs[:n].unsqueeze(1).expand(seg.shape)).view(-1, seg.shape[-1]).unsqueeze(0) + else: + seg = seg.view(-1, seg.shape[-1]).unsqueeze(0) + + B, T, C = seg.shape + kernel = round(T * self.second_per_frame / 30.0 / n) + stride = round(T * self.second_stride / 30.0 / n) + kernel = (1, kernel) + stride = (1, stride) + seg_embeds_tr = seg.transpose(1, 2).unsqueeze(2) + seg_embeds_overlap = F.unfold(seg_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride) + _, _, L = seg_embeds_overlap.shape + seg_embeds_overlap = seg_embeds_overlap.view(B, -1, kernel[1], L) + seg_embeds_overlap = torch.permute(seg_embeds_overlap, [0, 3, 2, 1]) + seg_embeds = seg_embeds_overlap.reshape(-1, kernel[1], C) + fold_seg_hidden_state.append(seg_embeds) + fold_size.append(seg_embeds.shape[0]) + frame_hidden_state = torch.cat(fold_seg_hidden_state, dim=0) + frame_atts = torch.ones(frame_hidden_state.size()[:-1], dtype=torch.long).to(audio_embeds.device) + # frame_hidden_state = rnn.pad_sequence(seg_hidden_state, batch_first=True) + # length = torch.tensor([s.size(0) for s in seg_hidden_state], device=audio_embeds.device).unsqueeze(1) + # col_indices = torch.arange(frame_hidden_state.size(1), device=audio_embeds.device).unsqueeze(0) + # frame_atts = (col_indices < length).to(torch.long) + speech_query_tokens = self.speech_query_tokens.expand(frame_hidden_state.shape[0], -1, -1) + if self.ps_instruct: + inst_ids = torch.repeat_interleave( + instruction_inds.input_ids, + torch.tensor(fold_size, device=self.device), + dim=0 + ) + inst_atts = torch.repeat_interleave( + instruction_inds.attention_mask, + torch.tensor(fold_size, device=self.device), + dim=0 + ) + query_atts = torch.ones(speech_query_tokens.size()[:-1], dtype=torch.long).to(audio_embeds.device) + Qformer_atts = torch.cat([query_atts, inst_atts],dim=1) + audio_query_output = self.speech_Qformer.bert( + inst_ids.to(self.speech_Qformer.bert.device), + attention_mask=Qformer_atts.to(self.speech_Qformer.bert.device), + query_embeds=speech_query_tokens.to(self.speech_Qformer.bert.device), + encoder_hidden_states=frame_hidden_state.to(self.speech_Qformer.bert.device), + encoder_attention_mask=frame_atts.to(self.speech_Qformer.bert.device), + return_dict=True, + ) + else: + audio_query_output = self.speech_Qformer.bert( + query_embeds=speech_query_tokens.to(self.speech_Qformer.bert.device), + encoder_hidden_states=frame_hidden_state.to(self.speech_Qformer.bert.device), + encoder_attention_mask=frame_atts.to(self.speech_Qformer.bert.device), + return_dict=True, + ) + audio_embeds = audio_query_output.last_hidden_state + audio_embeds = audio_embeds[:, :speech_query_tokens.shape[1]] + inputs_llama = self.llama_proj_speech(audio_embeds) # bsz x Q x llama_size + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 + seg_inputs_llama = list(torch.split(inputs_llama, fold_size)) + seg_inputs_llama = [seg.view(-1, seg.shape[-1]) for seg in seg_inputs_llama] + inputs_llama = rnn.pad_sequence(seg_inputs_llama, batch_first=True) + length = torch.tensor([s.size(0) for s in seg_inputs_llama], device=audio_embeds.device).unsqueeze(1) + col_indices = torch.arange(inputs_llama.size(1), device=audio_embeds.device).unsqueeze(0) + atts_llama = (col_indices < length).to(torch.long) + else: + # atts_llama = video_masks.long() + if not self.pure_aud: + inputs_llama = self.llama_proj(audio_embeds) # bsz x T x llama_size + else: + pass # [Yu] not implemented. + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 + if earlyalign and self.alignmode == 3: + return pre_qformer_embed, inputs_llama + else: + return inputs_llama, audio_embeds, atts_llama + + def encode_image(self, image_paths, instruction_inds=None, instruction_embs=None, earlyalign=False, audio_query=None): + if self.use_blip: + inputs = data.load_and_transform_vision_data_blip(image_paths, self.device, self.training) + inputs = inputs.to(self.llama_model.dtype) + with torch.no_grad(): + image_embeds = self.ln_vision(self.visual_encoder(inputs)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + if self.instructblip and instruction_inds is not None: + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device) + Qformer_atts = torch.cat([query_atts, instruction_inds.attention_mask],dim=1) + query_output = self.Qformer.bert( + instruction_inds.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + image_embeds = query_output.last_hidden_state # bsz x 32 x H + if self.instructblip and instruction_inds is not None: + image_embeds = image_embeds[:, :self.num_query_token] + else: + inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)} + # convert into visual dtype + inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs} + with torch.no_grad(): + embeddings = self.visual_encoder(inputs) + image_embeds = embeddings['vision'].unsqueeze(1) # bsz x 1 x 1024 + + if earlyalign and self.alignmode != 3: + return image_embeds, None + elif earlyalign and self.alignmode == 3: + pre_qformer_embed = image_embeds + + if self.use_qformer == 'true': + orig_image_embeds = image_embeds + image_embeds = self.ln_video(image_embeds) + position_ids = torch.arange(image_embeds.size(1), dtype=torch.long, device=image_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(image_embeds.size(0), -1) + frame_position_embeddings = self.video_frame_position_embedding(position_ids) + frame_hidden_state = frame_position_embeddings + image_embeds + frame_atts = torch.ones(frame_hidden_state.size()[:-1], dtype=torch.long).to(image_embeds.device) + video_query_tokens = self.video_query_tokens.expand(frame_hidden_state.shape[0], -1, -1) + if audio_query is not None: + video_query_tokens = torch.cat([video_query_tokens, audio_query], dim=1) + if self.instructblip_video and instruction_inds is not None: + query_atts = torch.ones(video_query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device) + Qformer_atts = torch.cat([query_atts, instruction_inds.attention_mask],dim=1) + video_query_output = self.video_Qformer.bert( + instruction_inds.input_ids, + attention_mask=Qformer_atts, + query_embeds=video_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + # video_query_tokens = torch.cat([video_query_tokens, instruction_embs], dim=1) + else: + video_query_output = self.video_Qformer.bert( + query_embeds=video_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + image_embeds = video_query_output.last_hidden_state + image_embeds = image_embeds[:, :self.num_video_query_token] + if self.skip_vqformer: + image_embeds = image_embeds * 0 + orig_image_embeds + + inputs_llama = self.llama_proj(image_embeds) # bsz x 1 x llama_size + # delete later + # inputs_llama = self.llm_proj(orig_image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 + if earlyalign and self.alignmode == 3: + return pre_qformer_embed, inputs_llama + else: + return inputs_llama, image_embeds + + def sequence_align_v2(self, video_embeds, video_masks, audio_embeds, inputmasks=None, instruction_inds=None, add_time=False): + """Args: + video_embeds: B x 32 T1 x D + video_mask: B x 32 T1 + audio_embeds: B x T1 x D + inputsmasks: B x 2 + """ + video_size = 32 + audio_size = 25 + bsize = video_embeds.size(0) + vid_T = video_embeds.size(1) // video_size + video_embeds = video_embeds.view(bsize, -1, video_size, video_embeds.size(-1)) + audio_embeds = audio_embeds.view(bsize, -1, audio_size, audio_embeds.size(-1)) + if self.early_align and self.alignmode == 3: + audio_embeds = self.speech_pre_qformer_proj(audio_embeds) + aud_T = audio_embeds.size(1) + if aud_T > vid_T: + diff_T = aud_T - vid_T + video_pad = video_embeds.new_zeros(bsize, diff_T, video_size, video_embeds.size(-1)) + video_embeds = torch.cat([video_embeds, video_pad], dim=1) + elif aud_T < vid_T: + diff_T = vid_T - aud_T + audio_pad = audio_embeds.new_zeros(bsize, diff_T, audio_size, audio_embeds.size(-1)) + audio_embeds = torch.cat([audio_embeds, audio_pad], dim=1) + audio_token_padding = audio_embeds.new_zeros(bsize, audio_embeds.size(1), video_size - audio_size, audio_embeds.size(-1)) + audio_embeds = torch.cat([audio_embeds, audio_token_padding], dim=2) + if inputmasks is not None: + video_embeds = video_embeds * inputmasks[:, 0:1].unsqueeze(-1).unsqueeze(-1) + audio_embeds = audio_embeds * inputmasks[:, 1:2].unsqueeze(-1).unsqueeze(-1) + concat_features = torch.cat([video_embeds, audio_embeds], dim=3).view( + bsize, video_embeds.size(1) * video_size, -1) + if self.bilinear_pooling: + if self.use_beats: + vis_feat, whisper_feat, beats_feat = torch.split(concat_features, self.visual_hidden_size, dim=-1) + bp_feat = self.bp_proj(F.tanh(self.bp_vis(vis_feat)) * F.tanh(self.bp_whisper(whisper_feat)) * F.tanh(self.bp_beats(beats_feat))) + concat_features = bp_feat + vis_feat + whisper_feat + beats_feat + else: + vis_feat, whisper_feat = torch.split(concat_features, self.visual_hidden_size, dim=-1) + bp_feat = self.bp_proj(F.tanh(self.bp_vis(vis_feat)) * F.tanh(self.bp_whisper(whisper_feat))) + concat_features = bp_feat + vis_feat + whisper_feat + + total_mask = concat_features.new_ones(concat_features.size()[:-1]) + vid_T = video_embeds.size(1) + + if self.ext_groupsize is not None: + hgroups = vid_T // self.ext_groupsize[1] + if vid_T % self.ext_groupsize[1] != 0: + hgroups = hgroups + 1 + diff_T = hgroups * self.ext_groupsize[1] - vid_T + concat_feature_paddings = concat_features.new_zeros(bsize, diff_T * video_size, concat_features.size(-1)) + concat_features = torch.cat([concat_features, concat_feature_paddings], dim=1) + total_mask = torch.cat([total_mask, total_mask.new_zeros(bsize, diff_T * video_size)], dim=1) + vid_T += diff_T + high_features = concat_features.view(bsize * hgroups, self.ext_groupsize[1] * video_size, concat_features.size(-1)) + high_mask = total_mask.view(bsize * hgroups, self.ext_groupsize[1] * video_size) + + lgroups = vid_T // self.ext_groupsize[0] + low_features = concat_features.view(bsize * lgroups, self.ext_groupsize[0] * video_size, concat_features.size(-1)) + low_mask = total_mask.view(bsize * lgroups, self.ext_groupsize[0] * video_size) + + low_embeds = self.ln_joint(low_features) # B x 32*T_max x D + position_ids = torch.arange(low_embeds.size(1), dtype=torch.long, device=low_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(low_embeds.size(0), -1) + frame_position_embeddings = self.joint_frame_position_embedding(position_ids) + low_hidden_state = frame_position_embeddings + low_embeds + low_query_tokens = self.low_query_tokens.expand(low_embeds.shape[0], -1, -1) + + if self.ext_same_qformer: + low_query_output = self.joint_Qformer.bert( + query_embeds=low_query_tokens, + encoder_hidden_states=low_hidden_state, + encoder_attention_mask=low_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + else: + low_query_output = self.low_Qformer.bert( + query_embeds=low_query_tokens, + encoder_hidden_states=low_hidden_state, + encoder_attention_mask=low_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + low_embeds = low_query_output.last_hidden_state + + high_embeds = self.ln_joint(high_features) # B x 32*T_max x D + position_ids = torch.arange(high_embeds.size(1), dtype=torch.long, device=high_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(high_embeds.size(0), -1) + frame_position_embeddings = self.joint_frame_position_embedding(position_ids) + high_hidden_state = frame_position_embeddings + high_embeds + high_query_tokens = self.high_query_tokens.expand(high_embeds.shape[0], -1, -1) + + if self.ext_same_qformer: + high_query_output = self.joint_Qformer.bert( + query_embeds=high_query_tokens, + encoder_hidden_states=high_hidden_state, + encoder_attention_mask=high_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + else: + high_query_output = self.high_Qformer.bert( + query_embeds=high_query_tokens, + encoder_hidden_states=high_hidden_state, + encoder_attention_mask=high_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + high_embeds = high_query_output.last_hidden_state + elif self.high_groupsize is not None: + hgroups = vid_T // self.high_groupsize + if vid_T % self.high_groupsize != 0: + hgroups = hgroups + 1 + diff_T = hgroups * self.high_groupsize - vid_T + concat_feature_paddings = concat_features.new_zeros(bsize, diff_T * video_size, concat_features.size(-1)) + concat_features = torch.cat([concat_features, concat_feature_paddings], dim=1) + total_mask = torch.cat([total_mask, total_mask.new_zeros(bsize, diff_T * video_size)], dim=1) + vid_T += diff_T + high_features = concat_features.view(bsize * hgroups, self.high_groupsize * video_size, concat_features.size(-1)) + high_mask = total_mask.view(bsize * hgroups, self.high_groupsize * video_size) + + high_embeds = self.ln_joint(high_features) # B x 32*T_max x D + position_ids = torch.arange(high_embeds.size(1), dtype=torch.long, device=high_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(high_embeds.size(0), -1) + frame_position_embeddings = self.joint_frame_position_embedding(position_ids) + high_hidden_state = frame_position_embeddings + high_embeds + high_query_tokens = self.high_query_tokens.expand(high_embeds.shape[0], -1, -1) + + if self.ext_same_qformer: + high_query_output = self.joint_Qformer.bert( + query_embeds=high_query_tokens, + encoder_hidden_states=high_hidden_state, + encoder_attention_mask=high_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + else: + high_query_output = self.high_Qformer.bert( + query_embeds=high_query_tokens, + encoder_hidden_states=high_hidden_state, + encoder_attention_mask=high_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + high_embeds = high_query_output.last_hidden_state + + if self.groupsize >= 1: + ngroups = vid_T // self.groupsize + if vid_T % self.groupsize != 0: + ngroups = ngroups + 1 + diff_T = ngroups * self.groupsize - vid_T + concat_feature_paddings = concat_features.new_zeros(bsize, diff_T * video_size, concat_features.size(-1)) + concat_features = torch.cat([concat_features, concat_feature_paddings], dim=1) + total_mask = torch.cat([total_mask, total_mask.new_zeros(bsize, diff_T * video_size)], dim=1) + vid_T += diff_T + if self.low_groupsize is not None: + lgroups = vid_T // self.low_groupsize + low_features = concat_features.view(bsize * lgroups, self.low_groupsize * video_size, concat_features.size(-1)) + low_mask = total_mask.view(bsize * lgroups, self.low_groupsize * video_size) + + low_embeds = self.ln_joint(low_features) # B x 32*T_max x D + position_ids = torch.arange(low_embeds.size(1), dtype=torch.long, device=low_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(low_embeds.size(0), -1) + frame_position_embeddings = self.joint_frame_position_embedding(position_ids) + low_hidden_state = frame_position_embeddings + low_embeds + low_query_tokens = self.low_query_tokens.expand(low_embeds.shape[0], -1, -1) + + if self.ext_same_qformer: + low_query_output = self.joint_Qformer.bert( + query_embeds=low_query_tokens, + encoder_hidden_states=low_hidden_state, + encoder_attention_mask=low_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + else: + low_query_output = self.low_Qformer.bert( + query_embeds=low_query_tokens, + encoder_hidden_states=low_hidden_state, + encoder_attention_mask=low_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + low_embeds = low_query_output.last_hidden_state + concat_features = concat_features.view(bsize * ngroups, self.groupsize * video_size, concat_features.size(-1)) + total_mask = total_mask.view(bsize * ngroups, self.groupsize * video_size) + + # Forward Q-Former + total_embeds = self.ln_joint(concat_features) # B x 32*T_max x D + position_ids = torch.arange(total_embeds.size(1), dtype=torch.long, device=total_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(total_embeds.size(0), -1) + frame_position_embeddings = self.joint_frame_position_embedding(position_ids) + frame_hidden_state = frame_position_embeddings + total_embeds + joint_query_tokens = self.joint_query_tokens.expand(total_embeds.shape[0], -1, -1) + + av_query_output = self.joint_Qformer.bert( + query_embeds=joint_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=total_mask, + return_dict=True, + segsize=ngroups if self.xsegalign else 0, + ) + total_embeds = av_query_output.last_hidden_state + if self.groupsize >= 1 and self.ext_groupsize is not None: + total_embeds = total_embeds.reshape(bsize, -1, total_embeds.size(-1)) + low_embeds = low_embeds.reshape(bsize, -1, low_embeds.size(-1)) + high_embeds = high_embeds.reshape(bsize, -1, high_embeds.size(-1)) + total_embeds = torch.cat((low_embeds, total_embeds, high_embeds), dim=-1) + elif self.groupsize >= 1 and self.low_groupsize is not None: + total_embeds = total_embeds.reshape(bsize, -1, total_embeds.size(-1)) + low_embeds = low_embeds.reshape(bsize, -1, low_embeds.size(-1)) + total_embeds = torch.cat((low_embeds, total_embeds), dim=-1) + elif self.groupsize >= 1 and self.high_groupsize is not None: + total_embeds = total_embeds.reshape(bsize, -1, total_embeds.size(-1)) + high_embeds = high_embeds.reshape(bsize, -1, high_embeds.size(-1)) + total_embeds = torch.cat((total_embeds, high_embeds), dim=-1) + if self.xsegalign and ngroups > self.seglen: + total_embeds = total_embeds[:, :self.num_video_query_token].reshape(bsize, ngroups, -1, total_embeds.size(-1)) + sel_index = torch.linspace(self.seglen-1, ngroups-1, steps=ngroups//self.seglen + 1).long().to(self.device) + total_embeds = total_embeds[:, sel_index] + inputs_llama = self.llama_proj_joint(total_embeds) + if self.groupsize >= 1 and self.ext_groupsize is None and self.low_groupsize is None and self.high_groupsize is None: + if add_time: + tlabels = np.round(np.tile((np.arange(vid_T) + 1), bsize) * 0.5 * self.groupsize, decimals=1).tolist() + tlabels = [str(t) + ' seconds' for t in tlabels] + tlabel_tokens = self.llama_tokenizer(tlabels, add_special_tokens=False).input_ids + if self.args['use_lora'] == 'true': + tlabel_embs = [self.llama_model.model.model.embed_tokens(torch.tensor(t).to(self.llama_model.model.model.device)) for t in tlabel_tokens] + else: + tlabel_embs = [self.llama_model.model.embed_tokens(torch.tensor(t).to(self.llama_model.model.device)) for t in tlabel_tokens] + feat_embs = torch.split(inputs_llama, 1, dim=0) + final_embs = [[] for _ in range(bsize)] + n = 0 + for f_emb, t_emb in zip(feat_embs, tlabel_embs): + final_embs[n // vid_T].append(torch.cat([f_emb.squeeze(0), t_emb], dim=0)) + n += 1 + final_embs = [torch.cat(fe, dim=0).unsqueeze(0) for fe in final_embs] + inputs_llama = torch.cat(final_embs, dim=0) + else: + inputs_llama = inputs_llama.reshape(bsize, -1, inputs_llama.size(-1)) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 + return inputs_llama, atts_llama, total_embeds + + def sequence_align(self, video_embeds, video_mask, audio_embeds, inputmasks=None, instruction_inds=None): + """Args: + video_embeds: B x 32 T1 x D + video_mask: B x 32 T1 + audio_embeds: B x T1 x D + inputsmasks: B x 2 + """ + video_size = 32 * self.groupsize + audio_size = 25 * self.groupsize + bsize = video_embeds.size(0) + vid_T = video_embeds.size(1) // video_size + vid_T += 1 if video_embeds.size(1) % video_size != 0 else 0 + aud_T = audio_embeds.size(1) // audio_size + aud_T += 1 if video_embeds.size(1) % video_size != 0 else 0 + max_T = max(vid_T, aud_T) + audio_mask = video_mask.new_ones(audio_embeds.size()[:-1]) + if video_embeds.size(1) < max_T * video_size: + padlen = max_T * video_size - video_embeds.size(1) + video_pad = video_embeds.new_zeros(bsize, padlen, video_embeds.size(-1)) + video_mask_pad = video_mask.new_zeros(bsize, padlen) + video_embeds = torch.cat([video_embeds, video_pad], dim=1) + video_mask = torch.cat([video_mask, video_mask_pad], dim=1) + if audio_embeds.size(1) < max_T * audio_size: + padlen = max_T * audio_size - audio_embeds.size(1) + audio_pad = audio_embeds.new_zeros(bsize, padlen, audio_embeds.size(-1)) + audio_mask_pad = audio_mask.new_zeros(bsize, padlen) + audio_embeds = torch.cat([audio_embeds, audio_pad], dim=1) + audio_mask = torch.cat([audio_mask, audio_mask_pad], dim=1) + if inputmasks is not None: + video_embeds = video_embeds * inputmasks[:, 0:1].unsqueeze(-1) + audio_embeds = audio_embeds * inputmasks[:, 1:2].unsqueeze(-1) + video_embeds = video_embeds.view(-1, video_size, video_embeds.size(-1)) + audio_embeds = audio_embeds.view(-1, audio_size, audio_embeds.size(-1)) + video_mask = video_mask.view(-1, video_size) + audio_mask = audio_mask.view(-1, audio_size) + + video_embeds = self.ln_video(video_embeds) + audio_embeds = self.ln_speech(audio_embeds) + position_ids = torch.arange(video_size, dtype=torch.long, device=video_embeds.device) + position_ids = position_ids.unsqueeze(0).expand(video_embeds.size(0), -1) + frame_position_embeddings = self.video_frame_position_embedding(position_ids) + frame_hidden_state = frame_position_embeddings + video_embeds + video_query_tokens = self.video_query_tokens.expand(video_embeds.size(0), -1, -1) + audio_query_tokens = self.speech_query_tokens.expand(audio_embeds.size(0), -1, -1) + video_query_output = self.video_Qformer.bert( + query_embeds=video_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=video_mask, + return_dict=True, + ) + video_query_output = video_query_output.last_hidden_state + audio_query_output = self.speech_Qformer.bert( + query_embeds=audio_query_tokens, + encoder_hidden_states=audio_embeds, + encoder_attention_mask=audio_mask, + return_dict=True, + ) + audio_query_output = audio_query_output.last_hidden_state + + total_embeds = torch.cat([video_query_output, audio_query_output], dim=1) + inputs_llama = self.llama_proj(total_embeds) # B*T_max x Q x llama_size + inputs_llama = inputs_llama.view(bsize, -1, inputs_llama.size(-1)) # B x T_max*Q x llama_size + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 + return inputs_llama, atts_llama, video_query_output + + def prompt_wrap( + self, + img_embeds, + input_ids, + target_ids, + attention_mask, + audio_embs=None, + audiomasks=None, + img_mask=None, + ): + ''' + input_ids, target_ids, attention_mask: bsz x s2 + ''' + input_ids = input_ids.to(self.device) # bsz x s2 + target_ids = target_ids.to(self.device) # bsz x s2 + attention_mask = attention_mask.to(self.device) # bsz x s2 + + batch_size = img_embeds.shape[0] + p_before = self.PROMPT_START + p_before_tokens = self.llama_tokenizer(p_before, + return_tensors="pt", add_special_tokens=False).to(self.device) + if audio_embs is not None: + p_sep = "" + p_sep_tokens = self.llama_tokenizer(p_sep, + return_tensors="pt", add_special_tokens=False).to(self.device) + # peft model need deeper call + if self.args['use_lora'] == 'true': + p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids.to(self.llama_model.model.model.device)).expand(batch_size, -1, -1) # bsz x s1 x embed_dim + p_after_embeds = self.llama_model.model.model.embed_tokens(input_ids.to(self.llama_model.model.model.device)).expand(batch_size, -1, -1) # bsz x s2 x embed_dim + if audio_embs is not None: + p_sep_embeds = self.llama_model.model.model.embed_tokens(p_sep_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s3 x embed_dim + else: + p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim + p_after_embeds = self.llama_model.model.embed_tokens(input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim + if audio_embs is not None: + p_sep_embeds = self.llama_model.model.embed_tokens(p_sep_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s3 x embed_dim + bos = torch.ones([batch_size, 1], + dtype=p_before_tokens.input_ids.dtype, + device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1 + if self.args['use_lora'] == 'true': + bos_embeds = self.llama_model.model.model.embed_tokens(bos.to(self.llama_model.model.model.device)) # bsz x 1 x embed_dim + else: + bos_embeds = self.llama_model.model.embed_tokens(bos) # bsz x 1 x embed_dim + + if self.early_align and self.alignmode == 3: + joint_embeds = torch.cat(audio_embs + [img_embeds], dim=1) + inputs_embeds = torch.cat( + [bos_embeds, p_before_embeds, joint_embeds, p_after_embeds], dim=1) + empty_targets = ( + torch.ones([batch_size, 1+p_before_embeds.size(1)+joint_embeds.size(1)], + dtype=torch.long).to(self.device).fill_(-100) + ) + if audiomasks is not None: + audiomasks = audiomasks if self.training else audiomasks * 0 + 1 + joint_masks = torch.cat([audiomasks, audiomasks.new_ones(batch_size, 1)], dim=-1).to(self.device) + joint_masks = joint_masks.unsqueeze(-1).repeat(1, 1, self.num_video_query_token).view(batch_size, -1) + atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size(1)], dtype=torch.long).to(self.device) + atts_prefix = torch.cat([atts_prefix, joint_masks], dim=1) + else: + atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size(1)+joint_embeds.size(1)], + dtype=torch.long).to(self.device) + elif audio_embs is not None: + inputs_embeds = torch.cat( + [bos_embeds, p_before_embeds, img_embeds, p_sep_embeds, audio_embs, p_after_embeds], dim=1) + empty_targets = ( + torch.ones([batch_size, 1+p_before_embeds.size(1)+img_embeds.size(1)+p_sep_embeds.size(1)+audio_embs.size(1)], + dtype=torch.long).to(self.device).fill_(-100) + ) + if audiomasks is not None: + audiomasks = audiomasks if self.training else audiomasks * 0 + 1 + visual_masks = audiomasks[:, 0].unsqueeze(1).repeat(1, img_embeds.size(1)).to(self.device) + audio_masks = audiomasks[:, 1].unsqueeze(1).repeat(1, audio_embs.size(1)).to(self.device) + atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size(1)], + dtype=torch.long).to(self.device) + sep_masks = torch.ones([batch_size, p_sep_embeds.size(1)], dtype=torch.long).to(self.device) + atts_prefix = torch.cat([atts_prefix, visual_masks, sep_masks, audio_masks], dim=1) + else: + atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size(1)+img_embeds.size(1)+p_sep_embeds.size(1)+audio_embs.size(1)], + dtype=torch.long).to(self.device) + else: + inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_after_embeds], dim=1) + empty_targets = ( + torch.ones([batch_size, 1+p_before_embeds.size()[1]+img_embeds.size(1)], + dtype=torch.long).to(self.device).fill_(-100) + ) + if img_mask is None: + atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size()[1]+img_embeds.size(1)], dtype=torch.long).to(self.device) + else: + atts_prefix = torch.cat([torch.ones([batch_size, 1+p_before_embeds.size()[1]], dtype=torch.long).to(img_mask.device), img_mask], dim=1) + # try: + # targets = torch.cat([empty_targets, target_ids], dim=1) # bsz x (1 + s1 + 1 + s2) + # except: + # import pdb; pdb.set_trace() + targets = torch.cat([empty_targets, target_ids], dim=1) # bsz x (1 + s1 + 1 + s2) + attention_mask = torch.cat([atts_prefix, attention_mask.to(atts_prefix.device)], dim=1) + modality_lengths = atts_prefix.size(1) + assert attention_mask.size() == targets.size() # bsz x (1 + s1 + S + s2) + return inputs_embeds, targets, attention_mask, modality_lengths + + def calc_diversity_loss(self, query): + dotprod = torch.einsum("bij,bkj->bik", query, query) + modulus = torch.sqrt((query**2).sum(dim=-1)) + modulus = torch.einsum("bij,bjk->bik", modulus.unsqueeze(-1), modulus.unsqueeze(1)) + cos_sim = dotprod / (modulus + 1e-9) + diag_mask = (1 - torch.eye(modulus.size(1), device=self.device)).unsqueeze(0) + ave_sim = (cos_sim * diag_mask).sum() / ((modulus.size(1)**2 - modulus.size(1)) * modulus.size(0)) + return ave_sim + + def forward(self, inputs, reduction=True, generate=False, generate_config=None): + # print("begin forward, {}".format(int(os.getenv('RANK', '0')))) + output_texts = inputs['output_texts'] + if generate and len(output_texts[0]) > 2: + assert len(output_texts) == 1, "Only support bsz=1 for multi turn test!" + all_gen_text = [] + for n in range(len(output_texts[0]) // 2): + tmp_texts = [ + [ + output_texts[0][2 * n], output_texts[0][2 * n + 1] + ] + ] + if n == 0: + all_prompts = tmp_texts + else: + all_prompts = [ + [ + { + 'from': 'human', + 'value': f'{all_prompts[0][0]["value"]}\nASSISTANT: {gen_text}\n USER: {output_texts[0][2 * n]["value"]}' if not self.use_llama2 else f'{all_prompts[0][0]["value"]} [/INST] {gen_text}\n [INST]: {output_texts[0][2 * n]["value"]}', + }, + output_texts[0][2 * n + 1] + ] + ] + input_ids, target_ids, attention_mask, instructs = process_batch_instance( + self.llama_tokenizer, + tmp_texts, + self.max_tgt_len, + modality=inputs['modality'], + prompt=self.prompt, + use_llama2=self.use_llama2, + ) + + instruction_ids = None + dummy_instruct = None + if self.instructblip: + instruction_ids = self.bert_tokenizer( + instructs, + padding='longest', + truncation=True, + max_length=self.max_tgt_len, + return_tensors="pt", + ).to(self.device) + dummy_instruct = self.bert_tokenizer(["dummy"], return_tensors="pt").to(self.device) + + image_paths = inputs['image_paths'] + audio_embeds = None + diversity_loss = 0 + atts_llama = None + # print(inputs['modality']) + if inputs['modality'] == 'image': + img_embeds, img_query = self.encode_image(image_paths, instruction_ids) + if self.use_whisper: + dummy, _, _ = self.encode_audio(dummy_audio_path, dummy_instruct, raw_audios=dummy_raw_audio) + img_embeds = img_embeds + dummy.sum() * 0 + if not self.training: + if self.early_align and self.alignmode == 2: + Tvideo = dummy.size(1) // 25 + video_embs = img_embeds.unsqueeze(1).repeat(1, Tvideo, 1, 1).view(img_embeds.size(0), -1, img_embeds.size(2)) + video_masks = torch.ones(video_embs.size()[:-1], dtype=torch.long).to(video_embs.device) + audiomasks = torch.tensor([1, 0]).unsqueeze(0).repeat(video_masks.size(0), 1).to(video_embs.device) + img_embeds, _, joint_query = self.sequence_align_v2( + video_embs, video_masks, dummy, audiomasks, instruction_inds=instruction_ids) + if self.diversity: + diversity_loss += self.calc_diversity_loss(img_query) + elif inputs['modality'] == 'video': + img_embeds, _, video_query = self.encode_video(image_paths, instruction_ids) + if self.diversity: + diversity_loss += self.calc_diversity_loss(video_query) + elif inputs['modality'] == 'audio': + img_embeds, audio_query, atts_llama = self.encode_audio(image_paths, instruction_ids, raw_audios=inputs["raw_audios"], earlyalign=self.early_align) + if self.speech_qformer and not self.pure_aud: + dummy, _ = self.encode_image(dummy_image_path, dummy_instruct) + img_embeds = img_embeds + dummy.sum() * 0 + if self.early_align: # TODO fix wrong dim + Tvideo = img_embeds.size(1) // 25 * 32 + if self.use_beats: + video_embs = img_embeds.new_zeros(img_embeds.size(0), Tvideo, img_embeds.size(2) // 2) + else: + video_embs = img_embeds.new_zeros(img_embeds.size(0), Tvideo, img_embeds.size(2)) + video_masks = torch.ones(video_embs.size()[:-1], dtype=torch.long).to(video_embs.device) + audiomasks = torch.tensor([0, 1]).unsqueeze(0).repeat(video_masks.size(0), 1).to(video_embs.device) + img_embeds, _, audio_query = self.sequence_align_v2(video_embs, video_masks, img_embeds, audiomasks) + elif inputs['modality'] == 'audioimage': + image_paths = list(zip(*image_paths)) + query_mask = 1 + if self.cascaded == "audiogrounding": + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + if inputs["audiomasks"] is not None: + query_mask = inputs["audiomasks"].unsqueeze(-1).unsqueeze(-1).to(audio_query.device) + inputs["audiomasks"] = None + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=audio_query*query_mask) + elif self.cascaded == "visualgrounding": + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align) + if inputs["audiomasks"] is not None: + query_mask = inputs["audiomasks"].unsqueeze(-1).unsqueeze(-1).to(img_query.device) + inputs["audiomasks"] = None + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align, visual_query=img_query*query_mask) + elif self.cascaded == "bothgrounding": + _, pre_audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + _, pre_img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align) + if inputs["audiomasks"] is not None: + query_mask = inputs["audiomasks"].unsqueeze(-1).unsqueeze(-1).to(pre_img_query.device) + inputs["audiomasks"] = None + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align, visual_query=pre_img_query*query_mask) + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=pre_audio_query*query_mask) + else: + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, raw_audios=inputs["raw_audios"], earlyalign=self.early_align) + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align) + # inputs["audiomasks"] = None + if self.early_align: + Tvideo = audio_embeds.size(1) // 25 # 0.5s per frame + video_embs = img_embeds.unsqueeze(1).repeat(1, Tvideo, 1, 1).view(img_embeds.size(0), -1, img_embeds.size(2)) + video_masks = torch.ones(video_embs.size()[:-1], dtype=torch.long).to(video_embs.device) + audiomasks = torch.tensor([1, 1]).unsqueeze(0).repeat(video_masks.size(0), 1).to(video_embs.device) + if self.alignmode == 1: + img_embeds, _, joint_query = self.sequence_align( + video_embs, video_masks, audio_embeds, audiomasks, instruction_inds=instruction_ids) + else: + img_embeds, _, joint_query = self.sequence_align_v2( + video_embs, video_masks, audio_embeds, inputs["audiomasks"].to(self.device), instruction_inds=instruction_ids) + if self.alignmode == 3: + audio_embeds = [audio_query, img_query] + else: + audio_embeds = None + if self.diversity: + if self.early_align: + diversity_loss = self.calc_diversity_loss(joint_query) + else: + diversity_loss = self.calc_diversity_loss(img_query) + elif inputs['modality'] == 'audiovideoimage': + image_paths = list(zip(*image_paths)) + if self.cascaded == "audiogrounding": + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + img_embeds, img_mask, video_query = self.encode_video( + image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=audio_query) + elif self.cascaded == "bothgrounding": + _, pre_audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + _, _, pre_video_query = self.encode_video(image_paths[1], instruction_ids, earlyalign=self.early_align) + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align, visual_query=pre_video_query) + img_embeds, img_mask, video_query = self.encode_video( + image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=pre_audio_query) + else: + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, raw_audios=inputs["raw_audios"], earlyalign=self.early_align) + img_embeds, img_mask, video_query = self.encode_video( + image_paths[1], instruction_ids, earlyalign=self.early_align) + if self.early_align: + if self.alignmode == 1: + img_embeds, _, joint_query = self.sequence_align( + img_embeds, img_mask, audio_embeds, inputs["audiomasks"].to(self.device), instruction_inds=instruction_ids) + else: + img_embeds, _, joint_query = self.sequence_align_v2( + img_embeds, img_mask, audio_embeds, inputs["audiomasks"].to(self.device), instruction_inds=instruction_ids) + if self.alignmode == 3: + audio_embeds = [audio_query, video_query] + else: + audio_embeds = None + if self.diversity: + if self.early_align: + diversity_loss = self.calc_diversity_loss(joint_query) + else: + diversity_loss = self.calc_diversity_loss(video_query) + else: + raise Exception("Undefined modality type") + # print("Finished encoder, {}".format(int(os.getenv('RANK', '0')))) + + gen_input_ids, gen_target_ids, gen_attention_mask, instructs = process_batch_instance( + self.llama_tokenizer, + all_prompts, + self.max_tgt_len, + modality=inputs['modality'], + generate=True, + prompt=self.prompt, + use_llama2=self.use_llama2, + ) + + gen_inputs_embeds, gen_targets, gen_attention_mask, gen_modality_lengths = self.prompt_wrap( + img_embeds, + gen_input_ids, + gen_target_ids, + gen_attention_mask, + audio_embs=audio_embeds, + audiomasks=inputs["audiomasks"], + img_mask=atts_llama + ) + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2], encounters=1)]) + gen_outputs = self.llama_model.generate( + inputs_embeds=gen_inputs_embeds.to(self.llama_model.device), + max_new_tokens=generate_config.get("max_new_tokens", self.max_tgt_len), + min_length=generate_config.get("min_length", 1), + do_sample=generate_config.get("do_sample", False), + num_beams=generate_config.get("num_beams", 5), + repetition_penalty=generate_config.get("repetition_penalty", 1.5), + length_penalty=generate_config.get("length_penalty", 1.0), + # length_penalty=generate_config.get("length_penalty", 1.5), + top_p=generate_config.get("top_p", 0.9), + stopping_criteria=stopping_criteria, + ) + gen_text = self.llama_tokenizer.batch_decode(gen_outputs, add_special_tokens=False)[0].replace("", "").replace("", "").strip() + all_gen_text.append(gen_text) + return [all_gen_text] + else: + input_ids, target_ids, attention_mask, instructs = process_batch_instance( + self.llama_tokenizer, + output_texts, + self.max_tgt_len, + modality=inputs['modality'], + prompt=self.prompt, + generate=generate, + use_llama2=self.use_llama2, + ) + + instruction_ids = None + dummy_instruct = None + # instruction_embs = None + if not self.pure_aud: + # instructs = [text[0]["value"] for text in output_texts] + if self.instructblip: + instruction_ids = self.bert_tokenizer( + instructs, + padding='longest', + truncation=True, + max_length=self.max_tgt_len, + return_tensors="pt", + ).to(self.device) + dummy_instruct = self.bert_tokenizer(["dummy"], return_tensors="pt").to(self.device) + + if self.ps_instruct: + instruction_ids = self.bert_tokenizer( + instructs, + padding='longest', + truncation=True, + max_length=self.max_tgt_len, + return_tensors="pt", + ).to(self.device) + + image_paths = inputs['image_paths'] + audio_embeds = None + diversity_loss = 0 + atts_llama = None + # print(inputs['modality']) + if inputs['modality'] == 'image': + img_embeds, img_query = self.encode_image(image_paths, instruction_ids) + if self.use_whisper: + dummy, _, _ = self.encode_audio(dummy_audio_path, dummy_instruct, raw_audios=dummy_raw_audio) + img_embeds = img_embeds + dummy.sum() * 0 + if not self.training: + if self.early_align and self.alignmode == 2: + Tvideo = dummy.size(1) // 25 + video_embs = img_embeds.unsqueeze(1).repeat(1, Tvideo, 1, 1).view(img_embeds.size(0), -1, img_embeds.size(2)) + video_masks = torch.ones(video_embs.size()[:-1], dtype=torch.long).to(video_embs.device) + audiomasks = torch.tensor([1, 0]).unsqueeze(0).repeat(video_masks.size(0), 1).to(video_embs.device) + img_embeds, _, joint_query = self.sequence_align_v2( + video_embs, video_masks, dummy, audiomasks, instruction_inds=instruction_ids) + if self.diversity: + diversity_loss += self.calc_diversity_loss(img_query) + elif inputs['modality'] == 'video': + img_embeds, _, video_query = self.encode_video(image_paths, instruction_ids) + if self.diversity: + diversity_loss += self.calc_diversity_loss(video_query) + elif inputs['modality'] == 'audio': + img_embeds, audio_query, atts_llama = self.encode_audio(image_paths, instruction_ids, raw_audios=inputs["raw_audios"], earlyalign=self.early_align) + if self.speech_qformer and not self.pure_aud: + dummy, _ = self.encode_image(dummy_image_path, dummy_instruct) + img_embeds = img_embeds + dummy.sum() * 0 + if self.early_align: # TODO fix wrong dim + Tvideo = img_embeds.size(1) // 25 * 32 + if self.use_beats: + video_embs = img_embeds.new_zeros(img_embeds.size(0), Tvideo, img_embeds.size(2) // 2) + else: + video_embs = img_embeds.new_zeros(img_embeds.size(0), Tvideo, img_embeds.size(2)) + video_masks = torch.ones(video_embs.size()[:-1], dtype=torch.long).to(video_embs.device) + audiomasks = torch.tensor([0, 1]).unsqueeze(0).repeat(video_masks.size(0), 1).to(video_embs.device) + img_embeds, _, audio_query = self.sequence_align_v2(video_embs, video_masks, img_embeds, audiomasks) + elif inputs['modality'] == 'audioimage': + image_paths = list(zip(*image_paths)) + query_mask = 1 + if self.cascaded == "audiogrounding": + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + if inputs["audiomasks"] is not None: + query_mask = inputs["audiomasks"].unsqueeze(-1).unsqueeze(-1).to(audio_query.device) + inputs["audiomasks"] = None + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=audio_query*query_mask) + elif self.cascaded == "visualgrounding": + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align) + if inputs["audiomasks"] is not None: + query_mask = inputs["audiomasks"].unsqueeze(-1).unsqueeze(-1).to(img_query.device) + inputs["audiomasks"] = None + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align, visual_query=img_query*query_mask) + elif self.cascaded == "bothgrounding": + _, pre_audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + _, pre_img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align) + if inputs["audiomasks"] is not None: + query_mask = inputs["audiomasks"].unsqueeze(-1).unsqueeze(-1).to(pre_img_query.device) + inputs["audiomasks"] = None + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align, visual_query=pre_img_query*query_mask) + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=pre_audio_query*query_mask) + else: + # import pdb; pdb.set_trace() + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, raw_audios=inputs["raw_audios"], earlyalign=self.early_align) + if self.img_hi_rs: + video_embs, video_masks, video_query = self.encode_video(image_paths[1], instruction_ids, earlyalign=self.early_align, is_img=True) + else: + img_embeds, img_query = self.encode_image(image_paths[1], instruction_ids, earlyalign=self.early_align) + # inputs["audiomasks"] = None + if self.early_align: + if not self.img_hi_rs: + Tvideo = audio_embeds.size(1) // 25 # 0.5s per frame + video_embs = img_embeds.unsqueeze(1).repeat(1, Tvideo, 1, 1).view(img_embeds.size(0), -1, img_embeds.size(2)) + video_masks = torch.ones(video_embs.size()[:-1], dtype=torch.long).to(video_embs.device) + audiomasks = torch.tensor([1, 1]).unsqueeze(0).repeat(video_masks.size(0), 1).to(video_embs.device) + if self.alignmode == 1: + img_embeds, _, joint_query = self.sequence_align( + video_embs, video_masks, audio_embeds, audiomasks, instruction_inds=instruction_ids) + else: + img_embeds, _, joint_query = self.sequence_align_v2( + video_embs, video_masks, audio_embeds, inputs["audiomasks"].to(self.device), instruction_inds=instruction_ids) + if self.alignmode == 3: + audio_embeds = [audio_query, img_query] + else: + audio_embeds = None + if self.diversity: + if self.early_align: + if self.ext_groupsize is None: + if self.low_groupsize is not None: + joint_query = torch.split(joint_query, joint_query.size(-1) // 2, dim=-1)[1] + ngroups = joint_query.size(1) // self.num_video_query_token + joint_query = joint_query.reshape(joint_query.size(0) * ngroups, -1, joint_query.size(-1)) + diversity_loss = self.calc_diversity_loss(joint_query) + else: + mid_query, high_query = torch.split(joint_query, joint_query.size(-1) // 3, dim=-1)[1:] + ngroups = mid_query.size(1) // self.num_video_query_token + mid_query = mid_query.reshape(mid_query.size(0) * ngroups, -1, mid_query.size(-1)) + hgroups = high_query.size(1) // int(self.num_video_query_token * self.ext_groupsize[1] / self.groupsize) + high_query = high_query.reshape(high_query.size(0) * hgroups, -1, high_query.size(-1)) + diversity_loss = self.calc_diversity_loss(mid_query) + self.calc_diversity_loss(high_query) + else: + diversity_loss = self.calc_diversity_loss(img_query) + elif inputs['modality'] == 'audiovideoimage': + image_paths = list(zip(*image_paths)) + if self.cascaded == "audiogrounding": + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + img_embeds, img_mask, video_query = self.encode_video( + image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=audio_query) + elif self.cascaded == "bothgrounding": + _, pre_audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align) + _, _, pre_video_query = self.encode_video(image_paths[1], instruction_ids, earlyalign=self.early_align) + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, earlyalign=self.early_align, visual_query=pre_video_query) + img_embeds, img_mask, video_query = self.encode_video( + image_paths[1], instruction_ids, earlyalign=self.early_align, audio_query=pre_audio_query) + else: + audio_embeds, audio_query, _ = self.encode_audio(image_paths[0], instruction_ids, raw_audios=inputs["raw_audios"], earlyalign=self.early_align) + if self.use_npy and self.training: # npy training + img_embeds, img_mask, video_query = self.load_video_npy_train(image_paths[1], instructs=instructs) + else: + img_embeds, img_mask, video_query = self.encode_video(image_paths[1], instruction_ids, earlyalign=self.early_align) + if self.early_align: + if self.alignmode == 1: + img_embeds, _, joint_query = self.sequence_align( + img_embeds, img_mask, audio_embeds, inputs["audiomasks"].to(self.device), instruction_inds=instruction_ids) + else: + if generate and inputs["audiomasks"][0].sum() == 1: + audio_embeds = img_embeds.new_zeros( + img_embeds.size(0), + img_embeds.size(1)//32*25, + audio_embeds.size(2) + ) + img_embeds, _, joint_query = self.sequence_align_v2( + img_embeds, img_mask, audio_embeds, inputs["audiomasks"].to(self.device), instruction_inds=instruction_ids, add_time=self.add_time) + if self.alignmode == 3: + audio_embeds = [audio_query, video_query] + else: + audio_embeds = None + if self.diversity: + if self.early_align: + if self.ext_groupsize is None: + if self.low_groupsize is not None: + joint_query = torch.split(joint_query, joint_query.size(-1) // 2, dim=-1)[1] + ngroups = joint_query.size(1) // self.num_video_query_token + joint_query = joint_query.reshape(joint_query.size(0) * ngroups, -1, joint_query.size(-1)) + diversity_loss = self.calc_diversity_loss(joint_query) + else: + mid_query, high_query = torch.split(joint_query, joint_query.size(-1) // 3, dim=-1)[1:] + ngroups = mid_query.size(1) // self.num_video_query_token + mid_query = mid_query.reshape(mid_query.size(0) * ngroups, -1, mid_query.size(-1)) + hgroups = high_query.size(1) // int(self.num_video_query_token * self.ext_groupsize[1] / self.groupsize) + high_query = high_query.reshape(high_query.size(0) * hgroups, -1, high_query.size(-1)) + diversity_loss = self.calc_diversity_loss(mid_query) + self.calc_diversity_loss(high_query) + else: + diversity_loss = self.calc_diversity_loss(video_query) + else: + raise Exception("Undefined modality type") + # print("Finished encoder, {}".format(int(os.getenv('RANK', '0')))) + + if generate: + gen_input_ids, gen_target_ids, gen_attention_mask, instructs = process_batch_instance( + self.llama_tokenizer, + output_texts, + self.max_tgt_len, + modality=inputs['modality'], + generate=True, + prompt=self.prompt, + use_llama2=self.use_llama2, + ) + + gen_inputs_embeds, gen_targets, gen_attention_mask, gen_modality_lengths = self.prompt_wrap( + img_embeds, + gen_input_ids, + gen_target_ids, + gen_attention_mask, + audio_embs=audio_embeds, + audiomasks=inputs["audiomasks"], + img_mask=atts_llama + ) + + if not isinstance(generate_config, dict): + generate_config = {} + + if len(generate_config) != 0: + lora_alpha = generate_config.get("lora_alpha", self.args.get('yu_lora_alpha', 32)) + modify_lora_layer(self.llama_model, lora_alpha) + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2], encounters=1)]) + gen_outputs = self.llama_model.generate( + inputs_embeds=gen_inputs_embeds.to(self.llama_model.device), + max_new_tokens=generate_config.get("max_new_tokens", self.max_tgt_len), + min_length=generate_config.get("min_length", 1), + do_sample=generate_config.get("do_sample", False), + num_beams=generate_config.get("num_beams", 5), + repetition_penalty=generate_config.get("repetition_penalty", 1.5), + length_penalty=generate_config.get("length_penalty", 1.0), + # length_penalty=generate_config.get("length_penalty", 0.1), + top_p=generate_config.get("top_p", 0.9), + stopping_criteria=stopping_criteria, + ) + gen_text = self.llama_tokenizer.batch_decode(gen_outputs, add_special_tokens=False) + + if len(generate_config) != 0: + modify_lora_layer(self.llama_model, self.args.get('yu_lora_alpha', 32)) + + return gen_text + + else: + inputs_embeds, targets, attention_mask, modality_lengths = self.prompt_wrap( + img_embeds, + input_ids, + target_ids, + attention_mask, + audio_embs=audio_embeds, + audiomasks=inputs["audiomasks"], + img_mask=atts_llama + ) + # print("Finished prompt wrap, {}".format(int(os.getenv('RANK', '0')))) + + outputs = self.llama_model( + inputs_embeds=inputs_embeds.to(self.llama_model.device), + attention_mask=attention_mask.to(self.llama_model.device), + return_dict=True, + labels=targets.to(self.llama_model.device), + modality_lengths=modality_lengths if self.modalitymask else 0, + ) + loss = outputs.loss + if self.diversity: + if not self.training: + print(diversity_loss) + loss += diversity_loss * self.diversity_loss_factor + # print("Finished vicuna forward, {}".format(int(os.getenv('RANK', '0')))) + # calculate the token accuarcy + chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1] # [B, S-1] + labels = targets[:, 2:] + gen_acc = (chosen_tokens.reshape(-1) == labels.to(chosen_tokens.device).reshape(-1)).to(torch.long) # [B*S] + valid_mask = (labels != -100).reshape(-1) + valid_tokens = gen_acc & valid_mask.to(gen_acc.device) # [B*S] + gen_acc = valid_tokens.sum().item() / valid_mask.sum().item() + + return loss, gen_acc + + def extract_multimodal_feature(self, inputs): + features = [] + instruction_ids = None + instruction_embs = None + instructs = [inputs["prompt"]] + if self.instructblip and not self.pure_aud: + instruction_ids = self.bert_tokenizer( + instructs, + padding='longest', + truncation=True, + max_length=self.max_tgt_len, + return_tensors="pt", + ).to(self.device) + + feature_dict = {} + audio_query = None + img_query = None + video_query = None + if inputs['image_paths']: + if self.cascaded in ["audiogrounding", "bothgrounding"] and inputs['audio_paths']: + _, audio_query, _ = self.encode_audio(inputs['audio_paths'], instruction_ids) + elif self.cascaded == ["audiogrounding", "bothgrounding"]: + _, audio_query, _ = self.encode_audio(dummy_audio_path, instruction_ids) + audio_query = audio_query * 0 + image_embeds, image_query = self.encode_image(inputs['image_paths'], instruction_ids, audio_query=audio_query, earlyalign=self.early_align) + feature_dict["image"] = image_embeds + if inputs['audio_paths']: + pre_vis_query = None + if self.cascaded == "bothgrounding" and inputs['image_paths']: + _, pre_vis_query = self.encode_image(inputs['image_paths'], instruction_ids) + elif self.cascaded == "bothgrounding" and inputs['video_paths']: + _, _, pre_vis_query = self.encode_video(inputs['video_paths'], instruction_ids) + elif self.cascaded == "bothgrounding": + _, pre_vis_query = self.encode_image(dummy_image_path, instruction_ids) + pre_vis_query = pre_vis_query * 0 + audio_embeds, audio_query, _ = self.encode_audio(inputs['audio_paths'], instruction_ids, visual_query=pre_vis_query, earlyalign=self.early_align) + feature_dict["audio"] = audio_embeds + if inputs['video_paths']: + pre_audio_query = None + if self.cascaded in ["audiogrounding", "bothgrounding"] and inputs['audio_paths']: + _, pre_audio_query, _ = self.encode_audio(inputs['audio_paths'], instruction_ids) + elif self.cascaded in ["audiogrounding", "bothgrounding"]: + _, pre_audio_query, _ = self.encode_audio(dummy_audio_path, instruction_ids) + pre_audio_query = pre_audio_query * 0 + video_embeds, video_mask, video_query = self.encode_video(inputs['video_paths'], instruction_ids, earlyalign=self.early_align, audio_query=pre_audio_query) + feature_dict["video"] = video_embeds + + if self.early_align: + if not inputs['audio_paths'] and inputs["video_paths"]: + audio_embeds = video_embeds.new_zeros( + video_embeds.size(0), + video_embeds.size(1)//32*25, + self.speech_encoder.config.d_model if self.speech_qformer else video_embeds.size(-1), + ) + audiomasks = torch.tensor([1, 0]).unsqueeze(0).repeat(video_mask.size(0), 1).to(video_embeds.device) + elif not inputs["video_paths"]: + if inputs["image_paths"]: + if not inputs["audio_paths"]: + inputs["audio_paths"] = dummy_audio_path + audio_embeds, audio_query, _ = self.encode_audio(inputs['audio_paths'], instruction_ids, earlyalign=self.early_align) + if inputs["audio_paths"]: + video_embeds = image_embeds.unsqueeze(1).repeat(1, audio_embeds.size(1)//25, 1, 1).view( + image_embeds.size(0), -1, image_embeds.size(2)) + video_mask = audio_embeds.new_ones(video_embeds.size()[:-1]) + audiomasks = torch.tensor([1, 1]).unsqueeze(0).repeat(video_mask.size(0), 1).to(self.device) + else: + video_embeds = image_embeds.unsqueeze(1).repeat(1, 60, 1, 1).view( + image_embeds.size(0), -1, image_embeds.size(2)) + audio_embeds = video_embeds.new_zeros( + video_embeds.size(0), + video_embeds.size(1)//32*25, + self.speech_encoder.config.d_model if self.speech_qformer else video_embeds.size(-1), + ) + video_mask = audio_embeds.new_ones(video_embeds.size()[:-1]) + audiomasks = torch.tensor([1, 0]).unsqueeze(0).repeat(video_mask.size(0), 1).to(video_embeds.device) + elif inputs["audio_paths"]: + video_embeds = audio_embeds.new_zeros( + audio_embeds.size(0), + audio_embeds.size(1)//25*32, + self.visual_hidden_size if self.speech_qformer else audio_embeds.size(2), + ).to(self.device) + # image_embeds, pre_vis_query = self.encode_image(dummy_image_path, instruction_ids, earlyalign=self.early_align) + # video_embeds = image_embeds.unsqueeze(1).repeat(1, audio_embeds.size(1)//25, 1, 1).view( + # image_embeds.size(0), -1, image_embeds.size(2)) + video_mask = audio_embeds.new_ones(video_embeds.size()[:-1]).to(self.device) + audiomasks = torch.tensor([1, 1]).unsqueeze(0).repeat(video_mask.size(0), 1).to(self.device) + else: + raise Exception("Early align mode has to have either audio or video!") + elif inputs["video_paths"] and inputs["audio_paths"]: + audiomasks = torch.tensor([1, 1]).unsqueeze(0).repeat(video_mask.size(0), 1).to(self.device) + else: + raise Exception("Early align mode has to have either audio or video!") + if self.alignmode == 1: + video_embeds, _, video_query = self.sequence_align( + video_embeds, video_mask, audio_embeds, audiomasks, instruction_inds=instruction_ids) + else: + video_embeds, _, video_query = self.sequence_align_v2( + video_embeds, video_mask, audio_embeds, audiomasks, instruction_inds=instruction_ids) + if self.diversity: + # Calculate cosine similarity + ave_sim = self.calc_diversity_loss(video_query) + post_ave_sim = self.calc_diversity_loss(video_embeds) + print("Average cosine similarity: {}".format(ave_sim)) + print("Average post cosine similarity: {}".format(post_ave_sim)) + if self.alignmode == 3: + video_embeds = torch.cat([audio_query, video_query if video_query is not None else image_query, video_embeds], dim=1) + feature_dict["video"] = video_embeds + features = [] + if "video" in feature_dict: + features.append(feature_dict["video"]) + elif "image" in feature_dict: + features.append(feature_dict["image"]) + if "audio" in feature_dict and not self.early_align: + features.append(feature_dict["audio"]) + return features + + def prepare_generation_embedding(self, inputs): + prompt = inputs['prompt'] + if len(inputs['modality_embeds']) == 1: + feature_embeds = inputs['modality_embeds'][0] + else: + feature_embeds = self.extract_multimodal_feature(inputs) + inputs['modality_embeds'].append(feature_embeds) + modality_mask = inputs["avmask"] if "avmask" in inputs else [1, 1] + + batch_size = feature_embeds[0].shape[0] + p_before = self.PROMPT_START + p_before_tokens = self.llama_tokenizer(p_before, + return_tensors="pt", add_special_tokens=False).to(self.device) + if self.args['use_lora'] == 'true': + p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim + else: + p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim + if self.use_llama2: + text = ' ' + prompt + ' [/INST]' + else: + text = ' ' + prompt + '\nASSISTANT:' + textsep = '' + sep_tokens = self.llama_tokenizer(textsep, add_special_tokens=False, return_tensors='pt').to(self.device) + p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device) + if self.args['use_lora'] == 'true': + p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim + sep_embeds = self.llama_model.model.model.embed_tokens(sep_tokens.input_ids).expand(batch_size, -1, -1) + else: + p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim + sep_embeds = self.llama_model.model.embed_tokens(sep_tokens.input_ids).expand(batch_size, -1, -1) + # delete later + # p_after_embeds = self.llama_tokenizer("" + prompt, add_special_tokens=False, return_tensors='pt').to(self.device) + # p_after_embeds = self.llama_model.model.embed_tokens(p_after_embeds.input_ids).expand(batch_size, -1, -1) + bos = torch.ones([batch_size, 1], + dtype=p_before_tokens.input_ids.dtype, + device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1 + if self.args['use_lora'] == 'true': + bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim + else: + bos_embeds = self.llama_model.model.embed_tokens(bos) # bsz x 1 x embed_dim + if len(feature_embeds) == 1: + feature_embeds[0] = feature_embeds[0] * modality_mask[0] + inputs_embeds = torch.cat([bos_embeds, p_before_embeds, feature_embeds[0], p_after_embeds], dim=1) # bsz x (1+s1+1+s2) x embed_dim + # delete later + # inputs_embeds = torch.cat([feature_embeds[0], p_after_embeds], dim=1) + else: + totalemb = [bos_embeds, p_before_embeds] + for k, feature_emb in enumerate(feature_embeds[:-1]): + totalemb.append(feature_emb * modality_mask[k]) + totalemb.append(sep_embeds) + totalemb.append(feature_embeds[-1] * modality_mask[-1]) + totalemb.append(p_after_embeds) + inputs_embeds = torch.cat(totalemb, dim=1) + return inputs_embeds + + def generate_npy(self, inputs, video_name, npy_prefix): + video_paths = inputs['video_paths'] + instructs = [text[0]["value"] for text in inputs['output_texts']] + self.encode_video_and_save(video_paths, video_name, npy_prefix, instructs=instructs) + return + + def encode_video_and_save(self, video_paths, video_name, npy_path_prefix, instruction_inds=None, instruction_embs=None, earlyalign=False, audio_query=None, instructs=None): + if self.use_blip: + # import pdb; pdb.set_trace() + inputs, video_masks = data.load_and_transform_video_data_blip(video_paths, self.device) + bsize, nframes = inputs.size(0), inputs.size(1) + inputs = inputs.to(self.llama_model.dtype).view( + bsize * nframes, inputs.size(2), inputs.size(3), inputs.size(4)) + with torch.no_grad(): + video_embeds = self.ln_vision(self.visual_encoder(inputs)) + video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(video_embeds.device) + query_tokens = self.query_tokens.expand(video_embeds.shape[0], -1, -1) + if self.instructblip: + assert instructs + # get instruction_ids + instruction_ids = self.bert_tokenizer( + instructs, + padding='longest', + truncation=True, + max_length=self.max_tgt_len, + return_tensors="pt", + ).to(self.device) + instruction_inds = instruction_ids + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(video_embeds.device) + instruction_mask = instruction_inds.attention_mask.unsqueeze(1).repeat(1, nframes, 1).view(bsize * nframes, -1) + Qformer_atts = torch.cat([query_atts, instruction_mask], dim=1) + input_ids = instruction_inds.input_ids.unsqueeze(1).repeat(1, nframes, 1).view(bsize * nframes, -1) + query_output = self.Qformer.bert( + input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=video_embeds, + encoder_attention_mask=video_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=video_embeds, + encoder_attention_mask=video_atts, + return_dict=True, + ) + video_embeds = query_output.last_hidden_state # (B * T) * Q * H + if self.instructblip and instruction_inds is not None: + video_embeds = video_embeds[:, :self.num_query_token] + + video_embeds = video_embeds.reshape(bsize, nframes * self.num_query_token, video_embeds.size(-1)) + video_masks = video_masks.unsqueeze(-1).repeat(1, 1, self.num_query_token).view(bsize, -1) + + # wav_name = os.path.basename(video_name) + wav_name = video_name + # vid = wav_name.split('.')[0] + npy_name = wav_name.replace('.mp4', ".npy") + # special for /mnt/bn/audio-visual-llm-data2/datasets/only_need_video + npy_name = npy_name.replace('/mnt/bn/audio-visual-llm-data2/datasets/cxz/', '') + npy_name = npy_name.replace('/', '-') + npy_path = os.path.join(npy_path_prefix, npy_name) + video_embeds_np = video_embeds.cpu().numpy() + if os.path.exists(npy_path): + ddd = np.load(npy_path, allow_pickle=True) + ddd.item()[instructs[0]] = video_embeds_np # [B, 1280, 768] + else: + ddd = {instructs[0]: video_embeds_np} + np.save(npy_path, ddd) + print(f"finish {video_name} to {npy_path}") + return + + def generate(self, inputs): + ''' + inputs = { + 'image_paths': optional, + 'audio_paths': optional + 'video_paths': optional + 'thermal_paths': optional + 'mode': generation mode, + 'prompt': human input prompt, + 'max_tgt_len': generation length, + 'top_p': top_p, + 'temperature': temperature + 'modality_embeds': None or torch.tensor + 'modality_cache': save the image cache + } + ''' + input_embeds = self.prepare_generation_embedding(inputs) + # stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277], encounters=1)]) + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2], encounters=1)]) + outputs = self.llama_model.generate( + inputs_embeds=input_embeds, + max_new_tokens=inputs['max_tgt_len'], + min_length=1, + top_p=inputs['top_p'], + temperature=inputs['temperature'], + do_sample=inputs.get('dosample', False), + use_cache=True, + num_beams=1, + repetition_penalty=1.5, + length_penalty=1.0, + stopping_criteria=stopping_criteria, + ) + output_text = self.llama_tokenizer.decode(outputs[0], skip_special_tokens=True) + if "\n#" in output_text: + output_text = output_text.split("\n#")[0] + # elif "\n\n" in output_text: + # output_text = output_text.split("\n\n")[0] + return output_text + + def calc_entropy(self, inputs): + with torch.no_grad(): + input_embeds = self.extract_multimodal_feature(inputs) + output_texts = inputs["prompt"] + input_ids, target_ids, attention_mask, instructs = process_batch_instance( + self.llama_tokenizer, + output_texts, + self.max_tgt_len, + prompt=self.prompt, + use_llama2=self.use_llama2, + ) + inputs_embeds, targets, attention_mask = self.prompt_wrap( + input_embeds[0], + input_ids, + target_ids, + attention_mask, + ) + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + logits = outputs.logits[0][-(targets[0] != -100).sum().item():] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + entropy = -torch.sum(log_probs * torch.exp(log_probs), dim=-1) + return entropy \ No newline at end of file diff --git a/video_salmonn/prompt/alignment_speech_multitask.json b/video_salmonn/prompt/alignment_speech_multitask.json new file mode 100644 index 0000000..4024fc9 --- /dev/null +++ b/video_salmonn/prompt/alignment_speech_multitask.json @@ -0,0 +1,206 @@ +{ + "asr": [ + "Can you transcribe the speech into a written format?", + "Listen to the speech and write down its content.", + "What is the content of the speech you heard?", + "Please write down the transcription of the speech.", + "Please transcribe the speech into a written format.", + "Write down the content of the speech you heard.", + "Can you write down the transcription of the speech?", + "Put the speech into a written format.", + "Please help me to transcribe the speech into a written format.", + "Recognize the content of the speech you heard.", + "Can you recognize what you heard in the speech?", + "Recognize the speech and write it down in a written format.", + "Listen to the speech and recognize its content.", + "Give me the transcription of the speech you heard.", + "Recognize the speech and give me the transcription." + ], + "asr_zh": [ + "前面的语音说了什么?", + "请将语音中的内容写下来。", + "听前面的音频,写出对方说的内容。", + "写下你听到的内容。", + "请记下语音中人说的话。", + "仔细听这段语音,记下语音中的话", + "将你听到的话写下来", + "这个人说了什么?请记下来。", + "请将语音转换为文字", + "请识别这个人说的内容" + ], + "asr_de": [ + "Können Sie die Rede in ein schriftliches Format übertragen?", + "Hören Sie sich die Rede an und schreiben Sie ihren Inhalt auf.", + "Bitte notieren Sie die Transkription der Rede.", + "Geben Sie mir die Transkription der Rede, die Sie gehört haben.", + "Was hat dieser Mann gesagt? Bitte schreiben Sie es auf.", + "Können Sie die Transkription der Rede aufschreiben?", + "Hören Sie der Stimme aufmerksam zu und notieren Sie die Wörter in der Stimme", + "Schreiben Sie auf, was Sie hören.", + "Bitte Sprache in Text umwandeln.", + "Erkennen Sie den Inhalt der Rede, die Sie gehört haben." + ], + "translation_ec": [ + "Can you translate the speech into Chinese?", + "Please translate the speech you heard into Chinese.", + "Listen to the speech and translate it into Chinese.", + "Give me the Chinese translation of this speech.", + "Could you please provide a Chinese translation for the speech?", + "Would you be willing to translate the speech into Chinese for me?", + "Would you be able to render the speech in Chinese?", + "Could you assist me in translating the speech into Chinese?", + "Can you help me convert the speech into Chinese text?", + "Please convert the speech into Chinese text.", + "请将这段语音的内容翻译成中文。", + "你能把这段语音用中文表达出来吗?", + "请将你听到的语音用中文写出来。" + ], + "translation_ce": [ + "Can you translate the speech into English?", + "Please translate the speech you heard into English.", + "Listen to the speech and translate it into English.", + "Give me the English translation of this speech.", + "Could you please provide an English translation for the speech?", + "Would you be willing to translate the speech into English for me?", + "Would you be able to render the speech in English", + "Could you assist me in translating the speech into English?", + "Can you help me convert the speech into English text?", + "Please convert the speech into English text.", + "请将这段语音的内容翻译成英文。", + "你能把这段语音用英文表达出来吗?", + "请将你听到的语音用英文写出来。" + ], + "translation_ce_pe": [ + "Can you translate the speech into English?", + "Please translate the speech you heard into English.", + "Listen to the speech and translate it into English.", + "Give me the English translation of this speech.", + "Could you please provide an English translation for the speech?", + "Would you be willing to translate the speech into English for me?", + "Would you be able to render the speech in English", + "Could you assist me in translating the speech into English?", + "Can you help me convert the speech into English text?", + "Please convert the speech into English text." + ], + "count_audio": [ + "Can you tell me how many pieces of speeches are there in the audio?", + "Can you tell me how many pieces of speeches is this audio consists of?", + "Please count the number of pieces of speeches in the audio you heard.", + "Please tell me the number of speeches in this audio.", + "Listen to the speech and tell me hwo many fragments are there in it.", + "The speech you heard might consist of several speeches and please count the number of speeches.", + "How many pieces of speeches did you hear in this audio?" + ], + "count_word": [ + "How many words did you hear in this audio?", + "Can you tell me how many words are there in the audio?", + "How many words does this speech you hear consist of?", + "Please count the number of words in this speech.", + "Listen to the speech and count how many words are there.", + "Please help me record the number of words the speech you heard consists of.", + "Give me the number of words in this speech." + ], + "audiocaption": [ + "Listen to this audio clip and provide its caption.", + "Describe the following audio in a caption.", + "Based on the sound you hear, create a caption for this audio.", + "Can you describe the scene or event depicted in this audio?", + "Could you summarise what's happening in this audio?", + "What does this audio describe?", + "Please describe the audio." + ], + "audiocaption_v2": [ + "Please write down what your hear in the audio." + ], + "QA": [ + "{}" + ], + "inference_QA": [ + "{}" + ], + "gender_QA": [ + "{}" + ], + "gender_recognition": [ + "What is the gender of the speaker?", + "Use one word to describe the speaker's gender.", + "Describe the speaker's gender.", + "Can you accurately identify the gender of the speaker?", + "Can you distinguish the gender of the speaker?", + "Describe the gender of the person speaking.", + "What is the speaker's gender based on the audio?", + "Tell me about the gender of the person you hear.", + "Is the speaker male or female?" + ], + "info_retrieval":[ + "Please check the time when each word appears and ends in that speech.", + "Write the start and end time of each word in the speech to form a sequence." + ], + "phone_recognition": [ + "Please transcribe the audio clip into its corresponding phonetic representation.", + "Write the sequence of phonemes corresponding to this speech.", + "Provide the phonetic transcription for the speech.", + "Transcribe the phonemes for the speech please.", + "Can you recognize the phonetic representation in the speech?", + "Listen to the speech and recognize its phonetic representation", + "What is the phoneme transcription of the speech?" + ], + "audio_speech_description":[ + "Describe the speech and the background audio", + "Record the speech and the background audio" + ], + "speech_separation": [ + "There are two people talking in the audio, please write what they say in order.", + "Please write down what you hear each person says.", + "Can you record what each person says?", + "Transcribe the words spoken by each person in the audio." + ], + "speech_query": [ + "Please answer the question in the speech.", + "Please answer the question.", + "Can you answer the previous question?", + "There is a question in the speech. Please answer it." + ], + "emotion_recognition": [ + "Describe the emotion of the speaker in one word.", + "Use one word to describe the speaker's emotion." + ], + "music_description": [ + "Listen to this music clip and describe the music.", + "Please describe the music.", + "Provide a description of the music.", + "Analyze the music in this clip and offer a description.", + "Give me a description of the music in this clip." + ], + "lyrics_recognition": [ + "Listen to the song and write down its content.", + "Please recognize what the singer sings.", + "Please write down the transcription of the speech.", + "Give me the transcription of the song you heard.", + "Recognize the song and give me the transcription.", + "Listen to the song and recognize its content.", + "Please help me to transcribe the speech into a written format.", + "Can you recognize what you heard in the song?" + ], + "speaker_verification": [ + "Are the two people speaking successively the same person? Answer yes or no.", + "Do you only hear the same person talking? Answer yes or no.", + "Is only one person speaking in the audio? Answer yes or no." + ], + "fluent_speech_audio": [ + "Describe the background audio and the speech in a fluent sentence." + ], + "audio_story_telling": [ + "Based on the audio, write a story in detail. Your story should be highly related to the audio." + ], + "speech_audio_query": [ + "Please answer the speaker's question in detail based on the background sound." + ], + "summarization": [ + "Please give me the summarization of the speech before.", + "summarise the speech before into several sentences." + ], + "speaker_asr": [ + "Here are several people talking. Please recognize and write down what they are saying and add a speaker number to each sentence." + ] +} \ No newline at end of file diff --git a/video_salmonn/videosalmonn.yml b/video_salmonn/videosalmonn.yml new file mode 100644 index 0000000..03b0013 --- /dev/null +++ b/video_salmonn/videosalmonn.yml @@ -0,0 +1,309 @@ +name: videosalmonn +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - ca-certificates=2023.08.22=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.11=h7f8727e_2 + - pip=23.2.1=py39h06a4308_0 + - python=3.9.18=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.0.0=py39h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py39h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - accelerate==0.23.0 + - addict==2.4.0 + - aiofiles==23.2.1 + - aiohttp==3.8.6 + - aiosignal==1.3.1 + - aliyun-python-sdk-core==2.14.0 + - aliyun-python-sdk-kms==2.16.2 + - altair==5.1.2 + - aniso8601==9.0.1 + - annotated-types==0.6.0 + - antlr4-python3-runtime==4.9.3 + - anyio==3.7.1 + - asgiref==3.7.2 + - async-timeout==4.0.3 + - attrs==23.1.0 + - automata-toolkit==1.0.2 + - av==12.3.0 + - azure-core==1.30.1 + - azure-identity==1.15.0 + - azure-keyvault-secrets==4.8.0 + - backoff==2.2.1 + - bcrypt==4.1.2 + - bidict==0.23.1 + - bitsandbytes==0.41.1 + - blinker==1.6.3 + - blis==0.7.11 + - cachetools==5.3.1 + - catalogue==2.0.10 + - certifi==2023.7.22 + - chardet==5.2.0 + - charset-normalizer==3.3.0 + - chroma-hnswlib==0.7.3 + - chromadb==0.4.21 + - click==8.1.7 + - cloudpathlib==0.16.0 + - cmake==3.25.0 + - colorama==0.4.6 + - coloredlogs==14.0 + - confection==0.1.4 + - contourpy==1.1.1 + - cpm-kernels==1.0.11 + - crcmod==1.7 + - cryptography==41.0.7 + - cycler==0.12.1 + - cymem==2.0.8 + - dataclasses-json==0.6.1 + - datasets==2.16.0 + - deprecated==1.2.14 + - dill==0.3.7 + - distance==0.1.3 + - distro==1.8.0 + - dnspython==2.3.0 + - editdistance==0.6.2 + - einops==0.7.0 + - en-core-web-sm==3.7.1 + - et-xmlfile==1.1.0 + - eventlet==0.36.1 + - exceptiongroup==1.1.3 + - fastapi==0.103.2 + - ffmpeg==1.4 + - ffmpy==0.3.1 + - filelock==3.12.4 + - flask==2.2.5 + - flask-cors==4.0.1 + - flask-restful==0.3.10 + - flask-socketio==5.3.6 + - flask-talisman==1.1.0 + - flatbuffers==23.5.26 + - fonttools==4.43.1 + - frozenlist==1.4.0 + - fsspec==2023.9.2 + - ftfy==6.2.3 + - fvcore==0.1.5.post20221221 + - g2p==2.0.0 + - g2p-en==2.1.0 + - gast==0.5.4 + - gitdb==4.0.10 + - gitpython==3.1.37 + - google-auth==2.25.2 + - googleapis-common-protos==1.62.0 + - gradio==3.47.1 + - gradio-client==0.6.0 + - greenlet==3.0.0 + - grpcio==1.60.0 + - h11==0.14.0 + - httpcore==0.18.0 + - httptools==0.6.1 + - httpx==0.25.0 + - huggingface-hub==0.20.1 + - humanfriendly==10.0 + - idna==3.4 + - imageio==2.31.5 + - importlib-metadata==6.8.0 + - importlib-resources==6.1.0 + - inflect==7.2.1 + - iniconfig==2.0.0 + - iopath==0.1.10 + - isodate==0.6.1 + - itsdangerous==2.2.0 + - jinja2==3.1.2 + - jiwer==3.0.3 + - jmespath==0.10.0 + - joblib==1.4.2 + - jsonpatch==1.33 + - jsonpointer==2.4 + - jsonschema==4.19.1 + - jsonschema-specifications==2023.7.1 + - kiwisolver==1.4.5 + - kubernetes==28.1.0 + - langchain==0.0.344 + - langchain-core==0.0.13 + - langcodes==3.3.0 + - langsmith==0.0.75 + - latex2mathml==3.76.0 + - lazy-loader==0.3 + - lit==15.0.7 + - loguru==0.7.2 + - markdown==3.5 + - markdown-it-py==3.0.0 + - markupsafe==2.1.3 + - marshmallow==3.20.1 + - matplotlib==3.8.0 + - mdtex2html==1.2.0 + - mdurl==0.1.2 + - mmh3==4.0.1 + - modelscope==1.10.0 + - monotonic==1.6 + - more-itertools==10.1.0 + - mpmath==1.3.0 + - msal==1.28.0 + - msal-extensions==1.1.0 + - multidict==6.0.4 + - multiprocess==0.70.15 + - munkres==1.1.4 + - murmurhash==1.0.10 + - mypy-extensions==1.0.0 + - networkx==3.1 + - nltk==3.8.1 + - numpy==1.24.4 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.20.5 + - nvidia-nvjitlink-cu12==12.2.140 + - nvidia-nvtx-cu12==12.1.105 + - oauthlib==3.2.2 + - omegaconf==2.3.0 + - onnxruntime==1.16.3 + - openai==0.28.1 + - openpyxl==3.1.2 + - opentelemetry-api==1.22.0 + - opentelemetry-exporter-otlp-proto-common==1.22.0 + - opentelemetry-exporter-otlp-proto-grpc==1.22.0 + - opentelemetry-instrumentation==0.43b0 + - opentelemetry-instrumentation-asgi==0.43b0 + - opentelemetry-instrumentation-fastapi==0.43b0 + - opentelemetry-proto==1.22.0 + - opentelemetry-sdk==1.22.0 + - opentelemetry-semantic-conventions==0.43b0 + - opentelemetry-util-http==0.43b0 + - orjson==3.9.9 + - oss2==2.18.3 + - overrides==7.4.0 + - packaging==23.2 + - pandas==2.1.1 + - panphon==0.20.0 + - parameterized==0.9.0 + - peft==0.5.0 + - pillow==10.0.1 + - platformdirs==4.1.0 + - pluggy==1.3.0 + - portalocker==2.8.2 + - posthog==3.1.0 + - preshed==3.0.9 + - progress==1.6 + - protobuf==4.24.4 + - psutil==5.9.5 + - pulsar-client==3.3.0 + - pyarrow==13.0.0 + - pyarrow-hotfix==0.6 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pycryptodome==3.19.0 + - pydantic==2.4.2 + - pydantic-core==2.10.1 + - pydeck==0.8.1b0 + - pydub==0.25.1 + - pygments==2.16.1 + - pyjwt==2.8.0 + - pyparsing==3.1.1 + - pypika==0.48.9 + - pypinyin==0.51.0 + - pytest==7.4.3 + - python-dateutil==2.8.2 + - python-dotenv==1.0.0 + - python-engineio==4.9.1 + - python-graphviz==0.16 + - python-multipart==0.0.6 + - python-socketio==5.11.2 + - pytorchvideo==0.1.5 + - pytz==2023.3.post1 + - pyyaml==6.0.1 + - rapidfuzz==3.4.0 + - referencing==0.30.2 + - regex==2023.10.3 + - reportlab==4.1.0 + - requests==2.31.0 + - requests-oauthlib==1.3.1 + - rich==13.6.0 + - rpds-py==0.10.6 + - rsa==4.9 + - safetensors==0.4.2 + - scikit-image==0.22.0 + - scipy==1.11.3 + - semantic-version==2.10.0 + - sentencepiece==0.1.99 + - simple-websocket==1.0.0 + - simplejson==3.19.2 + - six==1.16.0 + - smart-open==6.4.0 + - smmap==5.0.1 + - sniffio==1.3.0 + - sortedcontainers==2.4.0 + - spacy==3.7.4 + - spacy-legacy==3.0.12 + - spacy-loggers==1.0.5 + - sqlalchemy==2.0.22 + - srsly==2.4.8 + - starlette==0.27.0 + - streamlit==1.27.2 + - sympy==1.12 + - tabulate==0.9.0 + - tenacity==8.2.3 + - termcolor==2.4.0 + - text-unidecode==1.3 + - textdistance==4.6.0 + - thinc==8.2.3 + - tifffile==2023.9.26 + - tiktoken==0.5.1 + - timm==1.0.8 + - tokenizers==0.19.1 + - toml==0.10.2 + - tomli==2.0.1 + - toolz==0.12.0 + - torch==2.0.0+cu118 + - torchaudio==2.0.1+cu118 + - torchvision==0.15.1+cu118 + - tornado==6.3.3 + - tqdm==4.66.1 + - transformers==4.40.1 + - transformers-stream-generator==0.0.4 + - triton==2.0.0 + - typeguard==4.3.0 + - typer==0.9.0 + - typing-extensions==4.12.0 + - typing-inspect==0.9.0 + - tzdata==2023.3 + - tzlocal==5.1 + - ujson==4.0.2 + - unicodecsv==0.14.1 + - urllib3==1.26.6 + - uvicorn==0.23.2 + - uvloop==0.19.0 + - validators==0.22.0 + - wasabi==1.1.2 + - watchdog==3.0.0 + - watchfiles==0.21.0 + - wcwidth==0.2.13 + - weasel==0.3.4 + - websocket-client==1.7.0 + - websockets==11.0.3 + - werkzeug==3.0.3 + - wrapt==1.16.0 + - wsproto==1.2.0 + - xxhash==3.4.1 + - yacs==0.1.8 + - yapf==0.40.2 + - yarl==1.9.2 + - zipp==3.17.0