Multiturn mm single image (#1270)
* initial test

* Pad causal mask with zeros and set decoder max_seq_len to the max sequence length so both shapes are fixed at max_seq_len (see the sketch below)

* Fix control bug for image inputs

* Clear image input after submitting a chat

* Include empty assistant message for chat

* Pipe image input from CLI

---------

Co-authored-by: Jack-Khuu <[email protected]>
Co-authored-by: vmpuri <[email protected]>
3 people authored Oct 5, 2024
1 parent 766bee9 commit d0993b3
Showing 3 changed files with 151 additions and 107 deletions.
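
The central change is easiest to see in isolation. Below is a minimal sketch (not the torchchat code itself; seq_len, max_new_tokens, and max_seq_len are illustrative values) of how the causal mask is now built for the actual prompt-plus-generation length and then zero-padded so its shape always matches the max_seq_len that is also used as decoder_max_seq_len when the caches are set up:

import torch
import torch.nn.functional as F

# Illustrative values; in generate.py these come from the encoded prompt,
# generator_args.max_new_tokens, and the model's max_seq_length.
seq_len = 32
max_new_tokens = 64
max_seq_len = 2048

total_response_length = seq_len + max_new_tokens

# Lower-triangular mask covering the prompt plus the tokens to be generated,
# padded with zeros (False) on the right and bottom so the result is always
# (max_seq_len, max_seq_len).
causal_mask = F.pad(
    torch.tril(
        torch.ones(total_response_length, total_response_length, dtype=torch.bool)
    ),
    (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
    value=0,
)

assert causal_mask.shape == (max_seq_len, max_seq_len)

Because every mask now has the same fixed shape, the decoder cache can be sized once with decoder_max_seq_len=max_seq_length instead of per prompt with T_new, which is exactly what the diff below does.
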
141 changes: 96 additions & 45 deletions torchchat/generate.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(
 
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
@@ -616,7 +617,7 @@
                     batch_size=1,
                     dtype=self.dtype,
                     encoder_max_seq_len=6404,
-                    decoder_max_seq_len=T_new,
+                    decoder_max_seq_len=max_seq_length,
                 )
             else:
                 model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +630,7 @@
             model.reset_caches()
 
         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )
 
         prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@
             )
 
             accept_counts[len(next_tokens) - 1] += 1
-            num_added = min(T_new - input_pos - 1, len(next_tokens))
+            num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
             for token in next_tokens[:num_added,]:
                 callback(token)
                 yield token, None
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -778,32 +782,69 @@
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
 
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None
 
         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"
 
-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
+
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
 
-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
 
-        messages = [
+        messages.append(
             Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )
 
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+
@@ -812,7 +853,7 @@
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
@@ -822,17 +863,27 @@
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
 
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}
 
             total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
+            batch["causal_mask"] = torch.nn.functional.pad(
+                torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                ),
+                (
+                    0,
+                    max_seq_len - total_response_length,
+                    0,
+                    max_seq_len - total_response_length,
+                ),
+                value=0,
             )
 
         logging.debug(encoded)
@@ -845,12 +896,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +941,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
@@ -907,16 +958,16 @@
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
(Diffs for the remaining two changed files are not shown.)
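
For reference, here is a condensed sketch of the new message handling in _gen_model_input. The helper name split_content is hypothetical and used only for illustration; generate.py inlines this logic in its loop over the prompt messages. Each OpenAI-style content list is scanned for a text part and at most one base64 image_url part, which is decoded into a PIL image:

import base64
from io import BytesIO

from PIL import Image


def split_content(content_list):
    """Return (text, image) for one message's OpenAI-style content list."""
    text, image = "", None
    for part in content_list:
        if part["type"] == "text":
            text = part["text"]
        elif part["type"] == "image_url":
            assert image is None, "At most one image is supported at the moment"
            # Data URLs look like "data:image/png;base64,<payload>".
            payload = part["image_url"].split(";base64,")[1]
            image = Image.open(BytesIO(base64.b64decode(payload)))
    return text, image

Messages whose content is a plain string are passed through unchanged, except that the single decoded image is attached to the first such message, and an empty assistant message is appended at the end so generation starts from the assistant turn. On the chat() side, the prompt is now passed as a message list, e.g. [{"role": "user", "content": generator_args.prompt}], together with the image prompts, max_new_tokens, and max_seq_length.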
