Skip to content

Commit

Permalink
Fix (jit): remove patcher
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 10, 2023
1 parent 015eb09 commit edde30d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from brevitas.proxy.quant_proxy import QuantProxyProtocol
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.jit_utils import clear_class_registry
from brevitas.utils.jit_utils import jit_patches_generator
# from brevitas.utils.jit_utils import jit_patches_generator
from brevitas.utils.python_utils import patch


Expand Down Expand Up @@ -162,7 +162,7 @@ class BaseManager(ABC):

target_name = None
handlers = []
_base_trace_patches_generator = jit_patches_generator
_base_trace_patches_generator = None # jit_patches_generator
_fn_to_cache = []
_fn_cache = []
_cached_io_handler_map = {}
Expand Down
10 changes: 5 additions & 5 deletions src/brevitas/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from brevitas.config import JIT_ENABLED

IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0')
# IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0')


def _disabled(fn):
Expand All @@ -20,10 +20,10 @@ def _disabled(fn):
ScriptModule = torch.jit.ScriptModule
Attribute = torch.jit.Attribute

if not IS_ABOVE_110:
script_method_110_disabled = _disabled
else:
script_method_110_disabled = script_method
script_method_110_disabled = script_method
# script_method_110_disabled = _disabled
# if not IS_ABOVE_110:
# else:

else:

Expand Down
42 changes: 20 additions & 22 deletions src/brevitas/utils/jit_utils.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import inspect
# import inspect

import torch

try:
from torch._jit_internal import get_torchscript_modifier
except:
get_torchscript_modifier = None

from dependencies import Injector
# from dependencies import Injector
from packaging import version
import torch

from brevitas import torch_version
from brevitas.inject import ExtendedInjector
from brevitas.jit import IS_ABOVE_110

from .python_utils import patch
# try:
# from torch._jit_internal import get_torchscript_modifier
# except:
# get_torchscript_modifier = None

# from brevitas.inject import ExtendedInjector
# from brevitas.jit import IS_ABOVE_110

def _get_modifier_wrapper(fn):
if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)):
return None
else:
return get_torchscript_modifier(fn)
# from .python_utils import patch

# def _get_modifier_wrapper(fn):
# if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)):
# return None
# else:
# return get_torchscript_modifier(fn)

if IS_ABOVE_110:
# if IS_ABOVE_110:

def jit_patches_generator():
return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)]
else:
jit_patches_generator = None
# def jit_patches_generator():
# return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)]
# else:
# jit_patches_generator = None


def clear_class_registry():
Expand Down

0 comments on commit edde30d

Please sign in to comment.