Got a negative loss value while training a CLIP model: am I doing something wrong, or is it a loss function bug?
Here is my code:
```python
import torch
from torch.utils.data import DataLoader
from dalle2_pytorch import CLIP
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dataset import TextImgDataset
from tqdm import tqdm
from torch.optim.adamw import AdamW

clip = CLIP(
    dim_text=512,
    dim_image=32,
    dim_latent=512,
    num_text_tokens=49408,
    text_enc_depth=1,
    text_seq_len=256,
    text_heads=8,
    visual_enc_depth=1,
    visual_image_size=256,
    visual_patch_size=32,
    visual_heads=8,
    use_all_token_embeds=True,
    decoupled_contrastive_learning=True,
    extra_latent_projection=True,
    use_visual_ssl=True,
    visual_ssl_type='simclr',
    use_mlm=False,
    text_ssl_loss_weight=0.05,
    image_ssl_loss_weight=0.05
).cuda()

optim = AdamW(clip.parameters(), lr=3e-4)

dataloader = DataLoader(
    dataset=TextImgDataset(
        'hf://datasets/pranked03/flowers-blip-captions/data/train-00000-of-00001-f41d4839cc8f6449.parquet'),
    batch_size=4,
    shuffle=False
)

t = SimpleTokenizer()

# Early stopping parameters
patience = 10
best_loss = float('inf')
trigger_times = 0

for epoch in range(1, 500):
    losses = []
    for image, text in tqdm(dataloader, desc=f'epoch {epoch}'):
        optim.zero_grad()
        loss = clip(
            t.tokenize(text).cuda(),
            image.cuda(),
            return_loss=True
        )
        loss.backward()
        losses.append(loss.item())
        optim.step()

    epoch_loss = sum(losses) / len(losses)
    print(f"epoch {epoch}, loss: {epoch_loss}")

    # Check for early stopping
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        trigger_times = 0
        torch.save(clip.state_dict(), 'clip.pt')  # Save the model when it improves
    else:
        trigger_times += 1
        print(f"Trigger times: {trigger_times}")
        if trigger_times >= patience:
            print("Early stopping!")
            break
```
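One thing I noticed while reading about the flags I enabled: with `decoupled_contrastive_learning=True`, the positive pair is (as I understand the DCL formulation) removed from the softmax denominator, so the per-sample loss is no longer bounded below by zero the way standard InfoNCE is. Here is a toy sketch of that effect on a hand-made similarity matrix, not the library's actual implementation:

```python
import torch

# Toy similarity matrix for a batch of 3 (rows: images, cols: texts).
# Diagonal entries are the matched (positive) pairs.
sim = torch.tensor([
    [ 5.0, -1.0, -2.0],
    [-1.0,  6.0, -1.5],
    [-2.0, -1.0,  4.0],
])

pos = sim.diag()

# Standard InfoNCE: the positive stays in the denominator,
# so each per-sample loss is >= 0.
infonce = -(pos - sim.logsumexp(dim=-1))

# Decoupled contrastive loss: the positive is masked out of the
# denominator, so the loss can drop below zero once the positive
# similarity dominates the negatives.
eye = torch.eye(3, dtype=torch.bool)
neg_lse = sim.masked_fill(eye, float('-inf')).logsumexp(dim=-1)
dcl = -(pos - neg_lse)

print(infonce.mean())  # non-negative
print(dcl.mean())      # negative here: positives outweigh negatives
```

If that is what the flag does here, a negative loss once the positive pairs dominate would be expected rather than a bug, but I'd appreciate confirmation.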
Custom dataset class:
```python
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T
import pandas as pd
from PIL import Image
import io


# Dataset class, returns images and corresponding textual captions
class TextImgDataset(Dataset):
    def __init__(self, fp: str):
        self.df = pd.read_parquet(fp)
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((32, 32)),
            T.ToTensor(),
        ])

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx) -> tuple[torch.Tensor, str]:
        row = self.df.iloc[idx]
        img_bytes = io.BytesIO(row['image']['bytes'])
        image = Image.open(img_bytes)
        image_tensor = self.transform(image)
        caption = row['text']
        return image_tensor, caption
```
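One detail that may be worth ruling out alongside the loss sign: the dataset transform resizes every image to 32x32, while the CLIP model above is constructed with `visual_image_size=256` and `visual_patch_size=32`. A quick sanity check (reusing `dataloader` and `t` from the training script above) makes the shapes visible before training:

```python
# Continues from the training script above: `dataloader` and `t` are
# the DataLoader and SimpleTokenizer already constructed there.
image, text = next(iter(dataloader))

# The dataset transform produces 32x32 images, but the model was
# built with visual_image_size=256.
print(image.shape)  # torch.Size([4, 3, 32, 32])

# Tokenized text; the second dimension should match text_seq_len=256.
tokens = t.tokenize(text)
print(tokens.shape)
```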