Skip to content

Commit

Permalink
fixing mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Jan 15, 2025
1 parent ac15394 commit 6c0b1f5
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/dynode/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from jax import Array
from jax.random import PRNGKey
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
Expand Down Expand Up @@ -476,7 +477,7 @@ def plot_mcmc_chains(
return fig


def _sample_prior_distributions(priors, num_samples):
def _sample_prior_distributions(priors, num_samples) -> dict[str, Array]:
"""Sample numpyro.distributions `num_samples` times.
Parameters
Expand Down Expand Up @@ -589,7 +590,7 @@ def plot_prior_distributions(
ax.set_title(param_name)
samples = sampled_priors[param_name]
ax.hist(samples, **hist_kwargs)
ax.axvline(jnp.median(samples), **median_line_kwargs)
ax.axvline(float(jnp.median(samples)), **median_line_kwargs)
# testing
# Turn off any unused subplots
for j in range(i + 1, len(axs_flat)):
Expand All @@ -602,8 +603,8 @@ def plot_prior_distributions(


def plot_violin_plots(
priors: dict[str:list] = None,
posteriors: dict[str:list] = None,
priors: dict[str, list] | None = None,
posteriors: dict[str, list] | None = None,
matplotlib_style: list[str]
| str = [
"seaborn-v0_8-colorblind",
Expand All @@ -613,11 +614,11 @@ def plot_violin_plots(
raise VisualizationError(
"must provide either a dictionary of priors or posteriors"
)
num_params = (
len(posteriors.keys())
if posteriors is not None
else len(priors.keys())
)
# we are given that both are not none, so get num_params from one of them
if posteriors is not None:
num_params = len(posteriors.keys())
elif priors is not None:
num_params = len(priors.keys())
# Calculate the number of rows and columns for a square-ish layout
num_cols = int(np.ceil(np.sqrt(num_params)))
num_rows = int(np.ceil(num_params / num_cols))
Expand Down

0 comments on commit 6c0b1f5

Please sign in to comment.