Skip to content

Commit

Permalink
force output variable name in check_data_type
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerwwww committed Jan 29, 2024
1 parent 9bf91f3 commit db9a887
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions pygmtools/neural_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ def pca_gm(feat1, feat2, A1, A2, n1=None, n2=None,
backend = pygmtools.BACKEND
non_batched_input = False
if feat1 is not None: # if feat1 is None, this function skips the forward pass and only returns a network object
for _ in (feat1, feat2, A1, A2):
_check_data_type(_, backend)
for var, name in ((feat1, 'feat1'), (feat2, 'feat2'), (A1, 'A1'), (A2, 'A2')):
_check_data_type(var, name, backend)

if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]):
feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)]
Expand Down Expand Up @@ -587,10 +587,8 @@ def ipca_gm(feat1, feat2, A1, A2, n1=None, n2=None,
backend = pygmtools.BACKEND
non_batched_input = False
if feat1 is not None: # if feat1 is None, this function skips the forward pass and only returns a network object
_check_data_type(feat1, 'feat1', backend)
_check_data_type(feat2, 'feat2', backend)
_check_data_type(A1, 'A1', backend)
_check_data_type(A2, 'A2', backend)
for var, name in ((feat1, 'feat1'), (feat2, 'feat2'), (A1, 'A1'), (A2, 'A2')):
_check_data_type(var, name, backend)

if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]):
feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)]
Expand Down Expand Up @@ -907,12 +905,9 @@ def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1=None, n2=None
backend = pygmtools.BACKEND
non_batched_input = False
if feat_node1 is not None: # if feat_node1 is None, this function skips the forward pass and only returns a network object
_check_data_type(feat_node1, 'feat_node1', backend)
_check_data_type(feat_node2, 'feat_node2', backend)
_check_data_type(A1, 'A1', backend)
_check_data_type(A2, 'A2', backend)
_check_data_type(feat_edge1, 'feat_edge1', backend)
_check_data_type(feat_edge2, 'feat_edge2', backend)
for var, name in ((feat_node1, 'feat_node1'), (feat_node2, 'feat_node2'), (A1, 'A1'), (A2, 'A2'),
(feat_edge1, 'feat_edge1'), (feat_edge2, 'feat_edge2')):
_check_data_type(var, name, backend)

if all([_check_shape(_, 2, backend) for _ in (feat_node1, feat_node2, A1, A2)]) \
and all([_check_shape(_, 3, backend) for _ in (feat_edge1, feat_edge2)]):
Expand Down Expand Up @@ -1444,8 +1439,8 @@ def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=6
backend = pygmtools.BACKEND
non_batched_input = False
if feat1 is not None: # if feat1 is None, this function skips the forward pass and only returns a network object
for _ in (feat1, feat2, A1, A2):
_check_data_type(_, backend)
for var, name in ((feat1, 'feat1'), (feat2, 'feat2'), (A1, 'A1'), (A2, 'A2')):
_check_data_type(var, name, backend)

if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]):
feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)]
Expand Down

0 comments on commit db9a887

Please sign in to comment.