FEAT add auto mps support #7

Open · wants to merge 1 commit into base: main
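The change replaces hard-coded CUDA placement (.cuda(), device="cuda") with a DEVICE constant that is resolved once in ferret/model/utils.py and then reused by the model builder and the serving worker, so the code can also run on Apple-silicon (MPS) and CPU-only machines. A minimal sketch of the selection pattern the diff applies (the DEVICE name matches the PR; the tensor at the end is only an illustration):

import torch

# Pick the best available backend once; call sites then use DEVICE instead of
# a hard-coded .cuda() / device="cuda".
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Illustration only: tensors and modules are placed the same way on any backend.
x = torch.ones(2, 2, device=DEVICE)
print(DEVICE, x.device)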
144 changes: 102 additions & 42 deletions ferret/model/builder.py
@@ -15,98 +15,156 @@
 import shutil
 import pdb

-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoConfig,
+    BitsAndBytesConfig,
+)
 import torch
 from ferret.model import *
-from ferret.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from ferret.constants import (
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IM_END_TOKEN,
+)
+from ferret.model.utils import DEVICE

 DEFAULT_REGION_FEA_TOKEN = "<region_fea>"

-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
+
+def load_pretrained_model(
+    model_path,
+    model_base,
+    model_name,
+    load_8bit=False,
+    load_4bit=False,
+    device_map="auto",
+):
     kwargs = {"device_map": device_map}

     if load_8bit:
-        kwargs['load_in_8bit'] = True
+        kwargs["load_in_8bit"] = True
     elif load_4bit:
-        kwargs['load_in_4bit'] = True
-        kwargs['quantization_config'] = BitsAndBytesConfig(
+        kwargs["load_in_4bit"] = True
+        kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type='nf4'
+            bnb_4bit_quant_type="nf4",
         )
     else:
-        kwargs['torch_dtype'] = torch.float16
+        kwargs["torch_dtype"] = torch.float16

-    if 'llava' in model_name.lower() or 'ferret' in model_name.lower():
+    if "llava" in model_name.lower() or "ferret" in model_name.lower():
         # Load LLaVA/FERRET model
-        if 'lora' in model_name.lower() and model_base is not None:
+        if "lora" in model_name.lower() and model_base is not None:
             lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
-            print('Loading LLaVA/FERRET from base model...')
-            model = FERRETLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+            print("Loading LLaVA/FERRET from base model...")
+            model = FERRETLlamaForCausalLM.from_pretrained(
+                model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
+            )
             token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
             if model.lm_head.weight.shape[0] != token_num:
-                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
-                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+                model.lm_head.weight = torch.nn.Parameter(
+                    torch.empty(
+                        token_num, tokem_dim, device=model.device, dtype=model.dtype
+                    )
+                )
+                model.model.embed_tokens.weight = torch.nn.Parameter(
+                    torch.empty(
+                        token_num, tokem_dim, device=model.device, dtype=model.dtype
+                    )
+                )

-            print('Loading additional LLaVA/FERRET weights...')
-            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
-                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+            print("Loading additional LLaVA/FERRET weights...")
+            if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+                non_lora_trainables = torch.load(
+                    os.path.join(model_path, "non_lora_trainables.bin"),
+                    map_location="cpu",
+                )
             else:
                 # this is probably from HF Hub
                 from huggingface_hub import hf_hub_download
+
                 def load_from_hf(repo_id, filename, subfolder=None):
                     cache_file = hf_hub_download(
-                        repo_id=repo_id,
-                        filename=filename,
-                        subfolder=subfolder)
-                    return torch.load(cache_file, map_location='cpu')
-                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
-            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
-            if any(k.startswith('model.model.') for k in non_lora_trainables):
-                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+                        repo_id=repo_id, filename=filename, subfolder=subfolder
+                    )
+                    return torch.load(cache_file, map_location="cpu")
+
+                non_lora_trainables = load_from_hf(
+                    model_path, "non_lora_trainables.bin"
+                )
+            non_lora_trainables = {
+                (k[11:] if k.startswith("base_model.") else k): v
+                for k, v in non_lora_trainables.items()
+            }
+            if any(k.startswith("model.model.") for k in non_lora_trainables):
+                non_lora_trainables = {
+                    (k[6:] if k.startswith("model.") else k): v
+                    for k, v in non_lora_trainables.items()
+                }
             model.load_state_dict(non_lora_trainables, strict=False)

             from peft import PeftModel
-            print('Loading LoRA weights...')
+
+            print("Loading LoRA weights...")
             model = PeftModel.from_pretrained(model, model_path)
-            print('Merging LoRA weights...')
+            print("Merging LoRA weights...")
             model = model.merge_and_unload()
-            print('Model is loaded...')
+            print("Model is loaded...")
         elif model_base is not None:
             # this may be mm projector only
-            print('Loading LLaVA/FERRET from base model...')
+            print("Loading LLaVA/FERRET from base model...")
             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
             cfg_pretrained = AutoConfig.from_pretrained(model_path)
-            model = FERRETLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+            model = FERRETLlamaForCausalLM.from_pretrained(
+                model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
+            )

-            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
-            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
+            mm_projector_weights = torch.load(
+                os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
+            )
+            mm_projector_weights = {
+                k: v.to(torch.float16) for k, v in mm_projector_weights.items()
+            }
             model.load_state_dict(mm_projector_weights, strict=False)
         else:
             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-            model = FERRETLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+            model = FERRETLlamaForCausalLM.from_pretrained(
+                model_path, low_cpu_mem_usage=True, **kwargs
+            )
     else:
         # Load language model
         if model_base is not None:
             # PEFT model
             from peft import PeftModel
+
             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
-            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
+            model = AutoModelForCausalLM.from_pretrained(
+                model_base,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+                device_map="auto",
+            )
             print(f"Loading LoRA weights from {model_path}")
             model = PeftModel.from_pretrained(model, model_path)
             print(f"Merging weights")
             model = model.merge_and_unload()
-            print('Convert to FP16...')
+            print("Convert to FP16...")
             model.to(torch.float16)
         else:
             use_fast = False
             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path, low_cpu_mem_usage=True, **kwargs
+            )

     image_processor = None

-    if 'llava' in model_name.lower() or 'ferret' in model_name.lower():
+    if "llava" in model_name.lower() or "ferret" in model_name.lower():
         mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
         mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
         mm_im_region_fea_token = getattr(model.config, "im_region_fea_token", None)
@@ -115,20 +115,22 @@ def load_from_hf(repo_id, filename, subfolder=None):
         if mm_im_region_fea_token is not None:
             tokenizer.add_tokens([DEFAULT_REGION_FEA_TOKEN], special_tokens=True)
         if mm_use_im_start_end:
-            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+            tokenizer.add_tokens(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
+            )
         model.resize_token_embeddings(len(tokenizer))

         vision_tower = model.get_vision_tower()
-        vision_tower_path = os.path.join(model_path, 'vision_tower')
+        vision_tower_path = os.path.join(model_path, "vision_tower")
         if not vision_tower.is_loaded or os.path.exists(vision_tower_path):
             if os.path.exists(vision_tower_path):
-                print(f'Start Loading vision tower from {vision_tower_path}')
+                print(f"Start Loading vision tower from {vision_tower_path}")
                 vision_tower.load_model(vision_tower_path=vision_tower_path)
-                print(f'Finish Loading vision tower from {vision_tower_path}')
+                print(f"Finish Loading vision tower from {vision_tower_path}")
             else:
                 vision_tower.load_model()

-        vision_tower.to(device='cuda', dtype=torch.float16)
+        vision_tower.to(device=DEVICE, dtype=torch.float16)
         image_processor = vision_tower.image_processor

     if hasattr(model.config, "max_sequence_length"):
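For reference, a hypothetical call on an Apple-silicon machine, assuming load_pretrained_model keeps the usual LLaVA-style return signature (tokenizer, model, image_processor, context_len); the checkpoint path and model name below are placeholders:

from ferret.model.builder import load_pretrained_model
from ferret.model.utils import DEVICE

# Placeholder checkpoint; with this PR, DEVICE resolves to "mps" on a Mac and
# the vision tower is moved there instead of being sent to CUDA.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/ferret-7b",
    model_base=None,
    model_name="ferret-7b",
)
print(DEVICE)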
24 changes: 19 additions & 5 deletions ferret/model/utils.py
@@ -1,18 +1,32 @@
+import torch
+from typing import Literal
 from transformers import AutoConfig

+DEVICE: Literal["cpu", "cuda", "mps"] = None
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+elif torch.backends.mps.is_available():
+    DEVICE = "mps"
+else:
+    DEVICE = "cpu"
+
+
 def auto_upgrade(config):
     cfg = AutoConfig.from_pretrained(config)
-    if 'llava' in config and 'llava' not in cfg.model_type:
-        assert cfg.model_type == 'llama'
-        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
-        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+    if "llava" in config and "llava" not in cfg.model_type:
+        assert cfg.model_type == "llama"
+        print(
+            "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
+        )
+        print(
+            "You must upgrade the checkpoint to the new code base (this can be done automatically)."
+        )
         confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
         if confirm.lower() in ["y", "yes"]:
             print("Upgrading checkpoint...")
             assert len(cfg.architectures) == 1
             setattr(cfg.__class__, "model_type", "llava")
-            cfg.architectures[0] = 'FERRETLlamaForCausalLM'
+            cfg.architectures[0] = "FERRETLlamaForCausalLM"
             cfg.save_pretrained(config)
             print("Checkpoint upgraded.")
         else:
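Since DEVICE is resolved at import time, a quick sanity check of the new constant (a sketch; it only assumes torch is installed and the repo is on the import path) is:

from ferret.model.utils import DEVICE
import torch

print("Selected device:", DEVICE)            # "cuda", "mps", or "cpu"
t = torch.ones(2, 2, device=DEVICE).half()   # mirrors the .half() usage in model_worker.py
print((t + t).dtype, t.device)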
9 changes: 5 additions & 4 deletions ferret/serve/model_worker.py
@@ -29,6 +29,7 @@
 from ferret.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 # from transformers import TextIteratorStreamer
 from threading import Thread
+from ferret.model.utils import DEVICE


 GB = 1 << 30
@@ -177,7 +178,7 @@ def generate_stream(self, params):

         if region_masks is not None:
             assert self.add_region_feature
-            region_masks = [[torch.Tensor(region_mask_i).cuda().half() for region_mask_i in region_masks]]
+            region_masks = [[torch.Tensor(region_mask_i).to(DEVICE).half() for region_mask_i in region_masks]]
             image_args["region_masks"] = region_masks
             logger.info("Add region_masks to image_args.")
         else:
@@ -211,15 +212,15 @@ def generate_stream(self, params):
         for i in range(max_new_tokens):
             if i == 0:
                 out = model(
-                    torch.as_tensor([input_ids]).cuda(),
+                    torch.as_tensor([input_ids]).to(DEVICE),
                     use_cache=True,
                     **image_args)
                 logits = out.logits
                 past_key_values = out.past_key_values
             else:
                 attention_mask = torch.ones(
-                    1, past_key_values[0][0].shape[-2] + 1, device="cuda")
-                out = model(input_ids=torch.as_tensor([[token]], device="cuda"),
+                    1, past_key_values[0][0].shape[-2] + 1, device=DEVICE)
+                out = model(input_ids=torch.as_tensor([[token]], device=DEVICE),
                             use_cache=True,
                             attention_mask=attention_mask,
                             past_key_values=past_key_values,