From e15bb3fc11b50f4edc04452bfc03bf9fdc86c237 Mon Sep 17 00:00:00 2001 From: Pariente Manuel Date: Sat, 30 Jul 2022 11:54:44 +0200 Subject: [PATCH] [src] Allow disabling script_if_tracing for ONNX export (#16) --- asteroid_filterbanks/scripting.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/asteroid_filterbanks/scripting.py b/asteroid_filterbanks/scripting.py index 6e51a46..9b4183d 100644 --- a/asteroid_filterbanks/scripting.py +++ b/asteroid_filterbanks/scripting.py @@ -2,6 +2,20 @@ import torch +global SCRIPT_ENABLED +SCRIPT_ENABLED = True + + +def disable_script_if_tracing(): + global SCRIPT_ENABLED + SCRIPT_ENABLED = False + + +def enable_script_if_tracing(): + global SCRIPT_ENABLED + SCRIPT_ENABLED = True + + def is_tracing(): # Taken for pytorch for compat in 1.6.0 """ @@ -32,7 +46,7 @@ def script_if_tracing(fn): @functools.wraps(fn) def wrapper(*args, **kwargs): - if not is_tracing(): + if not is_tracing() or not SCRIPT_ENABLED: # Not tracing, don't do anything return fn(*args, **kwargs)