From 1d269e2a6d443bb87616c47594960af0dba91280 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 28 Nov 2024 01:36:03 +0000 Subject: [PATCH 1/5] feat: Reduce partitioning overhead --- py/torch_tensorrt/dynamo/_compiler.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9859668cd9..93e33db311 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -642,7 +642,7 @@ def compile_module( num_supported_ops, total_ops = partitioning.get_graph_converter_support( gm, settings.debug, settings.torch_executed_ops ) - + dryrun_tracker.total_ops_in_graph = total_ops dryrun_tracker.supported_ops_in_graph = num_supported_ops dryrun_tracker.compilation_settings = settings @@ -684,6 +684,20 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments." ) + breakpoint() + # Skip partitioning if the whole graph is supported to reduce partitioning overhead + if num_supported_ops == total_ops: + gm_inputs = partitioning.construct_submodule_inputs(gm) + trt_module = convert_module( + gm, + gm_inputs, + settings=settings, + name="whole_graph", + engine_cache=engine_cache, + ) + breakpoint() + return trt_module + # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -697,6 +711,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: torch_executed_ops=settings.torch_executed_ops, require_full_compilation=settings.require_full_compilation, ) + except torch.fx.passes.splitter_base.FxNetSplitterInternalError: logger.error( "Partitioning failed on the subgraph with fast partition. See trace above. 
" @@ -717,7 +732,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: require_full_compilation=settings.require_full_compilation, ) - dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators + # dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators # The global partitioner leaves non-TRT nodes as-is if not settings.use_fast_partitioner: From 4e96a91cfbb8c39f08f470051592a29c2a5f7d9a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 2 Dec 2024 23:42:50 +0000 Subject: [PATCH 2/5] chore: updates --- examples/dynamo/torch_export_flux.py | 254 ++++++++++++++++++++++++++ py/torch_tensorrt/dynamo/_compiler.py | 4 + 2 files changed, 258 insertions(+) create mode 100644 examples/dynamo/torch_export_flux.py diff --git a/examples/dynamo/torch_export_flux.py b/examples/dynamo/torch_export_flux.py new file mode 100644 index 0000000000..3a2559235b --- /dev/null +++ b/examples/dynamo/torch_export_flux.py @@ -0,0 +1,254 @@ + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from transformers import AutoModelForCausalLM, AutoTokenizer +from diffusers import FluxPipeline, FluxTransformer2DModel +from utils import export_llm, generate +from torch.export import Dim +from typing import Optional, Dict, Any +import logging +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +handler = logging.StreamHandler() +handler.setLevel(logging.DEBUG) +logger.addHandler(handler) + +import time +from contextlib import contextmanager + +@contextmanager +def timer(logger, name:str): + logger.info(f"{name} section Start...") + start = time.time() + yield + end = time.time() + logger.info(f"{name} section End...") + logger.info(f"{name} section elapsed time: {end - start} seconds") + +class MyModule(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, **kwargs): + + + return self.module.forward( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + img_ids, + txt_ids, + # guidance, + # joint_attention_kwargs, + # return_dict + ) + +def wrap_pipeline_transformer_call(instance, prompt, max_sequence_length): + from unittest.mock import patch + +# Assume `instance` is your class instance containing the `__call__` method + +# Use patch.object to mock the __call__ method of self.transformer + with patch.object(instance.transformer, 'forward', wraps=instance.transformer.forward) as mock_transformer_call: + # one step is enough for intercept the inputs + image =instance( + prompt, + guidance_scale=0.0, + num_inference_steps=1, + max_sequence_length=max_sequence_length, + generator=torch.Generator("cpu").manual_seed(0) + ).images[0] + + + # Access the call arguments of the first (or specific) call + if mock_transformer_call.call_args_list: + args, kwargs = mock_transformer_call.call_args_list[0] + # Store the inputs in a tuple + intercepted_inputs = (args, kwargs) + + # print("Intercepted args:", args) + # print("Intercepted kwargs:", kwargs) + return (args, kwargs) + else: + print("No calls were made to self.transformer.__call__") + return (None, None) + + +if __name__ == "__main__": + + # config + 
dryrun = False + + # parameter setting + batch_size = 2 + max_seq_len = 256 + prompt = ["A cat holding a sign that says hello world" for _ in range(batch_size)] + device = "cuda:0" + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.float16, num_layers=1) + # pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", + # torch_dtype=torch.float16, device_map="balanced") + pipe.to(device) + # image = pipe( + # prompt, + # guidance_scale=0.0, + # num_inference_steps=4, + # max_sequence_length=256, + # generator=torch.Generator("cpu").manual_seed(0) + # ).images[0] + # image.save("pytorch_flux-schnell.png") + # breakpoint() + # pipe.reset_device_map() + # pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power + example_args , example_kwargs = wrap_pipeline_transformer_call(pipe, prompt, max_seq_len) + tensor_inputs = ['hidden_states', 'timestep', 'pooled_projections', 'encoder_hidden_states', 'txt_ids', 'img_ids' ] + example_kwargs_shapes = {key: example_kwargs[key].shape for key in tensor_inputs} + BATCH = Dim("batch", min=1, max=batch_size) + SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len) + dynamic_shapes = ({0 : BATCH}, + {0 : BATCH, + 1 : SEQ_LEN, + }, + {0 : BATCH}, + {0 : BATCH}, + {0 : BATCH}, + {0 : BATCH, + 1 : SEQ_LEN, + }, + # None, + # None, + # None, + ) + example_args = ( + example_kwargs['hidden_states'], + example_kwargs['encoder_hidden_states'], + example_kwargs['pooled_projections'], + example_kwargs['timestep'], + example_kwargs['img_ids'], + example_kwargs['txt_ids'], + # example_kwargs['guidance'], + # example_kwargs['joint_attention_kwargs'], + # example_kwargs['return_dict'], + + ) + + # dynamic_shapes = {'hidden_states': {0 : BATCH}, + # 'encoder_hidden_states': {0 : BATCH, + # 1 : SEQ_LEN, + # }, + # 'pooled_projections': {0 : BATCH}, + # 'timestep': {0 : BATCH}, + # 'img_ids': {0 : BATCH}, + # 'txt_ids': {0 : BATCH, + # 1 : SEQ_LEN, + # }, + # 'guidance': None, + # 'joint_attention_kwargs': None, + # 'return_dict': None, + # } + + with timer(logger=logger, name="ep_gen"): + with torch.no_grad(): + + # model = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-schnell",torch_dtype=torch.float16) + model = MyModule(pipe.transformer).eval().half().to(device) + # try: + # logger.info("Trying to export the model using torch.export.export()..") + # # print("Trying to export the model using torch.export.export()..") + # # strict=False only enables aotautograd tracing and excludes dynamo. + # # Have core dump this path + # ep = torch.export.export( + # model, args=example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=False + # ) + # except: + logger.info("Directly use _export because torch.export.export doesn't work") + # This API is used to express the constraint violation guards as asserts in the graph. + from torch.export._trace import _export + ep = _export( + model, + args=example_args, + # kwargs=example_kwargs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + logger.info(f"Generating TRT engine now, dryrun={dryrun}...") + # print("Generating TRT engine now...") + #TODO: if some non-tensor input, do we still need to provide them. 
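+    # torch_tensorrt.dynamo.compile() below builds the engine in strongly
+    # typed mode (use_explicit_typing=True): layer dtypes follow the fp16
+    # module weights, which is why enabled_precisions is just {torch.float32}.
+    # use_fp32_acc additionally keeps matmul accumulation in FP32.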
+ with timer(logger, "trt_gen"): + with torch_tensorrt.logging.debug(): + trt_start = time.time() + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=list(example_args), + enabled_precisions={torch.float32}, + truncate_double=True, + device=torch.device(device), + disable_tf32=True, + use_explicit_typing=True, + dryrun=dryrun, + debug=True, + use_fp32_acc=True, + ) + trt_end = time.time() + + del pipe + del ep + del model + import gc + gc.collect() + torch.cuda.empty_cache() + + with timer(logger, "trt_save"): + try: + trt_ep = torch.export.export(trt_model, args=example_args, + dynamic_shapes=dynamic_shapes, strict=False) + torch.export.save(trt_ep, "trt.ep") + except Exception as e: + import traceback + # Capture the full traceback + tb = traceback.format_exc() + logger.warning("An error occurred. Here's the traceback:") + # print(tb) + logger.warning(tb) + breakpoint() + torch_tensorrt.save(trt_model, "trt.ep") + # finally: + # breakpoint() + + + + # if not dryrun: + # pipe.transformer.forward = MyModule(trt_model).forward + # with timer(logger, "trt_infer"): + # image = pipe( + # prompt, + # guidance_scale=0.0, + # num_inference_steps=4, + # max_sequence_length=256, + # generator=torch.Generator("cpu").manual_seed(0) + # ).images[0] + # image.save("trt_flux-schnell.png") + breakpoint() + + + + + +# breakpoint() +# flux_model_ep = export_llm(model, inputs=) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 93e33db311..88c81164fd 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -685,6 +685,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: ) breakpoint() + import time + start = time.time() # Skip partitioning if the whole graph is supported to reduce partitioning overhead if num_supported_ops == total_ops: gm_inputs = partitioning.construct_submodule_inputs(gm) @@ -695,6 +697,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: name="whole_graph", engine_cache=engine_cache, ) + end = time.time() + logger.info(f"Conversion time: {end - start} seconds") breakpoint() return trt_module From 1c7a638201c9dfe737fd236c99a8c12b3f3433e1 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 4 Dec 2024 22:21:25 +0000 Subject: [PATCH 3/5] fix: skip fusions for fully convertible models --- py/torch_tensorrt/dynamo/_compiler.py | 21 ++----------------- .../partitioning/_adjacency_partitioner.py | 5 +++++ 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 88c81164fd..a2f46763e5 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -642,7 +642,7 @@ def compile_module( num_supported_ops, total_ops = partitioning.get_graph_converter_support( gm, settings.debug, settings.torch_executed_ops ) - + dryrun_tracker.total_ops_in_graph = total_ops dryrun_tracker.supported_ops_in_graph = num_supported_ops dryrun_tracker.compilation_settings = settings @@ -684,24 +684,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments." 
) - breakpoint() - import time - start = time.time() - # Skip partitioning if the whole graph is supported to reduce partitioning overhead - if num_supported_ops == total_ops: - gm_inputs = partitioning.construct_submodule_inputs(gm) - trt_module = convert_module( - gm, - gm_inputs, - settings=settings, - name="whole_graph", - engine_cache=engine_cache, - ) - end = time.time() - logger.info(f"Conversion time: {end - start} seconds") - breakpoint() - return trt_module - # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -714,6 +696,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, require_full_compilation=settings.require_full_compilation, + skip_fusion=(num_supported_ops == total_ops), ) except torch.fx.passes.splitter_base.FxNetSplitterInternalError: diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 0e9077cdcb..429de3ffbb 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -111,6 +111,7 @@ def __init__( min_block_size: int = MIN_BLOCK_SIZE, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, return_tuple: bool = False, + skip_fusion: bool = False, ): """ Preprocesses graph before splitting: @@ -127,6 +128,7 @@ def __init__( self.settings = _SplitterSettingBase( min_acc_module_size=min_block_size, allow_non_tensor=True, + skip_fusion=skip_fusion, ) self.operator_support = operator_support @@ -252,6 +254,7 @@ def partition( min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, + skip_fusion: bool = False, ) -> Tuple[torch.fx.GraphModule, OpSupportTester]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -262,6 +265,7 @@ def partition( min_block_size: Minimum number of operators per TRT-Engine Block torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage require_full_compilation: Require that all computational operators be run in TRT + skip_fusion: Skip fusions found by FxNetAccFusionsFinder Returns: torch.fx.GraphModule, OpSupportTester """ @@ -277,6 +281,7 @@ def partition( supported_ops, min_block_size=min_block_size, require_full_compilation=require_full_compilation, + skip_fusion=skip_fusion, ) partitioned_graph = partitioner.partition_graph() From 6e275ba1575e647980d07cf167dacda7a659a68e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 5 Dec 2024 10:54:43 -0800 Subject: [PATCH 4/5] chore: remove flux example --- examples/dynamo/torch_export_flux.py | 254 --------------------------- 1 file changed, 254 deletions(-) delete mode 100644 examples/dynamo/torch_export_flux.py diff --git a/examples/dynamo/torch_export_flux.py b/examples/dynamo/torch_export_flux.py deleted file mode 100644 index 3a2559235b..0000000000 --- a/examples/dynamo/torch_export_flux.py +++ /dev/null @@ -1,254 +0,0 @@ - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer -from diffusers import FluxPipeline, FluxTransformer2DModel -from utils import export_llm, generate 
-from torch.export import Dim -from typing import Optional, Dict, Any -import logging -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) -handler = logging.StreamHandler() -handler.setLevel(logging.DEBUG) -logger.addHandler(handler) - -import time -from contextlib import contextmanager - -@contextmanager -def timer(logger, name:str): - logger.info(f"{name} section Start...") - start = time.time() - yield - end = time.time() - logger.info(f"{name} section End...") - logger.info(f"{name} section elapsed time: {end - start} seconds") - -class MyModule(torch.nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = False, **kwargs): - - - return self.module.forward( - hidden_states, - encoder_hidden_states, - pooled_projections, - timestep, - img_ids, - txt_ids, - # guidance, - # joint_attention_kwargs, - # return_dict - ) - -def wrap_pipeline_transformer_call(instance, prompt, max_sequence_length): - from unittest.mock import patch - -# Assume `instance` is your class instance containing the `__call__` method - -# Use patch.object to mock the __call__ method of self.transformer - with patch.object(instance.transformer, 'forward', wraps=instance.transformer.forward) as mock_transformer_call: - # one step is enough for intercept the inputs - image =instance( - prompt, - guidance_scale=0.0, - num_inference_steps=1, - max_sequence_length=max_sequence_length, - generator=torch.Generator("cpu").manual_seed(0) - ).images[0] - - - # Access the call arguments of the first (or specific) call - if mock_transformer_call.call_args_list: - args, kwargs = mock_transformer_call.call_args_list[0] - # Store the inputs in a tuple - intercepted_inputs = (args, kwargs) - - # print("Intercepted args:", args) - # print("Intercepted kwargs:", kwargs) - return (args, kwargs) - else: - print("No calls were made to self.transformer.__call__") - return (None, None) - - -if __name__ == "__main__": - - # config - dryrun = False - - # parameter setting - batch_size = 2 - max_seq_len = 256 - prompt = ["A cat holding a sign that says hello world" for _ in range(batch_size)] - device = "cuda:0" - pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", - torch_dtype=torch.float16, num_layers=1) - # pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", - # torch_dtype=torch.float16, device_map="balanced") - pipe.to(device) - # image = pipe( - # prompt, - # guidance_scale=0.0, - # num_inference_steps=4, - # max_sequence_length=256, - # generator=torch.Generator("cpu").manual_seed(0) - # ).images[0] - # image.save("pytorch_flux-schnell.png") - # breakpoint() - # pipe.reset_device_map() - # pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. 
Remove this if you have enough GPU power - example_args , example_kwargs = wrap_pipeline_transformer_call(pipe, prompt, max_seq_len) - tensor_inputs = ['hidden_states', 'timestep', 'pooled_projections', 'encoder_hidden_states', 'txt_ids', 'img_ids' ] - example_kwargs_shapes = {key: example_kwargs[key].shape for key in tensor_inputs} - BATCH = Dim("batch", min=1, max=batch_size) - SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len) - dynamic_shapes = ({0 : BATCH}, - {0 : BATCH, - 1 : SEQ_LEN, - }, - {0 : BATCH}, - {0 : BATCH}, - {0 : BATCH}, - {0 : BATCH, - 1 : SEQ_LEN, - }, - # None, - # None, - # None, - ) - example_args = ( - example_kwargs['hidden_states'], - example_kwargs['encoder_hidden_states'], - example_kwargs['pooled_projections'], - example_kwargs['timestep'], - example_kwargs['img_ids'], - example_kwargs['txt_ids'], - # example_kwargs['guidance'], - # example_kwargs['joint_attention_kwargs'], - # example_kwargs['return_dict'], - - ) - - # dynamic_shapes = {'hidden_states': {0 : BATCH}, - # 'encoder_hidden_states': {0 : BATCH, - # 1 : SEQ_LEN, - # }, - # 'pooled_projections': {0 : BATCH}, - # 'timestep': {0 : BATCH}, - # 'img_ids': {0 : BATCH}, - # 'txt_ids': {0 : BATCH, - # 1 : SEQ_LEN, - # }, - # 'guidance': None, - # 'joint_attention_kwargs': None, - # 'return_dict': None, - # } - - with timer(logger=logger, name="ep_gen"): - with torch.no_grad(): - - # model = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-schnell",torch_dtype=torch.float16) - model = MyModule(pipe.transformer).eval().half().to(device) - # try: - # logger.info("Trying to export the model using torch.export.export()..") - # # print("Trying to export the model using torch.export.export()..") - # # strict=False only enables aotautograd tracing and excludes dynamo. - # # Have core dump this path - # ep = torch.export.export( - # model, args=example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=False - # ) - # except: - logger.info("Directly use _export because torch.export.export doesn't work") - # This API is used to express the constraint violation guards as asserts in the graph. - from torch.export._trace import _export - ep = _export( - model, - args=example_args, - # kwargs=example_kwargs, - dynamic_shapes=dynamic_shapes, - strict=False, - allow_complex_guards_as_runtime_asserts=True, - ) - - logger.info(f"Generating TRT engine now, dryrun={dryrun}...") - # print("Generating TRT engine now...") - #TODO: if some non-tensor input, do we still need to provide them. - with timer(logger, "trt_gen"): - with torch_tensorrt.logging.debug(): - trt_start = time.time() - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=list(example_args), - enabled_precisions={torch.float32}, - truncate_double=True, - device=torch.device(device), - disable_tf32=True, - use_explicit_typing=True, - dryrun=dryrun, - debug=True, - use_fp32_acc=True, - ) - trt_end = time.time() - - del pipe - del ep - del model - import gc - gc.collect() - torch.cuda.empty_cache() - - with timer(logger, "trt_save"): - try: - trt_ep = torch.export.export(trt_model, args=example_args, - dynamic_shapes=dynamic_shapes, strict=False) - torch.export.save(trt_ep, "trt.ep") - except Exception as e: - import traceback - # Capture the full traceback - tb = traceback.format_exc() - logger.warning("An error occurred. 
Here's the traceback:") - # print(tb) - logger.warning(tb) - breakpoint() - torch_tensorrt.save(trt_model, "trt.ep") - # finally: - # breakpoint() - - - - # if not dryrun: - # pipe.transformer.forward = MyModule(trt_model).forward - # with timer(logger, "trt_infer"): - # image = pipe( - # prompt, - # guidance_scale=0.0, - # num_inference_steps=4, - # max_sequence_length=256, - # generator=torch.Generator("cpu").manual_seed(0) - # ).images[0] - # image.save("trt_flux-schnell.png") - breakpoint() - - - - - -# breakpoint() -# flux_model_ep = export_llm(model, inputs=) From f4c8befa08d8a90a0ff49db4e8b1dddf2ff08e4d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 13 Dec 2024 09:16:05 -0800 Subject: [PATCH 5/5] chore: minor fix --- py/torch_tensorrt/dynamo/_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 4bb521a643..a89d7bbd2c 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -788,7 +788,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: require_full_compilation=settings.require_full_compilation, ) - # dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators + dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators # The global partitioner leaves non-TRT nodes as-is if not settings.use_fast_partitioner:
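
A note on the final design of this series: FxNetAccFusionsFinder exists to keep
groups of ops with non-tensor data dependencies inside a single submodule, which
only matters when a graph is split between Torch and TRT segments. In a fully
convertible graph there is no boundary for a fusion group to straddle, so the
pass is pure overhead. Patch 3 therefore drops patch 1's separate whole-graph
fast path (which bypassed partitioning and called convert_module directly) in
favor of threading a skip_fusion flag through partition() into
_SplitterSettingBase, keeping a single code path. The sketch below illustrates
the resulting call pattern once the whole series is applied. It is illustrative
only: gm is assumed to be an aten-level torch.fx.GraphModule, settings a
populated CompilationSettings, and fast_partition the alias that
torch_tensorrt.dynamo.partitioning exposes for the adjacency partitioner's
partition().

    from torch_tensorrt.dynamo import partitioning

    # Count converter coverage up front (same call compile_module() makes).
    num_supported_ops, total_ops = partitioning.get_graph_converter_support(
        gm, settings.debug, settings.torch_executed_ops
    )

    # A fully convertible graph cannot have a fusion group split across a
    # Torch/TRT boundary, so fusion discovery is skipped; partially supported
    # graphs pass skip_fusion=False and behave exactly as before.
    partitioned_module, supported_ops = partitioning.fast_partition(
        gm,
        verbose=settings.debug,
        min_block_size=settings.min_block_size,
        torch_executed_ops=settings.torch_executed_ops,
        require_full_compilation=settings.require_full_compilation,
        skip_fusion=(num_supported_ops == total_ops),
    )

Because the partitioner still returns the usual (torch.fx.GraphModule,
OpSupportTester) pair on both paths, patch 5 can safely restore the
dryrun_tracker.unsupported_ops bookkeeping that patch 1 had commented out.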