[ONNX] Improve the conversion from dynamic axes to shapes (pytorch#140488)

Features:
(1) Add support for tree-structured inputs (a sketch follows this list).
(2) Add a user warning before the axes-to-shapes conversion.
(3) Add a suggestion to provide `dynamic_shapes` when the conversion fails.
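
A minimal sketch of feature (1), reusing the model, inputs, and names from the new test in test_api.py below; `dynamo=True` routes the export through the new axes-to-shapes conversion:

import torch

class NestedModelForDynamicShapes(torch.nn.Module):  # same model as added in test_api.py below
    def forward(self, x, ys, zs, c):
        y = ys[0] + ys[1] + zs["a"] + zs["b"]
        if x.shape[0] < 3 and c.shape[0] != 4:
            return x + 5, x + y, c
        return x - 5, x - y, c

inputs = (
    torch.ones(5),
    [torch.zeros(5), torch.ones(5)],
    {"a": torch.zeros(5), "b": torch.ones(5)},
    torch.ones(4),
)
# input_names label the flattened input leaves in order; the last leaf ("f") stays static.
input_names = ["input_x", "input_y", "input_z", "d", "e", "f"]
dynamic_axes = {name: {0: "dim"} for name in input_names[:-1]}
onnx_program = torch.onnx.export(
    NestedModelForDynamicShapes(),
    inputs,
    input_names=input_names,
    dynamic_axes=dynamic_axes,
    dynamo=True,
)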

Notes:
(1) `input_names` is crucial to the conversion, since we don't otherwise know the ONNX graph inputs.
(2) min and max are left at their defaults during the conversion, so an LLM export has a higher chance of failing when users rely on `dynamic_axes`, because of min/max constraint dependencies between `attention_mask` and `sequence_length`, etc. (found in llama-3.2-1B_Instruct). A sketch of the `dynamic_shapes` fallback follows below.
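
For note (2), a hedged sketch of the suggested fallback: pass `dynamic_shapes` with explicit bounds instead of `dynamic_axes`. The tiny model and the `max` bound here are illustrative assumptions, not part of this PR:

import torch

class TinyLM(torch.nn.Module):  # hypothetical stand-in for the real LLM
    def forward(self, input_ids, attention_mask):
        return input_ids * attention_mask

batch_size = torch.export.Dim("batch_size")
sequence_length = torch.export.Dim("sequence_length", max=4096)  # explicit bound instead of the defaults
onnx_program = torch.onnx.export(
    TinyLM(),
    (),
    kwargs={
        "input_ids": torch.ones(2, 16, dtype=torch.long),
        "attention_mask": torch.ones(2, 16, dtype=torch.long),
    },
    dynamic_shapes={
        "input_ids": {0: batch_size, 1: sequence_length},
        "attention_mask": {0: batch_size, 1: sequence_length},
    },
    dynamo=True,
)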
Pull Request resolved: pytorch#140488
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <[email protected]>
2 people authored and pytorchmergebot committed Nov 15, 2024
1 parent 9482476 commit 865a7c5
Showing 3 changed files with 377 additions and 40 deletions.
59 changes: 59 additions & 0 deletions test/onnx/exporter/test_api.py
@@ -33,6 +33,22 @@ def forward(self, x, b):
return x.relu(), b.sigmoid()


class NestedModelForDynamicShapes(torch.nn.Module):
def forward(
self,
x: torch.Tensor,
ys: list[torch.Tensor],
zs: dict[str, torch.Tensor],
c: torch.Tensor,
):
y = ys[0] + ys[1] + zs["a"] + zs["b"]
w = 5
if x.shape[0] < 3 and c.shape[0] != 4:
return x + w, x + y, c
else:
return x - w, x - y, c


class TestExportAPIDynamo(common_utils.TestCase):
"""Tests for the ONNX exporter API when dynamo=True."""

@@ -71,6 +87,7 @@ def test_dynamic_axes_supports_partial_dynamic_shapes(self):
self.assert_export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
input_names=["x", "b"],
dynamic_axes={
"b": [0, 1, 2],
},
@@ -80,6 +97,7 @@ def test_dynamic_axes_supports_output_names(self):
self.assert_export(
SampleModelForDynamicShapes(),
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
input_names=["x", "b"],
dynamic_axes={
"b": [0, 1, 2],
},
@@ -181,6 +199,47 @@ def forward(self, x):
assert onnx_program is not None
onnx_testing.assert_onnx_program(onnx_program)

def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(self):
# Inputs can still be renamed, as long as the names follow the flattened input order
input_names = ["input_x", "input_y", "input_z", "d", "e", "f"]

dynamic_axes = {
"input_x": {0: "dim"},
"input_y": {0: "dim"},
"input_z": {0: "dim"},
"d": {0: "dim"},
"e": {0: "dim"},
}

model = NestedModelForDynamicShapes()
input = (
torch.ones(5),
[torch.zeros(5), torch.ones(5)],
{"a": torch.zeros(5), "b": torch.ones(5)},
torch.ones(4),
)

self.assert_export(
model, input, dynamic_axes=dynamic_axes, input_names=input_names
)

# Check that all inputs except the last (static) one are dynamically shaped
onnx_program = torch.onnx.export(
model,
input,
dynamic_axes=dynamic_axes,
input_names=input_names,
dynamo=True,
)
self.assertTrue(
all(
[
input.type.tensor_type.shape.dim[0].dim_param
for input in onnx_program.model_proto.graph.input
][:-1]
)
)

def test_refine_dynamic_shapes_with_onnx_export(self):
# NOTE: From test/export/test_export.py

249 changes: 249 additions & 0 deletions test/onnx/exporter/test_compat.py
@@ -0,0 +1,249 @@
# Owner(s): ["module: onnx"]
"""Unit tests for the _compat module."""

from __future__ import annotations

import torch
from torch.onnx._internal.exporter import _compat
from torch.testing._internal import common_utils
from torch.utils import _pytree


class SignatureOnlyLlamaModel(torch.nn.Module):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
cache_position: torch.LongTensor | None = None,
num_logits_to_keep: int = 0,
**kwargs,
):
pass


@common_utils.instantiate_parametrized_tests
class TestPyTreeDynamicAxesShapes(common_utils.TestCase):
# The test can't be parametrized because torch.export.Dim creates distinct objects,
# and we need the exact same objects to compare them.
def test__unflatten_dynamic_shapes_with_inputs_tree_succeeds_on_tuple(self):
inputs = (torch.randn(1, 2, 3), torch.randn(1, 2, 3))
x_dim = torch.export.Dim("x_dim_0")
y_dim = torch.export.Dim("y_dim_1")
dynamic_shapes = {
"x": {0: x_dim},
"y": {1: y_dim},
}
unflatten_dynamic_shapes = _compat._unflatten_dynamic_shapes_with_inputs_tree(
inputs, dynamic_shapes
)

expected_dynamic_shapes = (
{0: x_dim},
{1: y_dim},
)
self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)

def test__unflatten_dynamic_shapes_with_inputs_tree_succeeds_on_dict(self):
inputs = {"x": torch.randn(1, 2, 3), "y": torch.randn(1, 2, 3)}
x_dim = torch.export.Dim("x_dim_0")
y_dim = torch.export.Dim("y_dim_1")
dynamic_shapes = {
"x": {0: x_dim},
"y": {1: y_dim},
}
unflatten_dynamic_shapes = _compat._unflatten_dynamic_shapes_with_inputs_tree(
inputs, dynamic_shapes
)

expected_dynamic_shapes = {
"x": {0: x_dim},
"y": {1: y_dim},
}
self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)

def test__unflatten_dynamic_shapes_with_inputs_tree_succeeds_on_tuple_of_mixed_structure(
self,
):
inputs = (
torch.randn(1, 2, 3),
({"x0": torch.randn(1, 2, 3)}, {"x1": torch.randn(1, 2, 3)}),
(torch.randn(1, 2, 3), torch.randn(1, 2, 3)),
[torch.randn(1, 2, 3), torch.randn(1, 2, 3)],
)
w_dim_0 = torch.export.Dim("w_dim_0")
x0_dim_1 = torch.export.Dim("x0_dim_1")
x0_dim_2 = torch.export.Dim("x0_dim_2")
x1_dim_1 = torch.export.Dim("x1_dim_1")
y0_dim_0 = torch.export.Dim("y0_dim_0")
y0_dim_1 = torch.export.Dim("y0_dim_1")
y1_dim_2 = torch.export.Dim("y1_dim_2")
z0_dim_2 = torch.export.Dim("z0_dim_2")
z1_dim_1 = torch.export.Dim("z1_dim_1")
dynamic_shapes = {
"w": {0: w_dim_0},
"x0": {1: x0_dim_1, 2: x0_dim_2},
"x1": {1: x1_dim_1},
"y0": {0: y0_dim_0, 1: y0_dim_1},
"y1": {2: y1_dim_2},
"z0": {2: z0_dim_2},
"z1": {1: z1_dim_1},
}
unflatten_dynamic_shapes = _compat._unflatten_dynamic_shapes_with_inputs_tree(
inputs, dynamic_shapes
)
expected_dynamic_shapes = (
{0: w_dim_0},
({"x0": {1: x0_dim_1, 2: x0_dim_2}}, {"x1": {1: x1_dim_1}}),
({0: y0_dim_0, 1: y0_dim_1}, {2: y1_dim_2}),
[{2: z0_dim_2}, {1: z1_dim_1}],
)
self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)

@common_utils.parametrize(
"model, args, kwargs,input_names, output_names, dynamic_axes, expected_dynamic_shapes",
[
# llama-3.2-1B-Instruct (trimmed)
(
SignatureOnlyLlamaModel(),
(),
{
"input_ids": torch.randn(2, 16),
"attention_mask": torch.randn(2, 32),
"position_ids": torch.randn(2, 16),
"past_key_values": [
(torch.randn(2, 8, 16, 64), torch.randn(2, 8, 16, 64)),
(torch.randn(2, 8, 16, 64), torch.randn(2, 8, 16, 64)),
],
},
[
"input_ids",
"attention_mask",
"position_ids",
"past_key_values.0.key",
"past_key_values.0.value",
"past_key_values.1.key",
"past_key_values.1.value",
],
[
"logits",
"present.0.key",
"present.0.value",
"present.1.key",
"present.1.value",
],
{
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {
0: "batch_size",
1: "past_sequence_length + sequence_length",
},
"position_ids": {0: "batch_size", 1: "sequence_length"},
"past_key_values.0.key": {
0: "batch_size",
2: "past_sequence_length",
},
"past_key_values.0.value": {
0: "batch_size",
2: "past_sequence_length",
},
"past_key_values.1.key": {
0: "batch_size",
2: "past_sequence_length",
},
"past_key_values.1.value": {
0: "batch_size",
2: "past_sequence_length",
},
"logits": {0: "batch_size", 1: "sequence_length"},
"present.0.key": {
0: "batch_size",
2: "past_sequence_length + sequence_length",
},
"present.0.value": {
0: "batch_size",
2: "past_sequence_length + sequence_length",
},
"present.1.key": {
0: "batch_size",
2: "past_sequence_length + sequence_length",
},
"present.1.value": {
0: "batch_size",
2: "past_sequence_length + sequence_length",
},
},
[
{
0: torch.export.Dim("batch_size"),
1: torch.export.Dim("sequence_length"),
},
{
0: torch.export.Dim("batch_size"),
1: torch.export.Dim("past_sequence_lengthsequence_length"),
},
{
0: torch.export.Dim("batch_size"),
1: torch.export.Dim("sequence_length"),
},
[
(
{
0: torch.export.Dim("batch_size"),
2: torch.export.Dim("past_sequence_length"),
},
{
0: torch.export.Dim("batch_size"),
2: torch.export.Dim("past_sequence_length"),
},
),
(
{
0: torch.export.Dim("batch_size"),
2: torch.export.Dim("past_sequence_length"),
},
{
0: torch.export.Dim("batch_size"),
2: torch.export.Dim("past_sequence_length"),
},
),
],
],
)
],
)
def test__from_dynamic_axes_to_dynamic_shapes_succeeds_on_llm(
self,
model,
args,
kwargs,
input_names,
output_names,
dynamic_axes,
expected_dynamic_shapes,
):
dynamic_shapes = _compat._from_dynamic_axes_to_dynamic_shapes(
model,
args,
kwargs,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)

# NOTE: torch.export.Dim objects can't be compared directly,
# and it's unrealistic to test the whole model, so we test only the structure of dynamic_shapes.
_, tree1 = _pytree.tree_flatten(dynamic_shapes)
_, tree2 = _pytree.tree_flatten(expected_dynamic_shapes)
self.assertEqual(tree1, tree2)


if __name__ == "__main__":
common_utils.run_tests()