Skip to content

Commit

Permalink
fix transforms in validation splits notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Jun 21, 2019
1 parent f304ba9 commit 1dae767
Showing 1 changed file with 84 additions and 11 deletions.
95 changes: 84 additions & 11 deletions pytorch_ipynb/mechanics/validation-splits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,68 @@
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"9920512it [00:02, 4390618.70it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"32768it [00:00, 293812.98it/s] \n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n",
"Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1654784it [00:00, 2762205.03it/s] \n",
"8192it [00:00, 124866.40it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
"Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
"Processing...\n",
"Done!\n",
"Image batch dimensions: torch.Size([64, 1, 28, 28])\n",
"Image label dimensions: torch.Size([64])\n"
]
Expand All @@ -85,11 +143,13 @@
"\n",
"\n",
"train_loader = DataLoader(dataset=train_dataset, \n",
" batch_size=BATCH_SIZE, \n",
" batch_size=BATCH_SIZE,\n",
" num_workers=4,\n",
" shuffle=True)\n",
"\n",
"test_loader = DataLoader(dataset=test_dataset, \n",
" batch_size=BATCH_SIZE, \n",
" batch_size=BATCH_SIZE,\n",
" num_workers=4,\n",
" shuffle=False)\n",
"\n",
"# Checking the dataset\n",
Expand Down Expand Up @@ -167,11 +227,13 @@
"outputs": [],
"source": [
"train_loader = DataLoader(dataset=train_dataset, \n",
" batch_size=BATCH_SIZE, \n",
" batch_size=BATCH_SIZE,\n",
" num_workers=4,\n",
" shuffle=True)\n",
"\n",
"valid_loader = DataLoader(dataset=valid_dataset, \n",
" batch_size=BATCH_SIZE, \n",
" batch_size=BATCH_SIZE,\n",
" num_workers=4,\n",
" shuffle=False)"
]
},
Expand Down Expand Up @@ -206,8 +268,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1, 0, 6, 5, 9, 5, 2, 6, 0, 2])\n",
"tensor([7, 7, 8, 6, 6, 2, 9, 3, 3, 2])\n"
"tensor([1, 7, 2, 4, 7, 7, 8, 4, 0, 5])\n",
"tensor([5, 5, 6, 4, 2, 3, 8, 0, 7, 5])\n"
]
}
],
Expand Down Expand Up @@ -306,24 +368,35 @@
"\n",
"train_dataset = datasets.MNIST(root='data', \n",
" train=True, \n",
" transform=transforms.ToTensor(),\n",
" transform=training_transform,\n",
" download=True)\n",
"\n",
"# note that this is the same dataset as \"train_dataset\" above\n",
"# however, we can now choose a different transform method\n",
"valid_dataset = datasets.MNIST(root='data', \n",
" train=False, \n",
" transform=transforms.ToTensor(),\n",
" train=True, \n",
" transform=valid_transform,\n",
" download=False)\n",
"\n",
"test_dataset = datasets.MNIST(root='data', \n",
" train=False, \n",
" transform=valid_transform,\n",
" download=False)\n",
"\n",
"train_loader = DataLoader(train_dataset,\n",
" batch_size=BATCH_SIZE,\n",
" num_workers=4,\n",
" sampler=train_sampler)\n",
"\n",
"valid_loader = DataLoader(valid_dataset,\n",
" batch_size=BATCH_SIZE,\n",
" sampler=valid_sampler)"
" num_workers=4,\n",
" sampler=valid_sampler)\n",
"\n",
"test_loader = DataLoader(dataset=test_dataset, \n",
" batch_size=BATCH_SIZE,\n",
" num_workers=4,\n",
" shuffle=False)"
]
},
{
Expand Down Expand Up @@ -425,7 +498,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
"version": "3.6.8"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 1dae767

Please sign in to comment.