This repository provides a PyTorch Lightning implementation of VICReg, as described in the paper VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning. It is inspired by Meta AI's original repository.
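The three regularization terms in the paper's title can be sketched in plain Python. This is a minimal, non-authoritative illustration of the loss as the VICReg paper describes it. It is written without PyTorch so it runs anywhere, and the module's own tensor-based implementation may differ in details. `vicreg_loss` and `_covariance` are hypothetical names. The default weights mirror the `invariance_coeff`, `variance_coeff` and `covariance_coeff` values used in this README's examples, and `eps` is an assumed small constant for numerical stability.

```python
import math


def _covariance(z):
    """Unbiased covariance matrix of a batch z (N rows, D columns)."""
    n, d = len(z), len(z[0])
    means = [sum(row[j] for row in z) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in z]
    return [
        [sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]


def vicreg_loss(z_a, z_b, invariance_coeff=25.0, variance_coeff=25.0,
                covariance_coeff=1.0, eps=1e-4):
    """Weighted VICReg loss for two batches of embeddings (lists of lists)."""
    n, d = len(z_a), len(z_a[0])

    # Invariance: mean squared error between the embeddings of the two views.
    invariance = sum(
        (a - b) ** 2 for ra, rb in zip(z_a, z_b) for a, b in zip(ra, rb)
    ) / (n * d)

    cov_a, cov_b = _covariance(z_a), _covariance(z_b)

    # Variance: hinge loss pushing each dimension's standard deviation up to 1.
    def variance_term(cov):
        return sum(max(0.0, 1.0 - math.sqrt(cov[j][j] + eps)) for j in range(d)) / d

    variance = variance_term(cov_a) / 2 + variance_term(cov_b) / 2

    # Covariance: sum of squared off-diagonal entries, scaled by the dimension D.
    def covariance_term(cov):
        return sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j) / d

    covariance = covariance_term(cov_a) + covariance_term(cov_b)

    return (invariance_coeff * invariance
            + variance_coeff * variance
            + covariance_coeff * covariance)


spread = [[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]]
collapsed = [[0.0, 0.0]] * 4
print(vicreg_loss(spread, spread))        # well-spread, decorrelated, identical views
print(vicreg_loss(collapsed, collapsed))  # collapsed embeddings are penalized by the variance term
```

A collapsed batch gets zero invariance and covariance penalties but a large variance penalty. That is the mechanism VICReg relies on to avoid representational collapse without negative pairs.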
The module follows the style that Lightning Bolts uses for its other state-of-the-art (SOTA) self-supervised models.
PyTorch Lightning is a lightweight PyTorch wrapper for high-performance AI research. It keeps your code organized and provides many useful features, such as the ability to run the same model on a CPU, a GPU, a multi-GPU cluster, or a TPU.
Lightning Bolts is a community-built deep learning research and production toolbox, featuring a collection of well-established and SOTA models and components, pre-trained weights, callbacks, loss functions, datasets, and data modules.
Here are some examples!
Python
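import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule

# VICReg, VICRegTrainDataTransform and VICRegEvalDataTransform are defined in
# this repository; import them from wherever they live in your checkout.

# ResNet-18 backbone; maxpool1/first_conv=False suit small 32x32 CIFAR10 images
model = VICReg(
    arch="resnet18",
    maxpool1=False,
    first_conv=False,
    mlp_expander="2048-2048-2048",  # widths of the expander MLP layers
    invariance_coeff=25.0,
    variance_coeff=25.0,
    covariance_coeff=1.0,
    optimizer="lars",
    learning_rate=0.3,
    warmup_steps=10,
)

# CIFAR10 images are 32x32, so both transforms use input_height=32
dm = CIFAR10DataModule(batch_size=128, num_workers=0)
dm.train_transforms = VICRegTrainDataTransform(
    input_height=32,
    gaussian_blur=False,
    jitter_strength=1.0,
)
dm.val_transforms = VICRegEvalDataTransform(
    input_height=32,
    gaussian_blur=False,
    jitter_strength=1.0,
)

trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)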
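The `mlp_expander` string describes the expander MLP that sits on top of the backbone. The following is a minimal sketch, assuming the repo follows the convention of Meta's original code: dash-separated layer widths, Linear, BatchNorm1d and ReLU between hidden layers, and a plain Linear layer at the end. `expander_layers` is a hypothetical helper that returns a plain description of the layers instead of `torch.nn` modules, so it runs without PyTorch. The 512-dimensional input corresponds to the output size of a ResNet-18 backbone.

```python
def expander_layers(embedding_dim, mlp_expander):
    """Describe the expander layers implied by a spec such as '2048-2048-2048'."""
    # Prepend the backbone's output size to the requested widths.
    dims = [embedding_dim] + [int(width) for width in mlp_expander.split("-")]
    layers = []
    # Every layer except the last is Linear -> BatchNorm1d -> ReLU.
    for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
        layers += [("Linear", in_dim, out_dim), ("BatchNorm1d", out_dim), ("ReLU",)]
    # The final projection is a single Linear layer.
    layers.append(("Linear", dims[-2], dims[-1]))
    return layers


for layer in expander_layers(512, "2048-2048-2048"):
    print(layer)
```

With a ResNet-50 backbone (2048-dimensional output) and `8192-8192-8192`, the same helper describes the wider ImageNet expander.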
Command line interface [cifar10]
python vicreg_module.py \
  --accelerator gpu \
  --devices 1 \
  --dataset cifar10 \
  --data_dir /path/to/cifar/ \
  --batch_size 128 \
  --arch resnet18 \
  --maxpool1 False \
  --first_conv False \
  --mlp_expander 2048-2048-2048 \
  --invariance_coeff 25.0 \
  --variance_coeff 25.0 \
  --covariance_coeff 1.0 \
  --optimizer adam \
  --learning_rate 0.3 \
  --warmup_steps 10
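`--warmup_steps` and `--learning_rate` together control the learning-rate schedule. This README does not say how the module applies them, so the following sketches only one common pattern: linear warmup followed by cosine decay, a schedule often paired with LARS. `lr_at_step`, `max_steps` and `end_lr_ratio` are hypothetical names. Whether `warmup_steps` counts optimizer steps or epochs in this repo is an assumption to check against the code.

```python
import math


def lr_at_step(step, base_lr, warmup_steps, max_steps, end_lr_ratio=0.001):
    """Hypothetical linear-warmup + cosine-decay schedule."""
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr.
        return base_lr * step / warmup_steps
    # Fraction of the post-warmup phase completed (0 at the end of warmup).
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    q = 0.5 * (1.0 + math.cos(math.pi * progress))
    end_lr = base_lr * end_lr_ratio
    # Cosine interpolation from base_lr down to end_lr.
    return base_lr * q + end_lr * (1.0 - q)


for step in (0, 5, 10, 55, 100):
    print(step, round(lr_at_step(step, base_lr=0.3, warmup_steps=10, max_steps=100), 5))
```

Warming up avoids applying the full learning rate while the randomly initialized network and the optimizer statistics are still unstable.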
Command line interface [imagenet]
python vicreg_module.py \
  --accelerator gpu \
  --devices 1 \
  --dataset imagenet \
  --data_dir /path/to/imagenet/ \
  --batch_size 512 \
  --arch resnet50 \
  --maxpool1 True \
  --first_conv True \
  --mlp_expander 8192-8192-8192 \
  --invariance_coeff 25.0 \
  --variance_coeff 25.0 \
  --covariance_coeff 1.0 \
  --optimizer lars \
  --learning_rate 0.6 \
  --warmup_steps 10
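The two CLI examples differ in `--first_conv` and `--maxpool1`. Assume this repo reuses the Lightning Bolts ResNet stem: a 7x7 stride-2 convolution when `first_conv` is true, otherwise a 3x3 stride-1 convolution, and a 3x3 stride-2 max-pool only when `maxpool1` is true. Under that assumption, the standard output-size formula shows why both are disabled for 32x32 CIFAR10 images: the ImageNet stem would shrink them to 8x8 before the first residual block. `conv_out` and `stem_output` are hypothetical helpers.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_output(size, first_conv, maxpool1):
    """Feature-map side length after an assumed Bolts-style ResNet stem."""
    if first_conv:
        size = conv_out(size, kernel=7, stride=2, padding=3)  # ImageNet-style first conv
    else:
        size = conv_out(size, kernel=3, stride=1, padding=1)  # CIFAR-style first conv
    if maxpool1:
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


print(stem_output(32, first_conv=False, maxpool1=False))  # 32
print(stem_output(32, first_conv=True, maxpool1=True))    # 8
print(stem_output(224, first_conv=True, maxpool1=True))   # 56
```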
I have pre-trained the model on CIFAR10 (the WandB evaluation metrics for CIFAR10 are here).
If you love notebooks and free GPUs, the Colab version of this repository can be found here