Skip to content

Commit

Permalink
Formatting and Sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay committed Jun 12, 2024
1 parent 717de64 commit 4cdc2d3
Show file tree
Hide file tree
Showing 26 changed files with 212 additions and 227 deletions.
44 changes: 23 additions & 21 deletions benchmate/benchmate/datagen.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,23 @@
#!/usr/bin/env python

import argparse
from collections import defaultdict
import json
import multiprocessing
import os
from collections import defaultdict
from pathlib import Path
import json
import warnings

warnings.filterwarnings('ignore')

import torch

from tqdm import tqdm


def write(args):
import torch
import torchvision.transforms as transforms

offset, outdir, size = args
offset, outdir, size = args

img = torch.randn(*size)
target = offset % 1000 # torch.randint(0, 1000, size=(1,), dtype=torch.long)[0]
target = offset % 1000 # torch.randint(0, 1000, size=(1,), dtype=torch.long)[0]
img = transforms.ToPILImage()(img)

class_val = int(target)
Expand All @@ -35,14 +30,16 @@ def write(args):
img.save(image_path)


def generate(image_size, n, outdir, start = 0):
def generate(image_size, n, outdir, start=0):
work_items = []
for i in range(n):
work_items.append([
start + i,
outdir,
image_size,
])
work_items.append(
[
start + i,
outdir,
image_size,
]
)

n_worker = min(multiprocessing.cpu_count(), 8)
with multiprocessing.Pool(n_worker) as pool:
Expand All @@ -53,7 +50,7 @@ def generate(image_size, n, outdir, start = 0):
def count_images(path):
count = defaultdict(int)
for root, _, files in tqdm(os.walk(path)):
split = root.split('/')[-2]
split = root.split("/")[-2]
count[split] += len(files)

return count
Expand All @@ -71,7 +68,12 @@ def generate_sets(root, sets, shape):

if current_count < count:
print(f"Generating {split} (current {current_count}) (target: {count})")
generate(shape, count - current_count, os.path.join(root, split), start=current_count)
generate(
shape,
count - current_count,
os.path.join(root, split),
start=current_count,
)

with open(sentinel, "w") as fp:
json.dump(sets, fp)
Expand All @@ -92,14 +94,14 @@ def generate_fakeimagenet():

total_images = args.batch_size * args.batch_count
size_spec = {
"train": total_images,
"val": int(total_images * args.val),
"test": int(total_images * args.test)
"train": total_images,
"val": int(total_images * args.val),
"test": int(total_images * args.test),
}

generate_sets(dest, size_spec, args.image_size)
print("Done!")


if __name__ == "__main__":
generate_fakeimagenet()
generate_fakeimagenet()
100 changes: 47 additions & 53 deletions benchmate/benchmate/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import torch
import torch.cuda.amp
import torchcompat.core as accelerator
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchcompat.core as accelerator
from torch.utils.data.distributed import DistributedSampler


Expand All @@ -23,14 +23,14 @@ def generate_tensors(batch_size, shapes, device):
tensors = []
if len(shapes[0]) == 2:
tensors = dict()

for kshape in shapes:
if len(kshape) == 2:
key, shape = kshape
tensors[key] = torch.randn((batch_size, *shape), device=device)
else:
tensors.append(torch.randn((batch_size, *kshape), device=device))
tensors.append(torch.randn((batch_size, *kshape), device=device))

return tensors


Expand Down Expand Up @@ -70,7 +70,7 @@ def __iter__(self):
if self.fixed_batch:
for _ in range(self.n):
yield self.tensors

else:
for _ in range(self.n):
yield [torch.rand_like(t) for t in self.tensors]
Expand All @@ -80,25 +80,25 @@ def __len__(self):


def dali(folder, batch_size, num_workers, rank=0, world_size=1):
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.pytorch import DALIGenericIterator

@pipeline_def(num_threads=num_workers, device_id=0)
def get_dali_pipeline():
images, labels = fn.readers.file(
file_root=folder,
random_shuffle=True,
file_root=folder,
random_shuffle=True,
name="Reader",
shard_id=rank,
num_shards=world_size,
)

# decode data on the GPU
images = fn.decoders.image_random_crop(
images,
device="mixed",
images,
device="mixed",
output_type=types.RGB,
)
# the rest of processing happens on the GPU as well
Expand All @@ -109,14 +109,14 @@ def get_dali_pipeline():
crop_w=224,
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
mirror=fn.random.coin_flip()
mirror=fn.random.coin_flip(),
)
return images, labels

train_data = DALIGenericIterator(
[get_dali_pipeline(batch_size=batch_size)],
['data', 'label'],
reader_name='Reader'
["data", "label"],
reader_name="Reader",
)

class Adapter:
Expand All @@ -130,18 +130,18 @@ def set_epoch(epoch):

def __len__(self):
return len(self.iter)

def __iter__(self):
for data in self.iter:
x, y = data[0]['data'], data[0]['label']
x, y = data[0]["data"], data[0]["label"]
yield x, torch.squeeze(y, dim=1).type(torch.LongTensor)

return Adapter(train_data)


def pytorch_fakedataset(folder, batch_size, num_workers):
train = FakeImageClassification((3, 224, 224), batch_size, 60)

return torch.utils.data.DataLoader(
train,
batch_size=batch_size,
Expand All @@ -152,7 +152,9 @@ def pytorch_fakedataset(folder, batch_size, num_workers):


def image_transforms():
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
data_transforms = transforms.Compose(
[
transforms.RandomResizedCrop(224),
Expand All @@ -163,11 +165,9 @@ def image_transforms():
)
return data_transforms

def pytorch(folder, batch_size, num_workers, distributed=False):
train = datasets.ImageFolder(
folder,
image_transforms()
)

def pytorch(folder, batch_size, num_workers, distributed=False, epochs=60):
train = datasets.ImageFolder(folder, image_transforms())

kwargs = {"shuffle": True}
if distributed:
Expand All @@ -179,9 +179,7 @@ def pytorch(folder, batch_size, num_workers, distributed=False):
# we reduce the standard deviation
if False:
kwargs["sampler"] = torch.utils.data.RandomSampler(
train,
replacement=True,
num_samples=len(train) * args.epochs
train, replacement=True, num_samples=len(train) * epochs
)
kwargs["shuffle"] = False

Expand All @@ -197,15 +195,13 @@ def pytorch(folder, batch_size, num_workers, distributed=False):
def synthetic(model, batch_size, fixed_batch):
return SyntheticData(
tensors=generate_tensor_classification(
model,
batch_size,
(3, 244, 244),
device=accelerator.fetch_device(0)
model, batch_size, (3, 244, 244), device=accelerator.fetch_device(0)
),
n=1000,
fixed_batch=fixed_batch,
)


def synthetic_fixed(*args):
return synthetic(*args, fixed_batch=True)

Expand All @@ -216,20 +212,28 @@ def synthetic_random(*args):

def dataloader_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--batch-size", type=int, default=16,
"--batch-size",
type=int,
default=16,
help="input batch size for training (default: 16)",
)
parser.add_argument(
"--loader", type=str, help="Dataloader implementation (dali, pytorch, synthetic_fixed, synthetic_random)",
default="pytorch"
"--loader",
type=str,
help="Dataloader implementation (dali, pytorch, synthetic_fixed, synthetic_random)",
default="pytorch",
)
parser.add_argument(
"--num-workers", type=int, default=8,
"--num-workers",
type=int,
default=8,
help="number of workers for data loading",
)
parser.add_argument(
"--data", type=str, default=os.environ.get("MILABENCH_DIR_DATA", None),
help="data directory"
parser.add_argument(
"--data",
type=str,
default=os.environ.get("MILABENCH_DIR_DATA", None),
help="data directory",
)


Expand All @@ -243,24 +247,14 @@ def data_folder(args):

def imagenet_dataloader(args, model, rank=0, world_size=1):
if args.loader == "synthetic_random":
return synthetic(
model=model,
batch_size=args.batch_size,
fixed_batch=False
)

return synthetic(model=model, batch_size=args.batch_size, fixed_batch=False)

if args.loader == "synthetic_fixed":
return synthetic(
model=model,
batch_size=args.batch_size,
fixed_batch=True
)

return synthetic(model=model, batch_size=args.batch_size, fixed_batch=True)

if args.loader == "pytorch_fakedataset":
return pytorch_fakedataset(
None,
batch_size=args.batch_size,
num_workers=args.num_workers
None, batch_size=args.batch_size, num_workers=args.num_workers
)

folder = os.path.join(data_folder(args), "train")
Expand Down
8 changes: 4 additions & 4 deletions benchmate/benchmate/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import os
from collections import defaultdict

Expand All @@ -11,13 +10,15 @@ def transform_images(transform_x, transform_y=no_transform):
def _(args):
print(args)
return transform_x(args[0]), transform_y(args[1])

return _


def transform_celebA(transform_x):
def _(args):
print(args)
return transform_x(args["image"])

return _


Expand All @@ -33,7 +34,6 @@ def __getitem__(self, item):
return self.transforms(self.dataset[item])



class ImageNetAsFrames:
def __init__(self, folder) -> None:
self.clip = defaultdict(list)
Expand All @@ -42,9 +42,9 @@ def __init__(self, folder) -> None:
video = self.clip[clip_id]
for frame in files:
video.append(frame)

def __getitem__(self, item):
return self.clip[item]

def __len__(self):
return len(self.clip)
Loading

0 comments on commit 4cdc2d3

Please sign in to comment.