From 18d7e8a6ac5b38cfd121ca67af51635db2643d18 Mon Sep 17 00:00:00 2001
From: Fengyuan Hu <127644049+HuFY-dev@users.noreply.github.com>
Date: Fri, 29 Mar 2024 00:05:15 -0400
Subject: [PATCH] Resolved import error for PyTorch >= 2.2

---
 sparse_autoencoder/optimizer/adam_with_reset.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/sparse_autoencoder/optimizer/adam_with_reset.py b/sparse_autoencoder/optimizer/adam_with_reset.py
index 5a84c941..dd1921c7 100644
--- a/sparse_autoencoder/optimizer/adam_with_reset.py
+++ b/sparse_autoencoder/optimizer/adam_with_reset.py
@@ -2,16 +2,20 @@
 
 This reset method is useful when resampling dead neurons during training.
 """
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
+from typing import Any
+from typing_extensions import TypeAlias
 
 from jaxtyping import Float, Int
 from torch import Tensor
 from torch.nn.parameter import Parameter
 from torch.optim import Adam
-from torch.optim.optimizer import params_t
 
 from sparse_autoencoder.tensor_types import Axis
 
+# params_t was renamed to ParamsT in PyTorch 2.2, which caused import errors
+# Copied from PyTorch 2.2 with modifications for better style
+ParamsT: TypeAlias = Iterable[Tensor] | Iterable[dict[str, Any]]
 
 class AdamWithReset(Adam):
     """Adam Optimizer with a reset method.
@@ -35,7 +39,7 @@ class AdamWithReset(Adam):
 
     def __init__(  # (extending existing implementation)
         self,
-        params: params_t,
+        params: ParamsT,
         lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,