Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] refactor grounding #2979

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/train/multimodal/grounding.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# 20GiB
# You can refer to `https://github.com/QwenLM/Qwen2-VL` for the meaning of the `MAX_PIXELS` parameter.
CUDA_VISIBLE_DEVICES=0 \
MAX_PIXELS=1003520 \
swift sft \
--model Qwen/Qwen2-VL-7B-Instruct \
--dataset 'swift/refcoco:grounding#1000' \
--dataset 'AI-ModelScope/coco#20000' \
--train_type lora \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
Expand All @@ -22,4 +23,5 @@ swift sft \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4
--dataloader_num_workers 4 \
--dataset_num_proc 4
29 changes: 29 additions & 0 deletions swift/llm/dataset/dataset/mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,35 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
tags=['multi-modal', 'en', 'vqa', 'quality']))


class CocoPreprocessor(ResponsePreprocessor):
    """Turn COCO detection rows into grounding-style SFT samples.

    Each numeric id in ``row['objects']['category']`` is resolved to its
    English class name, and the response is one ``<ref-object><bbox>``
    placeholder pair per object (comma separated), to be expanded later by
    the template's ref/bbox replacement hooks.
    """

    # The 80 COCO object-detection class names, indexed by category id.
    category = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
        'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
        'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
        'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
        'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Fill in query/ref/response for one COCO row, then delegate upward."""
        objects = row['objects']
        # Map numeric category ids to human-readable labels.
        objects['ref'] = [self.category[idx] for idx in objects['category']]
        row['query'] = 'Task: Object Detection'
        # Emit one placeholder pair per object; the template expands them.
        row['response'] = ','.join('<ref-object><bbox>' for _ in objects['ref'])
        return super().preprocess(row)


# Register the COCO detection dataset (ModelScope / HF mirrors) with the
# grounding preprocessor above.
# NOTE(review): tags look copied from the VQA registration above — 'vqa' seems
# wrong for an object-detection/grounding dataset; confirm the intended tags.
register_dataset(
    DatasetMeta(
        ms_dataset_id='AI-ModelScope/coco',
        hf_dataset_id='detection-datasets/coco',
        preprocess_func=CocoPreprocessor(),
        huge_dataset=True,
        tags=['multi-modal', 'en', 'vqa', 'quality']))


class LLaVAMixSFTPreprocessor(RowPreprocessor):

def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
28 changes: 27 additions & 1 deletion swift/llm/dataset/preprocessor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datasets import Dataset as HfDataset
from datasets import Image
from datasets import IterableDataset as HfIterableDataset
from datasets import Value
from datasets import Sequence, Value

from swift.llm import history_to_messages
from swift.utils import get_logger
Expand Down Expand Up @@ -143,6 +143,28 @@ def _fix_streaming_keys(row):
new_k = k[len('__@'):]
row[new_k] = row.pop(k)

@staticmethod
def _check_objects(row):
if 'objects' not in row:
return
objects = row['objects']
for k in list(objects.keys()):
if k not in {'bbox', 'ref', 'image_id'}:
objects.pop(k)
bbox = objects['bbox']
assert len(bbox) == len(
objects['ref']), (f"len(objects['bbox']): {len(bbox)}, len(objects['ref']): {len(objects['ref'])}")

# check bbox
for box in bbox:
assert len(box) % 2 == 0, f'len(box): {len(box)}'
if len(box) != 4:
continue
if box[0] > box[2]:
box[0], box[2] = box[2], box[0]
if box[1] > box[3]:
box[1], box[3] = box[3], box[1]

def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool,
ignore_max_length_error: bool) -> Dict[str, Any]:
from ...template import MaxLengthError
Expand All @@ -161,6 +183,7 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool,
if isinstance(row, dict):
row = [row]
for r in row:
self._check_objects(r)
self._check_messages(r)
self._check_rejected_response(r)
self._cast_images(r)
Expand Down Expand Up @@ -228,6 +251,9 @@ def _new_init(self, schema=None, features=None, *args, **kwargs):
'content': Value(dtype='string', id=None)
}]
features['images'] = [{'bytes': Value(dtype='binary', id=None), 'path': Value(dtype='string', id=None)}]
features['bbox'] = Sequence(feature=Sequence(feature=Value(dtype='float64'), length=4))
features['ref'] = Sequence(feature=Value(dtype='string'))

ArrowWriter.__origin_init__(self, schema, features, *args, **kwargs)

ArrowWriter.__origin_init__ = ArrowWriter.__init__
Expand Down
Empty file added swift/llm/template/grounding.py
Empty file.
35 changes: 11 additions & 24 deletions swift/llm/template/template/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
assert isinstance(image, str)
return [f'Picture {index + 1}: <img>{image}</img>\n']

def replace_object(self, object_: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
return [f'<ref>{object_["caption"]}</ref>']
def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
    """Wrap one object-reference string in Qwen-VL ``<ref>`` grounding tags."""
    return ['<ref>{}</ref>'.format(ref)]

def replace_box(self, object_: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
def replace_bbox(self, bbox: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
if isinstance(object_['bbox'][0], list):
all_objects = ''
for sub_object in object_['bbox']:
Expand Down Expand Up @@ -208,27 +208,14 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8)
return ['<|vision_start|><|video_pad|><|vision_end|>']

def replace_object(self, object_: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
if object_:
return ['<|object_ref_start|>', object_['caption'], '<|object_ref_end|>']
else:
return ['<ref-object>']

def replace_box(self, object_: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
if object_:
if isinstance(object_['bbox'][0], list):
all_objects = ''
for sub_object in object_['bbox']:
all_objects += (f'<|box_start|>({sub_object[0]},{sub_object[1]}),'
f'({sub_object[2]},{sub_object[3]})<|box_end|>')
return [all_objects]
else:
return [
f'<|box_start|>({object_["bbox"][0]},{object_["bbox"][1]}),'
f'({object_["bbox"][2]},{object_["bbox"][3]})<|box_end|>'
]
else:
return ['<bbox>']
def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
    """Wrap one object-reference string in Qwen2-VL object-ref special tokens."""
    return ['<|object_ref_start|>{}<|object_ref_end|>'.format(ref)]

def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
    """Render a flat ``[x1, y1, x2, y2, ...]`` box as Qwen2-VL box tokens.

    Coordinates are consumed pairwise, so multi-point boxes are supported;
    an odd trailing value would be silently dropped by the pairing.
    """
    coords = ','.join(f'({x},{y})' for x, y in zip(bbox[::2], bbox[1::2]))
    return [f'<|box_start|>{coords}<|box_end|>']

def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded = super()._encode(inputs)
Expand Down
Loading