diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py index 936106884..bf780a26c 100644 --- a/src/brevitas/export/inference/manager.py +++ b/src/brevitas/export/inference/manager.py @@ -39,9 +39,11 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bo class quant_inference_mode: - def __init__(self, model, cache_quant_weight=False, enabled=True): + def __init__(self, model, cache_quant_weight=False, delete_injector=False, enabled=True): self.model = model self.enabled = enabled + self.delete_injector = delete_injector + self.injector_reference = dict() self.cache_quant_weight = cache_quant_weight self.export_manager = InferenceManager self.hook_list = [] @@ -74,6 +76,10 @@ def __exit__(self, type, value, traceback): lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False)) InferenceManager.set_export_mode(self.model, enabled=False) restore_return_quant_tensor(self.model, self.return_quant_tensor_state) + if self.delete_injector: + for m in self.model.modules(): + if m in self.injector_reference: + m.quant_injector = self.injector_reference[m] def hook(self, module, inp, out): # After one forward pass with caching enabled, we can: @@ -85,6 +91,11 @@ def hook(self, module, inp, out): self.model.apply(InferenceManager.set_export_handler) InferenceManager.set_export_mode(self.model, enabled=True) self.return_quant_tensor_state = disable_return_quant_tensor(self.model) + if self.delete_injector: + for m in self.model.modules(): + if hasattr(m, 'quant_injector'): + self.injector_reference[m] = m.quant_injector + del m.quant_injector # Inheritance from BaseManager is not techincally needed