Merge pull request #61 from bytedance/videosalmonn
feat: add video salmonn
Showing 50 changed files with 14,016 additions and 0 deletions.
@@ -0,0 +1,52 @@
## Inference

### Preparation
Create the conda environment from the provided config:
```
conda env create -f videosalmonn.yml
```
Create directories to store the checkpoints (if you rename or restructure these directories, update the config files and model files accordingly):
```
mkdir -p ckpt/MultiResQFormer
mkdir -p ckpt/pretrained_ckpt
```

Then download the following model checkpoints:

1. Main video-SALMONN model [checkpoint](https://huggingface.co/tsinghua-ee/Video-SALMONN/tree/main), and put it under `ckpt/MultiResQFormer`
2. InstructBLIP [checkpoint](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna13b_trimmed.pth) for the Vicuna-13B model, and put it under `ckpt/pretrained_ckpt`
3. EVA_VIT model [checkpoint](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth) for InstructBLIP, and put it under `ckpt/pretrained_ckpt`
4. BEATs encoder [checkpoint](https://huggingface.co/spaces/fffiloni/SALMONN-7B-gradio/blob/677c0125de736ab92751385e1e8664cd03c2ce0d/beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt), and put it under `ckpt/pretrained_ckpt` (a scripted download sketch follows this list)
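
A hedged convenience sketch for fetching these checkpoints in Python: the two `storage.googleapis.com` links are downloaded directly, and the Hugging Face items via `huggingface_hub`; using that library here is our suggestion, not something the original instructions require.
```
# Sketch only: download the four checkpoints listed above.
# huggingface_hub usage is an assumption; manual download works equally well.
import urllib.request
from huggingface_hub import snapshot_download, hf_hub_download

# 1. Main video-SALMONN checkpoint (whole repo) -> ckpt/MultiResQFormer
snapshot_download(repo_id="tsinghua-ee/Video-SALMONN",
                  local_dir="ckpt/MultiResQFormer")

# 2. and 3. Direct links -> ckpt/pretrained_ckpt
for url in [
    "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna13b_trimmed.pth",
    "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth",
]:
    urllib.request.urlretrieve(url, "ckpt/pretrained_ckpt/" + url.rsplit("/", 1)[-1])

# 4. BEATs encoder, hosted inside a Hugging Face Space; this lands at
# ckpt/pretrained_ckpt/beats/... -- move the file if your config expects a different layout.
hf_hub_download(repo_id="fffiloni/SALMONN-7B-gradio",
                repo_type="space",
                revision="677c0125de736ab92751385e1e8664cd03c2ce0d",
                filename="beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
                local_dir="ckpt/pretrained_ckpt")
```
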
### Run inference
```
python inference.py --cfg-path config/test.yaml
```

### Check the result
The result is saved at the following path:
```
./ckpt/MultiResQFormer/<DateTime>/eval_result.json
```

The expected result looks like this:
```
[
    {
        "id": "./dummy/4405327307.mp4_Describe the video and audio in detail",
        "conversation": [
            {
                "from": "human",
                "value": "Describe the video and audio in detail"
            },
            {
                "from": "gpt",
                "value": "None"
            }
        ],
        "task": "audiovisual_video_input",
        "ref_answer": "None",
        "gen_answer": "The video shows a group of musicians performing on stage, with a man singing into a microphone and playing the piano. There is also a drum set and a saxophone on stage. The audience is not visible in the video. The music is upbeat and energetic, and the performers seem to be enjoying themselves.</s>"
    }
]
```
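
To sanity-check a run programmatically, the JSON above can be loaded directly. A small sketch (globbing over the `<DateTime>` directories is an assumption about the run-folder naming) that prints each generated answer:
```
import glob
import json

# Pick the newest <DateTime> run directory (assumes lexicographic order tracks time).
result_files = sorted(glob.glob("./ckpt/MultiResQFormer/*/eval_result.json"))
with open(result_files[-1]) as f:
    results = json.load(f)

for item in results:
    print(item["id"])
    print("  task:", item["task"])
    print("  gen :", item["gen_answer"][:100])
```
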
@@ -0,0 +1,37 @@
import yaml


def load_model_config(model, mode):
    # load the model-specific config and flatten in the section for the current mode
    config_path = f'config/{model}.yaml'
    print(f'[!] load configuration from {config_path}')
    with open(config_path) as f:
        configuration = yaml.load(f, Loader=yaml.FullLoader)
    new_config = {}
    for key, value in configuration.items():
        if key in ['train', 'test', 'validation']:
            if mode == key:
                new_config.update(value)
        else:
            new_config[key] = value
    configuration = new_config
    return configuration


def load_config(args):
    '''the model-specific configuration can override the base configuration'''
    # base config
    base_configuration = load_base_config()

    # load one model config
    configuration = load_model_config(args['model'], args['mode'])

    # update and append the model-specific config onto the base config
    base_configuration.update(configuration)
    configuration = base_configuration
    return configuration


def load_base_config():
    config_path = 'config/base.yaml'
    with open(config_path) as f:
        configuration = yaml.load(f, Loader=yaml.FullLoader)
    print(f'[!] load base configuration: {config_path}')
    return configuration
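
As a quick sanity check, `load_config` can be called with a plain dict; the `model` and `mode` keys come straight from the function bodies above, and the concrete values below match entries in `config/base.yaml` and the per-model YAML files in this commit:
```
# Example invocation of load_config (a sketch, not part of the original file).
args = {"model": "openllama_peft", "mode": "test"}
cfg = load_config(args)
print(cfg["logging_step"])   # 5, from the global section of config/base.yaml
print(cfg["lora_alpha"])     # 32, from the top level of config/openllama_peft.yaml
```
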
@@ -0,0 +1,20 @@ | ||
models: | ||
openllama: | ||
model_name: OpenLLAMAModel | ||
agent_name: DeepSpeedAgent | ||
stage1_train_dataset: SupervisedDataset | ||
test_dataset: SelfInstructTestDataset | ||
openllama_peft: | ||
model_name: OpenLLAMAPEFTModel | ||
agent_name: DeepSpeedAgent | ||
stage1_train_dataset: SupervisedDataset | ||
test_dataset: SelfInstructTestDataset | ||
openllama_peft_small: | ||
model_name: OpenLLAMAPEFTModel | ||
agent_name: DeepSpeedAgent | ||
stage1_train_dataset: SupervisedDataset | ||
test_dataset: SelfInstructTestDataset | ||
|
||
# ========= Global configuration ========== # | ||
logging_step: 5 | ||
# ========= Global configuration ========== # |
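
This `base.yaml` acts as a registry: each key under `models` names the model and agent classes to instantiate. A minimal sketch of how such an entry might be resolved (the `model` and `agents` module names here are hypothetical illustrations, not the repo's confirmed layout):
```
# Hypothetical registry lookup over the base.yaml structure above.
import importlib

def resolve_entry(base_cfg: dict, model_key: str):
    entry = base_cfg["models"][model_key]          # e.g. "openllama_peft"
    model_cls = getattr(importlib.import_module("model"), entry["model_name"])
    agent_cls = getattr(importlib.import_module("agents"), entry["agent_name"])
    return model_cls, agent_cls
```
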
@@ -0,0 +1,29 @@ | ||
from omegaconf import OmegaConf | ||
|
||
class Config: | ||
def __init__(self, args): | ||
self.config = {} | ||
|
||
self.args = args | ||
user_config = self._build_opt_list(self.args.options) | ||
config = OmegaConf.load(self.args.cfg_path) | ||
config = OmegaConf.merge(config, user_config) | ||
self.config = config | ||
|
||
def _convert_to_dot_list(self, opts): | ||
if opts is None: | ||
opts = [] | ||
|
||
if len(opts) == 0: | ||
return opts | ||
|
||
has_equal = opts[0].find("=") != -1 | ||
|
||
if has_equal: | ||
return opts | ||
|
||
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] | ||
|
||
def _build_opt_list(self, opts): | ||
opts_dot_list = self._convert_to_dot_list(opts) | ||
return OmegaConf.from_dotlist(opts_dot_list) |
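
`inference.py --cfg-path config/test.yaml` presumably builds one of these objects. A standalone sketch of driving `Config` via argparse; the `options` attribute name is taken from the class above, while the flag spelling is an assumption:
```
# Sketch: construct Config with a command-line-style override.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cfg-path", dest="cfg_path", required=True)
parser.add_argument("--options", nargs="+", default=None)

# Override max_tgt_len from config/test.yaml on the fly.
args = parser.parse_args(["--cfg-path", "config/test.yaml",
                          "--options", "max_tgt_len=256"])
cfg = Config(args).config
print(cfg.max_tgt_len)   # 256 instead of the file's 512
```
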
@@ -0,0 +1,22 @@ | ||
# generation hyper-parameters | ||
max_len: 512 | ||
penalty_alpha: 0.6 | ||
top_k: 10 | ||
top_p: 0.7 | ||
random_prefix_len: 5 | ||
sample_num: 2 | ||
decoding_method: sampling | ||
generate_len: 512 | ||
|
||
# lora hyper-parameters | ||
lora_r: 8 | ||
lora_alpha: 32 | ||
lora_dropout: 0.1 | ||
|
||
# some train configuration, more can be found under dsconfig folder | ||
train: | ||
seed: 1337 # 0 | ||
warmup_rate: 0.2 | ||
epochs: 10 | ||
max_length: 2000 | ||
max_shard_size: 80GB |
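
Assuming `OpenLLAMAPEFTModel` wraps Hugging Face's `peft` (an assumption suggested by the name, not confirmed by this diff), the `lora_*` values above would map onto a `LoraConfig` roughly like this; `target_modules` is taken from `config/test.yaml` further down:
```
# Hedged sketch of how the lora hyper-parameters above translate to peft.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                 # lora_r
    lora_alpha=32,       # lora_alpha
    lora_dropout=0.1,    # lora_dropout
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
```
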
@@ -0,0 +1,22 @@ | ||
# generation hyper-parameters | ||
max_len: 512 | ||
penalty_alpha: 0.6 | ||
top_k: 10 | ||
top_p: 0.7 | ||
random_prefix_len: 5 | ||
sample_num: 2 | ||
decoding_method: sampling | ||
generate_len: 512 | ||
|
||
# lora hyper-parameters | ||
lora_r: 32 | ||
lora_alpha: 32 | ||
lora_dropout: 0.1 | ||
|
||
# some train configuration, more can be found under dsconfig folder | ||
train: | ||
seed: 0 | ||
warmup_rate: 0.3 | ||
epochs: 10 | ||
max_length: 1024 | ||
max_shard_size: 80GB |
@@ -0,0 +1,51 @@ | ||
model: openllama_peft | ||
imagebind_ckpt_path: "" | ||
vicuna_ckpt_path: /scratch/LLM/LLM.ckpts/vicuna-13b-v1.5 # Should be modified to your own place | ||
orig_delta_path: "" | ||
delta_ckpt_path: ./ckpt/MultiResQFormer/pytorch_model_4_5001.pt | ||
|
||
all_decode_info: [ | ||
["audiovideoimage", "audiovisual_video_input", "example.json"] | ||
] | ||
|
||
stage: 2y | ||
max_tgt_len: 512 # 32000 | ||
yu_lora_r: 32 # 8 | ||
yu_lora_alpha: 32 | ||
yu_lora_dropout: 0.1 | ||
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] # ['q_proj', 'v_proj'] | ||
use_lora: "true" | ||
qformer: "true" | ||
use_whisper: "true" | ||
use_blip: "true" | ||
instructblip: "true" | ||
proj_checkpoint: "" | ||
num_video_query: 30 | ||
instructblip_video: "false" | ||
video_window_size: 240 | ||
skip_vqformer: "false" | ||
speech_qformer: "false" | ||
early_align: "true" | ||
cascaded: "" | ||
causal: "false" | ||
diversity_loss: "false" | ||
causal_attention: "true" # "false" | ||
groupsize: 10 | ||
alignmode: 2 | ||
pure_aud: False | ||
num_speech_query: 1 | ||
second_per_frame: 0.333333 | ||
second_stride: 0.333333 | ||
sin_pos: False | ||
use_beats: True # True | ||
return_raw: True # True | ||
n_pos: 120 | ||
flash_attn: False | ||
batch_size: 1 | ||
infer_mode: 2 | ||
bilinear_pooling: False | ||
# ext_groupsize: [1, 30] | ||
low_groupsize: 1 | ||
# # high_groupsize: 20 | ||
ext_same_qformer: True | ||
cache_dir: ./ckpt/pretrained_ckpt |