Skip to content

Commit

Permalink
Fix Test
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 10, 2023
1 parent edde30d commit ec08971
Showing 1 changed file with 0 additions and 7 deletions.
7 changes: 0 additions & 7 deletions tests/brevitas_examples/test_jit_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import torch

from brevitas.utils.jit_utils import jit_patches_generator
from brevitas_examples.bnn_pynq.models import model_with_cfg

FC_INPUT_SIZE = (1, 1, 28, 28)
Expand All @@ -28,9 +27,6 @@ def test_brevitas_fc_jit_trace(size, wbits, abits):
fc, _ = model_with_cfg(nname.lower(), pretrained=False)
fc.train(False)
input_tensor = torch.randn(FC_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(fc, input_tensor)
out_traced = traced_model(input_tensor)
out = fc(input_tensor)
Expand All @@ -46,9 +42,6 @@ def test_brevitas_cnv_jit_trace(wbits, abits):
cnv, _ = model_with_cfg(nname.lower(), pretrained=False)
cnv.train(False)
input_tensor = torch.randn(CNV_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(cnv, input_tensor)
out_traced = traced_model(input_tensor)
out = cnv(input_tensor)
Expand Down

0 comments on commit ec08971

Please sign in to comment.