Skip to content

Commit

Permalink
support python 3.9 lint
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 2, 2024
1 parent d32b54a commit 29acf37
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections import OrderedDict
from functools import partial
from typing import Callable
from typing import Callable, Optional

import jax
from jax import device_put, lax, random
Expand Down Expand Up @@ -348,7 +348,7 @@ def scan(
f: Callable,
init,
xs,
length: int | None = None,
length: Optional[int] = None,
reverse: bool = False,
history: int = 1,
):
Expand Down
12 changes: 6 additions & 6 deletions numpyro/contrib/stochastic_support/dcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType
from typing import Any, Callable, OrderedDict as OrderedDictType, Union

import jax
from jax import random
Expand All @@ -21,9 +21,9 @@

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])

RunInferenceResult = (
dict[str, Any] | tuple[AutoGuide, dict[str, Any]]
) # for mcmc or sdvi
RunInferenceResult = Union[
dict[str, Any], tuple[AutoGuide, dict[str, Any]]
] # for mcmc or sdvi


class StochasticSupportInference(ABC):
Expand Down Expand Up @@ -124,12 +124,12 @@ def _combine_inferences(
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> DCCResult | SDVIResult:
) -> Union[DCCResult, SDVIResult]:
raise NotImplementedError

def run(
self, rng_key: ArrayLike, *args: Any, **kwargs: Any
) -> DCCResult | SDVIResult:
) -> Union[DCCResult, SDVIResult]:
"""
Run inference on each SLP separately and combine the results.
Expand Down
5 changes: 3 additions & 2 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from collections import OrderedDict
from itertools import product
from typing import Union

import numpy as np

Expand Down Expand Up @@ -230,7 +231,7 @@ def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:


def summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
) -> dict:
"""
Returns a summary table displaying diagnostics of ``samples`` from the
Expand Down Expand Up @@ -284,7 +285,7 @@ def summary(


def print_summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
Expand Down
10 changes: 5 additions & 5 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import random
import re
from threading import Lock
from typing import Any, Callable, Generator
from typing import Any, Callable, Generator, Optional
import warnings

import numpy as np
Expand All @@ -27,7 +27,7 @@
_CHAIN_RE = re.compile(r"\d+$") # e.g. get '3' from 'TFRT_CPU_3'


def set_rng_seed(rng_seed: int | None = None) -> None:
def set_rng_seed(rng_seed: Optional[int] = None) -> None:
"""
Initializes internal state for the Python and NumPy random number generators.
Expand All @@ -49,7 +49,7 @@ def enable_x64(use_x64: bool = True) -> None:
jax.config.update("jax_enable_x64", use_x64)


def set_platform(platform: str | None = None) -> None:
def set_platform(platform: Optional[str] = None) -> None:
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
Expand Down Expand Up @@ -408,7 +408,7 @@ def loop_fn(collection):


def soft_vmap(
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: int | None = None
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: Optional[int] = None
) -> Any:
"""
Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
Expand Down Expand Up @@ -466,7 +466,7 @@ def format_shapes(
*,
compute_log_prob: bool = False,
title: str = "Trace Shapes:",
last_site: str | None = None,
last_site: Optional[str] = None,
):
"""
Given the trace of a function, returns a string showing a table of the shapes of
Expand Down

0 comments on commit 29acf37

Please sign in to comment.