[executorch][flat_tensor] Serialize flat tensor tests
Pull Request resolved: #7269

Introduce _deserialize_to_flat_tensor, which interprets a serialized flat_tensor blob as the flat_tensor schema (a FlatTensor dataclass).

Use it for more comprehensive testing of flat_tensor serialization, and later for deserialization.
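
As an illustration (not part of this commit), here is a minimal round-trip sketch of what the new test exercises, using the APIs that appear in the diff below. The DataPayload field name fqn_to_tensor and the tensor shape are assumptions for the example.

from executorch.exir._serialize.data_serializer import (
    DataPayload,
    TensorEntry,
    TensorLayout,
)
from executorch.exir.schema import ScalarType
from executorch.extension.flat_tensor.serialize.serialize import (
    FlatTensorConfig,
    FlatTensorHeader,
    FlatTensorSerializer,
    _deserialize_to_flat_tensor,
)

# Hypothetical fixture: one 4-byte buffer backing a single int32 tensor.
payload = DataPayload(
    buffers=[b"\x11" * 4],
    fqn_to_tensor={  # field name assumed; mirrors TEST_DATA_PAYLOAD in the test diff
        "fqn1": TensorEntry(
            buffer_index=0,
            layout=TensorLayout(scalar_type=ScalarType.INT, sizes=[1], dim_order=[0]),
        )
    },
)

serialized = bytes(FlatTensorSerializer(FlatTensorConfig()).serialize(payload))

# Parse the header to locate the flatbuffer region, then decode it back into the schema.
header = FlatTensorHeader.from_bytes(serialized[: FlatTensorHeader.EXPECTED_LENGTH])
flat_tensor = _deserialize_to_flat_tensor(
    serialized[header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size]
)
assert flat_tensor.tensors[0].fully_qualified_name == "fqn1"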

ghstack-source-id: 261976100
@exported-using-ghexport

Differential Revision: [D67007821](https://our.internmc.facebook.com/intern/diff/D67007821/)

Co-authored-by: lucylq <[email protected]>
2 people authored and YIWENX14 committed Jan 28, 2025
1 parent 03c85bf commit 2683adb
Showing 2 changed files with 136 additions and 10 deletions.
36 changes: 31 additions & 5 deletions extension/flat_tensor/serialize/serialize.py
@@ -14,9 +14,9 @@

import pkg_resources
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass

from executorch.exir._serialize._flatbuffer import _flatc_compile
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
@@ -33,8 +33,8 @@
)


def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
"""Converts a FlatTensor to a flatbuffer and returns the serialized data."""
def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
flat_tensor_json = json.dumps(flat_tensor, cls=_DataclassEncoder)
with tempfile.TemporaryDirectory() as d:
schema_path = os.path.join(d, "flat_tensor.fbs")
@@ -57,6 +57,32 @@ def _convert_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
return Cord(output_file.read())


def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
"""Deserializes a flatbuffer to a FlatTensor and returns the dataclass."""
with tempfile.TemporaryDirectory() as d:
schema_path = os.path.join(d, "flat_tensor.fbs")
with open(schema_path, "wb") as schema_file:
schema_file.write(
pkg_resources.resource_string(__name__, "flat_tensor.fbs")
)

scalar_type_path = os.path.join(d, "scalar_type.fbs")
with open(scalar_type_path, "wb") as scalar_type_file:
scalar_type_file.write(
pkg_resources.resource_string(__name__, "scalar_type.fbs")
)

bin_path = os.path.join(d, "flat_tensor.bin")
with open(bin_path, "wb") as bin_file:
bin_file.write(flatbuffer)

_flatc_decompile(d, schema_path, bin_path, ["--raw-binary"])

json_path = os.path.join(d, "flat_tensor.json")
with open(json_path, "rb") as output_file:
return _json_to_dataclass(json.load(output_file), cls=FlatTensor)


@dataclass
class FlatTensorConfig:
tensor_alignment: int = 16
@@ -244,7 +270,7 @@ def serialize(
segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
)

flatbuffer_payload = _convert_to_flatbuffer(flat_tensor)
flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
padded_flatbuffer_length: int = aligned_size(
input_size=len(flatbuffer_payload),
alignment=self.config.tensor_alignment,
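The two helpers in serialize.py are schema-level inverses: _serialize_to_flatbuffer flattens the FlatTensor dataclass through flatc, and _deserialize_to_flat_tensor decompiles the bytes back into the dataclass. A minimal direct round trip, sketched under the assumption that FlatTensor, TensorMetadata, and DataSegment are the flat_tensor_schema dataclasses with exactly the fields read by the test below:

from executorch.exir.schema import ScalarType
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
    DataSegment,
    FlatTensor,
    TensorMetadata,
)
from executorch.extension.flat_tensor.serialize.serialize import (
    _deserialize_to_flat_tensor,
    _serialize_to_flatbuffer,
)

# Hypothetical FlatTensor describing one tensor at offset 0 of segment 0.
original = FlatTensor(
    version=0,
    tensor_alignment=16,
    tensors=[
        TensorMetadata(
            fully_qualified_name="fqn1",
            scalar_type=ScalarType.INT,
            sizes=[1],
            dim_order=[0],
            segment_index=0,
            offset=0,
        )
    ],
    segments=[DataSegment(offset=0, size=16)],
)

# Serialize to flatbuffer bytes and decode back; the dataclass fields should survive.
round_tripped = _deserialize_to_flat_tensor(bytes(_serialize_to_flatbuffer(original)))
assert round_tripped.tensors[0].fully_qualified_name == "fqn1"
assert round_tripped.segments[0].size == 16
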
110 changes: 105 additions & 5 deletions extension/flat_tensor/test/test_serialize.py
@@ -8,6 +8,8 @@

import unittest

from typing import List

from executorch.exir._serialize.data_serializer import (
DataPayload,
DataSerializer,
@@ -18,15 +20,17 @@
from executorch.exir._serialize.padding import aligned_size

from executorch.exir.schema import ScalarType
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorMetadata

from executorch.extension.flat_tensor.serialize.serialize import (
_deserialize_to_flat_tensor,
FlatTensorConfig,
FlatTensorHeader,
FlatTensorSerializer,
)

# Test artifacts.
TEST_TENSOR_BUFFER = [b"tensor"]
TEST_TENSOR_BUFFER: List[bytes] = [b"\x11" * 4, b"\x22" * 32]
TEST_TENSOR_MAP = {
"fqn1": TensorEntry(
buffer_index=0,
@@ -44,6 +48,14 @@
dim_order=[0, 1, 2],
),
),
"fqn3": TensorEntry(
buffer_index=1,
layout=TensorLayout(
scalar_type=ScalarType.INT,
sizes=[2, 2, 2],
dim_order=[0, 1],
),
),
}
TEST_DATA_PAYLOAD = DataPayload(
buffers=TEST_TENSOR_BUFFER,
@@ -52,13 +64,24 @@


class TestSerialize(unittest.TestCase):
# TODO(T211851359): improve test coverage.
def check_tensor_metadata(
self, tensor_layout: TensorLayout, tensor_metadata: TensorMetadata
) -> None:
self.assertEqual(tensor_layout.scalar_type, tensor_metadata.scalar_type)
self.assertEqual(tensor_layout.sizes, tensor_metadata.sizes)
self.assertEqual(tensor_layout.dim_order, tensor_metadata.dim_order)

def test_serialize(self) -> None:
config = FlatTensorConfig()
serializer: DataSerializer = FlatTensorSerializer(config)

data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))

header = FlatTensorHeader.from_bytes(data[0 : FlatTensorHeader.EXPECTED_LENGTH])
# Check header.
header = FlatTensorHeader.from_bytes(
serialized_data[0 : FlatTensorHeader.EXPECTED_LENGTH]
)
self.assertTrue(header.is_valid())

# Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
@@ -77,9 +100,86 @@ def test_serialize(self) -> None:
self.assertTrue(header.segment_base_offset, expected_segment_base_offset)

# TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
self.assertEqual(header.segment_data_size, config.segment_alignment)
expected_segment_data_size = aligned_size(
sum(len(buffer) for buffer in TEST_TENSOR_BUFFER), config.segment_alignment
)
self.assertEqual(header.segment_data_size, expected_segment_data_size)

# Confirm the flatbuffer magic is present.
self.assertEqual(
data[header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8], b"FT01"
serialized_data[
header.flatbuffer_offset + 4 : header.flatbuffer_offset + 8
],
b"FT01",
)

# Check flat tensor data.
flat_tensor_bytes = serialized_data[
header.flatbuffer_offset : header.flatbuffer_offset + header.flatbuffer_size
]

flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)

self.assertEqual(flat_tensor.version, 0)
self.assertEqual(flat_tensor.tensor_alignment, config.tensor_alignment)

tensors = flat_tensor.tensors
self.assertEqual(len(tensors), 3)
self.assertEqual(tensors[0].fully_qualified_name, "fqn1")
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn1"].layout, tensors[0])
self.assertEqual(tensors[0].segment_index, 0)
self.assertEqual(tensors[0].offset, 0)

self.assertEqual(tensors[1].fully_qualified_name, "fqn2")
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn2"].layout, tensors[1])
self.assertEqual(tensors[1].segment_index, 0)
self.assertEqual(tensors[1].offset, 0)

self.assertEqual(tensors[2].fully_qualified_name, "fqn3")
self.check_tensor_metadata(TEST_TENSOR_MAP["fqn3"].layout, tensors[2])
self.assertEqual(tensors[2].segment_index, 0)
self.assertEqual(tensors[2].offset, config.tensor_alignment)

segments = flat_tensor.segments
self.assertEqual(len(segments), 1)
self.assertEqual(segments[0].offset, 0)
self.assertEqual(segments[0].size, config.tensor_alignment * 3)

# Length of serialized_data matches segment_base_offset + segment_data_size.
self.assertEqual(
header.segment_base_offset + header.segment_data_size, len(serialized_data)
)
self.assertTrue(segments[0].size <= header.segment_data_size)

# Check the contents of the segment. Expecting two tensors from
# TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
segment_data = serialized_data[
header.segment_base_offset : header.segment_base_offset + segments[0].size
]

# Tensor: b"\x11" * 4
t0_start = 0
t0_len = len(TEST_TENSOR_BUFFER[0])
t0_end = t0_start + aligned_size(t0_len, config.tensor_alignment)
self.assertEqual(
segment_data[t0_start : t0_start + t0_len], TEST_TENSOR_BUFFER[0]
)
padding = b"\x00" * (t0_end - t0_len)
self.assertEqual(segment_data[t0_start + t0_len : t0_end], padding)

# Tensor: b"\x22" * 32
t1_start = t0_end
t1_len = len(TEST_TENSOR_BUFFER[1])
t1_end = t1_start + aligned_size(t1_len, config.tensor_alignment)
self.assertEqual(
segment_data[t1_start : t1_start + t1_len],
TEST_TENSOR_BUFFER[1],
)
padding = b"\x00" * (t1_end - (t1_len + t1_start))
self.assertEqual(segment_data[t1_start + t1_len : t1_start + t1_end], padding)

# Check length of the segment is expected.
self.assertEqual(
segments[0].size, aligned_size(t1_end, config.segment_alignment)
)
self.assertEqual(segments[0].size, header.segment_data_size)
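
For reference, the offset arithmetic behind the assertions above, assuming tensor_alignment = 16 (the FlatTensorConfig default shown in serialize.py) and a segment_alignment whose default divides 48 (its value is not visible in this diff):

from executorch.exir._serialize.padding import aligned_size

# TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
tensor_alignment = 16
t0 = aligned_size(4, tensor_alignment)   # 4-byte tensor padded to 16, so the next tensor starts at 16
t1 = aligned_size(32, tensor_alignment)  # 32 bytes is already a multiple of 16
assert t0 == 16 and t1 == 32
assert t0 + t1 == 48 == 3 * tensor_alignment  # matches segments[0].size checked above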
