continue using llama3_2_vision_transform
Ankur-singh committed Feb 14, 2025
1 parent d2eb1bc commit d495f59
Showing 3 changed files with 9 additions and 11 deletions.
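In short, all three files switch the documented API from constructing Llama3VisionTransform directly to calling the llama3_2_vision_transform builder; in the updated examples a single image_size argument replaces tile_size and patch_size. A minimal before/after sketch, assembled only from the hunks below (the path and prompt template are the docs' own placeholder values):

    # Before: instantiate the transform class and spell out tile/patch geometry.
    from torchtune.models.llama3_2_vision import Llama3VisionTransform

    model_transform = Llama3VisionTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        prompt_template="torchtune.data.QuestionAnswerTemplate",
        max_seq_len=8192,
        tile_size=224,
        patch_size=14,
    )

    # After (this commit): call the builder and pass image_size instead.
    from torchtune.models.llama3_2_vision import llama3_2_vision_transform

    model_transform = llama3_2_vision_transform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        prompt_template="torchtune.data.QuestionAnswerTemplate",
        max_seq_len=8192,
        image_size=560,
    )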
docs/source/basics/model_transforms.rst (6 changes: 3 additions & 3 deletions)
@@ -32,8 +32,8 @@ These are intended to be drop-in replacements for tokenizers in multimodal datas
         Message(
             role="user",
             content=[
-                {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
-                {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
+                {"type": "image", "content": Image.new(mode="RGB", size=(560, 560))},
+                {"type": "image", "content": Image.new(mode="RGB", size=(560, 560))},
                 {"type": "text", "content": "What is common in these two images?"},
             ],
         ),
@@ -52,7 +52,7 @@ These are intended to be drop-in replacements for tokenizers in multimodal datas
     print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
     # '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>'
     print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
-    # torch.Size([1, 3, 224, 224])
+    # torch.Size([4, 3, 224, 224])
 
 Using model transforms
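Pieced together, the updated model_transforms.rst example now feeds two 560x560 images through the transform. The sketch below fills in the parts the hunks do not show; the builder arguments and the {"messages": ...} call pattern are assumptions for illustration, while the message contents, output keys, and the printed shape come from the hunks above:

    from PIL import Image
    from torchtune.data import Message
    from torchtune.models.llama3_2_vision import llama3_2_vision_transform

    # Assumed construction; the doc's actual arguments sit outside the hunk.
    transform = llama3_2_vision_transform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        max_seq_len=8192,
    )

    messages = [
        Message(
            role="user",
            content=[
                {"type": "image", "content": Image.new(mode="RGB", size=(560, 560))},
                {"type": "image", "content": Image.new(mode="RGB", size=(560, 560))},
                {"type": "text", "content": "What is common in these two images?"},
            ],
        ),
        Message(role="assistant", content="A robot is in both images."),
    ]

    # Assumed call pattern: model transforms take a sample dict with a "messages" key.
    tokenized_dict = transform({"messages": messages})

    print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
    print(tokenized_dict["encoder_input"]["images"][0].shape)
    # Per the updated comment in the hunk: torch.Size([4, 3, 224, 224]),
    # i.e. (num_tiles, num_channels, tile_height, tile_width).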
docs/source/basics/multimodal_datasets.rst (7 changes: 3 additions & 4 deletions)
@@ -42,15 +42,14 @@ in the text, ``"<image>"`` for where to place the image tokens. This will get re
 .. code-block:: python
 
-    from torchtune.models.llama3_2_vision import Llama3VisionTransform
+    from torchtune.models.llama3_2_vision import llama3_2_vision_transform
     from torchtune.datasets.multimodal import multimodal_chat_dataset
 
-    model_transform = Llama3VisionTransform(
+    model_transform = llama3_2_vision_transform(
         path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
         prompt_template="torchtune.data.QuestionAnswerTemplate",
         max_seq_len=8192,
-        tile_size=244,
-        patch_size=14,
+        image_size=560,
     )
 
     ds = multimodal_chat_dataset(
         model_transform=model_transform,
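In the multimodal_datasets guide, the same builder now feeds multimodal_chat_dataset. The hunk is cut off after model_transform=model_transform, so the remaining keyword arguments in this sketch (source, data_files, split) are illustrative assumptions rather than the doc's actual values:

    from torchtune.models.llama3_2_vision import llama3_2_vision_transform
    from torchtune.datasets.multimodal import multimodal_chat_dataset

    model_transform = llama3_2_vision_transform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        prompt_template="torchtune.data.QuestionAnswerTemplate",
        max_seq_len=8192,
        image_size=560,
    )

    # Everything past model_transform is assumed for illustration only.
    ds = multimodal_chat_dataset(
        model_transform=model_transform,
        source="json",
        data_files="data/my_data.json",
        split="train",
    )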
torchtune/datasets/multimodal/_multimodal.py (7 changes: 3 additions & 4 deletions)
@@ -120,14 +120,13 @@ def multimodal_chat_dataset(
     ::
 
-        >>> from torchtune.models.llama3_2_vision import Llama3VisionTransform
+        >>> from torchtune.models.llama3_2_vision import llama3_2_vision_transform
         >>> from torchtune.datasets.multimodal import multimodal_chat_dataset
-        >>> model_transform = Llama3VisionTransform(
+        >>> model_transform = llama3_2_vision_transform(
         >>>     path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
         >>>     prompt_template="torchtune.data.QuestionAnswerTemplate",
         >>>     max_seq_len=8192,
-        >>>     tile_size=224,
-        >>>     patch_size=14,
+        >>>     image_size=560,
         >>> )
         >>> ds = multimodal_chat_dataset(
         >>>     model_transform=model_transform,
