diff --git "a/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md" index c0806b928..e4c1e8f78 100644 --- "a/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md" @@ -115,7 +115,7 @@ RLHF的数据格式可以参考纯文本大模型的格式。 ``` 使用这种类型的数据需要注意: - 不同模型grounding任务的特殊字符和数据集格式不同 - - 需要对bbox的坐标进行归一化。例如:使用千分位坐标进行归一化 + - 不同模型对bbox是否归一化的处理不同。例如:qwen2.5-vl使用绝对坐标,而qwen2-vl、internvl2.5需要对bbox的坐标进行千分位坐标归一化 2. 使用SWIFT的grounding数据格式: diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 1de9cfff5..8de07c6a5 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -50,7 +50,7 @@ - truncation_strategy: 如果超长如何处理,支持`delete`, `left`和`right`,代表删除、左侧裁剪和右侧裁剪,默认为'delete' - 🔥max_pixels: 多模态模型图片前处理的最大像素数(H\*W),默认不缩放。 - tools_prompt: 智能体训练时的工具列表转为system的格式,请参考[智能体训练](./智能体的支持.md),默认为'react_en' -- norm_bbox: 控制如何对bbox进行缩放。可选项为'norm1000', 'none'。默认为'norm1000',即进行千分位坐标缩放 +- norm_bbox: 控制如何对bbox进行缩放。可选项为'norm1000', 'none'。其中'norm1000'代表对bbox进行千分位坐标缩放,'none'则代表不缩放。默认为None,根据模型进行自动选择 - padding_side: 当训练`batch_size>=2`时的padding_side,可选值为'left', 'right',默认为'right'。(`generate`的batch_size>=2时,只进行左padding) - loss_scale: 如何针对训练添加token的loss权重。默认为`'default'`,代表所有response(含history)以1计算交叉熵损失。可选值为'default', 'last_round', 'all', 以及agent需要的loss_scale: 'react', 'agentflan', 'alpha_umi', 'qwen'。具体可以查看[插件化](../Customization/插件化.md)和[智能体训练](./智能体的支持.md) - sequence_parallel_size: 序列并行数量。参考[example](https://github.com/modelscope/ms-swift/tree/main/examples/train/sequence_parallel/train.sh) diff --git a/docs/source_en/Customization/Custom-dataset.md b/docs/source_en/Customization/Custom-dataset.md index b87119cdb..1be9824ba 100644 --- a/docs/source_en/Customization/Custom-dataset.md +++ b/docs/source_en/Customization/Custom-dataset.md @@ -119,7 +119,7 @@ For grounding (object detection) tasks, SWIFT supports two methods: When using this type of data, please note: - Different models have different special characters and data format for the grounding task. -- It is necessary to normalize the coordinates of the bounding boxes, for example by using thousandth-scale coordinates for normalization. +- The handling of bounding box normalization varies across different models: for example, qwen2.5-vl uses absolute coordinates, while qwen2-vl and internvl2.5 require bounding box coordinates to be normalized to the thousandth scale. 1. Use SWIFT's grounding data format: diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 628b99a62..dd4ef0e42 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -50,7 +50,7 @@ The introduction to command line parameters will cover base arguments, atomic ar - truncation_strategy: How to handle overly long tokens, supports `delete`, `left`, `right`, representing deletion, left trimming, and right trimming, default is 'delete'. 
- 🔥max_pixels: Maximum pixel count for pre-processing images in multimodal models (H*W), default is no scaling. - tools_prompt: The list of tools for agent training converted to system format, refer to [Agent Training](./Agent-support.md), default is 'react_en'. -- norm_bbox: Controls how the bounding box (bbox) is scaled. Options are 'norm1000' and 'none'. The default is 'norm1000', which applies scaling to thousandths. +- norm_bbox: Controls how to scale the bounding box (bbox). Options are 'norm1000' and 'none'. 'norm1000' stands for scaling the bbox coordinates by a factor of a thousand, while 'none' means no scaling is applied. The default is None, which allows automatic selection based on the model. - padding_side: The padding_side used when training with `batch_size >= 2`, with optional values of 'left' and 'right', defaulting to 'right'. (When the batch_size in `generate` is >= 2, only left padding is applied.) - loss_scale: How to add token loss weight during training. Default is `'default'`, meaning all responses (including history) are treated as 1 for cross-entropy loss. The optional values are 'default', 'last_round', 'all', and the loss scale required by the agent: 'react', 'agentflan', 'alpha_umi', 'qwen'. For specifics, see [Pluginization](../Customization/Pluginization.md) and [Agent Training](./Agent-support.md). - sequence_parallel_size: Number of sequence parallelism. Refer to [example](https://github.com/modelscope/ms-swift/tree/main/examples/train/sequence_parallel/train.sh). diff --git a/examples/infer/demo_grounding.py b/examples/infer/demo_grounding.py new file mode 100644 index 000000000..5861e3178 --- /dev/null +++ b/examples/infer/demo_grounding.py @@ -0,0 +1,41 @@ +import os +import re +from typing import Literal + + +def draw_bbox_qwen2_vl(image, response, norm_bbox: Literal['norm1000', 'none']): + matches = re.findall( + r'<\|object_ref_start\|>(.*?)<\|object_ref_end\|><\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|>', + response) + ref = [] + bbox = [] + for match_ in matches: + ref.append(match_[0]) + bbox.append(list(match_[1:])) + draw_bbox(image, ref, bbox, norm_bbox=norm_bbox) + + +def infer_grounding(): + from swift.llm import PtEngine, RequestConfig, BaseArguments, InferRequest, safe_snapshot_download + output_path = 'bbox.png' + image = load_image('http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png') + infer_request = InferRequest(messages=[{'role': 'user', 'content': 'Task: Object Detection'}], images=[image]) + + request_config = RequestConfig(max_tokens=512, temperature=0) + adapter_path = safe_snapshot_download('swift/test_grounding') + args = BaseArguments.from_pretrained(adapter_path) + + engine = PtEngine(args.model, adapters=[adapter_path]) + resp_list = engine.infer([infer_request], request_config) + response = resp_list[0].choices[0].message.content + print(f'lora-response: {response}') + + draw_bbox_qwen2_vl(image, response, norm_bbox=args.norm_bbox) + print(f'output_path: {output_path}') + image.save(output_path) + + +if __name__ == '__main__': + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + from swift.llm import draw_bbox, load_image + infer_grounding() diff --git a/examples/notebook/qwen2.5-self-cognition/infer.ipynb b/examples/notebook/qwen2_5-self-cognition/infer.ipynb similarity index 100% rename from examples/notebook/qwen2.5-self-cognition/infer.ipynb rename to examples/notebook/qwen2_5-self-cognition/infer.ipynb diff --git a/examples/notebook/qwen2.5-self-cognition/infer.sh 
b/examples/notebook/qwen2_5-self-cognition/infer.sh similarity index 100% rename from examples/notebook/qwen2.5-self-cognition/infer.sh rename to examples/notebook/qwen2_5-self-cognition/infer.sh diff --git a/examples/notebook/qwen2.5-self-cognition/self-cognition-sft.ipynb b/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb similarity index 100% rename from examples/notebook/qwen2.5-self-cognition/self-cognition-sft.ipynb rename to examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb diff --git a/examples/notebook/qwen2.5-self-cognition/sft.sh b/examples/notebook/qwen2_5-self-cognition/sft.sh similarity index 100% rename from examples/notebook/qwen2.5-self-cognition/sft.sh rename to examples/notebook/qwen2_5-self-cognition/sft.sh diff --git a/examples/notebook/qwen2_5-vl-grounding/zh.ipynb b/examples/notebook/qwen2_5-vl-grounding/zh.ipynb new file mode 100644 index 000000000..63900c0a9 --- /dev/null +++ b/examples/notebook/qwen2_5-vl-grounding/zh.ipynb @@ -0,0 +1,259 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Qwen2.5-VL Grounding任务\n", + "\n", + "这里介绍使用qwen2.5-vl进行grounding任务的全流程介绍。当然,你也可以使用internvl2.5或者qwen2-vl等多模态模型。\n", + "\n", + "我们使用[AI-ModelScope/coco](https://modelscope.cn/datasets/AI-ModelScope/coco)数据集来展示整个流程。\n", + "\n", + "如果需要使用自定义数据集,需要符合以下格式:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"描述图像\"}, {\"role\": \"assistant\", \"content\": \"正在沙滩上玩耍\"}], \"images\": [\"/xxx/x.jpg\"], \"objects\": {\"ref\": [\"一只狗\", \"一个女人\"], \"bbox\": [[331.5, 761.4, 853.5, 1594.8], [676.5, 685.8, 1099.5, 1427.4]]}}\n", + "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"找到图像中的\"}, {\"role\": \"assistant\", \"content\": \"\"}], \"images\": [\"/xxx/x.jpg\"], \"objects\": {\"ref\": [\"羊\"], \"bbox\": [[90.9, 160.8, 135, 212.8], [360.9, 480.8, 495, 532.8]]}}\n", + "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"帮我打开谷歌浏览器\"}, {\"role\": \"assistant\", \"content\": \"Action: click(start_box='')\"}], \"images\": [\"/xxx/x.jpg\"], \"objects\": {\"ref\": [], \"bbox\": [[615, 226]]}}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "ms-swift在预处理数据集时,会使用模型特有的grounding任务格式,将objects中的ref填充``,bbox会根据模型类型选择是否进行0-1000的归一化,并填充``。例如:qwen2-vl为`f'<|object_ref_start|>羊<|object_ref_end|>'`和`f'<|box_start|>(101,201),(150,266)<|box_end|>'`(qwen2.5-vl不进行归一化,只将float型转成int型),internvl2.5则为`f''`和`f'[[101, 201, 150, 266]]'`等。\n", + "\n", + "\n", + "训练之前,你需要从main分支安装ms-swift:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "# pip install git+https://github.com/modelscope/ms-swift.git\n", + "\n", + "git clone https://github.com/modelscope/ms-swift.git\n", + "cd ms-swift\n", + "pip install -e .\n", + "\n", + "# 如果'transformers>=4.49'已经发版,则无需从main分支安装\n", + "pip install git+https://github.com/huggingface/transformers.git" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,使用以下shell进行训练。MAX_PIXELS的参数含义可以查看[这里](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html#specific-model-arguments)\n", + "\n", + "### 训练\n", + "\n", + 
"单卡训练:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "# 显存资源:24GiB\n", + "CUDA_VISIBLE_DEVICES=0 \\\n", + "MAX_PIXELS=1003520 \\\n", + "swift sft \\\n", + " --model Qwen/Qwen2.5-VL-7B-Instruct \\\n", + " --dataset 'AI-ModelScope/coco#2000' \\\n", + " --train_type lora \\\n", + " --torch_dtype bfloat16 \\\n", + " --num_train_epochs 1 \\\n", + " --per_device_train_batch_size 1 \\\n", + " --per_device_eval_batch_size 1 \\\n", + " --learning_rate 1e-4 \\\n", + " --lora_rank 8 \\\n", + " --lora_alpha 32 \\\n", + " --target_modules all-linear \\\n", + " --freeze_vit true \\\n", + " --gradient_accumulation_steps 16 \\\n", + " --eval_steps 100 \\\n", + " --save_steps 100 \\\n", + " --save_total_limit 5 \\\n", + " --logging_steps 5 \\\n", + " --max_length 2048 \\\n", + " --output_dir output \\\n", + " --warmup_ratio 0.05 \\\n", + " --dataloader_num_workers 4 \\\n", + " --dataset_num_proc 4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后我们将训练的模型推送到ModelScope:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "CUDA_VISIBLE_DEVICES=0 swift export \\\n", + " --adapters output/vx-xxx/checkpoint-xxx \\\n", + " --push_to_hub true \\\n", + " --hub_model_id '' \\\n", + " --hub_token '' \\\n", + " --use_hf false" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们将训练的checkpoint推送到[swift/test_grounding](https://modelscope.cn/models/swift/test_grounding)。\n", + "\n", + "### 推理\n", + "\n", + "训练完成后,我们使用以下命令对训练时的验证集进行推理。这里`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件,因此不需要额外指定`--model`。\n", + "\n", + "若模型采用的是绝对坐标的方式进行输出,推理时请提前对图像进行缩放而不使用`MAX_PIXELS`或者`--max_pixels`。若是千分位坐标,则没有此约束。\n", + "\n", + "由于我们已经将训练后的checkpoint推送到了ModelScope上,以下推理脚本可以直接运行:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "CUDA_VISIBLE_DEVICES=0 \\\n", + "swift infer \\\n", + " --adapters swift/test_grounding \\\n", + " --stream true \\\n", + " --load_data_args true \\\n", + " --max_new_tokens 512 \\\n", + " --dataset_num_proc 4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们也可以使用代码的方式进行推理:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", + "\n", + "import re\n", + "from typing import Literal\n", + "from swift.llm import (\n", + " PtEngine, RequestConfig, BaseArguments, InferRequest, safe_snapshot_download, draw_bbox, load_image, load_dataset, InferEngine\n", + ")\n", + "from IPython.display import display\n", + "\n", + "def infer_stream(engine: InferEngine, infer_request: InferRequest):\n", + " request_config = RequestConfig(max_tokens=512, temperature=0, stream=True)\n", + " gen = engine.infer([infer_request], request_config)\n", + " query = infer_request.messages[0]['content']\n", + " print(f'query: {query}\\nresponse: ', end='')\n", + " response = ''\n", + " for resp_list in gen:\n", + " if resp_list[0] is None:\n", + " continue\n", + " delta = resp_list[0].choices[0].delta.content\n", + " response += delta\n", + " print(delta, end='', flush=True)\n", + " print()\n", + " return response\n", + "\n", + "def draw_bbox_qwen2_vl(image, response, norm_bbox: 
Literal['norm1000', 'none']):\n", + " matches = re.findall(\n", + " r'<\\|object_ref_start\\|>(.*?)<\\|object_ref_end\\|><\\|box_start\\|>\\((\\d+),(\\d+)\\),\\((\\d+),(\\d+)\\)<\\|box_end\\|>',\n", + " response)\n", + " ref = []\n", + " bbox = []\n", + " for match_ in matches:\n", + " ref.append(match_[0])\n", + " bbox.append(list(match_[1:]))\n", + " draw_bbox(image, ref, bbox, norm_bbox=norm_bbox)\n", + "\n", + "# 下载权重,并加载模型\n", + "output_dir = 'images_bbox'\n", + "model_id_or_path = 'swift/test_grounding'\n", + "output_dir = os.path.abspath(os.path.expanduser(output_dir))\n", + "adapter_path = safe_snapshot_download(model_id_or_path)\n", + "args = BaseArguments.from_pretrained(adapter_path)\n", + "engine = PtEngine(args.model, adapters=[adapter_path])\n", + "\n", + "# 获取验证集并推理\n", + "_, val_dataset = load_dataset(args.dataset, split_dataset_ratio=args.split_dataset_ratio, num_proc=4, seed=args.seed)\n", + "print(f'output_dir: {output_dir}')\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "for i, data in enumerate(val_dataset):\n", + " image = data['images'][0]\n", + " image = load_image(image['bytes'] or image['path'])\n", + " display(image)\n", + " response = infer_stream(engine, InferRequest(**data))\n", + " draw_bbox_qwen2_vl(image, response, norm_bbox=args.norm_bbox)\n", + " print('-' * 50)\n", + " image.save(os.path.join(output_dir, f'{i}.png'))\n", + " display(image)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "test_py310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/train/multimodal/grounding.sh b/examples/train/multimodal/grounding.sh index c2213660d..077542d1e 100644 --- a/examples/train/multimodal/grounding.sh +++ b/examples/train/multimodal/grounding.sh @@ -18,7 +18,7 @@ swift sft \ --gradient_accumulation_steps 16 \ --eval_steps 100 \ --save_steps 100 \ - --save_total_limit 2 \ + --save_total_limit 5 \ --logging_steps 5 \ --max_length 2048 \ --output_dir output \ diff --git a/examples/train/seq_cls/bert/sft.sh b/examples/train/seq_cls/bert/sft.sh index 538e74337..17ec39024 100644 --- a/examples/train/seq_cls/bert/sft.sh +++ b/examples/train/seq_cls/bert/sft.sh @@ -17,7 +17,7 @@ swift sft \ --gradient_accumulation_steps 16 \ --eval_steps 50 \ --save_steps 50 \ - --save_total_limit 2 \ + --save_total_limit 5 \ --logging_steps 5 \ --max_length 512 \ --truncation_strategy right \ diff --git a/examples/train/seq_cls/qwen2_5/sft.sh b/examples/train/seq_cls/qwen2_5/sft.sh index fe33ee372..efb0ac9b2 100644 --- a/examples/train/seq_cls/qwen2_5/sft.sh +++ b/examples/train/seq_cls/qwen2_5/sft.sh @@ -17,7 +17,7 @@ swift sft \ --gradient_accumulation_steps 16 \ --eval_steps 50 \ --save_steps 50 \ - --save_total_limit 2 \ + --save_total_limit 5 \ --logging_steps 5 \ --max_length 2048 \ --output_dir output \ diff --git a/examples/train/seq_cls/qwen2_vl/sft.sh b/examples/train/seq_cls/qwen2_vl/sft.sh index ed134cf6b..a277f4aa6 100644 --- a/examples/train/seq_cls/qwen2_vl/sft.sh +++ b/examples/train/seq_cls/qwen2_vl/sft.sh @@ -17,7 +17,7 @@ swift sft \ --gradient_accumulation_steps 16 \ --eval_steps 50 \ --save_steps 50 \ - --save_total_limit 2 \ + --save_total_limit 5 \ --logging_steps 5 \ --max_length 2048 \ --output_dir output \ diff 
--git a/swift/llm/__init__.py b/swift/llm/__init__.py index bf41f1959..d29e2176f 100644 --- a/swift/llm/__init__.py +++ b/swift/llm/__init__.py @@ -15,7 +15,7 @@ RLHFArguments, WebUIArguments, BaseArguments, AppArguments) from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template, TemplateInputs, TemplateMeta, get_template_meta, InferRequest, load_image, MaxLengthError, - load_file) + load_file, draw_bbox) from .model import (register_model, MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download, HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys, ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup, @@ -47,7 +47,7 @@ 'template': [ 'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template', 'TemplateInputs', 'TemplateMeta', 'get_template_meta', 'InferRequest', 'load_image', 'MaxLengthError', - 'load_file' + 'load_file', 'draw_bbox' ], 'model': [ 'MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory', diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index e7a2a3aad..45722b520 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -164,6 +164,7 @@ def adapters_can_be_merged(self): @classmethod def from_pretrained(cls, checkpoint_dir: str): self = super().__new__(cls) + self.load_data_args = True self.ckpt_dir = checkpoint_dir self.load_args_from_ckpt() return self @@ -203,7 +204,7 @@ def load_args_from_ckpt(self) -> None: 'split_dataset_ratio', # template_args 'tools_prompt', - 'use_chat_template' + 'use_chat_template', ] skip_keys = list(f.name for f in fields(GenerationArguments) + fields(CompatArguments)) + ['adapters'] if not isinstance(self, TrainArguments): diff --git a/swift/llm/argument/base_args/template_args.py b/swift/llm/argument/base_args/template_args.py index 091f861d6..1153ba597 100644 --- a/swift/llm/argument/base_args/template_args.py +++ b/swift/llm/argument/base_args/template_args.py @@ -36,7 +36,7 @@ class TemplateArguments: truncation_strategy: Literal['delete', 'left', 'right'] = 'delete' max_pixels: Optional[int] = None tools_prompt: str = 'react_en' # Override the default_tools_prompt in the template. - norm_bbox: Literal['norm1000', 'none'] = 'norm1000' + norm_bbox: Literal['norm1000', 'none', None] = None # train padding_side: Literal['left', 'right'] = 'right' loss_scale: str = 'default' diff --git a/swift/llm/template/__init__.py b/swift/llm/template/__init__.py index 249e565dd..1892f5b70 100644 --- a/swift/llm/template/__init__.py +++ b/swift/llm/template/__init__.py @@ -2,6 +2,7 @@ from . 
import template from .base import MaxLengthError, Template from .constant import TemplateType +from .grounding import draw_bbox from .register import TEMPLATE_MAPPING, get_template, get_template_meta, register_template from .template_inputs import InferRequest, TemplateInputs from .template_meta import TemplateMeta diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index a9d7167bf..cb88c6c88 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -21,7 +21,6 @@ from swift.utils import get_dist_setting, get_logger, use_torchacc from ..utils import Processor, ProcessorMixin -from .grounding import normalize_bbox from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by from .vision_utils import load_image, rescale_image @@ -44,6 +43,7 @@ class Template(ProcessorMixin): load_images = True skip_prompt = True use_model = False + norm_bbox = 'norm1000' is_encoder_decoder = False @@ -58,7 +58,7 @@ def __init__( truncation_strategy: Literal['raise', 'left', 'right'] = 'raise', max_pixels: Optional[int] = None, tools_prompt: Optional[str] = None, - norm_bbox: Literal['norm1000', 'none'] = 'norm1000', + norm_bbox: Literal['norm1000', 'none', None] = None, # only for train padding_side: Literal['left', 'right'] = 'right', loss_scale: str = 'default', @@ -106,7 +106,7 @@ def __init__( self.padding_side = padding_side self.sequence_parallel_size = sequence_parallel_size self.tools_prompt = tools_prompt or template_meta.default_tools_prompt - self.norm_bbox = norm_bbox + self.norm_bbox = norm_bbox or self.norm_bbox if self.is_encoder_decoder: self.skip_prompt = False @@ -135,6 +135,38 @@ def _load_image(image, load_images: bool): image = load_image(image) return image + @staticmethod + def _get_height_width(inputs: StdTemplateInputs) -> None: + width = [] + height = [] + for image in inputs.images: + width.append(image.width) + height.append(image.height) + inputs.objects['width'] = width + inputs.objects['height'] = height + + def normalize_bbox(self, inputs: StdTemplateInputs) -> None: + objects = inputs.objects + bbox_list = objects['bbox'] + width_list = objects['width'] + height_list = objects['height'] + bbox_type = objects.pop('bbox_type', None) or 'real' + image_id_list = objects.pop('image_id', None) or [] + image_id_list += [0] * (len(bbox_list) - len(image_id_list)) + for bbox, image_id in zip(bbox_list, image_id_list): + if bbox_type == 'norm1': + width, height = 1, 1 + else: + width, height = width_list[image_id], height_list[image_id] + for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])): + if self.norm_bbox == 'norm1000': + norm_width, norm_height = 1000, 1000 + elif self.norm_bbox == 'none': + image = inputs.images[image_id] + norm_width, norm_height = image.width, image.height + bbox[2 * i] = int(round(x / width * norm_width)) + bbox[2 * i + 1] = int(round(y / height * norm_height)) + def _preprocess_inputs( self, inputs: StdTemplateInputs, @@ -148,7 +180,7 @@ def _preprocess_inputs( for i, image in enumerate(images): images[i] = self._load_image(images[i], load_images) if inputs.objects: - normalize_bbox(inputs.images, inputs.objects, norm_bbox=self.norm_bbox) + self._get_height_width(inputs) if self.max_pixels is not None: # Scale the image proportionally without affecting the scaled objects. 
images = [rescale_image(img, self.max_pixels) for img in images] @@ -501,6 +533,8 @@ def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float Returns: The context_list and loss_scale_list after replacement. """ + if inputs.images and inputs.objects: + self.normalize_bbox(inputs) # replace tag/object/box res: List[Context] = [] # result of context_list res_loss_scale: List[float] = [] # result of loss_scale_list diff --git a/swift/llm/template/grounding.py b/swift/llm/template/grounding.py index 6cff2c09b..4ce88e22e 100644 --- a/swift/llm/template/grounding.py +++ b/swift/llm/template/grounding.py @@ -1,24 +1,89 @@ -from typing import Any, Dict, List, Literal +import colorsys +import itertools +import os +from typing import Any, List, Literal -from PIL import Image +import requests +from modelscope.hub.utils.utils import get_cache_dir +from PIL import Image, ImageDraw, ImageFont -def normalize_bbox(images: List[Image.Image], - objects: Dict[str, List[Any]], - norm_bbox: Literal['norm1000', 'none'] = 'norm1000') -> None: - if not objects or not images or norm_bbox == 'none': - return - bbox_list = objects['bbox'] - bbox_type = objects.pop('bbox_type', None) or 'real' - image_id_list = objects.pop('image_id', None) or [] - image_id_list += [0] * (len(bbox_list) - len(image_id_list)) - for bbox, image_id in zip(bbox_list, image_id_list): +def _shuffle_colors(nums: List[Any]) -> List[Any]: + if len(nums) == 1: + return nums + + mid = len(nums) // 2 + + left = nums[:mid] + right = nums[mid:] + left = _shuffle_colors(left) + right = _shuffle_colors(right) + new_nums = [] + for x, y in zip(left, right): + new_nums += [x, y] + new_nums += left[len(right):] or right[len(left):] + return new_nums + + +def generate_colors(): + vs_combinations = [(v, s) for v, s in itertools.product([0.7, 0.3, 1], [0.7, 0.3, 1])] + colors = [colorsys.hsv_to_rgb(i / 16, s, v) for v, s in vs_combinations for i in _shuffle_colors(list(range(16)))] + colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors] + return _shuffle_colors(colors) + + +def download_file(url: str) -> str: + url = url.rstrip('/') + file_name = url.rsplit('/', 1)[-1] + cache_dir = os.path.join(get_cache_dir(), 'files') + os.makedirs(cache_dir, exist_ok=True) + req = requests.get(url) + file_path = os.path.join(cache_dir, file_name) + with open(file_path, 'wb') as f: + f.write(req.content) + return file_path + + +colors = generate_colors() +color_mapping = {} + + +def _calculate_brightness(image, region: List[int]): + cropped_image = image.crop(region) + grayscale_image = cropped_image.convert('L') + pixels = list(grayscale_image.getdata()) + average_brightness = sum(pixels) / len(pixels) + return average_brightness + + +def draw_bbox(image: Image.Image, + ref: List[str], + bbox: List[List[int]], + norm_bbox: Literal['norm1000', 'none'] = 'norm1000'): + font_path = 'https://modelscope.cn/models/Qwen/Qwen-VL-Chat/resolve/master/SimSun.ttf' + # norm bbox + for i, box in enumerate(bbox): + for i in range(len(box)): + box[i] = int(box[i]) if norm_bbox == 'norm1000': - if bbox_type == 'norm1': - width, height = 1, 1 - else: - image = images[image_id] - width, height = image.width, image.height - for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])): - bbox[2 * i] = int(x / width * 1000) - bbox[2 * i + 1] = int(y / height * 1000) + box[0] = box[0] / 1000 * image.width + box[2] = box[2] / 1000 * image.width + box[1] = box[1] / 1000 * image.height + box[3] = box[3] / 1000 * image.height + + draw = 
ImageDraw.Draw(image) + # draw bbox + assert len(ref) == len(bbox), f'len(refs): {len(ref)}, len(bboxes): {len(bbox)}' + for (left, top, right, bottom), box_ref in zip(bbox, ref): + if box_ref not in color_mapping: + color_mapping[box_ref] = colors[len(color_mapping) % len(colors)] + color = color_mapping[box_ref] + draw.rectangle([(left, top), (right, bottom)], outline=color, width=3) + # draw text + file_path = download_file(font_path) + font = ImageFont.truetype(file_path, 20) + for (left, top, _, _), box_ref in zip(bbox, ref): + brightness = _calculate_brightness( + image, [left, top, min(left + 100, image.width), + min(top + 20, image.height)]) + draw.text((left, top), box_ref, fill='white' if brightness < 128 else 'black', font=font) diff --git a/swift/llm/template/register.py b/swift/llm/template/register.py index 53636b3e0..c99804306 100644 --- a/swift/llm/template/register.py +++ b/swift/llm/template/register.py @@ -27,7 +27,7 @@ def get_template( truncation_strategy: Literal['raise', 'left', 'right'] = 'raise', max_pixels: Optional[int] = None, # h * w tools_prompt: str = 'react_en', - norm_bbox: Literal['norm1000', 'none'] = 'norm1000', + norm_bbox: Literal['norm1000', 'none', None] = None, # train padding_side: Literal['left', 'right'] = 'right', loss_scale: str = 'default', diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index d969364f3..73fb0ce42 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -325,6 +325,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in class Qwen2_5VLTemplate(Qwen2VLTemplate): version = 'v2_5' + norm_bbox = 'none' register_template( diff --git a/swift/llm/template/template_inputs.py b/swift/llm/template/template_inputs.py index 953a20fbd..0c29b5d78 100644 --- a/swift/llm/template/template_inputs.py +++ b/swift/llm/template/template_inputs.py @@ -39,6 +39,7 @@ class InferRequest: videos: List[str] = field(default_factory=list) tools: Optional[List[Tool]] = None + objects: Dict[str, List[Any]] = field(default_factory=dict) def __post_init__(self): for key in ['images', 'audios', 'videos']: @@ -81,7 +82,6 @@ class TemplateInputs(InferRequest): """ rejected_response: Optional[str] = None label: Optional[bool] = None - objects: Dict[str, List[Any]] = field(default_factory=dict) @dataclass
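
As a reference for the custom dataset format described in the notebook above, here is a minimal sketch of assembling one grounding sample in SWIFT's JSONL format: `bbox` values are absolute pixel coordinates, and ms-swift rescales them per model according to `norm_bbox`. The image path and output file name are placeholders, and the `<image>`/`<ref-object>`/`<bbox>` placeholder tags reflect my reading of the custom-dataset docs rather than anything shown verbatim in this diff.

```python
import json

# A single SWIFT-format grounding sample. Assumption: <image>, <ref-object> and
# <bbox> are the placeholder tags that ms-swift substitutes with the entries in
# `objects`. The bbox is given in absolute pixel coordinates of the image.
sample = {
    'messages': [
        {'role': 'user', 'content': '<image>找到图像中的<ref-object>'},
        {'role': 'assistant', 'content': '<ref-object><bbox>'},
    ],
    'images': ['/path/to/sheep.jpg'],  # placeholder path
    'objects': {'ref': ['羊'], 'bbox': [[90.9, 160.8, 135.0, 212.8]]},
}

# Write one JSON object per line, keeping non-ASCII text readable.
with open('grounding.jsonl', 'w', encoding='utf-8') as f:
    f.write(json.dumps(sample, ensure_ascii=False) + '\n')
```

A file like this can then be passed via `--dataset grounding.jsonl` in place of `AI-ModelScope/coco#2000` in the training command shown in the notebook.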
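
The `norm_bbox` behavior added in `Template.normalize_bbox` above reduces to the following standalone sketch (the `scale_bbox` helper and the 640x480 sample size are illustrative, not the library API): 'norm1000' maps pixel coordinates onto a 0-1000 grid, which is what qwen2-vl and internvl2.5 expect, while 'none' keeps absolute pixels and only rounds to integers, matching `Qwen2_5VLTemplate.norm_bbox = 'none'`.

```python
# Minimal re-implementation of the norm1000 scaling applied in
# Template.normalize_bbox (swift/llm/template/base.py in this diff),
# for illustration only.

def scale_bbox(bbox, image_width, image_height, norm_bbox='norm1000'):
    """Scale pixel-space [x1, y1, x2, y2, ...] coordinates.

    'norm1000' rescales each coordinate to the 0-1000 range;
    'none' keeps absolute pixel values and only rounds to int.
    """
    scaled = []
    for x, y in zip(bbox[::2], bbox[1::2]):
        if norm_bbox == 'norm1000':
            norm_w, norm_h = 1000, 1000
        else:  # 'none'
            norm_w, norm_h = image_width, image_height
        scaled.append(int(round(x / image_width * norm_w)))
        scaled.append(int(round(y / image_height * norm_h)))
    return scaled


# The same box on a 640x480 image, in both modes.
print(scale_bbox([90.9, 160.8, 135, 212.8], 640, 480, 'norm1000'))  # [142, 335, 211, 443]
print(scale_bbox([90.9, 160.8, 135, 212.8], 640, 480, 'none'))      # [91, 161, 135, 213]
```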
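
On the inference side, `draw_bbox` in `swift/llm/template/grounding.py` applies the inverse mapping before drawing. A standalone illustration of just that coordinate step (the `to_pixels` helper is hypothetical, not part of the patch):

```python
# Convert a norm1000 prediction back to pixel space before drawing,
# mirroring the coordinate handling in draw_bbox.
def to_pixels(box, image_width, image_height, norm_bbox='norm1000'):
    x1, y1, x2, y2 = (int(v) for v in box)
    if norm_bbox == 'norm1000':
        x1, x2 = x1 / 1000 * image_width, x2 / 1000 * image_width
        y1, y2 = y1 / 1000 * image_height, y2 / 1000 * image_height
    return [x1, y1, x2, y2]


print(to_pixels([142, 335, 211, 443], 640, 480))  # ≈ [90.9, 160.8, 135.0, 212.6]
```

This round-trips the norm1000 prediction from the previous sketch back to roughly the original pixel box. For absolute-coordinate models such as qwen2.5-vl there is nothing to invert, but the predicted box is tied to the exact pixel size the model saw, which is why the notebook advises resizing images up front rather than relying on `MAX_PIXELS`/`--max_pixels` at inference time.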