[MiniCPMV] fix precision bug
chuxiaoyi2023 committed Dec 2, 2024
1 parent e19dab9 commit 4648334
Showing 14 changed files with 741 additions and 49 deletions.
2 changes: 1 addition & 1 deletion models/MiniCPM-V-2_6/README.md
@@ -32,7 +32,7 @@ python3 export_onnx.py --model_path your_minicpmv_path
This section describes how to compile the ONNX model into a bmodel. You can also skip the compilation step and download a pre-compiled model directly:

``` shell
python3 -m dfss [email protected]:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4.bmodel
python3 -m dfss [email protected]:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4_seq1024.bmodel
```

#### 1. Download Docker and start the container
9 changes: 6 additions & 3 deletions models/MiniCPM-V-2_6/compile/README.md
@@ -3,15 +3,18 @@
## Export onnx

```shell
pip install transformers_stream_generator einops tiktoken accelerate torch==2.0.1+cpu torchvision==0.15.2 transformers==4.40.0
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu
pip install transformers_stream_generator einops tiktoken accelerate transformers==4.40.0
cp files/MiniCPM-V-2_6/modeling_qwen2.py /usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/
cp files/MiniCPM-V-2_6/resampler.py your_torch_model
cp files/MiniCPM-V-2_6/modeling_navit_siglip.py your_torch_model
```
your_torch_model is the location of your model.
```shell
python3 export_onnx.py --model_path your_torch_model --seq_length 512 --device cpu
python3 export_onnx.py --model_path your_torch_model --seq_length 512 --device cpu --image_file ../python_demo/test0.jpg
```
* image_file: the path to a real image. When the model is exported, the input size is fixed to this image's size, so pass your actual image as `image_file` (see the sketch after this list).
* Multiple images and variable image sizes are not currently supported.
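
Since the exported bmodel is shape-static, it can help to check how many vision placeholder tokens and image patches your chosen image actually produces before exporting. A minimal sketch, mirroring the processor calls that export_onnx.py itself uses (the paths are placeholders for your own):

```python
import torch
from PIL import Image
from transformers import AutoProcessor

# Placeholder paths -- substitute your model directory and image.
processor = AutoProcessor.from_pretrained("your_torch_model", trust_remote_code=True)
image = Image.open("../python_demo/test0.jpg").convert('RGB')

msgs = [{'role': 'user', 'content': '(<image>./</image>)\nDescribe the image'}]
prompts_lists = processor.tokenizer.apply_chat_template(
    msgs, tokenize=False, add_generation_prompt=True)

inputs = processor(
    prompts_lists,
    [[image]],
    max_slice_nums=processor.image_processor.max_slice_nums,
    use_image_id=None,
    return_tensors="pt",
    max_length=8192,
)

ids = inputs.input_ids[0]
# 128244 is the image-placeholder token id used throughout this commit.
print("vision tokens:", int((ids == 128244).sum()))
print("image patches:", len(inputs["pixel_values"][0]))
```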

## Compile bmodel
io_alone is used.
@@ -23,7 +26,7 @@ python3 export_onnx.py --model_path your_torch_model --seq_length 512 --device cpu
You can also download the pre-compiled model directly instead of compiling it yourself:
```shell
pip3 install dfss
python3 -m dfss [email protected]:/ext_model_information/LLM/LLM-TPU/minicpm_int4_seq512_1dev.bmodel
python3 -m dfss [email protected]:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4_seq1024.bmodel
```

### python demo
11 changes: 6 additions & 5 deletions models/MiniCPM-V-2_6/compile/export_onnx.py
@@ -286,7 +286,6 @@ def test_net_with_mask():
tgt_sizes = inputs["tgt_sizes"][0].to(dtype).to(device)
vit_infer = VisionTransformer(pixel_values, tgt_sizes)
vit_embeds = vit_infer(pixel_values) # [1, 64, 3584]
vit_token_length = vit_embeds.shape[1]

msgs = [{'role': 'user', 'content': '(<image>./</image>)\n请详细描述一下图片内容'}]
prompts_lists = processor.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
@@ -295,11 +294,11 @@
[[image]],
max_slice_nums=MAX_SLICE_NUMS,
use_image_id=None,
return_tensors="pt",
return_tensors="pt",
max_length=8192
).to(device)
ids = inputs.input_ids[0]
first_offset = int(torch.where(ids==128244)[0][0])
image_offsets = torch.where(ids==128244)[0].tolist()
ids = ids.tolist()

ID_IM_END = tokenizer.convert_tokens_to_ids("<|im_end|>")
@@ -308,8 +307,10 @@
input_ids = torch.tensor(ids).view(SEQ_LENGTH).to(device)
out = embed(input_ids).view(1, SEQ_LENGTH, HIDDEN_SIZE) # [1, 512, 3584]

for i in range(vit_embeds.shape[0]):
out[:, first_offset+i*vit_token_length:first_offset+(i+1)*vit_token_length, :] = vit_embeds[i]
patch_num = pixel_values.shape[0]
patch_size = len(image_offsets) // patch_num
for i in range(patch_num):
out[:, image_offsets[i*patch_size]:image_offsets[i*patch_size]+patch_size, :] = vit_embeds[i]

position_ids = list(range(token_len)) + (SEQ_LENGTH - token_len) * [0]
position_ids = torch.tensor([position_ids]).to(device)
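
The precision bug is visible in this hunk: the old code pasted every ViT patch into one contiguous run starting at the first 128244 placeholder, which is only correct when all placeholder runs sit back-to-back; the fix writes each patch at its actual placeholder offsets. A toy illustration of the corrected scatter, using synthetic shapes rather than the real model:

```python
import torch

SEQ_LENGTH, HIDDEN_SIZE = 16, 4
# Synthetic layout: 2 patches of 3 placeholder tokens each, with a gap between runs.
image_offsets = [2, 3, 4, 9, 10, 11]         # positions where ids == 128244
vit_embeds = torch.randn(2, 3, HIDDEN_SIZE)  # [patch_num, patch_size, hidden]
out = torch.zeros(1, SEQ_LENGTH, HIDDEN_SIZE)

patch_num = vit_embeds.shape[0]
patch_size = len(image_offsets) // patch_num
for i in range(patch_num):
    start = image_offsets[i * patch_size]    # true start of patch i's run
    out[:, start:start + patch_size, :] = vit_embeds[i]

# The old code instead wrote patch i at first_offset + i * patch_size, so any
# patch after a gap between placeholder runs landed at the wrong positions.
```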
2 changes: 1 addition & 1 deletion models/MiniCPM-V-2_6/compile/run_compile.sh
@@ -94,7 +94,7 @@ sudo cp files/${model_name_upper}/resampler.py ${model_path}
sudo cp files/${model_name_upper}/modeling_navit_siglip.py ${model_path}

echo "export onnx..."
python export_onnx.py --model_path ${model_path} --seq_length ${seq_length}
python export_onnx.py --model_path ${model_path} --seq_length ${seq_length} --image_file ../python_demo/test0.jpg

echo "compile model..."
source ${tpu_mlir_path}/envsetup.sh
6 changes: 3 additions & 3 deletions models/MiniCPM-V-2_6/python_demo/README.md
@@ -9,7 +9,7 @@ pip3 install gradio==3.39.0 mdtex2html==1.2.0 dfss

If you don't plan to compile the model yourself, you can use the downloaded pre-compiled model directly:
```
python3 -m dfss [email protected]:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4.bmodel
python3 -m dfss [email protected]:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4_seq1024.bmodel
```

Compile the library files:
@@ -20,5 +20,5 @@ cd build && cmake .. && make && cp *cpython* .. && cd ..

# python demo
```
python3 pipeline.py --model_path minicpmv26_bm1684x_int4.bmodel --tokenizer_path ../support/token_config/ --devid 0
```
python3 pipeline.py --model_path minicpmv26_bm1684x_int4_seq1024.bmodel --processor_path ../support/processor_config/ --devid 0
```
19 changes: 12 additions & 7 deletions models/MiniCPM-V-2_6/python_demo/chat.cpp
@@ -48,7 +48,7 @@ class MiniCPMV {
void init(int devid, std::string model_path);
void deinit();
int forward_first(std::vector<int> &tokens, std::vector<float> &pixel_values,
int img_offset);
std::vector<int> &img_offsets, int patch_num);
int forward_next();

std::mt19937 sgen;
Expand Down Expand Up @@ -160,7 +160,7 @@ void MiniCPMV::deinit() {
}

int MiniCPMV::forward_first(std::vector<int> &tokens,
std::vector<float> &pixel_values, int img_offset) {
std::vector<float> &pixel_values, std::vector<int> &img_offsets, int patch_num) {
std::vector<int> input_ids(SEQLEN, 0);
std::vector<int> position_id(SEQLEN, 0);
std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, ATTENTION_MASK);
@@ -185,7 +185,7 @@ int MiniCPMV::forward_first(std::vector<int> &tokens,
bm_memcpy_s2d(bm_handle, in_mem, (void *)input_ids.data());
net_launch(net_embed); // prefil embedding

if (pixel_values.size() * sizeof(float) == IMAGE_BYTES && img_offset > 0) {
if (pixel_values.size() * sizeof(float) == IMAGE_BYTES && img_offsets.size() > 0) {
d2d(dev_buffer, out_mem);
out_mem = dev_buffer;
// forward vision transformer
@@ -195,10 +195,15 @@
net_launch(net_vit);

// concatenante texting embedding and image embedding
int dst_offset = img_offset * HIDDEN_SIZE * 2;
int vit_size = bm_mem_get_device_size(vit_out_mem);
bm_memcpy_d2d_byte(bm_handle, out_mem, dst_offset, vit_out_mem, 0,
vit_size);
int type_byte = sizeof(uint16_t);
int patch_bytes = bm_mem_get_device_size(vit_out_mem) / patch_num;
int patch_size = net_vit->stages[0].output_shapes[0].dims[1];
for (int i = 0; i < patch_num; i++) {
int vit_offset = i * patch_bytes;
int dst_offset = img_offsets[i * patch_size] * HIDDEN_SIZE * type_byte;

bm_memcpy_d2d_byte(bm_handle, out_mem, dst_offset, vit_out_mem, vit_offset, patch_bytes);
}
}

// forward blocks
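
The C++ demo applies the same fix on-device: one d2d copy per patch, with the destination byte offset derived from that patch's first placeholder token. A quick Python check of the offset arithmetic (the constants match this repo: 3584 hidden units, 64 vision tokens per patch, 16-bit activations; the two placeholder runs below are hypothetical):

```python
HIDDEN_SIZE = 3584   # hidden width of MiniCPM-V-2_6 in this repo
TYPE_BYTE = 2        # sizeof(uint16_t): activations are 16-bit
PATCH_SIZE = 64      # vision tokens per patch (dims[1] of the ViT output)

def dst_offset(image_offsets, i):
    # Byte offset of patch i's first destination token, as computed in chat.cpp.
    return image_offsets[i * PATCH_SIZE] * HIDDEN_SIZE * TYPE_BYTE

# Hypothetical layout: two placeholder runs starting at tokens 5 and 80.
offsets = list(range(5, 5 + PATCH_SIZE)) + list(range(80, 80 + PATCH_SIZE))
print(dst_offset(offsets, 0), dst_offset(offsets, 1))  # 35840 573440
```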
59 changes: 30 additions & 29 deletions models/MiniCPM-V-2_6/python_demo/pipeline.py
@@ -3,7 +3,7 @@
import argparse
from PIL import Image
import torchvision.transforms as T
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoProcessor
from torchvision.transforms.functional import InterpolationMode
import chat
import os
Expand Down Expand Up @@ -38,18 +38,13 @@ def __init__(self, args):
self.device = args.devid

# load tokenizer
print("Load " + args.tokenizer_path + " ...")
self.tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_path, trust_remote_code=True
)
self.tokenizer.decode([0]) # warm up

print("Load " + args.processor_path + " ...")
self.processor = AutoProcessor.from_pretrained(
args.tokenizer_path, trust_remote_code=True
args.processor_path, trust_remote_code=True
)

# preprocess parameters, such as prompt & tokenizer
self.system_prompt = '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n'
self.tokenizer = self.processor.tokenizer
self.tokenizer.decode([0]) # warm up

# load model
self.model = chat.MiniCPMV()
@@ -58,27 +53,33 @@ def __init__(self, args):
self.ID_EOS = self.tokenizer.eos_token_id
self.ID_IM_END = self.tokenizer.convert_tokens_to_ids("<|im_end|>")

# parameters
self.MAX_SLICE_NUMS = self.processor.image_processor.max_slice_nums

def encode(self):
if not self.image_str:
inserted_image_str = ""
self.pixel_values = []
else:
inserted_image_str = "(<image>./</image>)\n"
image = Image.open(sample_image_file).convert('RGB')
inputs = processor.image_processor([image], do_pad=True, max_slice_nums=MAX_SLICE_NUMS, return_tensors="pt")
pixel_values = inputs["pixel_values"][0]

msgs = [{'role': 'user', 'content': '{}{}'.format(self.inserted_image_str, self.input_str)}]
prompt = self.system_prompt + self.input_str + "<|im_end|>\n<|im_start|>assistant\n"
self.input_ids = self.tokenizer.encode(prompt)
self.image_offset = 0
self.pixel_values = []
return
self.pixel_values = load_image(self.image_str).flatten().tolist()
msgs = [{'role': 'user', 'content': '(<image>./</image>)\n{}'.format(self.input_str)}]
self.input_ids = processor.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)[0]
self.image_offset = 0
breakpoint()
image = Image.open(self.image_str).convert('RGB')

msgs = [{'role': 'user', 'content': '{}{}'.format(inserted_image_str, self.input_str)}]
prompts_lists = self.processor.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

inputs = self.processor(
prompts_lists,
[[image]] if image else None,
max_slice_nums=self.MAX_SLICE_NUMS,
use_image_id=None,
return_tensors="pt",
max_length=8192
)
self.input_ids = inputs.input_ids[0]
self.pixel_values = torch.cat(inputs["pixel_values"][0], dim=0).flatten().tolist()
self.image_offsets = torch.where(self.input_ids==128244)[0].tolist()
self.patch_num = len(inputs["pixel_values"][0])

self.input_ids = self.input_ids.tolist()

def chat(self):
"""
@@ -107,7 +108,7 @@ def chat(self):
# Chat
first_start = time.time()
token = self.model.forward_first(
self.input_ids, self.pixel_values, self.image_offset)
self.input_ids, self.pixel_values, self.image_offsets, self.patch_num)
first_end = time.time()
tok_num = 1
# Following tokens
@@ -142,8 +143,8 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_path', type=str,
required=True, help='path to the bmodel file')
parser.add_argument('-t', '--tokenizer_path', type=str,
default="../support/token_config", help='path to the tokenizer file')
parser.add_argument('-p', '--processor_path', type=str,
default="../support/processor_config", help='path to the processor file')
parser.add_argument('-d', '--devid', type=int,
default=0, help='device ID to use')
args = parser.parse_args()