Add typing annotations for torch.nn.quantized.dynamic.modules.rnn (#43186)

Summary:
Fixes pytorch/pytorch#43185

xref: [gh-43072](pytorch/pytorch#43072)
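
For context, most of the diff below mechanically converts mypy `# type:` comments into equivalent inline annotations. A representative before/after, mirroring `apply_permutation` from the changed file:

```python
from torch import Tensor

# Before (as removed by this PR): types carried in a mypy comment
#     def apply_permutation(tensor, permutation, dim=1):
#         # type: (Tensor, Tensor, int) -> Tensor

# After: the same contract expressed as inline annotations
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)
```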

Pull Request resolved: pytorch/pytorch#43186

Reviewed By: ezyang

Differential Revision: D23441259

Pulled By: malfet

fbshipit-source-id: 80265ae7f3a70f0087e620969dbd4aa8ca17c317
guilhermeleobas authored and facebook-github-bot committed Sep 1, 2020
1 parent 8ca3913 commit 63a0bb0
Showing 2 changed files with 48 additions and 46 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
@@ -170,9 +170,6 @@ ignore_errors = True
[mypy-torch.nn.qat.modules.conv]
ignore_errors = True

-[mypy-torch.nn.quantized.dynamic.modules.rnn]
-ignore_errors = True
-
[mypy-torch.nn.quantized.dynamic.modules.linear]
ignore_errors = True
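
Deleting the block above means mypy now type-checks torch.nn.quantized.dynamic.modules.rnn rather than suppressing its errors. For reference, a per-module override in mypy.ini has this shape (the module name here is a hypothetical placeholder):

```ini
; silence all mypy errors for a single module; removing the block re-enables checking
[mypy-some.module.path]
ignore_errors = True
```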

91 changes: 48 additions & 43 deletions torch/nn/quantized/dynamic/modules/rnn.py
@@ -6,12 +6,11 @@
import torch
import torch.nn as nn
from torch import Tensor # noqa: F401
-from torch._jit_internal import Tuple, Optional, List # noqa: F401
+from torch._jit_internal import Tuple, Optional, List, Union, Dict # noqa: F401
from torch.nn.utils.rnn import PackedSequence
from torch.nn.quantized.modules.utils import _quantize_weight

-def apply_permutation(tensor, permutation, dim=1):
-    # type: (Tensor, Tensor, int) -> Tensor
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation)

class PackedParameter(torch.nn.Module):
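
As an aside, apply_permutation is a thin wrapper over index_select; a minimal sketch of what it does to a hidden state (shapes are illustrative):

```python
import torch

hidden = torch.arange(6.0).reshape(1, 3, 2)  # (num_layers * num_directions, batch, hidden_size)
perm = torch.tensor([2, 0, 1])

# index_select along dim=1 reorders the batch dimension, which is exactly
# what apply_permutation(hidden, perm) returns.
print(hidden.index_select(1, perm))  # batch rows now appear in order 2, 0, 1
```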
@@ -53,12 +52,14 @@ def __init__(self, mode, input_size, hidden_size,
self.training = False
num_directions = 2 if bidirectional else 1

-if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
-        isinstance(dropout, bool):
+# "type: ignore" is required since ints and Numbers are not fully comparable
+# https://github.com/python/mypy/issues/8566
+if not isinstance(dropout, numbers.Number) \
+        or not 0 <= dropout <= 1 or isinstance(dropout, bool): # type: ignore
raise ValueError("dropout should be a number in range [0, 1] "
"representing the probability of an element being "
"zeroed")
-if dropout > 0 and num_layers == 1:
+if dropout > 0 and num_layers == 1: # type: ignore
warnings.warn("dropout option adds dropout after all but last "
"recurrent layer, so non-zero dropout expects "
"num_layers greater than 1, but got dropout={} and "
@@ -149,8 +150,7 @@ def __repr__(self):
main_str += ')'
return main_str

-def check_input(self, input, batch_sizes):
-    # type: (Tensor, Optional[Tensor]) -> None
+def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
expected_input_dim = 2 if batch_sizes is not None else 3
if input.dim() != expected_input_dim:
raise RuntimeError(
@@ -161,33 +161,31 @@ def check_input(self, input, batch_sizes):
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
self.input_size, input.size(-1)))

-def get_expected_hidden_size(self, input, batch_sizes):
-    # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
+def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
if batch_sizes is not None:
-    mini_batch = batch_sizes[0]
-    mini_batch = int(mini_batch)
+    mini_batch = int(batch_sizes[0])
else:
mini_batch = input.size(0) if self.batch_first else input.size(1)
num_directions = 2 if self.bidirectional else 1
expected_hidden_size = (self.num_layers * num_directions,
mini_batch, self.hidden_size)
return expected_hidden_size

-def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
-    # type: (Tensor, Tuple[int, int, int], str) -> None
+def check_hidden_size(
+        self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
+        msg: str = 'Expected hidden size {}, got {}'
+) -> None:
if hx.size() != expected_hidden_size:
raise RuntimeError(msg.format(
expected_hidden_size, list(hx.size())))

-def check_forward_args(self, input, hidden, batch_sizes):
-    # type: (Tensor, Tensor, Optional[Tensor]) -> None
+def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
self.check_hidden_size(hidden, expected_hidden_size,
msg='Expected hidden size {}, got {}')

-def permute_hidden(self, hx, permutation):
-    # type: (Tensor, Optional[Tensor]) -> Tensor
+def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
if permutation is None:
return hx
return apply_permutation(hx, permutation)
@@ -287,7 +285,7 @@ def quantize_and_pack(w, b):

def _weight_bias(self):
# Returns a dict of weights and biases
-weight_bias_dict = {'weight' : {}, 'bias' : {}}
+weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
count = 0
num_directions = 2 if self.bidirectional else 1
for layer in range(self.num_layers):
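
The new Dict[str, Dict] annotation works around mypy's inference on empty dict literals; a minimal sketch (the key names are hypothetical):

```python
from typing import Dict

# Unannotated, mypy infers the nested value type from the empty literals
# as Dict[<nothing>, <nothing>], so later writes into them are rejected.
weight_bias: Dict[str, Dict] = {'weight': {}, 'bias': {}}
weight_bias['weight']['weight_ih'] = [1.0, 2.0]  # accepted thanks to the annotation
```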
@@ -337,8 +335,11 @@ def __init__(self, *args, **kwargs):
def _get_name(self):
return 'DynamicQuantizedLSTM'

-def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-    # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
+def forward_impl(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
+        batch_sizes: Optional[Tensor], max_batch_size: int,
+        sorted_indices: Optional[Tensor]
+) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
if hx is None:
num_directions = 2 if self.bidirectional else 1
zeros = torch.zeros(self.num_layers * num_directions,
@@ -367,8 +368,9 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
return output, hidden

@torch.jit.export
-def forward_tensor(self, input, hx=None):
-    # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+def forward_tensor(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
@@ -380,27 +382,32 @@ def forward_tensor(self, input, hx=None):
return output, self.permute_hidden(hidden, unsorted_indices)

@torch.jit.export
-def forward_packed(self, input, hx=None):
-    # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
-input, batch_sizes, sorted_indices, unsorted_indices = input
+def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa
+input_, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)

-output, hidden = self.forward_impl(
-    input, hx, batch_sizes, max_batch_size, sorted_indices)
+output_, hidden = self.forward_impl(
+    input_, hx, batch_sizes, max_batch_size, sorted_indices)

-output = PackedSequence(output, batch_sizes,
+output = PackedSequence(output_, batch_sizes,
sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)

-def permute_hidden(self, hx, permutation):
-    # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+# "type: ignore" is required due to issue #43072
+def permute_hidden(  # type: ignore
+        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
+) -> Tuple[Tensor, Tensor]:
if permutation is None:
return hx
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

-def check_forward_args(self, input, hidden, batch_sizes):
-    # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor])->None
+# "type: ignore" is required due to issue #43072
+def check_forward_args(  # type: ignore
+        self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
+) -> None:
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
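
Both `type: ignore` comments reference gh-43072: the LSTM subclass overrides these methods with a hidden-state type (Tuple[Tensor, Tensor]) that differs from the Tensor-based base-class signatures, which mypy flags as a Liskov substitution violation. Relatedly, the hunk rebinds the unpacked value to `input_` because mypy also rejects reusing one variable for two types (PackedSequence, then Tensor). A stripped-down sketch of the override issue, with hypothetical class names:

```python
from typing import Optional, Tuple
from torch import Tensor

class Base:
    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        return hx

class Sub(Base):
    # mypy: "Signature of 'permute_hidden' incompatible with supertype 'Base'"
    def permute_hidden(  # type: ignore
        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        return hx
```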

@@ -483,8 +490,7 @@ def check_forward_input(self, input):
"input has inconsistent input_size: got {}, expected {}".format(
input.size(1), self.input_size))

-def check_forward_hidden(self, input, hx, hidden_label=''):
-    # type: (Tensor, Tensor, str) -> None
+def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
if input.size(0) != hx.size(0):
raise RuntimeError(
"Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -518,6 +524,8 @@ def from_float(cls, mod):
if dtype not in supported_scalar_types:
raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))

+qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
+
if type(mod) == torch.nn.LSTMCell:
qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
elif type(mod) == torch.nn.GRUCell:
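
The new qRNNCellBase declaration pre-announces a Union type so each branch can assign a different cell class; without it, mypy pins the variable to the type of its first assignment and rejects the other branches. A minimal sketch with hypothetical classes:

```python
from typing import Union

class A: ...
class B: ...

def pick(flag: bool) -> Union[A, B]:
    obj: Union[A, B]  # declared up front, as qRNNCellBase is above
    if flag:
        obj = A()
    else:
        obj = B()  # without the declaration, mypy: incompatible assignment
    return obj
```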
@@ -561,7 +569,7 @@ def process_weights(weight, bias, dtype):

def _weight_bias(self):
# Returns a dict of weights and biases
-weight_bias_dict = {'weight' : {}, 'bias' : {}}
+weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
w1, b1 = self._packed_weight_ih.__getstate__()[0]
w2, b2 = self._packed_weight_hh.__getstate__()[0]
weight_bias_dict['weight']['weight_ih'] = w1
@@ -614,8 +622,7 @@ def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8):
def _get_name(self):
return 'DynamicQuantizedRNNCell'

-def forward(self, input, hx=None):
-    # type: (Tensor, Optional[Tensor]) -> Tensor
+def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -661,13 +668,12 @@ class LSTMCell(RNNCellBase):
"""

def __init__(self, *args, **kwargs):
-super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)
+super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs) # type: ignore

def _get_name(self):
return 'DynamicQuantizedLSTMCell'

-def forward(self, input, hx=None):
-    # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
self.check_forward_input(input)
if hx is None:
zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -707,8 +713,7 @@ def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
def _get_name(self):
return 'DynamicQuantizedGRUCell'

-def forward(self, input, hx=None):
-    # type: (Tensor, Optional[Tensor]) -> Tensor
+def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
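
All three cell forwards share the Optional-default pattern now spelled out in the signatures: after the None check, mypy narrows hx from Optional[Tensor] to Tensor with no ignore needed. A self-contained sketch (the hidden size 8 is arbitrary):

```python
from typing import Optional
import torch
from torch import Tensor

def forward(input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
    if hx is None:  # mypy narrows hx to Tensor past this branch
        hx = torch.zeros(input.size(0), 8, dtype=input.dtype, device=input.device)
    return hx
```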
