Skip to content

Commit

Permalink
Wrap model attributes in smart ptrs to avoid memory leaks
Browse files Browse the repository at this point in the history
Remove memory leaks from #57 and #61 by using shared ptrs on:
	model::graph
	model::session
And unique ptrs on vars from model constructor:
	session_options
	run_options
	meta_graph
  • Loading branch information
serizba committed Nov 3, 2020
1 parent 306a5c0 commit 8db5b32
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions include/cppflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,40 @@ namespace cppflow {
std::vector<tensor> operator()(std::vector<std::tuple<std::string, tensor>> inputs, std::vector<std::string> outputs);
tensor operator()(const tensor& input);

~model() = default;
model(const model &model) = default;
model(model &&model) = default;
model &operator=(const model &other) = default;
model &operator=(model &&other) = default;

private:

TF_Graph* graph;
TF_Session* session;
std::shared_ptr<TF_Graph> graph;
std::shared_ptr<TF_Session> session;
};
}


namespace cppflow {

model::model(const std::string &filename) {
this->graph = TF_NewGraph();
this->graph = {TF_NewGraph(), TF_DeleteGraph};

// Create the session.
TF_SessionOptions* session_options = TF_NewSessionOptions();
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
TF_Buffer* meta_graph = TF_NewBuffer();
std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString("", 0), TF_DeleteBuffer};
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};

auto session_deleter = [](TF_Session* sess) {
TF_DeleteSession(sess, context::get_status());
status_check(context::get_status());
};

int tag_len = 1;
const char* tag = "serve";
this->session = TF_LoadSessionFromSavedModel(session_options, run_options, filename.c_str(), &tag, tag_len, graph, meta_graph, context::get_status());
TF_DeleteSessionOptions(session_options);
TF_DeleteBuffer(run_options);
//TF_DeleteBuffer(meta_graph);
this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(),
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
session_deleter};

status_check(context::get_status());
}
Expand All @@ -58,7 +69,7 @@ namespace cppflow {
TF_Operation* oper;

// Iterate through the operations of a graph
while ((oper = TF_GraphNextOperation(this->graph, &pos)) != nullptr) {
while ((oper = TF_GraphNextOperation(this->graph.get(), &pos)) != nullptr) {
result.emplace_back(TF_OperationName(oper));
}
return result;
Expand All @@ -77,7 +88,7 @@ namespace cppflow {

// Operations
const auto[op_name, op_idx] = parse_name(std::get<0>(inputs[i]));
inp_ops[i].oper = TF_GraphOperationByName(this->graph, op_name.c_str());
inp_ops[i].oper = TF_GraphOperationByName(this->graph.get(), op_name.c_str());
inp_ops[i].index = op_idx;

if (!inp_ops[i].oper)
Expand All @@ -94,15 +105,15 @@ namespace cppflow {
for (int i=0; i<outputs.size(); i++) {

const auto[op_name, op_idx] = parse_name(outputs[i]);
out_ops[i].oper = TF_GraphOperationByName(this->graph, op_name.c_str());
out_ops[i].oper = TF_GraphOperationByName(this->graph.get(), op_name.c_str());
out_ops[i].index = op_idx;

if (!out_ops[i].oper)
throw std::runtime_error("No operation named \"" + op_name + "\" exists");

}

TF_SessionRun(this->session, NULL,
TF_SessionRun(this->session.get(), NULL,
inp_ops.data(), inp_val.data(), inputs.size(),
out_ops.data(), out_val.get(), outputs.size(),
NULL, 0,NULL , context::get_status());
Expand Down

0 comments on commit 8db5b32

Please sign in to comment.