diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 05190e24d7..298fd859e5 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -138,7 +138,7 @@ def __get__( if name not in obj.__dict__["_pyro_params"]: init_value, constraint, event_dim = self # bind method's self arg - init_value = functools.partial(init_value, obj) # type: ignore[arg-type,misc,operator] + init_value = functools.partial(init_value, obj) # type: ignore[arg-type,call-arg,misc,operator] setattr(obj, name, PyroParam(init_value, constraint, event_dim)) value: PyroParam = obj.__getattr__(name) return value diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 69d41756f6..19ad050b32 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -1,8 +1,7 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import numbers -from typing import Iterator, NamedTuple, Optional, Tuple +from typing import Iterator, NamedTuple, Optional, Tuple, Union import torch from typing_extensions import Self @@ -108,7 +107,7 @@ def __exit__(self, *args) -> None: _DIM_ALLOCATOR.free(self.name, self.dim) return super().__exit__(*args) - def __iter__(self) -> Iterator[int]: + def __iter__(self) -> Iterator[Union[int, float]]: if self._vectorized is True or self.dim is not None: raise ValueError( "cannot use plate {} as both vectorized and non-vectorized" @@ -121,7 +120,14 @@ def __iter__(self) -> Iterator[int]: for i in self.indices: self.next_context() with self: - yield i if isinstance(i, numbers.Number) else i.item() + if isinstance(i, (int, float)): + yield i + elif isinstance(i, torch.Tensor): + yield i.item() + else: + raise ValueError( + f"Expected int, float or torch.Tensor, but got {type(i)}" + ) def _reset(self) -> None: if self._vectorized: