Update docs and docstrings related to Llama3VisionTransform #2382
base: main
Conversation
@pbontrager Can you take a look? Seems right to me, but need an extra set of eyes.
Thanks for catching all these inconsistencies 🫡
@@ -52,7 +52,7 @@ These are intended to be drop-in replacements for tokenizers in multimodal datasets

     print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
     # '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>'
     print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width)
-    # torch.Size([4, 3, 224, 224])
+    # torch.Size([1, 3, 224, 224])
What's wrong with 4 here? The default for max_num_tiles is 4, and this example doesn't specify that the image is small. It might be worth adding max_num_tiles=4 to the `Llama3VisionTransform` call on line 46.
The example uses 224x224 images in its Messages (lines 35 & 36), so when I tried running it, I got torch.Size([1, 3, 224, 224]).
Oh, I missed that. Yes, 1 would be correct here, though it's less instructive than if we used a larger image.
I'll update the image size to be 560x560
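(For reference, here is a rough sketch of why the tile count depends on the image size. This is my own simplification, not torchtune's actual resolution-matching code; it assumes tile_size=224 and max_num_tiles=4 as in the docs example.)

```python
import math

# Illustrative sketch (not torchtune's actual algorithm): count the tiles
# needed to cover each dimension, then shrink to fit max_num_tiles, which
# mimics the resize-to-fit step applied to oversized images.
def num_tiles(height: int, width: int, tile_size: int = 224, max_num_tiles: int = 4) -> int:
    tiles_h = math.ceil(height / tile_size)
    tiles_w = math.ceil(width / tile_size)
    while tiles_h * tiles_w > max_num_tiles:
        # Shrink the larger dimension first, standing in for downscaling.
        if tiles_h >= tiles_w:
            tiles_h -= 1
        else:
            tiles_w -= 1
    return tiles_h * tiles_w

print(num_tiles(224, 224))  # 1 -> torch.Size([1, 3, 224, 224])
print(num_tiles(560, 560))  # 4 -> torch.Size([4, 3, 224, 224])
```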
@@ -42,14 +42,15 @@ in the text, ``"<image>"`` for where to place the image tokens. This will get replaced

     .. code-block:: python

-        from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+        from torchtune.models.llama3_2_vision import Llama3VisionTransform
         from torchtune.datasets.multimodal import multimodal_chat_dataset

         model_transform = Llama3VisionTransform(
I think the error was actually the other way: `Llama3VisionTransform` is supposed to be `llama3_2_vision_transform`, and then the other changes aren't necessary.
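(For context, a minimal sketch of the builder-based version being suggested here. The builder's parameters and the `multimodal_chat_dataset` kwargs follow my reading of torchtune's docs and may differ across versions; all paths are placeholders.)

```python
# Sketch of the reviewer's suggestion: keep the llama3_2_vision_transform
# builder rather than instantiating Llama3VisionTransform directly.
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

model_transform = llama3_2_vision_transform(
    path="/path/to/tokenizer.model",  # placeholder tokenizer path
    max_seq_len=8192,
)
ds = multimodal_chat_dataset(
    model_transform=model_transform,
    source="json",
    data_files="data/my_data.json",  # placeholder dataset file
    column_map={"dialogue": "conversations", "image_path": "image"},
    image_dir="/home/user/dataset/",  # placeholder image directory
    image_tag="<image>",
    split="train",
)
```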
@@ -120,14 +120,14 @@ def multimodal_chat_dataset(

     ::

-        >>> from torchtune.datasets.multimodal import multimodal_chat_dataset
-        >>> from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+        >>> from torchtune.models.llama3_2_vision import Llama3VisionTransform
         >>> from torchtune.datasets.multimodal import multimodal_chat_dataset
         >>> model_transform = Llama3VisionTransform(
Same comment from above.
@@ -120,14 +120,14 @@ def multimodal_chat_dataset(

     ::

-        >>> from torchtune.datasets.multimodal import multimodal_chat_dataset
-        >>> from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+        >>> from torchtune.models.llama3_2_vision import Llama3VisionTransform
Why isn't `multimodal_chat_dataset` needed anymore?
It was imported twice.
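(As I read the diff, the fix just drops the duplicate so each name is imported once; the constructor arguments are elided here as in the original doctest.)

```python
>>> from torchtune.models.llama3_2_vision import Llama3VisionTransform
>>> from torchtune.datasets.multimodal import multimodal_chat_dataset
>>> model_transform = Llama3VisionTransform(...)  # args as in the surrounding example
```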
Force-pushed from 8fd5598 to d495f59.
@pbontrager I've made the requested changes. Please let me know if there are any more.
Context

What is the purpose of this PR? Is it to add a new feature, fix a bug, or update tests and/or documentation? This PR updates documentation.

Please link to any issues this PR addresses: N/A
Changelog

What are the changes made in this PR?

- Update docs and docstrings related to `Llama3VisionTransform`
Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example.