Commit
chore: add update_metadata utility
peri044 committed Jan 27, 2024
1 parent 51e8bb7 commit 19c3fad
Showing 5 changed files with 41 additions and 26 deletions.
@@ -11,6 +11,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape

 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -21,6 +22,7 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )
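For reference, every entry in ATEN_LOWERING_PASSES is a callable of the same shape as view_to_reshape, taking (gm, sample_inputs) and returning the rewritten GraphModule; the pass manager applies them in list order. A minimal sketch of that contract (apply_passes and LoweringPass below are illustrative stand-ins, not the DynamoPassManager API):

```python
from typing import Callable, List, Sequence

import torch

# Illustrative alias for the lowering-pass signature used in this commit.
LoweringPass = Callable[
    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
]


def apply_passes(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    passes: List[LoweringPass],
) -> torch.fx.GraphModule:
    # Apply each pass in registration order; every pass returns the
    # (possibly rewritten) GraphModule for the next pass to consume.
    for lowering_pass in passes:
        gm = lowering_pass(gm, sample_inputs)
    return gm
```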
py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py (19 changes: 18 additions & 1 deletion)
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, Dict, List

 import torch

@@ -29,3 +29,20 @@ def get_tensor_placeholders(
     ]

     return placeholders
+
+
+def update_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: Dict[int, torch._ops.OpOverload]
+) -> None:
+    """
+    Given a graph and a node which has target_op in the graph,
+    a) If the node has metadata, store it in the map
+    b) If the node does not have metadata, retrieve it from the map
+    and assign to the node.
+    """
+    for idx, node in enumerate(gm.graph.nodes):
+        if node.target == target_op:
+            if idx not in metadata and node.meta:
+                metadata[idx] = node.meta
+            elif idx in metadata and not node.meta:
+                node.meta = metadata[idx]
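The two branches make update_metadata a record/restore round trip keyed by node index: called before a rewrite it snapshots each matching node's non-empty node.meta, and called again after a rewrite that emptied node.meta it writes the snapshot back. A self-contained sketch of that behavior; the traced function fn and the fake metadata are illustrative only, and the restore assumes node indices stay stable across the rewrite:

```python
from typing import Any, Dict

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import update_metadata


def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.view.default(x, [1, -1])


gm = torch.fx.symbolic_trace(fn)

# Stand-in for the shape/dtype metadata ("val", "tensor_meta", ...) that
# Dynamo-produced graphs carry; plain symbolic_trace leaves node.meta empty.
for node in gm.graph.nodes:
    if node.target == torch.ops.aten.view.default:
        node.meta = {"note": "recorded before rewrite"}

meta_map: Dict[int, Any] = {}
update_metadata(gm, torch.ops.aten.view.default, meta_map)  # snapshot .meta

# Simulate a rewrite that drops per-node metadata.
for node in gm.graph.nodes:
    if node.target == torch.ops.aten.view.default:
        node.meta = {}

update_metadata(gm, torch.ops.aten.view.default, meta_map)  # restore .meta
```

Note that the stored values are node.meta dicts, so Dict[int, Any] in this sketch describes them more closely than the Dict[int, torch._ops.OpOverload] annotation in the diff.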
py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py (37 changes: 17 additions & 20 deletions)
@@ -1,9 +1,10 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import Dict, List, Sequence

 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    update_metadata,
 )

 logger = logging.getLogger(__name__)
@@ -13,29 +14,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default

     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)

     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)

-    return orig, replacement
+    # Store metadata of the orig_op and copy it to the replacement op
+    meta_map: Dict[int, torch._ops.OpOverload] = {}
+    update_metadata(gm, orig_op, meta_map)
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    update_metadata(gm, replacement_op, meta_map)
+
+    return gm
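With this change, metadata survives torch.fx.subgraph_rewriter.replace_pattern, which does not carry node.meta over to the nodes it creates. An end-to-end sketch of invoking the pass on a toy graph (fn and the sample input are illustrative, assuming a plain FX-traceable function):

```python
import torch
from torch_tensorrt.dynamo.lowering.passes.view_to_reshape import view_to_reshape


def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.view.default(x, [1, -1]) + 1


gm = torch.fx.symbolic_trace(fn)
gm = view_to_reshape(gm, sample_inputs=[torch.randn(2, 3)])

# Every aten.view.default call should now be aten.reshape.default.
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
assert torch.ops.aten.view.default not in targets
assert torch.ops.aten.reshape.default in targets
```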
py/torch_tensorrt/dynamo/partitioning/common.py (6 changes: 3 additions & 3 deletions)
@@ -43,9 +43,9 @@ def construct_dynamic_input(input: Any) -> Input:
             if var_range.lower == 2:
                 min_shape.append(1)
             else:
-                min_shape.append(var_range.lower)
-            opt_shape.append(var_val)
-            max_shape.append(var_range.upper)
+                min_shape.append(int(var_range.lower))
+            opt_shape.append(int(var_val))
+            max_shape.append(int(var_range.upper))
         else:
             min_shape.append(dim)
             opt_shape.append(dim)
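The new int(...) casts matter because var_range.lower, var_range.upper, and var_val come back from the shape environment as sympy integers rather than Python ints, and downstream consumers of the min/opt/max shapes expect plain ints. A tiny illustration of the distinction (assumes sympy, which PyTorch's symbolic-shape machinery uses):

```python
import sympy

lower = sympy.Integer(3)

print(isinstance(lower, int))       # False: sympy.Integer is not a Python int
print(isinstance(int(lower), int))  # True after the explicit cast
```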
tests/py/dynamo/models/test_dyn_models.py (3 changes: 1 addition & 2 deletions)
@@ -3,9 +3,8 @@
 import pytest
 import timm
 import torch
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
-
 import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

 assertions = unittest.TestCase()
