From 0749b88f92038d202ff136206e1a471853181e19 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 16 Sep 2022 14:22:11 -0700 Subject: [PATCH] multipy/runtime: mock _decomp --- multipy/runtime/interpreter/interpreter_impl.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/multipy/runtime/interpreter/interpreter_impl.cpp b/multipy/runtime/interpreter/interpreter_impl.cpp index 9ed8e09d..ead180f4 100644 --- a/multipy/runtime/interpreter/interpreter_impl.cpp +++ b/multipy/runtime/interpreter/interpreter_impl.cpp @@ -125,8 +125,16 @@ import importlib.abc import linecache from zipfile import ZipFile +class DummyMultiPyModule: + def __getattr__(self, key): + return self + + def __call__(self, *args, **kwargs): + return self + # Disable Python library registration since it's not compatible with multipy. -sys.modules["torch._meta_registrations"] = object +sys.modules["torch._meta_registrations"] = DummyMultiPyModule() +sys.modules["torch._decomp"] = DummyMultiPyModule() class RegisterModuleImporter(importlib.abc.InspectLoader): def __init__(self, find_module_source): @@ -160,6 +168,8 @@ class RegisterModuleImporter(importlib.abc.InspectLoader): # print("modules:", sys.modules) import torch # has to be done serially otherwise things will segfault +torch._decomp = DummyMultiPyModule() + import multipy.utils try: import torch.version # for some reason torch doesn't import this and cuda fails?