Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 18, 2023
1 parent bbeca9f commit 2b6e81d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_resnet18_equalization():
# Check that equalization is not introducing FP variations
assert torch.allclose(expected_out, out, atol=ATOL)

regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs.keys()]))
regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
equalized_layers = set()
for r in resnet_18_regions:
Expand All @@ -45,9 +45,9 @@ def test_resnet18_equalization():

# Check that we found all the expected regions
for region, expected_region in zip(regions, resnet_18_regions):
srcs = list(region.srcs)
srcs = region.srcs_names
sources_check = set(srcs) == set(expected_region[0])
sinks = list(region.sinks)
sinks = region.sinks_names
sinks_check = set(sinks) == set(expected_region[1])
assert sources_check
assert sinks_check
Expand Down

0 comments on commit 2b6e81d

Please sign in to comment.