From ace577b693e3259dc678192149bca49bd2193798 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 21 Jan 2025 11:26:54 -0600 Subject: [PATCH] [executorch][serialization] Serialize PTD files. Pull Request resolved: https://github.com/pytorch/executorch/pull/7270 Introduce top-level serialization file that calls: - serialize_pte_binary for PTE file - FlatTensor.serialize_tensors for PTD files. ghstack-source-id: 262004271 @exported-using-ghexport Differential Revision: [D66523267](https://our.internmc.facebook.com/intern/diff/D66523267/) --------- Co-authored-by: lucylq --- exir/_serialize/TARGETS | 1 + exir/_serialize/_serialize.py | 91 ++++++++++++++++++++++++++++++++++ exir/program/TARGETS | 1 + exir/program/_program.py | 62 ++++++++++++++++------- extension/export_util/utils.py | 3 ++ 5 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 exir/_serialize/_serialize.py diff --git a/exir/_serialize/TARGETS b/exir/_serialize/TARGETS index cd6a4bc5a2..cc6f16d78d 100644 --- a/exir/_serialize/TARGETS +++ b/exir/_serialize/TARGETS @@ -33,6 +33,7 @@ runtime.python_library( "_dataclass.py", "_flatbuffer.py", "_program.py", + "_serialize.py", "data_serializer.py", "padding.py", ], diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py new file mode 100644 index 0000000000..c311274922 --- /dev/null +++ b/exir/_serialize/_serialize.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +from typing import Dict, Tuple + +from executorch.exir._serialize import _serialize_pte_binary + +from executorch.exir._serialize._cord import Cord +from executorch.exir._serialize.data_serializer import ( + DataPayload, + DataSerializer, + TensorEntry, + TensorLayout, +) + +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.emit import EmitterOutput +from executorch.exir.schema import Tensor, TensorDataLocation + + +def serialize_for_executorch( + emitter_output: EmitterOutput, + config: ExecutorchBackendConfig, + data_serializer: DataSerializer, +) -> Tuple[Cord, Dict[str, Cord]]: + """Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files.""" + + # Serialize PTE file. + pte: Cord = _serialize_pte_binary( + program=emitter_output.program, + mutable_data=emitter_output.mutable_data, + extract_delegate_segments=config.extract_delegate_segments, + segment_alignment=config.segment_alignment, + constant_tensor_alignment=config.constant_tensor_alignment, + delegate_alignment=config.delegate_alignment, + ) + + # Serialize PTD files. + ptd_files: Dict[str, Cord] = {} + + # Find all external tensors and organize into {fqn: TensorLayout}. + fqn_to_tensor_layout: Dict[str, TensorLayout] = {} + for plan in emitter_output.program.execution_plan: + for evalue in plan.values: + if isinstance(evalue.val, Tensor): + tensor = evalue.val + if ( + tensor.extra_tensor_info is not None + and tensor.extra_tensor_info.fully_qualified_name is not None + and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL + ): + fqn_to_tensor_layout[ + tensor.extra_tensor_info.fully_qualified_name + ] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order) + + if len(fqn_to_tensor_layout) > 0: + # emitter_output.external_constant_map contains the mapping from + # {file: {fqn: index into external_constant_buffer}} + # Contains the locations of the tensor buffers, and must be non-empty + # if there are external tensors to serialize. + assert emitter_output.external_constant_map is not None + for ( + filename, + fqn_to_index, + ) in ( + # pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`. + emitter_output.external_constant_map.items() + ): + # Create a TensorEntry for each external tensor. + fqn_to_tensor_entry: Dict[str, TensorEntry] = {} + for fqn, index in fqn_to_index.items(): + assert fqn in fqn_to_tensor_layout + fqn_to_tensor_entry[fqn] = TensorEntry( + buffer_index=index, + layout=fqn_to_tensor_layout[fqn], + ) + + ptd_files[filename] = data_serializer.serialize( + DataPayload( + buffers=emitter_output.external_constant_buffer, + fqn_to_tensor=fqn_to_tensor_entry, + ) + ) + + return pte, ptd_files diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 674d7baa35..33e417e732 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -44,6 +44,7 @@ python_library( "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/passes:weights_to_outputs_pass", "//executorch/exir/verification:verifier", + "//executorch/extension/flat_tensor/serialize:serialize", ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else []) ) diff --git a/exir/program/_program.py b/exir/program/_program.py index 7dbf97a047..e8cee0b5da 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -9,12 +9,14 @@ import copy import io import logging +import os from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union import torch import torch._export -from executorch.exir._serialize import _serialize_pte_binary from executorch.exir._serialize._cord import Cord +from executorch.exir._serialize._serialize import serialize_for_executorch +from executorch.exir._serialize.data_serializer import DataSerializer from executorch.exir._warnings import experimental from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.partitioner import Partitioner @@ -59,6 +61,7 @@ EXIREdgeDialectVerifier, get_aten_verifier, ) +from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass from torch.export import ExportedProgram from torch.export._remove_auto_functionalized_pass import ( @@ -497,6 +500,7 @@ def __init__( ) self.exported_program = exir_exported_program.exported_program self._pte_data: Optional[Cord] = None + self._tensor_data: Optional[Dict[str, Cord]] = None self._buffer: Optional[bytes] = None self._emitter_output: Optional[EmitterOutput] = None self._emit_stacktrace: bool = emit_stacktrace @@ -504,16 +508,23 @@ def __init__( self._segment_alignment: int = segment_alignment self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment self._delegate_alignment: Optional[int] = delegate_alignment + self._data_serializer: DataSerializer = FlatTensorSerializer() + + def _get_emitter_output(self) -> EmitterOutput: + if self._emitter_output is None: + self._emitter_output = emit_program( + self.exported_program, self._emit_stacktrace + ) + return self._emitter_output def _get_pte_data(self) -> Cord: if self._pte_data is None: - self._pte_data = _serialize_pte_binary( - program=self.program, - extract_delegate_segments=self._extract_delegate_segments, - segment_alignment=self._segment_alignment, - constant_tensor_alignment=self._constant_tensor_alignment, - delegate_alignment=self._delegate_alignment, + self._pte_data, self._tensor_data = serialize_for_executorch( + self._get_emitter_output(), + ExecutorchBackendConfig(), + self._data_serializer, ) + assert self._pte_data is not None return self._pte_data @property @@ -532,11 +543,7 @@ def buffer(self) -> bytes: @property def program(self) -> Program: - if self._emitter_output is None: - self._emitter_output = emit_program( - self.exported_program, self._emit_stacktrace - ) - return self._emitter_output.program + return self._get_emitter_output().program @property def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]: @@ -571,6 +578,17 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None: """ self._get_pte_data().write_to_file(open_file) + def write_tensor_data_to_file(self, outdir) -> None: + """ + Writes the serialized ExecuTorch data files to the directory at `outdir`. + """ + assert self._tensor_data is not None + # pyre-ignore[16]: `Optional` has no attribute `items`. + for filename, cord in self._tensor_data.items(): + with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f: + logging.info(f"Writing data file to {filename}.ptd") + cord.write_to_file(f) + def _get_aten_to_edge_passes(config: EdgeCompileConfig): # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable @@ -1453,13 +1471,9 @@ def __init__( ) # Serialize emitter output, ready to be written to a file. - self._pte_data: Cord = _serialize_pte_binary( - program=self._emitter_output.program, - mutable_data=self._emitter_output.mutable_data, - extract_delegate_segments=backend_config.extract_delegate_segments, - segment_alignment=backend_config.segment_alignment, - constant_tensor_alignment=backend_config.constant_tensor_alignment, - delegate_alignment=backend_config.delegate_alignment, + self._data_serializer = FlatTensorSerializer() + self._pte_data, self._tensor_data = serialize_for_executorch( + self._emitter_output, ExecutorchBackendConfig(), self._data_serializer ) self._buffer: Optional[bytes] = None @@ -1542,6 +1556,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None: """ self._pte_data.write_to_file(open_file) + def write_tensor_data_to_file(self, outdir) -> None: + """ + Writes the serialized ExecuTorch data files to the directory at `outdir`. + """ + assert self._tensor_data is not None + for filename, cord in self._tensor_data.items(): + with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f: + logging.info(f"Writing data file to {filename}") + cord.write_to_file(f) + def save(self, path: str) -> None: """ Saves the serialized ExecuTorch binary to the file at `path`. diff --git a/extension/export_util/utils.py b/extension/export_util/utils.py index a289355919..2679930178 100644 --- a/extension/export_util/utils.py +++ b/extension/export_util/utils.py @@ -135,9 +135,12 @@ def save_pte_program( filename = os.path.join(output_dir, f"{model_name}.pte") try: + # Write program to file. with open(filename, "wb") as file: prog.write_to_file(file) logging.info(f"Saved exported program to {filename}") + # Write data to file/s. + prog.write_tensor_data_to_file(outdir=output_dir) except Exception as e: logging.error(f"Error while saving to {filename}: {e}")