🐛 [Bug] Cannot convert simple torchscript containing two torch.nn.Upsample operations #1823
Comments
I've reproduced the error on main, and it is occurring on this line:

TensorRT/core/lowering/lowering.cpp, line 106, at commit a245b86

The graph at that point is shown below:

graph(%X.1 : Tensor):
%2 : bool = prim::Constant[value=0]()
%3 : float[] = prim::Constant[value=[2., 2.]]()
%4 : str = prim::Constant[value="bilinear"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/upsampling.py:156:66
%5 : int = prim::Constant[value=5]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3956:22
%6 : int = prim::Constant[value=3]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3952:22
%7 : int = prim::Constant[value=4]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3949:76
%8 : int = prim::Constant[value=2]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3880:24
%9 : NoneType = prim::Constant() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3871:32
%10 : str = prim::Constant[value="builtins.ValueError"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3872:18
%11 : str = prim::Constant[value="Input and scale_factor must have the same number of spatial dimensions, but got input with spatial dimensions of {} and scale_factor of shape {}. Please provide input tensor in (N, C, d1, d2, ...,dK) format and scale_factor in (s1, s2, ...,sK) format."]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:20
%12 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:34
%13 : str = prim::Constant[value="builtins.NotImplementedError"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:14
%14 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4004:34
%15 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4007:8
%16 : int = prim::Constant[value=1]()
%48 : int = prim::Constant[value=2]()
%49 : str = prim::Constant[value="builtins.ValueError"]()
%50 : float[] = prim::Constant[value=[2., 2.]]()
%51 : str = prim::Constant[value="Input and scale_factor must have the same number of spatial dimensions, but got input with spatial dimensions of {} and scale_factor of shape {}. Please provide input tensor in (N, C, d1, d2, ...,dK) format and scale_factor in (s1, s2, ...,sK) format."]()
%52 : int = prim::Constant[value=1]()
%53 : NoneType = prim::Constant()
%54 : int = prim::Constant[value=4]()
%55 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]()
%56 : str = prim::Constant[value="bilinear"]()
%57 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
%58 : int = prim::Constant[value=5]()
%59 : str = prim::Constant[value="builtins.NotImplementedError"]()
%60 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
%61 : int = prim::Constant[value=3]()
%62 : bool = prim::Constant[value=0]()
%63 : Tensor = prim::Uninitialized()
%64 : int = aten::dim(%X.1) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3880:10
%dim.2 : int = aten::sub(%64, %48) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3880:10
%66 : bool = aten::ne(%48, %dim.2) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3907:15
= prim::If(%66) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3907:12
block0():
%67 : int[] = aten::size(%X.1) # <string>:13:9
%68 : int[] = aten::slice(%67, %48, %53, %52) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:26
%69 : int[] = aten::list(%68) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:21
%70 : str = aten::format(%51, %69, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:20
= prim::RaiseException(%70, %49) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3908:16
-> ()
block1():
-> ()
%71 : bool = aten::eq(%64, %54) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3979:7
%out1.1 : Tensor = prim::If(%71) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3979:4
block0():
%73 : Tensor = aten::upsample_bilinear2d(%X.1, %53, %62, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3983:15
-> (%73)
block1():
%74 : bool = aten::eq(%64, %61) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:7
= prim::If(%74) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:4
block0():
= prim::RaiseException(%60, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:8
-> ()
block1():
-> ()
%75 : bool = aten::eq(%64, %58) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:7
= prim::If(%75) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:4
block0():
= prim::RaiseException(%57, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4004:8
-> ()
block1():
-> ()
%76 : str = aten::format(%55, %64, %56) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4007:8
= prim::RaiseException(%76, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4006:4
-> (%63)
= prim::If(%66) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3907:12
block0():
%77 : int[] = aten::size(%X.1) # <string>:13:9
%78 : int[] = aten::slice(%77, %48, %53, %52) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:26
%79 : int[] = aten::list(%78) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:21
%80 : str = aten::format(%51, %79, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:20
= prim::RaiseException(%80, %49) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3908:16
-> ()
block1():
-> ()
%out2.1 : Tensor = prim::If(%71) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3979:4
block0():
%82 : Tensor = aten::upsample_bilinear2d(%X.1, %53, %62, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3983:15
-> (%82)
block1():
%83 : bool = aten::eq(%64, %61) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:7
= prim::If(%83) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:4
block0():
= prim::RaiseException(%60, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:8
-> ()
block1():
-> ()
%84 : bool = aten::eq(%64, %58) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:7
= prim::If(%84) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:4
block0():
= prim::RaiseException(%57, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4004:8
-> ()
block1():
-> ()
%85 : str = aten::format(%55, %64, %56) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4007:8
= prim::RaiseException(%85, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4006:4
-> (%63)
%out.1 : Tensor = aten::add(%out1.1, %out2.1, %16) # unrelated.py:24:14
return (%out.1)

@bowang007 - this seems related to your work with exceptions and control flow. Do you have any suggestions on this?
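As a side note, a minimal sketch (not part of the original report) for viewing a similar graph from Python, assuming the Network repro module shared below - the duplicated prim::If / prim::RaiseException blocks coming from F.interpolate show up in the inlined graph:

import torch

# Sketch only: Network is the repro module defined later in this thread.
model = Network().eval()
scripted = torch.jit.script(model)
print(scripted.inlined_graph)  # shows the interpolate dimension checks and exception blocks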
Thanks @gs-olive for taking a look at this issue so quickly! Nice that you could reproduce it!

Network definition

import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
class Block(nn.Module):
def __init__(self, in_channel, out_channel):
super(Block, self).__init__()
self.conv = nn.Conv2d(
in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
)
self.norm = nn.BatchNorm2d(out_channel)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False
)
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = self.relu(out)
out = self.maxpool(out)
return out
class Network(torch.nn.Module):
def __init__(self, num_classes=2):
super(Network, self).__init__()
self.num_classes = num_classes
self.block1 = Block(3, 32)
self.block2 = Block(32, 64)
self.upsample1 = nn.Upsample(
scale_factor=2, mode="bilinear", align_corners=False
)
self.upsample2 = nn.Upsample(
scale_factor=2, mode="bilinear", align_corners=False
)
self.conv = nn.Conv2d(64, num_classes, 1, bias=True)
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
gclayer1 = self.upsample1(out)
gclayer2 = self.upsample2(gclayer1)
out = self.conv(gclayer2)
        return out

Interestingly, when converting a torchscript generated from that network using

Both graphs, as printed by Torch-TensorRT when calling

Graph diff
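A rough sketch of the conversion path being discussed (the input shape and compile settings here are illustrative assumptions, not the exact values from the original script):

import torch
import torch_tensorrt

# Illustrative sketch: shape and precision are assumptions for this example.
model = Network(num_classes=2).eval().cuda()
scripted = torch.jit.script(model)
trt_module = torch_tensorrt.compile(
    scripted,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions={torch.float32},
)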
Please let me know if I can provide more details to help solve this issue!
Thank you for the additional details - this is very helpful!
Following [this issue](pytorch#1823), and [this proposal](pytorch#1842), this commit implements the proposal in a way that is extremely specific to Upsample resulting in upsample_bilinear2d.
Hi @gs-olive, I would appreciate some feedback on this. Is my approach going in the right direction, w.r.t. what you had in mind when describing the #1842 issue?
Also, what do you (or @bowang007 maybe?) think of the changes I made to
or
you could also catch more complex blocks, given that:
What I've typically been observing in the scope of the Upsample investigation is something like:
Hi @gcuendet - thanks for the update! The commit linked here is definitely along the lines of what was intended for #1842. One thing I was wondering about for that commit - on line 83 - was there an issue with calling

Regarding the changes made to
Also @narendasan for any comments on the proposed edits to
Thanks for the quick feedback! I am still fiddling with these changes and trying to make them work in more generic cases than the two overly simplified networks shared above, but regarding your questions:
At some point I had the impression that destroying the node was unnecessary, because some dead code removal pass would do it for you (I did observe that in some cases, but that might not be true in all cases). Maybe it could still be good to try to do it at that point, that way if the outputs are not properly replaced, this would fail. I'll check that.
I don't think
You are right, but we also check earlier that:

if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
  // Make sure that the node doesn't actually produce any Value that are
  // used by other nodes
  return false;
}

So the hypothesis at that point is that none of the arms actually return a Value. In the example you point to (the
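As an illustration only (not the actual lowering pass), a rough Python equivalent of that guard using the TorchScript graph bindings:

import torch

def report_if_nodes(graph):
    # Mirrors the C++ guard quoted above: an exception-raising prim::If is only
    # a removal candidate when neither arm produces outputs used elsewhere.
    for node in graph.nodes():
        if node.kind() != "prim::If":
            continue
        arm1, arm2 = list(node.blocks())
        if len(list(arm1.outputs())) != 0 or len(list(arm2.outputs())) != 0:
            print("If node produces values; not a candidate:", node)
        else:
            print("If node with output-free arms (candidate):", node)

# Usage sketch: report_if_nodes(torch.jit.script(model).graph)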
Following [this issue](pytorch#1823), and [this proposal](pytorch#1842), this commit implements the proposal in a way that is extremely specific to Upsample resulting in upsample_bilinear2d.
Hi! I have a small update regarding this work.

Upsample bilinear 2D exception elimination

New commit implementing the custom and specific

A small note on that commit: even though the

Regarding your comment 1. above,
I think my previous answer was not completely accurate:
Exception elimination

New commit implementing changes to

I changed the implementation just slightly, most importantly to verify that the block of the

Let me know if that's of interest to you - I'll be happy to open a PR!
Just for completeness: using
This seems linked to the check on the validity of the

A workaround for this is to not call

I also include the lowered graph below (lowered with all the changes described above):

Lowered graph

INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
%4 : float = prim::Constant[value=2.]()
%14 : int[] = prim::Constant[value=[0, 0]]()
%16 : int[] = prim::Constant[value=[1, 1]]()
%17 : int = prim::Constant[value=1]()
%18 : bool = prim::Constant[value=1]()
%self.conv.bias : Float(2, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=0.01 * 7.4246 -8.4862 [ CUDAFloatType{2} ]]()
%self.conv.weight : Float(2, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.conv.weight : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.conv.bias.15 : NoneType = prim::Constant()
%self.block1.norm.running_var : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.norm.running_mean : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.norm.training.15 : bool = prim::Constant[value=0]()
%452 : float = prim::Constant[value=0.10000000000000001]()
%453 : float = prim::Constant[value=1.0000000000000001e-05]()
%455 : int[] = prim::Constant[value=[2, 2]]()
%self.block2.conv.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block2.norm.running_var : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block2.norm.running_mean : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%460 : str[] = prim::Constant[value=["nearest", "area", "nearest-exact"]]()
%461 : str = prim::Constant[value="bilinear"]()
%462 : str = prim::Constant[value="builtins.ValueError"]()
%463 : str = prim::Constant[value="align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"]()
%464 : int = prim::Constant[value=2]()
%579 : bool = prim::Constant[value=0]()
%580 : int[] = prim::Constant[value=[0, 0]]()
%581 : Tensor = aten::_convolution(%x.1, %self.block1.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %579, %580, %17, %579, %579, %579, %579)
%out1.2 : Tensor = aten::batch_norm(%581, %self.block1.norm.running_var, %self.block1.norm.running_mean, %self.block1.norm.running_mean, %self.block1.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
%467 : Tensor = aten::relu(%out1.2)
%out.2 : Tensor = aten::max_pool2d(%467, %455, %455, %14, %16, %self.block1.norm.training.15)
%582 : bool = prim::Constant[value=0]()
%583 : int[] = prim::Constant[value=[0, 0]]()
%584 : Tensor = aten::_convolution(%out.2, %self.block2.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %582, %583, %17, %582, %582, %582, %582)
%out0.1 : Tensor = aten::batch_norm(%584, %self.block2.norm.running_var, %self.block2.norm.running_mean, %self.block2.norm.running_mean, %self.block2.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
%471 : Tensor = aten::relu(%out0.1)
%out0.2 : Tensor = aten::max_pool2d(%471, %455, %455, %14, %16, %self.block1.norm.training.15)
%474 : bool? = prim::Uninitialized() # :0:0
%475 : bool = prim::Uninitialized() # :0:0
%476 : bool = aten::__contains__(%460, %461)
%align_corners0.1 : bool? = prim::If(%476)
block0():
= prim::RaiseException(%463, %462)
-> (%474)
block1():
-> (%self.block1.norm.training.15)
%478 : int = aten::dim(%out0.2)
%dim.2 : int = aten::sub(%478, %464)
%scale_factors2.2 : float[] = prim::ListConstruct()
= prim::Loop(%dim.2, %18)
block0(%49 : int):
%50 : float[] = aten::append(%scale_factors2.2, %4)
-> (%18)
%480 : int = prim::Constant[value=4]()
%481 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]()
%482 : str = prim::Constant[value="AssertionError: "]()
%485 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
%486 : int = prim::Constant[value=3]()
%487 : str = prim::Constant[value="builtins.NotImplementedError"]()
%488 : int = prim::Constant[value=5]()
%489 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
%494 : bool = aten::eq(%478, %480)
%571 : bool = aten::__isnot__(%align_corners0.1, %self.block1.conv.bias.15)
%align_corners6.2 : bool = prim::If(%571)
block0():
%align_corners7.3 : bool = prim::unchecked_cast(%align_corners0.1)
-> (%align_corners7.3)
block1():
= prim::RaiseException(%482, %self.block1.conv.bias.15)
-> (%475)
%574 : Tensor = aten::upsample_bilinear2d(%out0.2, %self.block1.conv.bias.15, %align_corners6.2, %scale_factors2.2)
= prim::If(%494)
block0():
= prim::If(%571)
block0():
-> ()
block1():
= prim::RaiseException(%482, %self.block1.conv.bias.15)
-> ()
-> ()
block1():
%500 : bool = aten::eq(%478, %486)
= prim::If(%500)
block0():
= prim::RaiseException(%485, %487)
-> ()
block1():
-> ()
%501 : bool = aten::eq(%478, %488)
= prim::If(%501)
block0():
= prim::RaiseException(%489, %487)
-> ()
block1():
-> ()
%502 : str = aten::format(%481, %478, %461)
= prim::RaiseException(%502, %487)
-> ()
%align_corners0 : bool? = prim::If(%self.block1.norm.training.15)
block0():
= prim::RaiseException(%463, %462)
-> (%474)
block1():
-> (%self.block1.norm.training.15)
%504 : int = aten::dim(%574)
%dim.1 : int = aten::sub(%504, %464)
%scale_factors2.1 : float[] = prim::ListConstruct()
= prim::Loop(%dim.1, %18)
block0(%64 : int):
%65 : float[] = aten::append(%scale_factors2.1, %4)
-> (%18)
%516 : bool = aten::eq(%504, %480)
%575 : bool = aten::__isnot__(%align_corners0, %self.block1.conv.bias.15)
%align_corners6.6 : bool = prim::If(%575)
block0():
%align_corners7.7 : bool = prim::unchecked_cast(%align_corners0)
-> (%align_corners7.7)
block1():
= prim::RaiseException(%482, %self.block1.conv.bias.15)
-> (%475)
%578 : Tensor = aten::upsample_bilinear2d(%574, %self.block1.conv.bias.15, %align_corners6.6, %scale_factors2.1)
= prim::If(%516)
block0():
= prim::If(%575)
block0():
-> ()
block1():
= prim::RaiseException(%482, %self.block1.conv.bias.15)
-> ()
-> ()
block1():
%522 : bool = aten::eq(%504, %486)
= prim::If(%522)
block0():
= prim::RaiseException(%485, %487)
-> ()
block1():
-> ()
%523 : bool = aten::eq(%504, %488)
= prim::If(%523)
block0():
= prim::RaiseException(%489, %487)
-> ()
block1():
-> ()
%524 : str = aten::format(%481, %504, %461)
= prim::RaiseException(%524, %487)
-> ()
%585 : bool = prim::Constant[value=0]()
%586 : int[] = prim::Constant[value=[0, 0]]()
%587 : Tensor = aten::_convolution(%578, %self.conv.weight, %self.conv.bias, %16, %14, %16, %585, %586, %17, %585, %585, %585, %585)
return (%587)
Hi @gcuendet - thank you very much for all of the work and detailed answers on this topic. I made a few comments on the

@narendasan - do you have any input on the proposed changes to

Regarding the
Thanks @bowang007. Did you test on the small network described above? Is it working for you?
For reference, the lowered graph is the following.

Lowered graph

INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
%4 : float = prim::Constant[value=2.]()
%14 : int[] = prim::Constant[value=[0, 0]]()
%16 : int[] = prim::Constant[value=[1, 1]]()
%17 : int = prim::Constant[value=1]()
%18 : bool = prim::Constant[value=1]()
%self.conv.bias : Float(2, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=0.01 * 7.4246 -8.4862 [ CUDAFloatType{2} ]]()
%self.conv.weight : Float(2, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.conv.weight : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.conv.bias.15 : NoneType = prim::Constant()
%self.block1.norm.running_var : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.norm.running_mean : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block1.norm.training.15 : bool = prim::Constant[value=0]()
%452 : float = prim::Constant[value=0.10000000000000001]()
%453 : float = prim::Constant[value=1.0000000000000001e-05]()
%455 : int[] = prim::Constant[value=[2, 2]]()
%self.block2.conv.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block2.norm.running_var : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.block2.norm.running_mean : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%460 : str[] = prim::Constant[value=["nearest", "area", "nearest-exact"]]()
%461 : str = prim::Constant[value="bilinear"]()
%462 : str = prim::Constant[value="builtins.ValueError"]()
%463 : str = prim::Constant[value="align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"]()
%464 : int = prim::Constant[value=2]()
%579 : bool = prim::Constant[value=0]()
%580 : int[] = prim::Constant[value=[0, 0]]()
%581 : Tensor = aten::_convolution(%x.1, %self.block1.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %579, %580, %17, %579, %579, %579, %579)
%out1.2 : Tensor = aten::batch_norm(%581, %self.block1.norm.running_var, %self.block1.norm.running_mean, %self.block1.norm.running_mean, %self.block1.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
%467 : Tensor = aten::relu(%out1.2)
%out.2 : Tensor = aten::max_pool2d(%467, %455, %455, %14, %16, %self.block1.norm.training.15)
%582 : bool = prim::Constant[value=0]()
%583 : int[] = prim::Constant[value=[0, 0]]()
%584 : Tensor = aten::_convolution(%out.2, %self.block2.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %582, %583, %17, %582, %582, %582, %582)
%out0.1 : Tensor = aten::batch_norm(%584, %self.block2.norm.running_var, %self.block2.norm.running_mean, %self.block2.norm.running_mean, %self.block2.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
%471 : Tensor = aten::relu(%out0.1)
%out0.2 : Tensor = aten::max_pool2d(%471, %455, %455, %14, %16, %self.block1.norm.training.15)
%474 : bool? = prim::Uninitialized() # :0:0
%475 : bool = prim::Uninitialized() # :0:0
%476 : bool = aten::__contains__(%460, %461)
%align_corners0.1 : bool? = prim::If(%476)
block0():
= prim::RaiseException(%463, %462)
-> (%474)
block1():
-> (%self.block1.norm.training.15)
%478 : int = aten::dim(%out0.2)
%dim.2 : int = aten::sub(%478, %464)
%scale_factors2.2 : float[] = prim::ListConstruct()
= prim::Loop(%dim.2, %18)
block0(%49 : int):
%50 : float[] = aten::append(%scale_factors2.2, %4)
-> (%18)
%480 : int = prim::Constant[value=4]()
%482 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]()
%483 : int = prim::Constant[value=5]()
%485 : str = prim::Constant[value="AssertionError: "]()
%486 : str = prim::Constant[value="builtins.NotImplementedError"]()
%487 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
%488 : int = prim::Constant[value=3]()
%489 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
%494 : bool = aten::eq(%478, %480)
%571 : bool = aten::__isnot__(%align_corners0.1, %self.block1.conv.bias.15)
%align_corners6.2 : bool = prim::If(%571)
block0():
%align_corners7.3 : bool = prim::unchecked_cast(%align_corners0.1)
-> (%align_corners7.3)
block1():
= prim::RaiseException(%485, %self.block1.conv.bias.15)
-> (%475)
%574 : Tensor = aten::upsample_bilinear2d(%out0.2, %self.block1.conv.bias.15, %align_corners6.2, %scale_factors2.2)
= prim::If(%494)
block0():
= prim::If(%571)
block0():
-> ()
block1():
= prim::RaiseException(%485, %self.block1.conv.bias.15)
-> ()
-> ()
block1():
%500 : bool = aten::eq(%478, %488)
= prim::If(%500)
block0():
= prim::RaiseException(%487, %486)
-> ()
block1():
-> ()
%501 : bool = aten::eq(%478, %483)
= prim::If(%501)
block0():
= prim::RaiseException(%489, %486)
-> ()
block1():
-> ()
%502 : str = aten::format(%482, %478, %461)
= prim::RaiseException(%502, %486)
-> ()
%align_corners0 : bool? = prim::If(%self.block1.norm.training.15)
block0():
= prim::RaiseException(%463, %462)
-> (%474)
block1():
-> (%self.block1.norm.training.15)
%504 : int = aten::dim(%574)
%dim.1 : int = aten::sub(%504, %464)
%scale_factors2.1 : float[] = prim::ListConstruct()
= prim::Loop(%dim.1, %18)
block0(%64 : int):
%65 : float[] = aten::append(%scale_factors2.1, %4)
-> (%18)
%516 : bool = aten::eq(%504, %480)
%575 : bool = aten::__isnot__(%align_corners0, %self.block1.conv.bias.15)
%align_corners6.6 : bool = prim::If(%575)
block0():
%align_corners7.7 : bool = prim::unchecked_cast(%align_corners0)
-> (%align_corners7.7)
block1():
= prim::RaiseException(%485, %self.block1.conv.bias.15)
-> (%475)
%578 : Tensor = aten::upsample_bilinear2d(%574, %self.block1.conv.bias.15, %align_corners6.6, %scale_factors2.1)
= prim::If(%516)
block0():
= prim::If(%575)
block0():
-> ()
block1():
= prim::RaiseException(%485, %self.block1.conv.bias.15)
-> ()
-> ()
block1():
%522 : bool = aten::eq(%504, %488)
= prim::If(%522)
block0():
= prim::RaiseException(%487, %486)
-> ()
block1():
-> ()
%523 : bool = aten::eq(%504, %483)
= prim::If(%523)
block0():
= prim::RaiseException(%489, %486)
-> ()
block1():
-> ()
%524 : str = aten::format(%482, %504, %461)
= prim::RaiseException(%524, %486)
-> ()
%585 : bool = prim::Constant[value=0]()
%586 : int[] = prim::Constant[value=[0, 0]]()
%587 : Tensor = aten::_convolution(%578, %self.conv.weight, %self.conv.bias, %16, %14, %16, %585, %586, %17, %585, %585, %585, %585)
return (%587)
Hey @gcuendet,
Any details that I might have missed?
This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
Bug Description
Scripting a simple "network" containing two torch.nn.Upsample modules and trying to convert the resulting torchscript does not work.

To Reproduce
Steps to reproduce the behavior: script the network with torch.jit.script and try to convert the resulting torchscript.

Expected behavior
The conversion succeeds and a new valid torchscript is obtained.
Environment
I managed to reproduce the bug both with PyTorch 1.11 and Torch-TensorRT 1.1.0, and with PyTorch 1.13.1 and Torch-TensorRT main.
Torch-TensorRT 1.1.0
Installation (conda, pip, libtorch, source): pip for the Python package used to generate the torchscript, source for the C++ dependency linked to Torch-TensorRT.

When using Torch-TensorRT 1.1.0, I get the following error:
That looked somewhat similar to this issue, and patching Torch-TensorRT with this PR makes the behavior exactly the same as in the second case (i.e. when using PyTorch 1.13.1 and Torch-TensorRT main).
Torch-TensorRT main (commit 861edd0)
Installation (conda, pip, libtorch, source): pip for the Python package used to generate the torchscript, source for the C++ dependency linked to Torch-TensorRT.

When using Torch-TensorRT main, the conversion just hangs forever after
Additional context
Interestingly, when using the tracing mechanism of PyTorch to generate the torchscript, everything seems fine (I didn't check the results, but the conversion finishes properly).
Also, when scripting with PyTorch 1.9, everything works fine 🤯
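For reference, a minimal sketch of the tracing path that converts cleanly, assuming the Network module shared earlier in this thread (the dummy input shape is an assumption):

import torch

# Tracing records only the ops executed for this 4D example input, so the
# dimension-check control flow and exception branches never end up in the graph.
model = Network(num_classes=2).eval()
example = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example)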
The thing I noticed is that PyTorch slightly changed the torch.nn.functional.interpolate API, and I am wondering if that could explain (at least partially) the problem:

torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)

vs.

torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False)
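For illustration, nn.Upsample forwards to torch.nn.functional.interpolate, so the dimension checks seen in the graphs above come from scripting a call like this sketch (the input shape is illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 56, 56)  # illustrative 4D input
# Functional equivalent of nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False);
# scripting this is what introduces the prim::If / prim::RaiseException dimension checks.
y = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
print(y.shape)  # torch.Size([1, 64, 112, 112])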
See the attached .zip file containing a python file to generate the torchscript.
upsample.zip
Let me know if you need more details to reproduce the problem. Thanks!