Skip to content

Commit

Permalink
better type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcla committed Feb 17, 2023
1 parent 3f81d7c commit b3079c8
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions h5max/h5max.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@
'coo': sparse.coo_matrix,
# 'bsr': sparse.bsr_matrix, <- could be implemented, need extra attribute to describe data shape
# 'dia': sparse.dia_matrix, <- could be implemented, need extra attribute to describe data shape
# 'dok': sparse.dok_matrix, <- seems not feasible
# 'lil': sparse.lil_matrix, <- seems not feasible
# 'dok': sparse.dok_matrix, <- does not seem feasible
# 'lil': sparse.lil_matrix, <- does not seem feasible
}

S = TypeVar("S", *list(format_dict.keys()))
type_dict = {
'csr': sparse._csr.csr_matrix,
'csc': sparse._csc.csc_matrix,
'coo': sparse._coo.coo_matrix,
'bsr': sparse._bsr.bsr_matrix,
'dia': sparse._dia.dia_matrix,
'dok': sparse._dok.dok_matrix,
'lil': sparse._lil.lil_matrix,
}

format_attr_dict = {
'csr': ['data', 'indices', 'indptr', 'shape'],
Expand All @@ -24,6 +32,8 @@
'dia': ['data', 'offsets', 'shape'],
}

S = TypeVar("S", *list(format_dict.keys()))

def store_sparse(
f: Union[h5py._hl.group.Group, h5py._hl.files.File],
data: Union[np.ndarray, List[np.ndarray], S, List[S]],
Expand All @@ -43,9 +53,9 @@ def store_sparse(
overwrite (bool, optional): whether to overwrite existing nodes by default or raise an error.
Defaults to False.
"""
if type(data) != list:
if type(data) not in [list, np.ndarray]:
data = [data]
transform = type(data[0]) == np.ndarray
transform = type(data[0]) != type_dict[format]
data_attr = {key: [] for key in format_attr_dict[format]}

for sample in data:
Expand Down

0 comments on commit b3079c8

Please sign in to comment.