Skip to content

Commit

Permalink
Resolved import error for PyTorch >= 2.2
Browse files Browse the repository at this point in the history
  • Loading branch information
HuFY-dev authored Mar 29, 2024
1 parent b6ba6cb commit 18d7e8a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions sparse_autoencoder/optimizer/adam_with_reset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
This reset method is useful when resampling dead neurons during training.
"""
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from typing import Any
from typing_extensions import TypeAlias

from jaxtyping import Float, Int
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.optim.optimizer import params_t

from sparse_autoencoder.tensor_types import Axis

# params_t was renamed to ParamsT in PyTorch 2.2, which caused import errors

Check failure on line 16 in sparse_autoencoder/optimizer/adam_with_reset.py

View workflow job for this annotation

GitHub Actions / Checks (3.10)

Ruff (I001)

sparse_autoencoder/optimizer/adam_with_reset.py:5:1: I001 Import block is un-sorted or un-formatted

Check failure on line 16 in sparse_autoencoder/optimizer/adam_with_reset.py

View workflow job for this annotation

GitHub Actions / Checks (3.11)

Ruff (I001)

sparse_autoencoder/optimizer/adam_with_reset.py:5:1: I001 Import block is un-sorted or un-formatted
# Copied from PyTorch 2.2 with modifications for better style.
# `torch.optim.optimizer.params_t` was renamed to `ParamsT` in PyTorch 2.2, so
# importing it by either name breaks on one side of that version boundary.
# Defining the alias locally keeps this module importable across torch versions.
ParamsT: TypeAlias = Iterable[Tensor] | Iterable[dict[str, Any]]

class AdamWithReset(Adam):
"""Adam Optimizer with a reset method.
Expand All @@ -35,7 +39,7 @@ class AdamWithReset(Adam):

def __init__( # (extending existing implementation)
self,
params: params_t,
params: ParamsT,
lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
Expand Down

0 comments on commit 18d7e8a

Please sign in to comment.