Skip to content

Commit

Permalink
[Squeeze] Add MoveSqueezePastMatMul needed by depth-wise convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
iksnagreb committed Jan 21, 2025
1 parent bfc70f1 commit 712ee5a
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/finn/transformation/streamline/reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,3 +1321,68 @@ def apply(self, model: ModelWrapper): # noqa
# Return the transformed model and indicate whether the graph
# actually has been transformed
return model, graph_modified


# Moves a Squeeze operation past MatMul
# TODO: extend to all operations invariant to or compatible with squeezing
class MoveSqueezePastMatMul(Transformation):
    """Moves a non-fork, non-join Squeeze node past a directly following
    MatMul node, i.e. rewires the graph so the MatMul consumes the
    un-squeezed tensor and the Squeeze consumes the MatMul output instead.
    Shape annotations of the affected tensors are re-inferred afterwards.
    """

    # Applies the transform to a whole model graph
    def apply(self, model: ModelWrapper):  # noqa
        # Get the model graph out of the model wrapper object
        graph = model.graph
        # Keep track of whether the graph has been modified
        graph_modified = False
        # Iterate all nodes in the graph keeping track of the index
        for index, node in enumerate(graph.node):
            # Applies to Squeeze operation types
            if node.op_type == "Squeeze":
                # Currently does not handle fork- or join-nodes
                if model.is_fork_node(node) or model.is_join_node(node):
                    # Softly skip this node
                    continue
                # As this is not a fork-node, there can be at most one
                # successor
                successor = model.find_direct_successors(node)
                # If Squeeze is the final operation in the graph, there
                # might be no successor
                if successor is None:
                    # Softly skip this node
                    continue
                # Now there is exactly one successor which needs to be
                # extracted from the list
                successor = successor[0]
                # Applies to MatMul
                # TODO: Check behavior for multi-dimensional and potentially
                # broadcasting MatMuls...
                if successor.op_type in {"MatMul"}:
                    # Get names of all tensors involved in # noqa: Duplicate
                    # connecting the nodes
                    inp = node.input[0]  # noqa: Duplicate
                    mid = node.output[0]
                    out = successor.output[0]
                    # The Squeeze output may feed either operand of the
                    # MatMul: locate which input index actually consumes it
                    # instead of assuming index 0, otherwise rewiring the
                    # wrong operand would leave mid as both an input and the
                    # output of the MatMul (a self-loop)
                    mid_index = list(successor.input).index(mid)
                    # Rewire the graph to feed the original input into the
                    # MatMul node first
                    successor.input[mid_index] = inp
                    # Repurpose the middle tensor for the output of the
                    # MatMul
                    successor.output[0] = mid
                    # The Squeeze operator now gets the middle tensor as its
                    # input
                    node.input[0] = mid
                    # Squeeze now produces the original output tensor
                    node.output[0] = out
                    # Delete the shape annotation of the connecting tensors
                    # to be re-done later
                    model.set_tensor_shape(mid, None)
                    model.set_tensor_shape(out, None)
                    # Track whether the graph has been modified, never
                    # resets to False
                    graph_modified = True
                    # Break the loop after deleting shape annotations to
                    # immediately re-do these before changing the next
                    # operator
                    break
        # Need to redo the shape inference after potentially deleting them
        model = model.transform(InferShapes())  # noqa: Shadows model
        # Return the transformed model and indicate whether the graph
        # actually has been transformed
        return model, graph_modified

0 comments on commit 712ee5a

Please sign in to comment.