select_scatter decomp #2515

Merged: 1 commit, May 31, 2024
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -213,6 +213,19 @@ def slice_scatter_decomposition(
return output_tensor


@register_torch_trt_decomposition(
torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def select_scatter_decomposition(
input_tensor: torch.Tensor,
src_tensor: torch.Tensor,
dim: int,
index: int,
) -> torch.Tensor:
src_tensor = torch.unsqueeze(src_tensor, dim)
return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
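The new decomposition rewrites aten.select_scatter as an unsqueeze of the source tensor followed by aten.slice_scatter over a length-one slice. A minimal eager-mode sketch of that identity (not part of this PR; shapes chosen only for illustration):

import torch

# select_scatter(x, src, dim, idx) should equal
# slice_scatter(x, src.unsqueeze(dim), dim, idx, idx + 1, 1)
x = torch.zeros(2, 3, 4)
src = torch.ones(2, 4)
dim, index = 1, 0

reference = torch.select_scatter(x, src, dim, index)
decomposed = torch.slice_scatter(x, src.unsqueeze(dim), dim, index, index + 1, 1)
assert torch.equal(reference, decomposed)
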
195 changes: 192 additions & 3 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -530,7 +530,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
"torch_compile",
inputs,
min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -593,7 +593,7 @@ def forward(self, x, src, dim, start, end, step):
"torch_compile",
inputs,
min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -663,7 +663,7 @@ def forward(self, x, src, dim, start, end, step):
"torch_compile",
inputs,
min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -679,6 +679,195 @@ def forward(self, x, src, dim, start, end, step):
f"Slice_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_dimZero_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

        # Ops expected to appear (expected_ops) and to be removed (unexpected_ops) after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
            truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_dimOne_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

        # Ops expected to appear (expected_ops) and to be removed (unexpected_ops) after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_multidimension_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

        # Ops expected to appear (expected_ops) and to be removed (unexpected_ops) after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
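
Outside the test harness, the same lowering path is exercised whenever a compiled module contains aten.select_scatter. A condensed usage sketch mirroring the calls in the tests above (module name and settings are illustrative, not part of this PR):

import torch
import torch_tensorrt


class SelectScatterModule(torch.nn.Module):
    def forward(self, x, src):
        # Overwrite index 0 along dim 1 of x with src; decomposed during lowering.
        return torch.ops.aten.select_scatter.default(x, src, 1, 0)


inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda()]
optimized = torch_tensorrt.compile(
    SelectScatterModule(),
    "torch_compile",
    inputs,
    min_block_size=1,
)
print(optimized(*inputs))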