Skip to content

Commit

Permalink
backward compatibility in ms tensor check
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerwwww committed Jan 26, 2024
1 parent e311400 commit 7997f13
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pygmtools/mindspore_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,11 +509,15 @@ def _check_data_type(input: mindspore.Tensor, var_name, raise_err):
"""
mindspore implementation of _check_data_type
"""
if raise_err:
if type(input) is not mindspore.Tensor or type(input) is not mindspore.common._stub_tensor.StubTensor:
raise ValueError(f'Expected MindSpore Tensor{f" for variable {var_name}" if var_name is not None else ""}, '
f'but got {type(input)}.')
return type(input) is mindspore.Tensor
ms_types = [mindspore.Tensor]
if hasattr(mindspore.common, '_stub_tensor'): # MS tensor may be automatically transformed to StubTensor
ms_types += [mindspore.common._stub_tensor.StubTensor]
is_tensor = any([type(input) is t for t in ms_types])

if raise_err and not is_tensor:
raise ValueError(f'Expected MindSpore Tensor{f" for variable {var_name}" if var_name is not None else ""}, '
f'but got {type(input)}.')
return is_tensor


def _check_shape(input, dim_num):
Expand Down

0 comments on commit 7997f13

Please sign in to comment.