
Negative loss on CLIP training #321

Open
u1ug opened this issue Aug 11, 2024 · 0 comments
u1ug commented Aug 11, 2024

Got a negative loss value while training a CLIP model: am I doing something wrong, or is it a bug in the loss function?

Here is my code:

import torch
from torch.utils.data import DataLoader
from dalle2_pytorch import CLIP
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dataset import TextImgDataset
from tqdm import tqdm
from torch.optim.adamw import AdamW

clip = CLIP(
    dim_text=512,
    dim_image=32,
    dim_latent=512,
    num_text_tokens=49408,
    text_enc_depth=1,
    text_seq_len=256,
    text_heads=8,
    visual_enc_depth=1,
    visual_image_size=256,                # model expects 256x256 input images
    visual_patch_size=32,
    visual_heads=8,
    use_all_token_embeds=True,            # fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning=True,  # remove positive pairs from the InfoNCE denominator (DCL)
    extra_latent_projection=True,         # separate projections for text-to-image vs image-to-text (CLOOB)
    use_visual_ssl=True,                  # auxiliary self-supervised loss on images
    visual_ssl_type='simclr',
    use_mlm=False,
    text_ssl_loss_weight=0.05,
    image_ssl_loss_weight=0.05
).cuda()

optim = AdamW(clip.parameters(), lr=3e-4)

dataloader = DataLoader(
    dataset=TextImgDataset('hf://datasets/pranked03/flowers-blip-captions/data/train-00000-of-00001-f41d4839cc8f6449.parquet'),
    batch_size=4,
    shuffle=False,
)
t = SimpleTokenizer()

# Early stopping parameters
patience = 10
best_loss = float('inf')
trigger_times = 0

for epoch in range(1, 500):
    losses = []
    for image, text in tqdm(dataloader, desc=f'epoch {epoch}'):
        optim.zero_grad()
        loss = clip(
            t.tokenize(text).cuda(),
            image.cuda(),
            return_loss=True
        )
        loss.backward()
        losses.append(loss.item())
        optim.step()

    epoch_loss = sum(losses) / len(losses)
    print(f"epoch {epoch}, loss: {epoch_loss}")

    # Check for early stopping
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        trigger_times = 0
        torch.save(clip.state_dict(), 'clip.pt')  # Save the model when it improves
    else:
        trigger_times += 1
        print(f"Trigger times: {trigger_times}")

        if trigger_times >= patience:
            print("Early stopping!")
            break
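
Regarding the sign of the loss: with decoupled_contrastive_learning=True, the positive pair is removed from the denominator of the InfoNCE objective (the DCL formulation), so the contrastive term is no longer bounded below by zero, and the total loss can legitimately go negative once the positive similarities start to dominate. Below is a minimal standalone sketch in plain PyTorch (a made-up similarity matrix, not dalle2_pytorch's internal code) showing the difference:

import torch
import torch.nn.functional as F

# Hypothetical temperature-scaled similarity matrix for a batch of 4 pairs;
# diagonal entries are the matched image-text pairs.
sim = torch.full((4, 4), -5.0)
sim.fill_diagonal_(5.0)

# Standard InfoNCE keeps the positive in the denominator, so loss >= 0.
infonce = F.cross_entropy(sim, torch.arange(4))

# DCL drops the positive from the denominator: loss = -pos + logsumexp(negatives),
# which goes below zero once the positive outweighs the negatives.
pos = sim.diag()
neg = sim.masked_fill(torch.eye(4, dtype=torch.bool), float('-inf'))
dcl = (-pos + neg.logsumexp(dim=-1)).mean()

print(infonce.item())  # ~1e-4, non-negative
print(dcl.item())      # ~-8.9, negative

If that is the cause, the negative value is expected behavior of the DCL objective rather than a bug.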

Custom dataset class:

import torch
from torch.utils.data import Dataset
from torchvision import transforms as T
import pandas as pd
from PIL import Image
import io


# Dataset class, returns images and corresponding textual captions
class TextImgDataset(Dataset):
    def __init__(self, fp: str):
        self.df = pd.read_parquet(fp)
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((256, 256)),  # must match visual_image_size=256 in the CLIP config
            T.ToTensor(),
        ])

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, str]:
        row = self.df.iloc[idx]
        img_bytes = io.BytesIO(row['image']['bytes'])
        image = Image.open(img_bytes)
        image_tensor = self.transform(image)
        caption = row['text']

        return image_tensor, caption
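
As a quick sanity check (a hypothetical snippet using the same parquet file and classes as above), it is worth verifying that the transform output matches the image size the model is configured for:

ds = TextImgDataset('hf://datasets/pranked03/flowers-blip-captions/data/train-00000-of-00001-f41d4839cc8f6449.parquet')
img, cap = ds[0]
# the transform must yield 3x256x256 tensors to match visual_image_size=256
assert img.shape == (3, 256, 256), img.shape
assert isinstance(cap, str)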