-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar.py
65 lines (50 loc) · 2.58 KB
/
cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Normalization values are from:
https://github.com/zhenxun-zhuang/AdamW-Scale-free/blob/main/src/data_loader.py
"""
import torchvision
from torchvision import transforms
def get_cifar10(split, path, *, download=True):
    """Build the CIFAR-10 dataset for ``split`` with the standard transforms.

    Args:
        split: ``'train'`` selects the training set and applies random
            horizontal-flip and padded random-crop augmentation. Any other
            value (e.g. ``'val'`` or ``'test'``) selects the test set with
            no augmentation.
        path: Root directory passed to ``torchvision.datasets.CIFAR10``.
        download: Forwarded to ``torchvision.datasets.CIFAR10``. Defaults to
            True, which matches the previous hard-coded behavior. Pass False
            to skip downloading and require the data to already be under
            ``path``. This is useful on offline machines, or to stop several
            processes from downloading into the same directory at once.

    Returns:
        A ``torchvision.datasets.CIFAR10`` whose images are converted to
        tensors and normalized.
    """
    # Per-channel (3-value) mean/std; provenance is given in the module docstring.
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2470, 0.2435, 0.2616])
    is_train = split == 'train'
    if is_train:
        # Augmentation is applied to the training split only.
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    return torchvision.datasets.CIFAR10(root=path,
                                        train=is_train,
                                        download=download,
                                        transform=transform)
def get_cifar100(split, path, *, download=True):
    """Build the CIFAR-100 dataset for ``split`` with the standard transforms.

    Args:
        split: ``'train'`` selects the training set and applies random
            horizontal-flip and padded random-crop augmentation. Any other
            value (e.g. ``'val'`` or ``'test'``) selects the test set with
            no augmentation.
        path: Root directory passed to ``torchvision.datasets.CIFAR100``.
        download: Forwarded to ``torchvision.datasets.CIFAR100``. Defaults to
            True, which matches the previous hard-coded behavior. Pass False
            to skip downloading and require the data to already be under
            ``path``. This is useful on offline machines, or to stop several
            processes from downloading into the same directory at once.

    Returns:
        A ``torchvision.datasets.CIFAR100`` whose images are converted to
        tensors and normalized.
    """
    # Per-channel (3-value) mean/std; provenance is given in the module docstring.
    normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409],
                                     std=[0.2673, 0.2564, 0.2762])
    is_train = split == 'train'
    if is_train:
        # Augmentation is applied to the training split only.
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    return torchvision.datasets.CIFAR100(root=path,
                                         train=is_train,
                                         download=download,
                                         transform=transform)