Skip to content

Commit

Permalink
Merge pull request #166 from fastmachinelearning/feature/insert_identity
Browse files Browse the repository at this point in the history
Insert Identity nodes on given tensor and top-level graph inputs
  • Loading branch information
maltanar authored Jan 13, 2025
2 parents 51965ab + 11a1373 commit 530f3e2
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
python-version: ["3.10"]

steps:
- name: Checkout
Expand Down
127 changes: 127 additions & 0 deletions src/qonnx/transformation/insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of AMD nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from onnx import helper as oh

from qonnx.transformation.base import Transformation
from qonnx.transformation.general import SortGraph


class InsertIdentityOnAllTopLevelIO(Transformation):
"""
Transformation that inserts an Identity node on all top-level inputs and outputs
of the ONNX graph. This can be useful before calling transformations that do not
gracefully handle edge cases where transformed tensors are top-level inputs or outputs.
"""

def apply(self, model):
graph = model.graph
for inp in graph.input:
model = model.transform(InsertIdentity(inp.name, "consumer"))
for out in graph.output:
model = model.transform(InsertIdentity(out.name, "producer"))
return model, False


class InsertIdentity(Transformation):
"""
Transformation that inserts an Identity node in the ONNX graph. For edge cases
where tensor_name is a graph input and producer_or_consumer is 'producer', the
graph input will be replaced with a new tensor name <old_name>_identity. For the
edge case where tensor_name is a graph output and producer_or_consumer is 'consumer',
the graph output will be replaced with a new tensor name <old_name>_identity
Parameters:
tensor_name (str): The name of the tensor where the Identity node will be inserted.
producer_or_consumer (str): Indicates whether the Identity node will be inserted before ('producer')
or after ('consumer') the tensor_name.
"""

def __init__(self, tensor_name, producer_or_consumer):
super().__init__()
self.tensor_name = tensor_name
self.producer_or_consumer = producer_or_consumer

def insert_node_before(self, model, tensor):
graph = model.graph
new_tensor_name = tensor + "_identity"
# rewire the tensor's producer to the new tensor
prod = model.find_producer(tensor)
if prod is not None:
prod_outlist = list(prod.output)
prod.output[prod_outlist.index(tensor)] = new_tensor_name
else:
# if the tensor is an input tensor (top-level)
# update the graph's input
top_inp_names = [inp.name for inp in graph.input]
graph.input[top_inp_names.index(tensor)].name = new_tensor_name
# Create a new node
identity_node = oh.make_node("Identity", [new_tensor_name], [tensor])
# Insert the new node
# we do this late in the process to avoid affecting find_producer
graph.node.append(identity_node)

def insert_node_after(self, model, tensor):
graph = model.graph
new_tensor_name = tensor + "_identity"
# rewire the tensor's consumers to the new node
consumers = model.find_consumers(tensor)
if consumers == []:
# if the tensor is an output tensor (top-level)
# find the graph's output and replace it with the new name
top_out_name = [out.name for out in graph.output]
graph.output[top_out_name.index(tensor)].name = new_tensor_name
# TODO what if feeding multiple graph outputs? seems unlikely...
else:
for consumer in consumers:
consumer_inplist = list(consumer.input)
consumer.input[consumer_inplist.index(tensor)] = new_tensor_name
# Create a new node
# we do this late in the process to avoid affecting find_consumers
identity_node = oh.make_node("Identity", [tensor], [new_tensor_name])
# Insert the new node
graph.node.append(identity_node)

def apply(self, model):
# Find the tensor in the graph
tshape = model.get_tensor_shape(self.tensor_name)
if tshape is None:
raise ValueError(f"Tensor '{self.tensor_name}' not found in the graph.")
tensor = self.tensor_name
# Insert the Identity node before or after the specified tensor
if self.producer_or_consumer == "producer":
self.insert_node_before(model, tensor)
elif self.producer_or_consumer == "consumer":
self.insert_node_after(model, tensor)
else:
raise ValueError("producer_or_consumer must be either 'producer' or 'consumer'.")

model = model.transform(SortGraph())
# important to return run_again=False to avoid infinite loop
return (model, False)
128 changes: 128 additions & 0 deletions tests/transformation/test_insert_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of AMD nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

from onnx import TensorProto
from onnx import helper as oh

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.insert import InsertIdentity, InsertIdentityOnAllTopLevelIO


@pytest.fixture
def simple_model():
# Create a simple ONNX model for testing
input_tensor = oh.make_tensor_value_info("input", TensorProto.FLOAT, [1, 2])
output_tensor = oh.make_tensor_value_info("output", TensorProto.FLOAT, [1, 2])
node1 = oh.make_node("Relu", ["input"], ["intermediate"])
node2 = oh.make_node("Relu", ["intermediate"], ["output"])
graph = oh.make_graph([node1, node2], "test_graph", [input_tensor], [output_tensor])
model = ModelWrapper(oh.make_model(graph))
model = model.transform(InferShapes())
return model


def test_insert_identity_on_all_top_level_io(simple_model):
orig_top_inp_names = [inp.name for inp in simple_model.graph.input]
orig_top_out_names = [out.name for out in simple_model.graph.output]
model = simple_model.transform(InsertIdentityOnAllTopLevelIO())
for inp in orig_top_inp_names:
assert model.find_consumer(inp).op_type == "Identity"
for out in orig_top_out_names:
assert model.find_producer(out).op_type == "Identity"
assert orig_top_inp_names == [inp.name for inp in model.graph.input]
assert orig_top_out_names == [out.name for out in model.graph.output]


def test_insert_identity_before_input(simple_model):
# Apply the transformation
transformation = InsertIdentity("input", "producer")
model = simple_model.transform(transformation)

identity_node = model.find_producer("input")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_after_input(simple_model):
# Apply the transformation
transformation = InsertIdentity("input", "consumer")
model = simple_model.transform(transformation)

identity_node = model.find_consumer("input")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_before_intermediate(simple_model):
# Apply the transformation
transformation = InsertIdentity("intermediate", "producer")
model = simple_model.transform(transformation)

identity_node = model.find_producer("intermediate")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_after_intermediate(simple_model):
# Apply the transformation
transformation = InsertIdentity("intermediate", "consumer")
model = simple_model.transform(transformation)

identity_node = model.find_consumer("intermediate")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_before_output(simple_model):
# Apply the transformation
transformation = InsertIdentity("output", "producer")
model = simple_model.transform(transformation)

identity_node = model.find_producer("output")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_after_output(simple_model):
# Apply the transformation
transformation = InsertIdentity("output", "consumer")
model = simple_model.transform(transformation)

identity_node = model.find_consumer("output")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_tensor_not_found(simple_model):
# Apply the transformation with a non-existent tensor
transformation = InsertIdentity("non_existent_tensor", "producer")
with pytest.raises(ValueError):
simple_model.transform(transformation)

0 comments on commit 530f3e2

Please sign in to comment.