diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index da8be9b17e0b7..74ecae3ac8d98 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -524,6 +524,7 @@ static LogicalResult rewriteMonomorphizedFuncClone( auto newOp = OpBuilder(op).create( op.getLoc(), op.getType(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName()); + newOp->setAttr("FXOutputName", op->getAttr("FXOutputName")); op.replaceAllUsesWith(&*newOp); } toErase.push_back(op); diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 76b57fe8c9a3a..1d7900ab2b7af 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -375,8 +375,10 @@ class InlineGlobalSlotsPass getBackwardSliceIncludingRoot(initialValue); IRMapping mapping; OpBuilder builder(op); - for (Operation *opInSlice : slice) - builder.clone(*opInSlice, mapping); + for (Operation *opInSlice : slice) { + auto clonedOp = builder.clone(*opInSlice, mapping); + clonedOp->setAttr("FXOutputName", op->getAttr("FXOutputName")); + } auto inlinedInitialValue = mapping.lookup(initialValue); inlinedInitialValue = Torch::adjustStaticInformation( builder, op.getLoc(), inlinedInitialValue, op.getType(), diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index 15cffedbe8342..5905662c5faee 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -84,6 +84,21 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, MlirLocation loc = getMlirLocationFromNode(context, node); auto kind = node->kind(); + + auto outs = node->outputs(); + auto output = outs.size() == 1 ? outs[0] : nullptr; + + auto addFxOutputNameAttr = [&](MlirOperation& operation) { + if (output && output->hasDebugName()) { + std::string name = output->debugName(); + size_t len = name.size(); + if (len > 2 && name[len-2] == '.' && name[len-1] == '1') + name = name.substr(0, len-2); + auto strAttr = mlirStringAttrGet(context, toMlirStringRef(name)); + mlirOperationSetAttributeByName(operation, toMlirStringRef("FXOutputName"), strAttr); + } + }; + auto createAndMapTrivialNode = [&](Node *node, const std::string &opName, InputsTransformFn t) { std::vector mappedInputs = lookupMappedValues(node->inputs()); @@ -91,6 +106,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, appendToBlock, opName, loc, getMlirTypesFromValues(loc, node->outputs(), importOptions), t ? t(mappedInputs) : mappedInputs); + addFxOutputNameAttr(operation); mapResults(node, operation); }; @@ -102,6 +118,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, getMlirTypesFromValues(loc, node->outputs(), importOptions), lookupMappedValues(node->inputs()), toMlirNamedAttribute(attrName.c_str(), attr)); + addFxOutputNameAttr(operation); mapResults(node, operation); }; @@ -112,6 +129,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, appendToBlock, loc, node->schema(), getMlirTypesFromValues(loc, node->outputs(), importOptions), lookupMappedValues(node->inputs())); + addFxOutputNameAttr(operation); mapResults(node, operation); return; }