Skip to content

Commit

Permalink
Add data.py
Browse files Browse the repository at this point in the history
  • Loading branch information
adosar committed May 25, 2024
1 parent 869f984 commit 0fa1e53
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 0 deletions.
File renamed without changes.
65 changes: 65 additions & 0 deletions src/moxel/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
r"""
Write the docstring of the module.
"""

import json
import torch
from torch.utils.data import Dataset, random_split
from . utils import load_json


def prepare_data(source, split_ratio=(0.8, 0.1, 0.1), seed=1):
r"""
Split a set of materials into train, validation and test sets.
.. warning::
* You should use this function **after** :func:`utils.batch_clean`.
* No directory is created by :func:`prepare_data`. **All ``.json``
files are stored under the directory containing ``source``**.
Before the split::
voxels_data
├──clean_voxels.npy
└──clean_names.json
After the split::
voxels_data
├──clean_voxels.npy
├──clean_names.json
├──train.json
├──validation.json
└──test.json
Each ``.json`` file stores the indices of ``clean_voxels.npy`` that will be
used for training, validation and testing.
Parameters
----------
source: str
Pathname to the file holding the names of the materials
(``clean_names.json``).
split_ratio: sequence, default=(0.8, 0.1, 0.1)
The sizes or fractions of splits to be produced.
* ``split_ratio[0] == train``.
* ``split_ratio[1] == validation``.
* ``split_ratio[2] == test``.
seed : int, default=1
Controls the randomness of the ``rng`` used for splitting.
"""
rng = torch.Generator().manual_seed(seed)
indices = range(len(load_json(source)))

train, val, test = random_split(indices, split_ratio, generator=rng)

for split, mode in zip((train, val, test), ('train', 'validation', 'test')):
mode_indices = list(split)
with open(os.path.join(path, f'{mode}.json'), 'w') as fhand:
json.dump(mode_indices, fhand, indent=4)

print('\033[32mData preparation completed!\033[0m')


class VoxelsDataset(Dataset):
...
Empty file added tests/test_data.py
Empty file.

0 comments on commit 0fa1e53

Please sign in to comment.