yiyunwacc/ALIA (forked from lisadunlap/ALIA)

Augmenting with Language-guided Image Augmentation (ALIA)




Automatic Language-guided Image Augmentation (ALIA)


Welcome to the official repository for the paper "Diversify Your Vision Datasets with Automatic Diffusion-based Augmentation". If you prefer a condensed version, visit our TL;DR website. If you find our work useful, we welcome citations:

@article{dunlap2023alia,
  author    = {Dunlap, Lisa and Umino, Alyssa and Zhang, Han and Yang, Jiezhi and Gonzalez, Joseph and Darrell, Trevor},
  title     = {Diversify Your Vision Datasets with Automatic Diffusion-based Augmentation},
  journal   = {arXiv},
  year      = {2023},
}

UPDATE: We are currently rerunning experiments due to a bug in our checkpointing (shoutout to EyalMuchaeli for pointing it out), and the paper will be updated with the new numbers once all experiments are done. To track our newest results, see the WandB projects for CUB, iWildCam, and Planes. Note that the traditional augmentation baselines for CUB now outperform ALIA, and that with a ResNet50, Txt2Img beats ALIA on Planes. Due to various issues with the Planes dataset, we have replaced it with the Waterbirds dataset.

NEW: We have added the Waterbirds dataset (subsampled to exaggerate the real-data gains; the full dataset is coming soon). We use the 95% bias split and use the 5% unbiased data as the extra set. Note that, unlike the standard dataset, our validation set is also biased (our test set is unbiased). The full dataset can either be generated through their repo or downloaded from this Gdrive link.

Table of Contents

  1. Getting Started
  2. Prompt Generation
  3. Generating Images
  4. Filtering
  5. Training
  6. WandB Projects
  7. Checkpoints
  8. Add Custom Datasets

Getting Started

To begin, install our code dependencies using Conda. You may need to adjust the environment.yaml file based on your setup:

conda env create -f environment.yaml
conda activate ALIA
pip install -e .

All experiment parameters are set in YAML configs; configs/base.yaml contains every default parameter along with its description. If this is your first time downloading a dataset for this project, change base_root in configs/base.yaml to point to the root directory of your downloaded datasets. If you don't have any precomputed CLIP embeddings for this project, also set embedding_root in configs/base.yaml to null. The defaults for each individual dataset are in configs/DATASET/base.yaml.
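The layering described above (global defaults, then per-dataset settings, then command-line overrides such as filter.load=false) can be sketched with plain dictionaries. This is a minimal illustration, assuming the real config loader behaves like an OmegaConf-style deep merge; deep_merge, parse_scalar, and apply_dotlist are hypothetical helpers, not part of the repository, and the config values are placeholders.

```python
import copy
from typing import Any, Dict, List


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `override` applied, recursing into nested sections."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # merge nested sections key by key
        else:
            merged[key] = copy.deepcopy(value)  # scalars and new keys replace the default
    return merged


def parse_scalar(text: str) -> Any:
    """Convert a command-line string into a bool, None, int, or float where possible."""
    lowered = text.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    if lowered in ("null", "none"):
        return None
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text  # anything else stays a string, e.g. "cutmix"


def apply_dotlist(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply overrides written as 'section.key=value', e.g. 'filter.load=false'."""
    result = copy.deepcopy(cfg)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            node = node.setdefault(part, {})  # create missing sections on the fly
        node[leaf] = parse_scalar(raw_value)
    return result


# Hypothetical values: only the key names are taken from this README.
defaults = {"name": "base", "data": {"base_dataset": None, "augmentation": None}, "filter": {"load": True}}
dataset_cfg = {"name": "Cub2011", "data": {"base_dataset": "Cub2011"}}
cfg = apply_dotlist(deep_merge(defaults, dataset_cfg), ["filter.load=false", "data.augmentation=cutmix"])
print(cfg)
```

Deep-merging (rather than replacing whole sections) is what lets a dataset config change one key under data without restating every other default.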

The pipeline is split across several scripts: caption.py captions the dataset, prompt_generation.py extracts the domains from the captions, the scripts in editing_methods generate the augmented training data, main.py handles all training and evaluation, and filtering/filter.py saves the indices of generated images to filter out for a given dataset. To train a model with ALIA, run them in this order: caption -> prompt_generation -> editing -> main (base model trained on the original data) -> filter (applied to the generated data) -> main (model trained on the original data plus the filtered generated data). The exact commands are given below.
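A small driver that runs these stages in order might look like the following. It is a sketch, not a script shipped with the repo: alia_pipeline and run are hypothetical names, it assumes the final training run uses configs/DATASET/alia.yaml (the same config passed to the filter), and in practice the editing prompt comes from the output of prompt_generation.py rather than being hard-coded. By default it only prints the commands.

```python
import shlex
import subprocess
from typing import List


def alia_pipeline(dataset: str, edit_prompt: str, n: int = 2) -> List[List[str]]:
    """Return the ALIA stages for one dataset as argument lists, in execution order."""
    base_cfg = f"configs/{dataset}/base.yaml"
    alia_cfg = f"configs/{dataset}/alia.yaml"
    return [
        # 1. caption every training image with BLIP
        ["python", "caption.py", "--config", base_cfg],
        # 2. summarize the captions into domain prompts with an LLM
        ["python", "prompt_generation.py", "--config", base_cfg],
        # 3. edit the training images with the chosen prompt
        ["python", "editing_methods/img2img.py", "--dataset", dataset,
         "--prompt", edit_prompt, "--n", str(n)],
        # 4. train the base model on the original data only
        ["python", "main.py", "--config", base_cfg],
        # 5. compute which generated images to filter out
        ["python", "filtering/filter.py", "--config", alia_cfg, "filter.load=false"],
        # 6. retrain on original + filtered generated data (assumed to use alia.yaml)
        ["python", "main.py", "--config", alia_cfg],
    ]


def run(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False, stopping on the first failure."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


run(alia_pipeline("Cub2011", "a photo of a {} bird on rocks."))
```

Passing argument lists (rather than shell strings) to subprocess.run keeps prompts containing spaces or braces intact.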

Prompt Generation

  • Captioning: We use the BLIP captioning model to caption the entire dataset:

    python caption.py --config configs/Cub2011/base.yaml

    This will save your captions here.

  • LLM Summarization: In our paper, we used GPT-4 to summarize the domains from the captions. Alternatively, we provide Vicuna support for those who prefer not to give money to OpenAI. Download the Vicuna weights here (we used the 13B-parameter model).

    pip3 install fschat
    python huggingface_api.py message="Hi! How are you doing today?" # quick test that the model loads and responds
    python prompt_generation.py --config configs/Cub2011/base.yaml #return prompts

We randomly sample 20 captions so that the prompt fits within the LLM's context length, but we highly encourage others to develop better methods :)
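The caption-sampling step can be sketched as follows. The instruction text sent to the LLM is our own placeholder, not the prompt used in prompt_generation.py, and build_summary_prompt is a hypothetical helper.

```python
import random
from typing import List, Optional


def build_summary_prompt(captions: List[str], k: int = 20, seed: Optional[int] = 0) -> str:
    """Sample up to k captions at random and wrap them in a domain-summarization request."""
    rng = random.Random(seed)  # seeded so the same captions are chosen on every run
    sample = rng.sample(captions, min(k, len(captions)))
    bullets = "\n".join(f"- {caption}" for caption in sample)
    # Placeholder instruction text; the repository's actual LLM prompt is not reproduced here.
    return (
        "Here are captions of images from one dataset:\n"
        f"{bullets}\n"
        "List the distinct backgrounds or settings (domains) these images appear in, "
        "as short phrases."
    )


captions = [f"a photo of a bird perched on branch {i}" for i in range(500)]
print(build_summary_prompt(captions))
```

Capping k with min(k, len(captions)) keeps the helper usable on small datasets with fewer than 20 captions.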

Generating Images

Our editing methods are housed in editing_methods and use the Hugging Face Diffusers library and the tyro CLI.

  • Per Example: To generate multiple images from a prompt, use txt2img_example.py; to edit a single image, use img2img_example.py.

    python editing_methods/txt2img_example.py --prompt "Arachnophobia" --n 20
  • Per Dataset: To generate images for an entire dataset, use the class_names attribute of the dataset to create per-class prompts; the {} in the prompt is filled in with each class name.

    python editing_methods/img2img.py --dataset Cub2011 --prompt "a photo of a {} bird on rocks." --n 2
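The per-dataset mode amounts to filling the {} slot of the template with each image's class name. A minimal sketch of how the edit jobs could be enumerated, assuming --n is the number of edits per source image; img2img_jobs and the example file and class names are hypothetical.

```python
from typing import List, Sequence, Tuple


def img2img_jobs(samples: Sequence[Tuple[str, int]], class_names: Sequence[str],
                 template: str, n: int) -> List[Tuple[str, int, str]]:
    """For every (image_path, label) pair, emit n edit jobs whose prompt names the image's class."""
    jobs = []
    for path, label in samples:
        prompt = template.format(class_names[label])  # fill the {} slot with the class name
        jobs.extend([(path, label, prompt)] * n)      # n edits per source image (assumed meaning of --n)
    return jobs


samples = [("img0.jpg", 0), ("img1.jpg", 1)]
jobs = img2img_jobs(samples, ["Cardinal", "Blue Jay"], "a photo of a {} bird on rocks.", n=2)
for job in jobs:
    print(job)
```

Keeping the label next to each prompt means every generated image inherits the class of the image it was edited from.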

Filtering

Once you have generated your data, determine which indices to filter out by running the following command:

python filtering/filter.py --config configs/Cub2011/alia.yaml filter.load=false

NOTE: Because the confidence-based filtering relies on a pretrained model, you need to train a base model first (see Training below).
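To illustrate what a confidence-based filter can look like, here is a simplified rule: drop a generated image when the base model predicts a class other than its label with at least the average confidence it achieves on correctly classified real images of that predicted class. This rule is an assumption for illustration; the exact criterion in filtering/filter.py may differ, and all function names are hypothetical. Probabilities are plain lists so the sketch runs without PyTorch.

```python
from typing import Dict, List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def class_thresholds(real_probs: Sequence[Sequence[float]], real_labels: Sequence[int]) -> Dict[int, float]:
    """Average confidence of the base model on correctly classified real images, per class."""
    totals: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for probs, label in zip(real_probs, real_labels):
        pred = argmax(probs)
        if pred == label:
            totals[label] = totals.get(label, 0.0) + probs[pred]
            counts[label] = counts.get(label, 0) + 1
    return {c: totals[c] / counts[c] for c in totals}


def indices_to_filter(gen_probs: Sequence[Sequence[float]], gen_labels: Sequence[int],
                      thresholds: Dict[int, float]) -> List[int]:
    """Indices of generated images confidently predicted as a class other than their label."""
    flagged = []
    for i, (probs, label) in enumerate(zip(gen_probs, gen_labels)):
        pred = argmax(probs)
        # Classes without a threshold get +inf, so they never trigger the filter.
        if pred != label and probs[pred] >= thresholds.get(pred, float("inf")):
            flagged.append(i)
    return flagged


real_probs = [[0.75, 0.25], [0.875, 0.125], [0.25, 0.75]]
real_labels = [0, 0, 1]
gen_probs = [[0.25, 0.75], [0.5, 0.5], [0.875, 0.125], [0.125, 0.875]]
gen_labels = [0, 0, 1, 1]
thresholds = class_thresholds(real_probs, real_labels)
print(indices_to_filter(gen_probs, gen_labels, thresholds))  # [0, 2]
```

Using per-class thresholds rather than one global cutoff accounts for classes the base model is naturally more or less confident about.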

Training

To train the base models or the models with augmented data, pass the appropriate YAML file from the configs folder to main.py:

python main.py --config configs/Cub2011/base.yaml

To apply a traditional data augmentation technique, set data.augmentation (for example, data.augmentation=cutmix). All available augmentations are listed in helpers/load_dataset.py.
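For reference, the core of CutMix is sampling a rectangular patch whose area matches a Beta-distributed mixing ratio; the patch is pasted in from a second image and the two labels are mixed by visible area. The sketch below shows only that box-sampling and label-mixing step in plain Python; it is not the repository's implementation, and cutmix_box and mix_labels are hypothetical names.

```python
import math
import random
from typing import List, Optional, Tuple


def cutmix_box(width: int, height: int, alpha: float = 1.0,
               rng: Optional[random.Random] = None) -> Tuple[Tuple[int, int, int, int], float]:
    """Sample a CutMix patch (x1, y1, x2, y2) and the label weight kept by the original image."""
    if rng is None:
        rng = random.Random()
    lam = rng.betavariate(alpha, alpha)       # target share of the original image, lam ~ Beta(alpha, alpha)
    cut_ratio = math.sqrt(1.0 - lam)          # patch sides scale with the square root of its area share
    cut_w, cut_h = int(width * cut_ratio), int(height * cut_ratio)
    cx, cy = rng.randrange(width), rng.randrange(height)  # patch centre, uniform over the image
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    # Clipping at the border can shrink the patch, so recompute the weight from the actual area.
    lam_adjusted = 1.0 - ((x2 - x1) * (y2 - y1)) / (width * height)
    return (x1, y1, x2, y2), lam_adjusted


def mix_labels(y_a: List[float], y_b: List[float], lam: float) -> List[float]:
    """Blend two one-hot labels in proportion to the visible area of each image."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(y_a, y_b)]


box, lam = cutmix_box(224, 224, rng=random.Random(0))
print(box, lam, mix_labels([1.0, 0.0], [0.0, 1.0], lam))
```

The pixels inside box come from the second image; recomputing the weight after clipping keeps the label mix faithful to what the model actually sees.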

WandB Projects

Our generated datasets can be found here under the 'Artifacts' tab. Each artifact includes the hyperparameters and prompts used to create it.

Download the images with the following Python snippet:

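import wandb

run = wandb.init()  # W&B will prompt you to log in if you have not already (or run `wandb login` first)
# Artifact names follow "<entity>/<project>/<artifact>:<version>"
artifact = run.use_artifact('clipinvariance/ALIA/cub_generic:v0', type='dataset')
artifact_dir = artifact.download()  # local directory the images were downloaded to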

View generated data examples for Txt2Img, Img2Img, and InstructPix2Pix.

Checkpoints

All of our runs, checkpoints, and captions are on WandB. We reran all experiments with the cleaned-up repo, so results may differ slightly from those in the paper.

Seriously, Weights and Biases, send me a care package; I'm giving you some serious promo here.

Add Custom Datasets

To add your own dataset, add a file to the datasets folder and then register it as an option in helpers/load_dataset.py. The repository expects a specific dataset format, in which __getitem__ returns three values: image, target, and group (the domain the image belongs to; set it to 0 if this is not a bias or domain adaptation dataset).

Additionally, the dataset class needs the following attributes: classes, groups, class_names, group_names, targets, and class_weights. Here's an example:

import torchvision

# NOTE: get_counts is a helper from this repository; import it from the module where it is defined.


class BasicDataset(torchvision.datasets.ImageFolder):
    """
    Wrapper class for torchvision.datasets.ImageFolder.
    """
    def __init__(self, root, transform=None, group=0, cfg=None):
        # cfg is not used in this example.
        self.group = group  # used for domain adaptation/bias datasets, where the group is the domain or bias type
        super().__init__(root, transform=transform)
        self.groups = [self.group] * len(self.samples)  # all images come from the same domain, so they share one group label (0 by default)
        self.group_names = ["all"]  # only one group name (used for logging)
        self.class_names = self.classes  # used for logging
        self.targets = [s[1] for s in self.samples]  # integer class label of every image
        self.class_weights = get_counts(self.targets)  # class weights for the cross-entropy loss

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, self.group

After adding your dataset to the get_dataset function, create a default config and set data.base_dataset to the name of your dataset. You can then generate the prompts and images as described above, mirroring the data.extra_dataset parameters used for CUB but setting data.extra_root to the location of your generated data.

For example, suppose you want to add a typical PyTorch ImageFolder dataset like ImageNet. You can control how much data is added either through the extraset (the real-data baseline from the paper) or through the data.num_extra parameter. If you want to use ALIA or other methods to improve performance, you can ignore the real-data baseline and set data.num_extra to the number of augmented samples you want to add. For this example, say you want to add 1000 augmented samples to your training set.
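Conceptually, the final training set is every real sample plus up to data.num_extra generated samples that survive filtering. A minimal sketch of that bookkeeping, assuming generated samples are drawn at random from those not flagged by the filter; build_training_indices and the dataset sizes in the usage line are hypothetical, and main.py may select samples differently.

```python
import random
from typing import List, Optional, Set, Tuple


def build_training_indices(num_real: int, num_generated: int, filtered: Set[int],
                           num_extra: int, seed: Optional[int] = 0) -> List[Tuple[str, int]]:
    """All real samples plus up to num_extra generated samples that were not filtered out."""
    kept = [j for j in range(num_generated) if j not in filtered]  # drop indices flagged by the filter
    rng = random.Random(seed)
    chosen = rng.sample(kept, min(num_extra, len(kept)))           # never ask for more than survived
    return [("real", i) for i in range(num_real)] + [("generated", j) for j in sorted(chosen)]


# Hypothetical sizes: 1000 extra samples drawn from 5000 generated images, 300 of them filtered out.
indices = build_training_indices(num_real=6000, num_generated=5000,
                                 filtered=set(range(300)), num_extra=1000)
print(len(indices))
```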

Since we already have a wrapper for the ImageFolder class in datasets/base.py, you can use that to add your dataset (like ImageNet) into the get_dataset function.

def get_dataset(dataset_name, transform, val_transform, root='./data', embedding_root=None):
    .....

    elif dataset_name == 'ImageNet':
        trainset = BasicDataset(root='/path/to/imagenet/train', transform=transform)
        valset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
        extraset = None  # None because the amount of generated data is set with data.num_extra instead
        # ImageNet's test labels are not public, so the val split doubles as the test set here.
        testset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
    ......

    return trainset, valset, testset, extraset

Now all you need to do is create your config:

base_config: configs/base.yaml # this sets default parameters
proj: ALIA-ImageNet # wandb project
name: ImageNet # name of dataset used for logging (can set this to anything)

data: 
  base_dataset: ImageNet # name of dataset used in the get_dataset function

From here, you should be able to follow the README as normal.
