Skip to content

Commit

Permalink
Merge pull request #10 from kuanweih/simsiam_transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshua Yao-Yu Lin authored Oct 29, 2023
2 parents 0751c5a + be3686e commit 72f5d78
Show file tree
Hide file tree
Showing 8 changed files with 920 additions and 11 deletions.
2 changes: 1 addition & 1 deletion calc_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def main(device, args):

# Load dataset
dataset = get_dataset(args.dataset.name, args.dataset.data_dir, args.dataset.subset_size)
dataset = get_dataset(args.dataset.name, args.dataset.data_dir, args.dataset.subset_size, args.aug_method)
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.train.batch_size)

# Load the trained model
Expand Down
2 changes: 1 addition & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ seed: null # None type for yaml file
# worker_init_fn from dataloader and torch.nn.functional.interpolate
# (keep this in mind if you want to achieve 100% deterministic)


aug_method: simsiam


4 changes: 3 additions & 1 deletion configs/umap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ testsets:
RealHST:
root: D:\Datasets\2023_hst_lensed_quasars\npy_files
suffix: cutout
subset_size: null
subset_size: null

aug_method: lensiam
4 changes: 2 additions & 2 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from .umap_testsets import get_stl10, get_2022_lens_geoff, get_real_hst


def get_dataset(dataset_name, data_dir, subset_size=None):
def get_dataset(dataset_name, data_dir, subset_size=None, aug_method=None):
if dataset_name == 'paired-lensing':
dataset = PairedLensingImageDataset(data_dir)
dataset = PairedLensingImageDataset(data_dir, aug_method=aug_method)
else:
raise NotImplementedError
if subset_size is not None:
Expand Down
47 changes: 42 additions & 5 deletions datasets/paired_lensing_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import glob
import os
import numpy as np
import random
from PIL import Image

from astropy.io import fits
from torch.utils.data import Dataset
Expand Down Expand Up @@ -42,11 +44,34 @@ def __call__(self, x):
return self.transform(x)


class SimSiamTransform():
def __init__(self):
# image_size = 224 if image_size is None else image_size # by default simsiam use image size 224
p_blur = 0.5 if image_size > 32 else 0 # exclude cifar
# the paper didn't specify this, feel free to change this value
# I use the setting from simclr which is 50% chance applying the gaussian blur
# the 32 is prepared for cifar training where they disabled gaussian blur
self.transform = transforms.Compose([
transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.4,0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.RandomApply([transforms.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=p_blur),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # grayscale -> RGB
transforms.Normalize(*imagenet_mean_std)
])
def __call__(self, x):
x = Image.fromarray((x * 255).astype(np.uint8)) # Convert NumPy array to PIL Image for the input of random augmentation transformation.
return self.transform(x)


class PairedLensingImageDataset(Dataset):
""" Pytorch Dataset Object for the paired lensing image dataset in fits file.
"""
def __init__(self, root=None):
def __init__(self, root=None, aug_method=None):
self.root = root
self.aug_method = aug_method
self.file_names = glob.glob(os.path.join(self.root, "*.fits"))
self.size = len(self.file_names)

Expand All @@ -55,10 +80,22 @@ def __getitem__(self, idx):
raise Exception
file_path = self.file_names[idx]
img_pair, label = load_fits_file(file_path)
transform = LensingImageTransform()
img1 = transform(img_pair[:, :, 0])
img2 = transform(img_pair[:, :, 1])
return img1, img2, label, file_path

if self.aug_method == 'lensiam':
transform = LensingImageTransform()
img1 = transform(img_pair[:, :, 0])
img2 = transform(img_pair[:, :, 1])
return img1, img2, label, file_path
elif self.aug_method == 'simsiam':
transform_aug = SimSiamTransform()
# Randomly choose one of the images from img_pair
i = random.choice([0, 1])
img1 = transform_aug(img_pair[:, :, i])
img2 = transform_aug(img_pair[:, :, i])
return img1, img2, label, file_path
else:
raise NotImplementedError


def __len__(self):
return self.size
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@


def main(device, args):


args.dataset_kwargs['aug_method'] = args.aug_method
# Load dataset
train_loader = torch.utils.data.DataLoader(
dataset=get_dataset(**args.dataset_kwargs),
Expand Down
316 changes: 316 additions & 0 deletions notebooks/ICML_umap_plots.ipynb

Large diffs are not rendered by default.

553 changes: 553 additions & 0 deletions notebooks/NeurIPs_umap_plots.ipynb

Large diffs are not rendered by default.

0 comments on commit 72f5d78

Please sign in to comment.