Skip to content

Commit

Permalink
Merge pull request fastai#218 from PiotrCzapla/master
Browse files Browse the repository at this point in the history
Make the assertion less restrictive + add some docs
  • Loading branch information
jph00 authored Mar 20, 2018
2 parents bcd7e48 + 31504f1 commit 7694eed
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions fastai/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def dict_source(folder, fnames, csv_labels, suffix='', continuous=False):
return full_names, label_arr, all_labels

class BaseDataset(Dataset):
"""An abstract class representing a fastai dataset, it extends torch.utils.data.Dataset."""
def __init__(self, transform=None):
self.transform = transform
self.n = self.get_n()
Expand All @@ -171,19 +172,39 @@ def get(self, tfm, x, y):
return (x,y) if tfm is None else tfm(x,y)

@abstractmethod
def get_n(self): raise NotImplementedError
def get_n(self):
"""Return number of elements in the dataset == len(self)."""
raise NotImplementedError

@abstractmethod
def get_c(self): raise NotImplementedError
def get_c(self):
"""Return number of classes in a dataset."""
raise NotImplementedError

@abstractmethod
def get_sz(self): raise NotImplementedError
def get_sz(self):
"""Return maximum size of an image in a dataset."""
raise NotImplementedError

@abstractmethod
def get_x(self, i): raise NotImplementedError
def get_x(self, i):
"""Return i-th example (image, wav, etc)."""
raise NotImplementedError

@abstractmethod
def get_y(self, i): raise NotImplementedError
def get_y(self, i):
"""Return i-th label."""
raise NotImplementedError

@property
def is_multi(self): return False
def is_multi(self):
"""Returns true if this data set contains multiple labels per sample."""
return False

@property
def is_reg(self): return False
def is_reg(self):
"""True if the data set is used to train regression models."""
return False

def open_image(fn):
""" Opens an image using OpenCV given the file path.
Expand All @@ -192,7 +213,7 @@ def open_image(fn):
fn: the file path of the image
Returns:
The numpy array representation of the image in the RGB format
The image in RGB format as numpy array of floats normalized to range between 0.0 - 1.0
"""
flags = cv2.IMREAD_UNCHANGED+cv2.IMREAD_ANYDEPTH+cv2.IMREAD_ANYCOLOR
if not os.path.exists(fn):
Expand Down Expand Up @@ -394,8 +415,7 @@ def from_paths(cls, path, bs=64, tfms=(None,None), trn_name='train', val_name='v
Returns:
ImageClassifierData
"""
assert isinstance(tfms[0], Transforms) and isinstance(tfms[1], Transforms), \
"please provide transformations for your train and validation sets"
assert not(tfms[0] is None or tfms[1] is None), "please provide transformations for your train and validation sets"
trn,val = [folder_source(path, o) for o in (trn_name, val_name)]
if test_name:
test = folder_source(path, test_name) if test_with_labels else read_dir(path, test_name)
Expand Down Expand Up @@ -448,6 +468,16 @@ def from_names_and_array(cls, path, fnames,y,classes, val_idxs=None, test_name=N
return cls(path, datasets, bs, num_workers, classes=classes)

def split_by_idx(idxs, *a):
"""
Split each array passed as *a, to a pair of arrays like this (elements selected by idxs, the remaining elements)
This can be used to split multiple arrays containing training data to validation and training set.
:param idxs [int]: list of indexes selected
:param a list: list of np.array, each array should have same amount of elements in the first dimension
:return: list of tuples, each containing a split of corresponding array from *a.
First element of each tuple is an array composed from elements selected by idxs,
second element is an array of remaining elements.
"""
mask = np.zeros(len(a[0]),dtype=bool)
mask[np.array(idxs)] = True
return [(o[mask],o[~mask]) for o in a]
Expand Down

0 comments on commit 7694eed

Please sign in to comment.