# [RFC] Image Generation Dataset #2140

# Overview

This is an RFC regarding how we should support datasets for finetuning text-conditioned image generation models.

A basic data pipeline for this would be (a minimal sketch follows the list):
1. Load the JSON/CSV/TSV/Parquet/LMDB/etc. file containing the image paths/URLs and captions
2. For each pair:
   - load/download the image
   - resize the image, optionally apply random augmentations (horizontal flip, etc.), and normalize it
   - optionally apply random augmentations to the caption (rearrange caption parts, etc.)
   - tokenize the caption using the model's tokenizer
3. Collate into a batch
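
As a point of reference, here is a minimal, self-contained sketch of this pipeline in plain PyTorch, independent of any torchtune abstractions. The `load_rows` and `tokenize` names and the TSV layout are assumptions for illustration only:

```python
import csv
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


def load_rows(tsv_path):
    # Step 1: read (image path, caption) pairs from a TSV file.
    with open(tsv_path) as f:
        return [tuple(row) for row in csv.reader(f, delimiter="\t")]


class ToyImgCaptionDataset(Dataset):
    def __init__(self, rows, tokenize):
        self.rows = rows
        self.tokenize = tokenize  # assumed callable: str -> tensor of token ids
        # Step 2: resize + random augmentation; model-specific normalization is left to the model transform.
        self.img_tf = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.PILToTensor(),
        ])

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        path, caption = self.rows[idx]
        img = self.img_tf(Image.open(path).convert("RGB"))
        return {"img": img, "tokens": self.tokenize(caption)}


# Step 3: the default collate stacks same-shape images; variable-length tokens need a custom collate_fn.
# loader = DataLoader(ToyImgCaptionDataset(load_rows("data.tsv"), tokenize), batch_size=4)
```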

At a broad level, this fits well into our current TorchTune data ecosystem (except we wouldn't use the "list of Message objects" abstraction, which would change how we interact with the model's tokenizer).

In TorchTune, a simple version would look something like this:

```yaml
dataset:
  _component_: torchtune.datasets.img_caption_dataset
  path: ~/my_dataset/data.tsv
  img_transform:
    resize: [256, 256]
    center_crop: true
    horizontal_flip: 0.5
  caption_transform:
    drop: 0.05
    shuffle_parts: 0.1
tokenizer:
  _component_: torchtune.models.flux.FluxTransform
  clip_tokenizer_path: ...
  t5_tokenizer_path: ...
  t5_max_seq_len: 256
```

> **Review comment** (on `img_caption_dataset`): Maybe just my naivety, but when I hear "image-caption dataset", I assume it's a dataset for taking an image and generating a caption, which is not the case here. Hugging Face has a label for these datasets called "Text-to-Image", which I think is a more accurate description. This is also in line with our addition of task-centered dataset builders like the `vqa_dataset`. Concretely proposing changing the default dataset for diffusion from […]
>
> **Reply:** I figured that this dataset could be used for any downstream task that uses pairs of images + text, like finetuning CLIP for example. Maybe […]

> **Review comment** (on `path`): Is it more common to have, ahem, private data to finetune diffusion models, or data that might be published on the Hugging Face Hub? That should affect what the first-class citizen is here and what goes in all our examples. Regardless, if we're using the […]

> **Review comment** (on `caption_transform`): See above comment, but would opt for […]

> **Review comment** (on `tokenizer`): Can you show how this would look from code? I know we prefer flattened params for our configs, but if this was built via code I'd imagine we'd instantiate CLIP and T5 and then pass that to our `FluxTransform`, right?
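
In answer to the question above about the code-side view, a rough sketch under the proposal as written. None of these modules exist yet, and the dict-valued `img_transform`/`caption_transform` arguments simply mirror the config shape; the reviewer's alternative would instead instantiate the CLIP and T5 tokenizers directly and pass them into `FluxTransform`:

```python
# Rough sketch only: FluxTransform and img_caption_dataset are the components proposed
# in this RFC, not existing torchtune APIs, and the exact argument shapes are assumptions.
from torchtune.datasets import img_caption_dataset
from torchtune.models.flux import FluxTransform

model_transform = FluxTransform(
    clip_tokenizer_path="/path/to/clip_tokenizer",
    t5_tokenizer_path="/path/to/t5_tokenizer",
    t5_max_seq_len=256,
)

dataset = img_caption_dataset(
    model_transform,
    path="~/my_dataset/data.tsv",
    img_transform={"resize": [256, 256], "center_crop": True, "horizontal_flip": 0.5},
    caption_transform={"drop": 0.05, "shuffle_parts": 0.1},
)
```

The proposed builder and supporting classes are sketched below.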

```python
def img_caption_dataset(
    model_transform: Transform,
    *,
    path: str,
    img_transform: Config,
    caption_transform: Config,
):
    """Builder for an image caption dataset."""
    data = _load_img_text_dataset(path)
    img_transform = _build_torchvision_transforms(img_transform)
    caption_transform = _CaptionTransform(caption_transform)
    return ImgTextDataset(
        data,
        img_transform=img_transform,
        text_transform=caption_transform,
        model_transform=model_transform,
    )
```

> **Review comment** (on the `Config`-typed arguments): I don't think we ever want our builders to see the notion of configs. Configs are just a way to interface with our recipes, but builders should be able to be dropped into place anywhere.

```python
def _load_img_text_dataset(path):
    if '.' not in path:
        return datasets.load_dataset(path, ...)

    path = Path(path).expanduser().resolve()
    if path.suffix == ".tsv":
        data = []
        with open(path, "r") as f:
            for line in f:
                img_path_or_url, text = [x.strip() for x in line.split("\t")]
                data.append((img_path_or_url, text))
        return data

    elif path.suffix == "...":
        ...
```

> **Review comment:** We use huggingface `load_dataset` as well as `load_image`.
>
> **Reply:** Regarding `load_image`: thanks, I'll switch to this. Question though: when the image path is a URL, should we include the option of saving these images to disk so that they don't need to be re-downloaded during the next epoch? Regarding `load_dataset`: I address this in the first bullet of the user experience section. I personally think it's better if we handle simple cases like loading an image-caption TSV ourselves so the user doesn't have to go read the huggingface docs, especially since most image gen finetuning will be done on small local datasets, but I'm also ok with just relying on huggingface's `load_dataset` since that does make our code simpler.
>
> **Review comment:** Totally understand your point on not wanting to overcomplicate things, but using `load_dataset` under the hood makes our lives way easier lol
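
For comparison, a rough sketch of the `load_dataset`-based alternative discussed in the thread above. The function name, the suffix-to-builder mapping, and the `split="train"` default are illustrative assumptions:

```python
from pathlib import Path

from datasets import load_dataset

# Illustrative only: lean on Hugging Face `load_dataset` for both Hub datasets and local files.
_SUFFIX_TO_BUILDER = {".csv": "csv", ".tsv": "csv", ".json": "json", ".parquet": "parquet"}


def _load_img_text_dataset_hf(path: str):
    if "." not in path:
        # Treat it as a dataset repo on the Hub, e.g. "user/my-text-to-image-data".
        return load_dataset(path, split="train")
    local = Path(path).expanduser().resolve()
    builder = _SUFFIX_TO_BUILDER[local.suffix]
    # The csv builder forwards extra kwargs (like `delimiter`) to pandas.read_csv.
    kwargs = {"delimiter": "\t"} if local.suffix == ".tsv" else {}
    return load_dataset(builder, data_files=str(local), split="train", **kwargs)
```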

```python
def _build_torchvision_transforms(cfg):
    """
    Create a series of torchvision transforms
    (resize, crop, flip, etc.)
    """
    ...
```

> **Review comment:** This, along with `_CaptionTransform`, is all within the abstraction of the model transform or data transform as the user needs. Or is this meant to be an example?
>
> **Reply:** I think we should separate the data transform logic from the model transform logic, e.g. data augmentations like horizontal flip would live in an image transform that's entirely separate from model logic, and model-specific logic like image normalization would live in the model transform.
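
One possible body for `_build_torchvision_transforms`, using standard torchvision transforms and assuming the dict-like config keys shown in the YAML example (`resize`, `center_crop`, `horizontal_flip`). This is illustrative, not part of the proposal:

```python
from torchvision import transforms


def _build_torchvision_transforms(cfg):
    size = cfg["resize"]  # e.g. [256, 256]
    tfs = []
    if cfg.get("center_crop", False):
        # Resize the shorter side, then crop to the target size (preserves aspect ratio).
        tfs += [transforms.Resize(min(size)), transforms.CenterCrop(size)]
    else:
        tfs.append(transforms.Resize(size))
    if "horizontal_flip" in cfg:
        tfs.append(transforms.RandomHorizontalFlip(p=cfg["horizontal_flip"]))
    # Keep 0-255 pixel values; model-specific normalization lives in the model transform (see FluxTransform below).
    tfs.append(transforms.PILToTensor())
    return transforms.Compose(tfs)
```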

```python
class _CaptionTransform:
    """
    Callable that randomly augments image captions with comma-separated parts
    (shuffle parts, randomly drop entire caption, etc.)
    (or does nothing if disabled)
    """

    def __init__(self, cfg): ...

    def __call__(self, caption: str) -> str: ...
```
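
One way the `_CaptionTransform` stubs could be filled in, assuming a dict-like config with the `drop` and `shuffle_parts` probabilities from the YAML example; returning an empty string for a dropped caption is an assumption:

```python
import random


class _CaptionTransform:
    """Illustrative sketch of the caption augmentations described in the docstring above."""

    def __init__(self, cfg):
        self.drop = cfg.get("drop", 0.0)                  # probability of dropping the whole caption
        self.shuffle_parts = cfg.get("shuffle_parts", 0.0)  # probability of shuffling comma-separated parts

    def __call__(self, caption: str) -> str:
        if random.random() < self.drop:
            return ""  # drop the entire caption
        if random.random() < self.shuffle_parts:
            parts = [p.strip() for p in caption.split(",")]
            random.shuffle(parts)
            caption = ", ".join(parts)
        return caption
```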

```python
class ImgTextDataset(torch.utils.data.Dataset):
    def __init__(self, data, img_transform, text_transform, model_transform):
        self._data = data
        self._img_transform = img_transform
        self._text_transform = text_transform
        self._model_transform = model_transform

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        img_path_or_url, text = self._data[idx]
        img = (
            Image.open(BytesIO(requests.get(img_path_or_url).content))
            if img_path_or_url.startswith(("http://", "https://", "ftp://", "ftps://"))
            else Image.open(img_path_or_url)
        )
        img = self._img_transform(img)
        text = self._text_transform(text)
        data_dict = self._model_transform(img, text)
        return data_dict
```

> **Review comment:** I think this should basically look like the text completion dataset, but with `model_transform` as you have here and `column_map` instead of `column`. I also don't know if we want to anchor this to vision, as diffusion for audio etc. would use the same pattern. I think an optional `data_transform` would work here instead of img/text.
>
> **Reply:** I think we should include an image transform here so that we can separate model-independent image augmentations from model-specific ones in the model transform. Also, I don't think this should be a generic dataset class for diffusion in general. It shouldn't be tied to diffusion at all, and should instead be for any downstream task that uses image-text pairs, e.g. non-diffusion image gen models, image captioning models, image-text joint encoders, etc. There were a lot of papers at NeurIPS this year that were finetuning CLIP; I would expect that to utilize the same `ImageTextDataset` as finetuning Flux would. If you're doing diffusion for audio, you would use an `AudioTextDataset`.
>
> **Review comment:** I agree with you WRT a dataset class (hence why everything so far essentially returns an SFT dataset). However, I do think there's tremendous value in aligning our dataset builders with specific tasks. It makes it easier to utilize from configs and to find datasets to use on the Hub.

```python
class FluxTransform(Transform):
    def __init__(self, clip_tokenizer_path, t5_tokenizer_path, t5_max_seq_len):
        ...

    def __call__(self, img, text):
        return {
            'img': (img / 127.5) - 1.0,
            'clip_text_tokens': self._clip_tokenizer(text),
            't5_text_tokens': self._t5_tokenizer(text),
        }
```

> **Review comment:** A generic diffusion model transform would just take a dict instead of a list of messages but otherwise be the same.
>
> **Reply:** There is logic here that is specific to Flux, and I think it should exist within a Flux-specific model transform.

# TODO: Collate

We'll need to generalize our collate functions such that they can handle data outside of the tokens-and-labels format they currently expect. I will update this section after I've looked into this.
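
As a rough starting point, one possible collate for the dict samples produced by the `FluxTransform` sketch above, assuming the CLIP tokenizer returns fixed-length sequences while the T5 sequences may vary in length, and assuming a padding index of 0:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


# Illustrative collate for samples shaped like {"img", "clip_text_tokens", "t5_text_tokens"}.
def img_text_collate(batch, pad_idx: int = 0):
    return {
        "img": torch.stack([sample["img"] for sample in batch]),
        "clip_text_tokens": torch.stack([sample["clip_text_tokens"] for sample in batch]),
        "t5_text_tokens": pad_sequence(
            [sample["t5_text_tokens"] for sample in batch],
            batch_first=True,
            padding_value=pad_idx,
        ),
    }
```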

# Caching/Preprocessing

From what I've seen online, some people finetune image generators on massive datasets, but most people just finetune on very small personal datasets, often 5-100 images. So we should probably add support for various caching/preprocessing options that trade extra disk/memory usage for faster iterations. Some ideas for optional configurations:

- cache up to N images in each data worker so they don't have to be loaded fresh from disk each epoch (a rough sketch of this one follows at the end of this section)
  - in the extreme case of <10 images, we could even keep the whole dataset on each GPU so we don't have to transfer it each step
- in the case of a web dataset, save up to N downloaded images to local storage for the next epoch
- before training, preprocess the outputs of the frozen parts of the model (text tokens, image autoencoder embeddings) and save them to disk so that we don't have to recompute them every epoch
  - tokenization would be negligible, but I bet preprocessing the Flux image encoding would save a lot of time and GPU memory
  - this could also be done on the fly, i.e. caching instead of preprocessing: during the first epoch, save the intermediate values to disk and reuse them in all subsequent epochs. But this makes the code much more complicated.

But we should evaluate whether each of these is worth it:
- how much performance gain would we actually get, and under what circumstances?
- how much would it complicate the code and the configs?
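
As a concrete illustration of the first idea (per-worker image caching), a minimal sketch using `functools.lru_cache`; the function name and cache size are arbitrary:

```python
from functools import lru_cache

from PIL import Image


# Each DataLoader worker process holds its own cache, so up to `maxsize`
# decoded images are kept in memory per worker and reused across epochs.
@lru_cache(maxsize=100)
def _load_image_cached(img_path: str) -> Image.Image:
    with Image.open(img_path) as img:
        return img.convert("RGB")  # force a full decode so the image stays usable after the file closes
```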

# Dataset Creation

Should we include scripts/utilities for creating the captions? Users will often have just a folder with a bunch of images that they want to finetune on. So we could help them turn that folder into a dataset by using some model to automatically caption the images. We could even provide our own models for this by distilling the image captioning capabilities of Llama3.2V-90B into several smaller Llama3.2V models, and letting the user pick the one that fits on their device.

We'll also want to support adding words/phrases to the caption that tell the model to generate in the style of this dataset. For example, if I'm finetuning a model on images of myself, I'll want to include something like "a photo of cpelletier" in the caption so that the model learns to associate "cpelletier" with my face.
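
One simple way to support this would be a small caption wrapper applied at the data level; a hedged sketch where the class name and behavior are illustrative, not part of the proposal above:

```python
# Illustrative: prepend a trigger phrase such as "a photo of cpelletier" to every caption.
class _TriggerPhraseTransform:
    def __init__(self, trigger_phrase: str):
        self.trigger_phrase = trigger_phrase

    def __call__(self, caption: str) -> str:
        return f"{self.trigger_phrase}, {caption}" if caption else self.trigger_phrase
```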

# User Experience

- Regarding loading the TSV/Parquet/whatever data file, should we just rely on huggingface's `load_dataset` like we currently do in `SFTDataset`? It keeps the code simpler, but it makes the user leave torchtune and go read the huggingface docs, which is overkill if they just have some simple JSON file we could easily load ourselves.
- In addition to absolute image paths in the data file, we should probably support image paths relative to the dataset folder, because it would be super annoying if you had to regenerate your data file any time you move the dataset to a new location.
  > **Review comment:** This is handled via our current image/text dataset utilities.
- There are currently some potentially unnecessary fields in the config. For example, with Flux models, the model determines the image size and the T5 tokenizer sequence length. Is it better to pass this information to the image transform and model transform, respectively (which complicates the code but lowers the chance of user error), or to have the user define these values in the dataset config and tokenizer config, respectively (which puts the burden on the user to match what the model expects)?
- Should we add scripts/utilities for inspecting the dataset? It's nice to see a preview of what a batch looks like, especially when you're messing around with color jitter and other hard-to-configure image augmentations.
  > **Review comment:** Definitely a cool feature, but probably a P2 or upon-request-from-users type of thing.

# Other

- Naming of the image-text dataset builders/classes? Maybe the more verbose `image_caption_dataset_for_image_generation` is better, to make it clear that this is NOT for something like finetuning a VLM to do image captioning (although maybe it could be generalized to the point where it can also handle lists of Message objects and therefore be used for whatever purpose).
- Support multiple captions per image? I can imagine people wanting to generate multiple captions for their images and randomly selecting one at a time during training to prevent overfitting. It's kind of a caption augmentation, but it's unique for each caption, so it would have to be supported at the data level. (A rough sketch follows below.)
  > **Review comment:** This should be possible to do easily with torchtune, but definitely not OOTB.
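
A sketch of the multiple-captions-per-image idea mentioned above; the row layout is a hypothetical assumption:

```python
import random


# Hypothetical row layout: (image_path, [caption_1, caption_2, ...]).
# Sampling at __getitem__ time means a different caption can be seen each epoch.
def _sample_caption(row):
    img_path, captions = row
    return img_path, random.choice(captions)
```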

> **Review comment:** Okay, one big question: what direction are we trying to go in?
>
> We landed torchdata support, which started a refactor of our datasets into dataset-specific utils rather than an entire builder that essentially just spits back an SFT dataset class. IMO this means less code for the user to worry about and makes hacking easier. In addition, this gives us all the benefits of torchdata.
>
> If we believe torchdata is the right way to go (especially for these more data-intensive use cases), then should this be refactored towards that end?
>
> cc @pbontrager @ebsmothers

> **Reply:** The goal of this was to follow the pattern of our current SFT dataset solution so it'd be easier to move in parallel with the torchdata solution. By staying close to SFT, it should be trivial to convert this to the torchdata solution once that's finalized.