diff --git a/backends/qualcomm/_passes/recompose_rms_norm.py b/backends/qualcomm/_passes/recompose_rms_norm.py index b26de8bd79..bfaddfc47b 100644 --- a/backends/qualcomm/_passes/recompose_rms_norm.py +++ b/backends/qualcomm/_passes/recompose_rms_norm.py @@ -34,7 +34,9 @@ def _get_gamma_node(self, output_node): def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph - partitions = get_source_partitions(graph, [torch.nn.RMSNorm]) + partitions = get_source_partitions( + graph, [torch.nn.RMSNorm, torch.ops.aten.rms_norm.default] + ) for _, src_partitions in partitions.items(): for src_partition in src_partitions: input_len = len(src_partition.input_nodes)