From e40c2e7b9ae286aa2b8cfb07e5eebc1991e7896e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 4 Jan 2023 09:54:47 +0000 Subject: [PATCH] used download, checking and extraction logic consistent with other datasets --- tonic/datasets/s_mnist.py | 57 +++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/tonic/datasets/s_mnist.py b/tonic/datasets/s_mnist.py index 03705878..b34a8951 100644 --- a/tonic/datasets/s_mnist.py +++ b/tonic/datasets/s_mnist.py @@ -4,7 +4,7 @@ import numpy as np from tonic.dataset import Dataset -from tonic.download_utils import download_and_extract_archive +from tonic.download_utils import check_integrity, download_and_extract_archive from tonic.io import make_structured_array @@ -25,9 +25,6 @@ class SMNIST(Dataset): duplicate (bool): If True, emits two spikes per threshold crossing num_neurons (integer): How many neurons to use to encode thresholds(must be odd) dt (float): Duration(in microseconds) of each timestep - download (bool): Choose to download data or verify existing files. If True - and a file with the same name and correct hash is already - in the directory, download is automatically skipped. transform (callable, optional): A callable of transforms to apply to the data. target_transform (callable, optional): A callable of transforms to apply to the targets/labels. transforms (callable, optional): A callable of transforms that is applied to both data and @@ -43,6 +40,12 @@ class SMNIST(Dataset): train_labels_file = "train-labels-idx1-ubyte" test_images_file = "t10k-images-idx3-ubyte" test_labels_file = "t10k-labels-idx1-ubyte" + + train_images_md5 = "f68b3c2dcbeaaa9fbdd348bbdeb94873" + train_labels_md5 = "d53e105ee54ea40749a09fcbcd1e9432" + test_images_md5 = "9fb629c4189551a2d022fa330f9573f3" + test_labels_md5 = "ec29112dd5afa0611ce80d1b7f02629c" + dtype = np.dtype([("t", int), ("x", int), ("p", int)]) ordering = dtype.names @@ -66,7 +69,6 @@ def __init__( duplicate=True, num_neurons=99, dt=1000.0, - download=True, transform=None, target_transform=None, ): @@ -82,10 +84,18 @@ def __init__( if (num_neurons % 2) == 0: raise Exception("Number of neurons must be odd") - self.images_file = self.train_images_file if train else self.test_images_file - self.labels_file = self.train_labels_file if train else self.test_labels_file - - if download: + if train: + self.images_file = self.train_images_file + self.labels_file = self.train_labels_file + self.images_md5 = self.train_images_md5 + self.labels_md5 = self.train_labels_md5 + else: + self.images_file = self.test_images_file + self.labels_file = self.test_labels_file + self.images_md5 = self.test_images_md5 + self.labels_md5 = self.test_labels_md5 + + if not self._check_exists(): self.download() # Open images file @@ -179,7 +189,32 @@ def __len__(self): return self.image_data.shape[0] def download(self): - for f in [self.images_file, self.labels_file]: + for (f, m) in [(self.images_file, self.images_md5), + (self.labels_file, self.labels_md5)]: download_and_extract_archive( - self.base_url + f + ".gz", self.location_on_system, filename=f + ".gz" + self.base_url + f + ".gz", self.location_on_system, + filename=f + ".gz", md5=m ) + + def _are_labels_present(self) -> bool: + """Check if the label file is present on disk. + + No hashing. + """ + return check_integrity(os.path.join(self.location_on_system, + self.labels_file)) + + def _are_images_present(self) -> bool: + """Check if the images file is present on disk. + + No hashing. + """ + return check_integrity(os.path.join(self.location_on_system, + self.images_file)) + + + def _check_exists(self): + return ( + self._are_labels_present() + and self._are_images_present() + )