Qwen2-VL for classification #709

Open
Garibelhj opened this issue Feb 1, 2025 · 2 comments

Comments


Garibelhj commented Feb 1, 2025

I am using the Qwen2-VL-2B model for a classification task and want to adapt Qwen2VLForConditionalGeneration by adding a linear classification head. I am not sure my modification is correct: after adding the head, the test results show the model has almost no classification ability. Below is my training code.

class Qwen2VLWithClassificationHead(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        # Define the classification head (mha_layer and sigmoid are defined but never used in forward below)
        self.mha_layer = torch.nn.MultiheadAttention(embed_dim=512, kdim=512, vdim=512, num_heads=1, batch_first=True)
        self.sigmoid = nn.Sigmoid()
        self.model = Qwen2VLModel(config)  # note: super().__init__ already creates self.model

        # Linear head: 1536 is the hidden size of Qwen2-VL-2B, 7 is the number of classes
        self.classification_head = nn.Linear(1536, 7, bias=False)
        torch.nn.init.xavier_uniform_(self.classification_head.weight)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value


    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        labels_int=None,
        prob_id=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        rope_deltas=None,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

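        # Embed the text tokens and splice the vision-tower features into the image/video placeholder positions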
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

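        # Run the Qwen2-VL language model on the merged text+vision embeddings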
        transformer_outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Last hidden states from the language model: (batch, seq_len, hidden_size)
        hidden_states = transformer_outputs[0]
        batch_size = input_ids.shape[0]
        # Project hidden states to class logits: (batch, seq_len, 7)
        classification_logits = self.classification_head(hidden_states)
        # Pool by taking the logits at the last sequence position of each example
        pool_classification_logits = classification_logits[torch.arange(batch_size, device=classification_logits.device), -1]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(pool_classification_logits.view(-1, 7), labels_int.view(-1))


        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pool_classification_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

def process_func(example):
    """
    Preprocess a single dataset example.
    """
    number_map = {
        0: '(A) Pornographic Websites',
        1: '(B) Gambling Websites',
        2: '(C) Prize Scam Websites',
        3: '(D) Phishing Websites',
        4: '(E) Malicious Distribution Websites',
        5: '(F) Fraudulent E-Commerce Website',
        6: '(G) Fraudulent Financial Services Website',
    }
    MAX_LENGTH = 8192
    input_ids, attention_mask, labels = [], [], []
    conversation = example["conversations"]
    prompt = conversation['prompt']
    image = conversation['image']
    _id = example['id']
    solution = conversation['solution']
    prob_id = torch.tensor(int(_id))

    answer = conversation['answer']
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"{image}",
                    "resized_height": 280,
                    "resized_width": 280,
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    labels = number_map[answer]
    labels = f"The answer is {labels}. Because: " + solution
    response = tokenizer(labels, add_special_tokens=False)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )  # build the prompt text from the chat template
    image_inputs, video_inputs = process_vision_info(messages)  # extract the (preprocessed) image/video inputs
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    inputs = {key: value.tolist() for key, value in inputs.items()}  # tensor -> list, to make concatenation easier
    instruction = inputs

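    # Concatenate the prompt tokens with the answer tokens and a trailing pad token;
    # prompt positions in the labels are masked with -100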
    input_ids = (
            instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id]
    )

    attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]
    labels = (
            [-100] * len(instruction["input_ids"][0])
            + response["input_ids"]
            + [tokenizer.pad_token_id]
    )

    if len(input_ids) > MAX_LENGTH:  # truncate if too long
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels_seq = torch.tensor(labels)
    labels_int = torch.tensor(int(answer))
    inputs['pixel_values'] = torch.tensor(inputs['pixel_values'])
    inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0)  # squeeze from (1, h, w) to (h, w)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels_seq,'labels_int':labels_int,'prob_id':prob_id,
            "pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']}





# Download the Qwen2-VL model from ModelScope to a local directory
# model_dir = snapshot_download("Qwen/Qwen2-VL-7B-Instruct", cache_dir="./", revision="master")
from transformers import AutoModel

# Load the model weights with Transformers
tokenizer = AutoTokenizer.from_pretrained("/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
base_model = AutoModel.from_pretrained(
    "/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct",
    device_map=device
)

# Initialize the custom model (key step)
model = Qwen2VLWithClassificationHead.from_pretrained(
    "/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct",
    config=base_model.config,
    ignore_mismatched_sizes=True
)
model = model.to(device)

# Step 2: manually set custom parameters

model.enable_input_require_grads()  # required when gradient checkpointing is enabled

# Split into training and evaluation sets, saved as data_vl_train.json and data_vl_test.json
train_dataset = train_ds.map(process_func)
eval_dataset = eval_ds.map(process_func)



# Configure training arguments
args = TrainingArguments(
    output_dir="/home/hongjiegu/projects/GMMR/checkpoint/linear_classifier",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    logging_steps=10,
    logging_first_step=5,
    num_train_epochs=10,
    save_steps=1000,
    learning_rate=5e-5,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)
        
# Set up the SwanLab callback
swanlab_callback = SwanLabCallback(
    project="Qwen2-VL-finetune_2classes",
    experiment_name="GMMR_FWC_stage1",
)

# Configure the Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

# Start training
trainer.train()
@Adilmunawar

Here are the relevant code sections and documentation for the Qwen2VLForConditionalGeneration model and its training process:

  1. Qwen2.5-VL Model Implementation and Usage
  2. Dynamic Resolution and Frame Rate Training for Video Understanding
  3. Web Demo Script Implementation

Based on these findings, I have identified a few areas for potential improvement in your custom model code (a minimal sketch follows the list):

  1. Classification Head Dimensions: Ensure the input dimension of the classification head matches the model's hidden size (1536 for Qwen2-VL-2B) and its output matches the number of classes.
  2. Gradient Flow: Ensure gradients flow correctly through the added layers, i.e. the head's parameters are trainable and the pooled position actually depends on the inputs you care about.
  3. Data Preprocessing: Verify that data preprocessing aligns with the model's expected input format (image placeholder tokens, attention mask, and padding).
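
To make points 1 and 2 concrete, here is a minimal, self-contained sketch (random tensors stand in for the language model's hidden states; the 1536 hidden size, 7 classes, and right-padded batches are assumptions taken from the code above). It pools the hidden state of the last non-padding token via the attention mask instead of always indexing position -1, which can land on a padding token once the collator pads the batch:

import torch
import torch.nn as nn

hidden_size, num_classes = 1536, 7          # Qwen2-VL-2B hidden size, 7 website classes
classification_head = nn.Linear(hidden_size, num_classes, bias=False)

# Random stand-ins for the LM outputs of a right-padded batch of 2 sequences.
batch_size, seq_len = 2, 10
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
attention_mask = torch.tensor([[1] * 10, [1] * 6 + [0] * 4])

# Index of the last real (non-padding) token in each sequence.
last_token_idx = attention_mask.sum(dim=1) - 1                    # shape: (batch,)

# Pool at that position instead of blindly taking position -1.
pooled = hidden_states[torch.arange(batch_size), last_token_idx]  # (batch, hidden_size)
logits = classification_head(pooled)                              # (batch, num_classes)

labels_int = torch.tensor([3, 5])
loss = nn.CrossEntropyLoss()(logits, labels_int)
print(logits.shape, loss.item())

If you swap this pooling into your forward (using the attention_mask you already pass in), the pooled position no longer depends on how the batch was padded.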


thohemp commented Feb 6, 2025

@Adilmunawar What's the point of this AI-generated answer?

@Garibelhj I am working on the same issue. In my case, the model does not learn from images at all. Please let me know if you manage to resolve this.
