From c2905f7cdff4dc4f620a9863042df0b8aeff8d75 Mon Sep 17 00:00:00 2001
From: Christoph Berganski
Date: Tue, 21 Jan 2025 15:12:12 +0100
Subject: [PATCH] Make activation handler guess the layout based on tensor
 rank if missing

---
 .../qonnx/qonnx_activation_handlers.py        | 38 +++++++++++++++----
 1 file changed, 31 insertions(+), 7 deletions(-)

diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 36181e7a4..8085e5a8e 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -402,12 +402,24 @@ def _calculate_thresholds(self):
                 else:
                     thresholds[c][t] = step / selu_scale
 
-        # First try to consider the tensor layout of the input for determining
-        # the number of output channels
+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is safer, as we do not want to
+        # accidentally propagate shapes backwards.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])  # noqa
+        # First try to consider the tensor layout of the input for
+        # determining the number of output channels
         layout = self._model.get_tensor_layout(self._q_node.input[0])
-        # If there is a layout annotation, use this to determine the index of
-        # the channel dimension
-        if layout is not None and "C" in layout:
+        # If there is no layout annotation, guess based on the rank of the
+        # tensor
+        # TODO: No support for rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Look up the layout matching this input shape
+            layout = rank_to_layout[len(shape)]
+        # If there is a layout annotation, use this to determine the index
+        # of the channel dimension
+        if layout is not None and "C" in layout:  # noqa: Duplicate
             # Lookup the index in list
             cdim = layout.index("C")
         # If no layout has been annotated or there is no channel dimension, fall
@@ -570,12 +582,24 @@ def _calculate_thresholds(self):
             for t in range(num_thresholds):
                 thresholds[c][t] = min_threshold[c] + step[c] * t
 
+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is safer, as we do not want to
+        # accidentally propagate shapes backwards.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])
         # First try to consider the tensor layout of the input for
         # determining the number of output channels
-        layout = self._model.get_tensor_layout(self._q_node.input[0])
+        layout = self._model.get_tensor_layout(self._q_node.input[0])  # noqa
+        # If there is no layout annotation, guess based on the rank of the
+        # tensor
+        # TODO: No support for rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Look up the layout matching this input shape
+            layout = rank_to_layout[len(shape)]
         # If there is a layout annotation, use this to determine the index
         # of the channel dimension
-        if layout is not None and "C" in layout:
+        if layout is not None and "C" in layout:  # noqa: Duplicate
             # Lookup the index in list
             cdim = layout.index("C")
         # If no layout has been annotated or there is no channel dimension,
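
For illustration, the rank-based fallback can be exercised on its own. The
sketch below replicates the rank-to-layout mapping and the rank < 5
restriction introduced by the patch; guess_layout and the example shapes are
hypothetical stand-ins for this note, not part of FINN or qonnx.

    def guess_layout(shape):
        """Guess a data layout annotation from the tensor rank."""
        # Maps tensor rank to layout annotation, as in the patch
        rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
        # Ranks >= 5 are not supported and yield no annotation
        return rank_to_layout[len(shape)] if len(shape) < 5 else None

    # The channel dimension index then follows from the layout string,
    # mirroring the layout.index("C") lookup in the handler
    for shape in [(10,), (1, 10), (1, 32, 10), (1, 10, 28, 28)]:
        layout = guess_layout(shape)
        cdim = layout.index("C") if layout and "C" in layout else None
        print(shape, "->", layout, "channel dim:", cdim)

Note that rank 3 is guessed as channel-last (NWC) while rank 4 is guessed as
channel-first (NCHW), so the channel index is 2 for (1, 32, 10) but 1 for
(1, 10, 28, 28).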