Skip to content

Commit

Permalink
fix(sgx): fix memory leak of json_engine
Browse files Browse the repository at this point in the history
  • Loading branch information
zeuson0 committed Apr 3, 2024
1 parent 5940cbc commit f17dedb
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 225 deletions.
6 changes: 0 additions & 6 deletions sgx/grpc/v1.38.1/examples/dynamic_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
"verify_isv_prod_id" : "on",
"verify_isv_svn" : "on",
"sgx_mrs": [
{
"mr_enclave" : "",
"mr_signer" : "",
"isv_prod_id" : "0",
"isv_svn" : "0"
}
],
"other" : []
}
205 changes: 114 additions & 91 deletions sgx/grpc/v1.38.1/src/cpp/sgx/sgx_ra_tls_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@ namespace sgx {

struct ra_tls_context _ctx_;

void check_free(void* ptr) {
if (ptr) {
free(ptr);
ptr = nullptr;
};
}

bool hex_to_byte(const char *src, char *dst, size_t dst_size) {
if (strlen(src) < dst_size*2) {
return false;
Expand Down Expand Up @@ -100,107 +93,140 @@ void* library_engine::get_handle() {
return handle;
}

json_engine::json_engine() : handle(nullptr) {};
json_engine::json_engine() : handle(nullptr){};

json_engine::json_engine(const char* file) : handle(nullptr){
this->open(file);
json_engine::json_engine(const char *file) : handle(nullptr)
{
this->open(file);
}

json_engine::~json_engine() {
this->close();
json_engine::~json_engine()
{
this->close();
}

bool json_engine::open(const char* file) {
if (!file) {
grpc_printf("wrong json file path\n");
return false;
}
bool json_engine::open(const char *file)
{
if (!file)
{
grpc_printf("wrong json file path\n");
return false;
}

this->close();
this->close();

auto file_ptr = fopen(file, "r");
fseek(file_ptr, 0, SEEK_END);
auto length = ftell(file_ptr);
fseek(file_ptr, 0, SEEK_SET);
auto buffer = malloc(length);
fread(buffer, 1, length, file_ptr);
fclose(file_ptr);
auto file_ptr = fopen(file, "r");
fseek(file_ptr, 0, SEEK_END);
auto length = ftell(file_ptr);
fseek(file_ptr, 0, SEEK_SET);
auto buffer = malloc(length + 1);
memset(buffer, 0, length);
fread(buffer, 1, length, file_ptr);
fclose(file_ptr);

this->handle = cJSON_Parse((const char *)buffer);
this->handle = cJSON_Parse((const char*)buffer);

check_free(buffer);
if (buffer)
{
free(buffer);
buffer = nullptr;
}

if (this->handle) {
return true;
} else {
grpc_printf("cjson open %s error: %s", file, cJSON_GetErrorPtr());
return false;
}
if (this->handle)
{
return true;
}
else
{
grpc_printf("cjson open %s error: %s", file, cJSON_GetErrorPtr());
return false;
}
}

void json_engine::close() {
if (this->handle) {
cJSON_Delete(this->handle);
this->handle = nullptr;
}
void json_engine::close()
{
if (this->handle)
{
cJSON_Delete(this->handle);
this->handle = nullptr;
}
}

cJSON * json_engine::get_handle() {
return this->handle;
cJSON *json_engine::get_handle()
{
return this->handle;
}

cJSON * json_engine::get_item(cJSON *obj, const char *item) {
return cJSON_GetObjectItem(obj, item);
};

char * json_engine::print_item(cJSON *obj) {
return cJSON_Print(obj);
cJSON *json_engine::get_item(cJSON *obj, const char *item)
{
return cJSON_GetObjectItem(obj, item);
};

bool json_engine::compare_item(cJSON *obj, const char *item) {
auto obj_item = this->print_item(obj);
return strncmp(obj_item+1, item, std::min(strlen(item), strlen(obj_item)-2)) == 0;
bool json_engine::compare_item(cJSON *obj, const char *item)
{
if (!obj || !cJSON_IsString(obj)){
return false;
}
auto obj_item = obj->valuestring;
return strncmp(obj_item, item, std::min(strlen(item), strlen(obj_item))) == 0;
};

sgx_config parse_sgx_config_json(const char* file) {
class json_engine sgx_json(file);
struct sgx_config sgx_cfg;

sgx_cfg.verify_in_enclave = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_in_enclave"), "on");
sgx_cfg.verify_mr_enclave = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_mr_enclave"), "on");
sgx_cfg.verify_mr_signer = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_mr_signer"), "on");
sgx_cfg.verify_isv_prod_id = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_isv_prod_id"), "on");
sgx_cfg.verify_isv_svn = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_isv_svn"), "on");
// grpc_printf("%d, %d, %d, %d, %d\n", sgx_cfg.verify_in_enclave,
// sgx_cfg.verify_mr_enclave,
// sgx_cfg.verify_mr_signer,
// sgx_cfg.verify_isv_prod_id,
// sgx_cfg.verify_isv_svn);

auto objs = sgx_json.get_item(sgx_json.get_handle(), "sgx_mrs");
auto obj_num = cJSON_GetArraySize(objs);

sgx_cfg.sgx_mrs = std::vector<sgx_measurement>(obj_num, sgx_measurement());
for (auto i = 0; i < obj_num; i++) {
auto obj = cJSON_GetArrayItem(objs, i);

auto mr_enclave = sgx_json.print_item(sgx_json.get_item(obj, "mr_enclave"));
memset(sgx_cfg.sgx_mrs[i].mr_enclave, 0, sizeof(sgx_cfg.sgx_mrs[i].mr_enclave));
hex_to_byte(mr_enclave+1, sgx_cfg.sgx_mrs[i].mr_enclave, sizeof(sgx_cfg.sgx_mrs[i].mr_enclave));

auto mr_signer = sgx_json.print_item(sgx_json.get_item(obj, "mr_signer"));
memset(sgx_cfg.sgx_mrs[i].mr_signer, 0, sizeof(sgx_cfg.sgx_mrs[i].mr_signer));
hex_to_byte(mr_signer+1, sgx_cfg.sgx_mrs[i].mr_signer, sizeof(sgx_cfg.sgx_mrs[i].mr_signer));

auto isv_prod_id = sgx_json.print_item(sgx_json.get_item(obj, "isv_prod_id"));
sgx_cfg.sgx_mrs[i].isv_prod_id = strtoul(isv_prod_id, nullptr, 10);
const char* json_engine::get_item_string(cJSON *obj, const char* item){
auto item_json = get_item(obj, item);
return cJSON_IsString(item_json) ? item_json->valuestring : "";
}

auto isv_svn = sgx_json.print_item(sgx_json.get_item(obj, "isv_svn"));
sgx_cfg.sgx_mrs[i].isv_svn = strtoul(isv_svn, nullptr, 10);

// grpc_printf("%s, %s, %s, %s\n", mr_enclave, mr_signer, isv_prod_id, isv_svn);
};
return sgx_cfg;
sgx_config parse_sgx_config_json(const char *file)
{
class json_engine sgx_json(file);
struct sgx_config sgx_cfg;

sgx_cfg.verify_in_enclave = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_in_enclave"), "on");
sgx_cfg.verify_mr_enclave = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_mr_enclave"), "on");
sgx_cfg.verify_mr_signer = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_mr_signer"), "on");
sgx_cfg.verify_isv_prod_id = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_isv_prod_id"), "on");
sgx_cfg.verify_isv_svn = sgx_json.compare_item(sgx_json.get_item(sgx_json.get_handle(), "verify_isv_svn"), "on");

grpc_printf("|- verify_in_enclave: %s\n", sgx_cfg.verify_in_enclave ? "on" : "off");
grpc_printf("|- verify_mr_enclave: %s\n", sgx_cfg.verify_mr_enclave ? "on" : "off");
grpc_printf("|- verify_mr_signer: %s\n", sgx_cfg.verify_mr_signer ? "on" : "off");
grpc_printf("|- verify_isv_prod_id: %s\n", sgx_cfg.verify_isv_prod_id ? "on" : "off");
grpc_printf("|- verify_isv_svn: %s\n", sgx_cfg.verify_isv_svn ? "on" : "off");

auto objs = sgx_json.get_item(sgx_json.get_handle(), "sgx_mrs");
auto obj_num = cJSON_GetArraySize(objs);

sgx_cfg.sgx_mrs = std::vector<sgx_measurement>(obj_num, sgx_measurement());
for (auto i = 0; i < obj_num; i++)
{
auto obj = cJSON_GetArrayItem(objs, i);
grpc_printf(" |- expect measurement [%d]:\n", i + 1);
auto mr_enclave = sgx_json.get_item_string(obj, "mr_enclave");
grpc_printf(" |- mr_enclave: %s\n", mr_enclave);
memset(sgx_cfg.sgx_mrs[i].mr_enclave, 0, sizeof(sgx_cfg.sgx_mrs[i].mr_enclave));
auto res = hex_to_byte(mr_enclave, sgx_cfg.sgx_mrs[i].mr_enclave, sizeof(sgx_cfg.sgx_mrs[i].mr_enclave));
if (!res){
grpc_printf("mr_enclave invalid, %s\n", mr_enclave);
}

auto mr_signer = sgx_json.get_item_string(obj, "mr_signer");
grpc_printf(" |- mr_signer: %s\n", mr_signer);
memset(sgx_cfg.sgx_mrs[i].mr_signer, 0, sizeof(sgx_cfg.sgx_mrs[i].mr_signer));
res = hex_to_byte(mr_signer, sgx_cfg.sgx_mrs[i].mr_signer, sizeof(sgx_cfg.sgx_mrs[i].mr_signer));
if (!res){
grpc_printf("mr_signer invalid, %s\n", mr_signer);
}

auto isv_prod_id = sgx_json.get_item_string(obj, "isv_prod_id");
grpc_printf(" |- isv_prod_id: %s\n", isv_prod_id);
sgx_cfg.sgx_mrs[i].isv_prod_id = strtoul(isv_prod_id ,nullptr, 10);

auto isv_svn = sgx_json.get_item_string(obj, "isv_svn");
grpc_printf(" |- isv_svn: %s\n", isv_svn);
sgx_cfg.sgx_mrs[i].isv_svn = strtoul(isv_svn ,nullptr, 10);;
};
return sgx_cfg;
}

int TlsAuthorizationCheck::Schedule(grpc::experimental::TlsServerAuthorizationCheckArg* arg) {
Expand All @@ -210,13 +236,10 @@ int TlsAuthorizationCheck::Schedule(grpc::experimental::TlsServerAuthorizationCh
auto peer_cert_buf = arg->peer_cert();
peer_cert_buf.copy(der_crt, peer_cert_buf.length(), 0);

// char der_crt[16000] = TEST_CRT_PEM;
// grpc_printf("%s\n", der_crt);

int ret = (*_ctx_.verify_callback_f)(reinterpret_cast<uint8_t *>(der_crt), 16000);

if (ret != 0) {
grpc_printf("something went wrong while verifying quote\n");
grpc_printf("something went wrong while verifying quote, error: %s\n", mbedtls_high_level_strerr(ret));
arg->set_success(0);
arg->set_status(GRPC_STATUS_UNAUTHENTICATED);
} else {
Expand Down Expand Up @@ -405,7 +428,7 @@ int ra_tls_auth_check_schedule(void* /* confiuser_data */,
int ret = (*_ctx_.verify_callback_f)(reinterpret_cast<uint8_t *>(der_crt), 16000);

if (ret != 0) {
grpc_printf("something went wrong while verifying quote\n");
grpc_printf("something went wrong while verifying quote, error: %s\n", mbedtls_high_level_strerr(ret));
arg->success = 0;
arg->status = GRPC_STATUS_UNAUTHENTICATED;
} else {
Expand Down
23 changes: 11 additions & 12 deletions sgx/grpc/v1.38.1/src/cpp/sgx/sgx_ra_tls_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,28 +85,29 @@ class library_engine {
char* error;
};

class json_engine {
public:
class json_engine
{
public:
json_engine();

json_engine(const char*);
json_engine(const char *);

~json_engine();

bool open(const char*);
bool open(const char *);

void close();

cJSON * get_handle();

cJSON * get_item(cJSON *obj, const char *item);
cJSON *get_handle();

char * print_item(cJSON *obj);
cJSON *get_item(cJSON *obj, const char *item);

bool compare_item(cJSON *obj, const char *item);

private:
cJSON* handle;
const char* get_item_string(cJSON *obj, const char* item);

private:
cJSON *handle;
};

class TlsAuthorizationCheck
Expand Down Expand Up @@ -155,8 +156,6 @@ struct ra_tls_context {
int (*verify_callback_f)(uint8_t* der_crt, size_t der_crt_size) = nullptr;
};

void check_free(void* ptr);

sgx_config parse_sgx_config_json(const char* file);

bool ra_tls_verify_measurement(const char* mr_enclave, const char* mr_signer,
Expand Down
5 changes: 2 additions & 3 deletions sgx/tf/sgx_tls_sample.diff
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index ce1a20a5ae9..65266309cba 100644
index ce1a20a5ae9..23f329096ca 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -41,9 +41,10 @@ filegroup(
@@ -41,9 +41,9 @@ filegroup(

cc_library(
name = "grpc_util",
Expand All @@ -11,7 +11,6 @@ index ce1a20a5ae9..65266309cba 100644
- linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]),
+ srcs = ["grpc_util.cc", "grpc_sgx_ra_tls_utils.cc", "grpc_sgx_ra_tls_server.cc", "grpc_sgx_ra_tls_client.cc", "grpc_sgx_credentials_provider.cc"],
+ hdrs = ["grpc_util.h", "grpc_sgx_ra_tls.h", "grpc_sgx_ra_tls_utils.h", "grpc_sgx_credentials_provider.h"],
+ include_dirs = ["/usr/local/include"],
+ linkopts = ["-L/usr/local/lib", "-l:libmbedx509_gramine.a", "-l:libmbedcrypto_gramine.a", "-l:libcjson.a", "-l:libcjson_utils.a"],
deps = [
"//tensorflow/core:lib",
Expand Down
Loading

0 comments on commit f17dedb

Please sign in to comment.