diff --git a/test/dataset_utils.py b/test/dataset_utils.py index 430d125..4b99a19 100644 --- a/test/dataset_utils.py +++ b/test/dataset_utils.py @@ -2,14 +2,18 @@ import unittest from typing import Any, Dict, Union from unittest.mock import patch - +import pytest import numpy as np +import os + +# Location of the files to be saved and extracted by the datasets during testing +TEST_LOCATION_ON_SYSTEM = "~/../../tmp" +TEST_LOCATION_ON_SYSTEM = os.path.expanduser(TEST_LOCATION_ON_SYSTEM) class DatasetTestCase(unittest.TestCase): DATASET_CLASS = None FEATURE_TYPES = None - _CHECK_FUNCTIONS = {"check_md5", "check_integrity", "check_exists"} _DOWNLOAD_EXTRACT_FUNCTIONS = { "download_url", @@ -41,8 +45,7 @@ def inject_fake_data( ) def create_dataset(self, inject_fake_data: bool = True, **kwargs: Any): - tmpdir = "/tmp/" - info = self._inject_fake_data(tmpdir) + info = self._inject_fake_data(TEST_LOCATION_ON_SYSTEM) if inject_fake_data: with patch.object(self.DATASET_CLASS, "_check_exists", return_value=True): @@ -81,7 +84,7 @@ def test_feature_types(self): assert len(data) == len(self.FEATURE_TYPES) assert len(target) == len(self.TARGET_TYPES) - for (data_piece, feature_type) in zip(data, self.FEATURE_TYPES): + for data_piece, feature_type in zip(data, self.FEATURE_TYPES): if type(data_piece) == np.ndarray: assert data_piece.dtype == feature_type else: @@ -93,6 +96,9 @@ def test_num_examples(self): @classmethod def setUpClass(cls): - cls.KWARGS.update({"save_to": "/tmp"}) - shutil.rmtree("/tmp/" + cls.DATASET_CLASS.__name__, ignore_errors=True) + cls.KWARGS.update({"save_to": TEST_LOCATION_ON_SYSTEM}) + shutil.rmtree( + f"{TEST_LOCATION_ON_SYSTEM}/" + cls.DATASET_CLASS.__name__, + ignore_errors=True, + ) super().setUpClass() diff --git a/test/torch_requirements.txt b/test/torch_requirements.txt index fd14302..00faaa7 100644 --- a/test/torch_requirements.txt +++ b/test/torch_requirements.txt @@ -1,5 +1,5 @@ --index-url https://download.pytorch.org/whl/cpu -torch==2.1.0 -torchaudio==2.1.0 -torchvision==0.16.0 -torchdata +torch==2.3.0 +torchaudio==2.3.0 +torchvision==0.18.0 +torchdata<=0.8.0 diff --git a/tonic/dataset.py b/tonic/dataset.py index db44a3b..4ec16d1 100644 --- a/tonic/dataset.py +++ b/tonic/dataset.py @@ -18,7 +18,7 @@ def __init__( target_transform: Optional[Callable] = None, transforms: Optional[Callable] = None, ): - self.location_on_system = os.path.join(save_to, self.__class__.__name__) + self.location_on_system = os.path.join(os.path.expanduser(save_to), self.__class__.__name__) self.transform = transform self.target_transform = target_transform self.transforms = transforms diff --git a/tonic/download_utils.py b/tonic/download_utils.py index e613557..91386ed 100644 --- a/tonic/download_utils.py +++ b/tonic/download_utils.py @@ -159,9 +159,9 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No with zipfile.ZipFile( from_path, "r", - compression=_ZIP_COMPRESSION_MAP[compression] - if compression - else zipfile.ZIP_STORED, + compression=( + _ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED + ), ) as zip: zip.extractall(to_path) @@ -313,12 +313,10 @@ def download_and_extract_archive( md5: Optional[str] = None, remove_finished: bool = False, ) -> None: - download_root = os.path.expanduser(download_root) if extract_root is None: extract_root = download_root if not filename: filename = os.path.basename(url) - download_url(url, download_root, filename, md5) archive = os.path.join(download_root, filename)