Skip to content

Commit

Permalink
Update on "[executorch][flat_tensor] DataMap implementation"
Browse files Browse the repository at this point in the history
DataMap implementation that:
* Loads a flat_tensor file.
* Populates a map with {fqn: tensor} and {fqn: TensorLayout}, where fqn is the tensor's fully qualified name.
* Makes tensor information available via the named_data_map.h interface.

For now, DataMap doesn't store the DataLoader.
- If/when tensors are in their own segments, DataMap should also store a DataLoader.

Differential Revision: [D67064580](https://our.internmc.facebook.com/intern/diff/D67064580/)

[ghstack-poisoned]
  • Loading branch information
lucylq committed Feb 5, 2025
2 parents 4cdd830 + 6193e1c commit 9c37b57
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/flat_tensor/data_map.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>

#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
#include <executorch/extension/flat_tensor/serialize/schema_generated.h>
Expand Down Expand Up @@ -76,7 +76,7 @@ Result<const TensorLayout> create_tensor_layout(

} // namespace

ET_NODISCARD Result<const TensorLayout> DataMap::get_metadata(
ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
const char* key) const {
Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
get_flat_tensor_metadata(key, flat_tensor_->tensors());
Expand All @@ -86,7 +86,8 @@ ET_NODISCARD Result<const TensorLayout> DataMap::get_metadata(
return create_tensor_layout(metadata_res.get());
}

ET_NODISCARD Result<FreeableBuffer> DataMap::get_data(const char* key) const {
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
const char* key) const {
auto tensor_metadata = flat_tensor_->tensors();

Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
Expand All @@ -95,41 +96,46 @@ ET_NODISCARD Result<FreeableBuffer> DataMap::get_data(const char* key) const {
return metadata_res.error();
}
const auto metadata = metadata_res.get();
if (metadata->segment_index() == -1 || metadata->offset() == -1) {
// Key doesn't exist.
return Error::NotFound;
if (metadata->segment_index() < 0 || metadata->offset() < 0) {
// Invalid segment_index/offset; malformed PTD file.
return Error::InvalidExternalData;
}

Result<const TensorLayout> tensor_layout_res = create_tensor_layout(metadata);
if (!tensor_layout_res.ok()) {
return tensor_layout_res.error();
}

// This FreeableBuffer doesn't own the underlying data, and will not free it,
// which is why the free function is a nullptr.
// TODO(T214294528)
return FreeableBuffer(
static_cast<const uint8_t*>(data_ro_.data()) + metadata->offset(),
tensor_layout_res.get().nbytes(),
nullptr);
}

ET_NODISCARD Result<size_t> DataMap::load_data_into(
// Not implemented: the arguments are unused and Error::NotImplemented is
// always returned.
ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
ET_UNUSED const char* key,
ET_UNUSED void* buffer,
ET_UNUSED size_t size) const {
return Error::NotImplemented;
}

ET_NODISCARD Result<size_t> DataMap::get_num_keys() const {
// Reports how many tensor entries the wrapped FlatTensor metadata holds.
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
  const auto* tensors = flat_tensor_->tensors();
  return tensors->size();
}

ET_NODISCARD Result<const char*> DataMap::get_key(size_t index) const {
// Returns the fully qualified name of the tensor entry at `index`, or
// Error::InvalidArgument when `index` is not smaller than the number of
// tensor entries.
ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
    size_t index) const {
  // `index` is unsigned, so it can never be negative; only the upper bound
  // needs to be checked.
  if (index >= flat_tensor_->tensors()->size()) {
    return Error::InvalidArgument;
  }
  return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
}

/* static */ Result<DataMap> DataMap::load(DataLoader* loader) {
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
DataLoader* loader) {
// Load data map.
size_t flatbuffer_offset = 0;
size_t flatbuffer_size = 0;
Expand Down Expand Up @@ -196,6 +202,16 @@ ET_NODISCARD Result<const char*> DataMap::get_key(size_t index) const {
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());

// Validate flatbuffer data.
flatbuffers::Verifier verifier(
reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
flat_tensor_data->size());
bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
ET_CHECK_OR_RETURN_ERROR(
ok,
InvalidExternalData,
"Verification failed; data may be truncated or corrupt");

// Get pointer to tensor metadata.
const auto* s_tensor_metadata = flat_tensor->tensors();
if (s_tensor_metadata == nullptr) {
Expand All @@ -213,13 +229,9 @@ ET_NODISCARD Result<const char*> DataMap::get_key(size_t index) const {
"FlatTensor has %u segments, only 1 supported.",
s_data_segment->size());
}
// First segment offset should be 0.
int segment_offset = s_data_segment->Get(0)->offset();
if (segment_offset != 0) {
ET_LOG(Error, "FlatTensor segment offset %d != 0", segment_offset);
}
// First segment size should be <= the total segment data size.
int segment_size = s_data_segment->Get(0)->size();
int segment_offset = s_data_segment->Get(0)->offset();
if (segment_size > segment_data_size) {
ET_LOG(
Error,
Expand All @@ -236,14 +248,9 @@ ET_NODISCARD Result<const char*> DataMap::get_key(size_t index) const {
return data_ro.error();
}

return DataMap(
return FlatTensorDataMap(
std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
}

DataMap::~DataMap() {
flat_tensor_data_.Free();
data_ro_.Free();
}

} // namespace extension
} // namespace executorch
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ namespace extension {
/**
* A NamedDataMap implementation for FlatTensor-serialized data.
*/
class DataMap final : public executorch::runtime::NamedDataMap {
class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
public:
/**
* Creates a new FlatTensorDataMap that wraps FlatTensor data.
*
* @param[in] loader The DataLoader that wraps the FlatTensor file.
* Note: the loader must outlive the FlatTensorDataMap instance.
*/
static executorch::runtime::Result<DataMap> load(
static executorch::runtime::Result<FlatTensorDataMap> load(
executorch::runtime::DataLoader* loader);

ET_NODISCARD
Expand All @@ -54,22 +55,21 @@ class DataMap final : public executorch::runtime::NamedDataMap {
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
size_t index) const override;

DataMap(DataMap&&) noexcept = default;
~DataMap() override;
FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default;

private:
DataMap(
FlatTensorDataMap(
executorch::runtime::FreeableBuffer&& flat_tensor_data,
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
executorch::runtime::FreeableBuffer&& data_ro)
: flat_tensor_data_(std::move(flat_tensor_data)),
flat_tensor_(flat_tensor),
data_ro_(std::move(data_ro)){};
data_ro_(std::move(data_ro)) {}

// Not copyable or assignable.
DataMap(const DataMap& rhs) = delete;
DataMap& operator=(DataMap&& rhs) noexcept = delete;
DataMap& operator=(const DataMap& rhs) = delete;
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;

// Serialized flat_tensor flatbuffer data.
executorch::runtime::FreeableBuffer flat_tensor_data_;
Expand Down
6 changes: 3 additions & 3 deletions extension/flat_tensor/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
runtime.cxx_library(
name = "data_map",
name = "flat_tensor_data_map",
srcs = [
"data_map.cpp",
"flat_tensor_data_map.cpp",
],
exported_headers = ["data_map.h"],
exported_headers = ["flat_tensor_data_map.h"],
deps = [
"//executorch/extension/flat_tensor/serialize:schema",
"//executorch/extension/flat_tensor/serialize:serialize",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/data_map.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
#include <executorch/extension/flat_tensor/serialize/schema_generated.h>
#include <executorch/runtime/core/error.h>
Expand All @@ -17,7 +17,7 @@
#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::DataMap;
using executorch::extension::FlatTensorDataMap;
using executorch::extension::FlatTensorHeader;
using executorch::runtime::DataLoader;
using executorch::runtime::Error;
Expand All @@ -26,7 +26,7 @@ using executorch::runtime::Result;
using executorch::runtime::TensorLayout;
using torch::executor::util::FileDataLoader;

class DataMapTest : public ::testing::Test {
class FlatTensorDataMapTest : public ::testing::Test {
protected:
void SetUp() override {
// Since these tests cause ET_LOG to be called, the PAL must be initialized
Expand All @@ -45,13 +45,15 @@ class DataMapTest : public ::testing::Test {
std::unique_ptr<FileDataLoader> data_map_loader_;
};

TEST_F(DataMapTest, LoadDataMap) {
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
// Loading a FlatTensorDataMap through the fixture's loader should succeed.
TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
// A successful load is reported as Error::Ok on the returned Result.
EXPECT_EQ(data_map.error(), Error::Ok);
}

TEST_F(DataMapTest, DataMap_GetMetadata) {
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);

// Check tensor layouts are correct.
Expand Down Expand Up @@ -91,8 +93,9 @@ TEST_F(DataMapTest, DataMap_GetMetadata) {
EXPECT_EQ(const_c_res.error(), Error::NotFound);
}

TEST_F(DataMapTest, DataMap_GetData) {
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);

// Check tensor data sizes are correct.
Expand All @@ -111,8 +114,9 @@ TEST_F(DataMapTest, DataMap_GetData) {
EXPECT_EQ(data_c_res.error(), Error::NotFound);
}

TEST_F(DataMapTest, DataMap_Keys) {
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);

// Check num tensors is 2.
Expand Down
6 changes: 3 additions & 3 deletions extension/flat_tensor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def define_common_targets(is_fbcode=False):
}

runtime.cxx_test(
name = "data_map",
name = "flat_tensor_data_map",
srcs = [
"data_map_test.cpp",
"flat_tensor_data_map_test.cpp",
],
deps = [
"//executorch/extension/data_loader:file_data_loader",
"//executorch/extension/flat_tensor:data_map",
"//executorch/extension/flat_tensor:flat_tensor_data_map",
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
"//executorch/extension/flat_tensor/serialize:generated_headers",
"//executorch/extension/flat_tensor/serialize:schema",
Expand Down

0 comments on commit 9c37b57

Please sign in to comment.