Skip to content

Commit

Permalink
Format refactor with black
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeBailey181 committed Mar 22, 2023
1 parent c62517b commit d1c20e9
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 127 deletions.
60 changes: 44 additions & 16 deletions k_dropout/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,44 +32,72 @@ def get_mnist(train_batch_size=64, test_batch_size=1000, num_workers=2):
test_batch_size -- size of batches in returned test set dataloader
"""

transform = transforms.Compose([
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)), # to mean 0, std 1
transforms.Lambda(lambda x: torch.flatten(x)),
])
]
)

train_set = datasets.MNIST(
root=DATASET_ROOT, train=True, download=True, transform=transform)
root=DATASET_ROOT, train=True, download=True, transform=transform
)
test_set = datasets.MNIST(
root=DATASET_ROOT, train=False, download=True, transform=transform)
root=DATASET_ROOT, train=False, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(
train_set, batch_size=train_batch_size, shuffle=True, drop_last=True,
pin_memory=True, num_workers=num_workers)
train_set,
batch_size=train_batch_size,
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=num_workers,
)
test_loader = torch.utils.data.DataLoader(
test_set, batch_size=test_batch_size, shuffle=False, drop_last=True,
pin_memory=True, num_workers=num_workers)
test_set,
batch_size=test_batch_size,
shuffle=False,
drop_last=True,
pin_memory=True,
num_workers=num_workers,
)

return train_loader, test_loader


def get_cifar10(train_batch_size=4, test_batch_size=4, num_workers=2):
    """Return (train_loader, test_loader) DataLoaders for CIFAR-10.

    Images are normalized from [0, 1] to [-1, 1] per channel and flattened
    to 1-D vectors (3 * 32 * 32 = 3072 features), matching the MLP-style
    models this repo trains.

    Keyword arguments:
    train_batch_size -- size of batches in the returned train loader
    test_batch_size  -- size of batches in the returned test loader
    num_workers      -- worker processes per DataLoader

    NOTE(review): the visible span interleaved the pre- and post-black
    versions of every reformatted statement (diff residue); this body is
    the reconstructed post-refactor version only.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # to [-1, 1]
            transforms.Lambda(lambda x: torch.flatten(x)),
        ]
    )

    # download=True fetches the dataset into DATASET_ROOT on first use
    train_set = datasets.CIFAR10(
        root=DATASET_ROOT, train=True, download=True, transform=transform
    )
    test_set = datasets.CIFAR10(
        root=DATASET_ROOT, train=False, download=True, transform=transform
    )

    # drop_last=True keeps every batch the same size (simplifies per-batch
    # mask logic elsewhere in the repo); pin_memory speeds host->GPU copies
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=test_batch_size,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
        num_workers=num_workers,
    )

    return train_loader, test_loader
9 changes: 5 additions & 4 deletions k_dropout/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class SequentialKDropout(nn.Module):
m: number of masks to use per batch, defaults to 0 which will set m=batch_size for
any input.
"""

def __init__(self, k: int, p: float = 0.5, m=-1):
super(SequentialKDropout, self).__init__()
self.k = k
Expand Down Expand Up @@ -90,12 +91,12 @@ def forward(self, x: Tensor) -> Tensor:

if self.training:
g = torch.Generator(device=x.device)
seed_idxs = torch.randint(high=self.n_masks, size=(self.m,))
seed_idxs = torch.randint(high=self.n_masks, size=(self.m,))
gen_seeds = [self.mask_seeds[i] for i in seed_idxs]
masks = []

masks = []
for seed in gen_seeds:
g.manual_seed(seed)
g.manual_seed(seed)
masks.append(torch.rand(d, device=x.device, generator=g) >= self.p)

mask_block = torch.stack(masks)
Expand Down
2 changes: 1 addition & 1 deletion k_dropout/training_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def train_net(
wandb.log({"test_loss": test_loss, "test_acc": acc}, step=example_ct)

# final test of model
if test_set is not None and (epochs-1) not in test_losses:
if test_set is not None and (epochs - 1) not in test_losses:
if return_results:
test_loss, acc = test_net(net, test_set)
test_losses[epochs - 1] = test_loss
Expand Down
Loading

0 comments on commit d1c20e9

Please sign in to comment.