Enable weight sharing in LiteRt core.
PiperOrigin-RevId: 726125027
YunanAZ authored and tensorflower-gardener committed Feb 12, 2025
1 parent 6ee6437 commit b6a99af
Showing 4 changed files with 40 additions and 13 deletions.
1 change: 1 addition & 0 deletions tensorflow/lite/experimental/litert/core/model/BUILD
@@ -84,6 +84,7 @@ cc_library(
         "//tensorflow/lite/experimental/litert/cc:litert_macros",
         "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools",
         "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings:string_view",
     ],
 )
@@ -796,6 +796,9 @@ TEST_P(MultiSubgraphDupeConstTest, CheckGraph) {
     Tensor t(&cst);
     EXPECT_THAT(*t.WeightsData<float>(), ElementsAreArray(kWeights));
   }
+  auto buf_id_0 = model.Subgraph(0).Op(0).Input(1).Weights().GetBufferId();
+  auto buf_id_1 = model.Subgraph(1).Op(0).Input(1).Weights().GetBufferId();
+  ASSERT_EQ(buf_id_0, buf_id_1);
 }
 
 INSTANTIATE_TEST_SUITE_P(ModelLoadTests, MultiSubgraphDupeConstTest,
@@ -804,7 +807,7 @@ INSTANTIATE_TEST_SUITE_P(ModelLoadTests, MultiSubgraphDupeConstTest,
 INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, MultiSubgraphDupeConstTest,
                          Values(MakeRoundTripFactory(kCstMultiSubgraph)));
 
-// Tests that programatically check litert against tflite models.
+// Tests that programmatically check litert against tflite models.
 //===---------------------------------------------------------------------------
 
 using ModelLoadOpCheckTest = TestWithModelPath;
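The hunk above extends MultiSubgraphDupeConstTest. Before this change the test only verified that the duplicated constant in each subgraph carried the expected contents; the new assertions additionally require both subgraphs to resolve their weights to the same LiteRT buffer id, i.e. one shared buffer rather than two identical copies. Below is a minimal sketch of that invariant as a standalone check, assuming only the accessor chain used in the test (Subgraph/Op/Input/Weights().GetBufferId()); the helper name, the template, and the input_index parameter are hypothetical:

// Hypothetical helper mirroring the new assertions above. ModelT stands for
// the litert model wrapper used in the test.
template <typename ModelT>
bool ConstIsShared(ModelT& model, int input_index) {
  const auto id0 =
      model.Subgraph(0).Op(0).Input(input_index).Weights().GetBufferId();
  const auto id1 =
      model.Subgraph(1).Op(0).Input(input_index).Weights().GetBufferId();
  return id0 == id1;  // sharing holds when both tensors name one buffer
}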
17 changes: 16 additions & 1 deletion tensorflow/lite/experimental/litert/core/model/model_load.cc
@@ -21,6 +21,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/lite/experimental/litert/c/litert_common.h"
 #include "tensorflow/lite/experimental/litert/c/litert_logging.h"
@@ -42,6 +43,10 @@ namespace {
 // Provides a view of model-level resources when constructing litert graph.
 class FlatbufferContext {
  public:
+  using LiteRtBufferId = uint32_t;
+  using TflBufferInd = uint32_t;
+  using BufferIdMap = absl::flat_hash_map<TflBufferInd, LiteRtBufferId>;
+
   FlatbufferContext(const FlatbufferWrapper& tfl_flatbuffer,
                     BufferManager* buffer_manager)
       : tfl_flatbuffer_(tfl_flatbuffer), buffer_manager_(buffer_manager) {}
@@ -71,9 +76,12 @@ class FlatbufferContext {
     return tfl_flatbuffer_.PackedModel();
   }
 
+  BufferIdMap& RegisteredTflBufferIds() { return registered_tfl_buffer_ids_; }
+
  private:
   const FlatbufferWrapper& tfl_flatbuffer_;
   BufferManager* buffer_manager_;
+  BufferIdMap registered_tfl_buffer_ids_;
 };
 
 LiteRtStatus UnpackOp(FlatbufferContext& context, LiteRtSubgraphT& parent,
@@ -172,7 +180,14 @@ LiteRtStatus UnpackTensor(FlatbufferContext& context,
       return buffer.Error().Status();
     }
 
-    SetWeightsFromUnownedBuffer(litert_tensor.Weights(), *buffer);
+    auto it = context.RegisteredTflBufferIds().find(buffer_ind);
+    if (it != context.RegisteredTflBufferIds().end()) {
+      litert_tensor.Weights().SetBufferId(it->second);
+    } else {
+      SetWeightsFromUnownedBuffer(litert_tensor.Weights(), *buffer);
+      context.RegisteredTflBufferIds()[buffer_ind] =
+          litert_tensor.Weights().GetBufferId();
+    }
   }
 
   // TENSOR TYPE
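On the load path, FlatbufferContext now carries a BufferIdMap from TFL buffer index to the LiteRT buffer id minted when that buffer is first unpacked. UnpackTensor consults the map before registering weights: a hit means an earlier tensor already registered the buffer, so the current tensor is simply pointed at the existing id via SetBufferId; a miss registers the buffer once and records the resulting id. Here is a self-contained sketch of this first-seen-registers pattern, with hypothetical names (BufferPool, IdFor) standing in for the BufferManager plumbing, and std::unordered_map standing in for absl::flat_hash_map to keep the sketch dependency-free:

#include <cstdint>
#include <unordered_map>

using TflBufferInd = uint32_t;    // index into the flatbuffer buffer table
using LiteRtBufferId = uint32_t;  // id minted on first registration

class BufferPool {
 public:
  // Returns the LiteRT id for a TFL buffer index, registering it only the
  // first time that index is seen.
  LiteRtBufferId IdFor(TflBufferInd tfl_ind) {
    auto it = registered_.find(tfl_ind);
    if (it != registered_.end()) return it->second;  // reuse: no second copy
    const LiteRtBufferId id = next_id_++;            // first sight: new id
    registered_[tfl_ind] = id;
    return id;
  }

 private:
  std::unordered_map<TflBufferInd, LiteRtBufferId> registered_;
  LiteRtBufferId next_id_ = 0;
};

Any two tensors that arrive with the same TFL buffer index therefore observe the same LiteRT buffer id, which is exactly the property the new test assertions check.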
30 changes: 19 additions & 11 deletions tensorflow/lite/experimental/litert/core/model/model_serialize.cc
@@ -47,7 +47,6 @@ namespace {
 
 using TensorMap = absl::flat_hash_map<LiteRtTensor, int32_t>;
 
-
 // This is expected to be used to serialize the dispatch op custom code.
 TflOpCodePtr MakeCustomOpCode(std::string custom_code_name) {
   auto custom_code = std::make_unique<TflOpCode>();
@@ -67,6 +66,8 @@ class SerializationContext {
   using TflBufferInd = uint32_t;
   using TflOffsetTensorMap =
       absl::flat_hash_map<TflBufferInd, LiteRtModelT::BufferId>;
+  using TflBufferIdMap =
+      absl::flat_hash_map<LiteRtModelT::BufferId, TflBufferInd>;
 
   explicit SerializationContext(uint32_t dispatch_op_code_ind,
                                 LiteRtModelT& litert_model)
@@ -100,17 +101,23 @@ class SerializationContext {
       return litert_buf.Error().Status();
     }
 
-    auto& tfl_buffer =
-        tfl_model_->buffers.emplace_back(std::make_unique<TflBuffer>());
-    const auto tfl_buffer_ind = tfl_model_->buffers.size() - 1;
-
-    if (litert_buf_ctx->get().should_append) {
-      tfl_buffer->offset = 1;
-      tfl_buffer->size = 1;
-      offset_tensor_map_.emplace(tfl_buffer_ind, litert_buf_id);
+    TflBufferInd tfl_buffer_ind;
+    if (buffer_id_map_.contains(litert_buf_id)) {
+      tfl_buffer_ind = buffer_id_map_.at(litert_buf_id);
     } else {
-      tfl_buffer->data.assign(litert_buf->Data(),
-                              litert_buf->Data() + litert_buf->Size());
+      auto& tfl_buffer =
+          tfl_model_->buffers.emplace_back(std::make_unique<TflBuffer>());
+      tfl_buffer_ind = tfl_model_->buffers.size() - 1;
+
+      if (litert_buf_ctx->get().should_append) {
+        tfl_buffer->offset = 1;
+        tfl_buffer->size = 1;
+        offset_tensor_map_.emplace(tfl_buffer_ind, litert_buf_id);
+      } else {
+        tfl_buffer->data.assign(litert_buf->Data(),
+                                litert_buf->Data() + litert_buf->Size());
+      }
+      buffer_id_map_[litert_buf_id] = tfl_buffer_ind;
     }
 
     tfl_tensor.buffer = tfl_buffer_ind;
@@ -158,6 +165,7 @@ class SerializationContext {
 
   TflOpAssetMap op_asset_map_;
   TflOffsetTensorMap offset_tensor_map_;
+  TflBufferIdMap buffer_id_map_;
 };
 
 void SetOptions(const LiteRtOpT& litert_op, TflOp& tfl_op) {
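Serialization applies the inverse map. Previously every tensor appended a fresh TflBuffer, so a buffer shared by N tensors was written out N times and sharing was lost across a round trip. With buffer_id_map_ keyed by LiteRT buffer id, the buffer is emitted into tfl_model_->buffers exactly once; every later tensor with the same id reuses the recorded index for tfl_tensor.buffer. A sketch under the same caveats as before (hypothetical BufferEmitter/Emit names, std::unordered_map in place of absl::flat_hash_map, raw byte vectors in place of TflBuffer; the should_append/offset-tensor handling is omitted):

#include <cstdint>
#include <unordered_map>
#include <vector>

using LiteRtBufferId = uint32_t;
using TflBufferInd = uint32_t;

class BufferEmitter {
 public:
  // Returns the TFL buffer index for `id`, appending `payload` only the
  // first time a given LiteRT buffer id is seen.
  TflBufferInd Emit(LiteRtBufferId id, const std::vector<uint8_t>& payload) {
    if (auto it = emitted_.find(id); it != emitted_.end()) return it->second;
    buffers_.push_back(payload);  // write the shared data exactly once
    const auto ind = static_cast<TflBufferInd>(buffers_.size() - 1);
    emitted_[id] = ind;
    return ind;
  }

 private:
  std::vector<std::vector<uint8_t>> buffers_;  // stand-in for the buffer table
  std::unordered_map<LiteRtBufferId, TflBufferInd> emitted_;
};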
