Skip to content

Commit

Permalink
torch dyanamo in c++ test
Browse files Browse the repository at this point in the history
ghstack-source-id: 53fa37b8b40ac4c4b4778cef83f9405aba1d6259
Pull Request resolved: #299
  • Loading branch information
PaliC committed Jan 14, 2023
1 parent 2e9a48a commit a8755bb
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
12 changes: 12 additions & 0 deletions multipy/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ set(INTERPRETER_TEST_SOURCES_GPU
${DEPLOY_DIR}/test_deploy_gpu.cpp
)

set(INTERPRETER_TEST_SOURCES_COMPAT
${DEPLOY_DIR}/test_dynamo_compat.cpp
)

# TODO: Currently tests can only be done when ABI=1 as the testing infrustructure
# used by ASSERT_TRUE requires ABI=1 in Github actions, we should fix this!

Expand All @@ -99,6 +103,14 @@ target_link_libraries(test_deploy
)
target_include_directories(test_deploy PRIVATE ${CMAKE_SOURCE_DIR}/../..)

add_executable(test_compat ${INTERPRETER_TEST_SOURCES_COMPAT})
# target_compile_definitions(test_compat PUBLIC TEST_CUSTOM_LIBRARY)
target_include_directories(test_compat PRIVATE ${PYTORCH_ROOT}/torch)
target_link_libraries(test_compat
PUBLIC "-Wl,--no-as-needed -rdynamic" gtest dl torch_deploy_interface c10 torch_cpu
)
target_include_directories(test_compat PRIVATE ${CMAKE_SOURCE_DIR}/../..)

if(BUILD_CUDA_TESTS)
LINK_DIRECTORIES("${PYTORCH_ROOT}/torch/lib")
add_executable(test_deploy_gpu ${INTERPRETER_TEST_SOURCES_GPU})
Expand Down
4 changes: 4 additions & 0 deletions multipy/runtime/test_deploy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ TEST(TorchpyTest, SimpleModel) {
compare_torchpy_jit(path("SIMPLE", simple), path("SIMPLE_JIT", simple_jit));
}

TEST(TorchpyTest, DynamoTest) {

}

#ifdef FBCODE_CAFFE2
TEST(TorchpyTest, LoadTextAndBinary) {
torch::deploy::InterpreterManager manager(1);
Expand Down
68 changes: 68 additions & 0 deletions multipy/runtime/test_dynamo_compat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/Parallel.h>
#include <gtest/gtest.h>
#include <libgen.h>
#include <cstring>

#include <c10/util/irange.h>
#include <libgen.h>
#include <multipy/runtime/deploy.h>
#include <torch/script.h>
#include <torch/torch.h>

#include <future>
#include <iostream>
#include <string>

void compare_torchpy_jit(const char* model_filename, const char* jit_filename) {
// Test

torch::deploy::InterpreterManager m(2);
torch::deploy::Package p = m.loadPackage(model_filename);
auto model = p.loadPickle("model", "model.pkl");
at::IValue eg;
{
auto I = p.acquireSession();
eg = I.self.attr("load_pickle")({"model", "example.pkl"}).toIValue();
}
auto I = p.acquireSession();
auto cModelObj = I.global("torch", "compile")(model.toObj(&I));
auto cModel = m.createMovable(cModelObj, &I);
at::Tensor output = cModel(eg.toTupleRef().elements()).toTensor();

// Reference
auto ref_model = torch::jit::load(jit_filename);
at::Tensor ref_output =
ref_model.forward(eg.toTupleRef().elements()).toTensor();

ASSERT_TRUE(ref_output.allclose(output, 1e-03, 1e-05));
}

const char* simple = "multipy/runtime/example/generated/simple";
const char* simple_jit = "multipy/runtime/example/generated/simple_jit";

const char* path(const char* envname, const char* path) {
const char* e = getenv(envname);
return e ? e : path;
}

TEST(TorchpyTest, SimpleModel) {
compare_torchpy_jit(path("SIMPLE", simple), path("SIMPLE_JIT", simple_jit));
}

int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
char tempeh[256];
getcwd(tempeh, 256);
std::cout << "Current working directory: " << tempeh << std::endl;
int rc = RUN_ALL_TESTS();
char tmp[256];
getcwd(tmp, 256);
std::cout << "Current working directory: " << tmp << std::endl;
return rc;
}

0 comments on commit a8755bb

Please sign in to comment.