
Commit

Unified flow
Giuseppe5 committed Oct 17, 2023
1 parent bb8f0ad commit 354ff29
Showing 2 changed files with 3 additions and 17 deletions.
12 changes: 3 additions & 9 deletions src/brevitas/nn/quant_rnn.py
@@ -23,7 +23,6 @@
 from brevitas.quant import Int8WeightPerTensorFloat
 from brevitas.quant import Int32Bias
 from brevitas.quant import Uint8ActPerTensorFloat
-from brevitas.utils.quant_utils import _is_proxy_in_export_mode
 
 QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]]
 QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]]
@@ -674,14 +673,9 @@ def forward(self, inp, hidden_state, cell_state):
         quant_weight_io, quant_weight_ho, quant_bias_output = self.gate_params_fwd(
             self.output_gate_params, quant_input)
         if self.cifg:
-            if _is_proxy_in_export_mode(self.input_gate_params):
-                # Avoid dealing with None and set it the same as the forget one
-                quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd(
-                    self.input_gate_params, quant_input)
-            else:
-                quant_weight_if = quant_weight_ii
-                quant_weight_hf = quant_weight_hi
-                quant_bias_forget = quant_bias_input
+            # Avoid dealing with None and set it the same as the forget one
+            quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd(
+                self.input_gate_params, quant_input)
         else:
             quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd(
                 self.forget_gate_params, quant_input)
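Note on the change above: with CIFG (coupled input-forget gate) enabled, the forget-gate weights and bias are now always recomputed from self.input_gate_params through gate_params_fwd, instead of branching on whether a proxy is in export mode. The sketch below mirrors that control flow in plain PyTorch; the class, the gate_params_fwd return values, and the parameter layout are simplified stand-ins, not Brevitas' actual API.

import torch
import torch.nn as nn


class ToyCifgGates(nn.Module):
    """Toy stand-in for the quantized LSTM cell's gate handling (not Brevitas code).

    With cifg=True the input and forget gates are coupled, so the forget-gate
    tensors are always derived from the input-gate parameters and are never None.
    """

    def __init__(self, input_size, hidden_size, cifg=True):
        super().__init__()
        self.cifg = cifg
        self.input_gate_params = self._make_gate(input_size, hidden_size)
        # A dedicated forget gate only exists when the gates are not coupled.
        self.forget_gate_params = None if cifg else self._make_gate(input_size, hidden_size)

    @staticmethod
    def _make_gate(input_size, hidden_size):
        return nn.ParameterDict({
            'weight_ih': nn.Parameter(torch.randn(hidden_size, input_size)),
            'weight_hh': nn.Parameter(torch.randn(hidden_size, hidden_size)),
            'bias': nn.Parameter(torch.zeros(hidden_size))})

    def gate_params_fwd(self, gate, quant_input):
        # In Brevitas this step would return quantized tensors; here the raw
        # parameters are passed through unchanged.
        return gate['weight_ih'], gate['weight_hh'], gate['bias']

    def forward(self, quant_input):
        if self.cifg:
            # Unified flow: always derive the forget-gate tensors from the
            # input-gate parameters, so no code path has to handle None.
            return self.gate_params_fwd(self.input_gate_params, quant_input)
        return self.gate_params_fwd(self.forget_gate_params, quant_input)


# Hypothetical usage:
gates = ToyCifgGates(input_size=4, hidden_size=8)
w_if, w_hf, b_f = gates(torch.randn(1, 4))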
8 changes: 0 additions & 8 deletions src/brevitas/utils/quant_utils.py
@@ -5,14 +5,6 @@
 from brevitas.core.function_wrapper import *
 from brevitas.core.quant import RescalingIntQuant
 from brevitas.inject.enum import FloatToIntImplType
-from brevitas.proxy.quant_proxy import QuantProxyFromInjector
-
-
-def _is_proxy_in_export_mode(model):
-    for submodule in model.modules():
-        if isinstance(submodule, QuantProxyFromInjector) and hasattr(submodule, 'export_mode'):
-            return submodule.export_mode
-    return False
 
 
 def has_learned_weight_bit_width(module):
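For reference, the deleted helper simply walked model.modules() and returned the export_mode flag of the first QuantProxyFromInjector that defined it, falling back to False. A generic version of that submodule-scanning pattern, written against plain torch.nn with a dummy module instead of the Brevitas proxy class, might look like this:

import torch.nn as nn


def first_submodule_flag(model, cls, attr, default=False):
    """Return attr from the first submodule of type cls that defines it.

    Generic version of the removed _is_proxy_in_export_mode helper,
    parameterized over the class and attribute name.
    """
    for submodule in model.modules():
        if isinstance(submodule, cls) and hasattr(submodule, attr):
            return getattr(submodule, attr)
    return default


class DummyProxy(nn.Module):
    """Hypothetical stand-in for a quantization proxy carrying an export_mode flag."""

    def __init__(self):
        super().__init__()
        self.export_mode = True


model = nn.Sequential(nn.Linear(4, 4), DummyProxy())
print(first_submodule_flag(model, DummyProxy, 'export_mode'))  # -> True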
