Skip to content

Commit

Permalink
make check NaN optional
Browse files Browse the repository at this point in the history
  • Loading branch information
AllanChain committed Dec 10, 2024
1 parent da20835 commit a3cab4b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
14 changes: 7 additions & 7 deletions netobs/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def reweighting_log_psi(params, electrons, system):
i, params, subkeys, data, system, state, aux_data
)

has_nan = any(jnp.isnan(v).any() for v in obs_values.values())

if options.reweight_ratio > 0.0:
log_psi = pmap_log_psi(params, data, system)
weights = jnp.minimum(
Expand All @@ -146,11 +148,9 @@ def reweighting_log_psi(params, electrons, system):
mean_obs_values = {
k: weighted_sum(v, weights) for k, v in obs_values.items()
}
mean_obs_values["reweighting_weights"] = jnp.mean(weights)
mean_obs_values["reweighting_weights"] = jnp.nanmean(weights)
else:
mean_obs_values = {k: jnp.mean(v, (0, 1)) for k, v in obs_values.items()}

has_nan = any(jnp.isnan(v).any() for v in mean_obs_values.values())
mean_obs_values = {k: jnp.nanmean(v, (0, 1)) for k, v in obs_values.items()}

all_values = {k: v.at[i].set(mean_obs_values[k]) for k, v in all_values.items()}

Expand All @@ -161,7 +161,7 @@ def reweighting_log_psi(params, electrons, system):
should_log = last_log < now - options.log_interval
should_save = last_save < now - options.save_interval

if has_nan or should_save or should_log:
if options.check_nan and has_nan or should_save or should_log:
all_values_yet = {k: v[: i + 1] for k, v in all_values.items()}

if options.reweight_ratio > 0.0:
Expand All @@ -171,7 +171,7 @@ def reweighting_log_psi(params, electrons, system):
}
digest = estimator.digest(all_values_yet, state)

if has_nan or should_save:
if options.check_nan and has_nan or should_save:
last_save = now
checkpoint_mgr.save(
i, data, digest, all_values, state, aux_data, metadata
Expand All @@ -181,7 +181,7 @@ def reweighting_log_psi(params, electrons, system):
logger.info("Loop %s", i)
log_digest(i, digest)

if has_nan:
if options.check_nan and has_nan:
logger.error("NaN detected. Stopping")
return digest, all_values, state

Expand Down
3 changes: 3 additions & 0 deletions netobs/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class NetObsOptions:
save_interval: int = 600
"Time interval in seconds between saves."

check_nan: bool = False
"Check if there are NaN values and stop."

network_restore: Any = None
"""The restore option for the network adaptor.
Defaults to None, which means the network adaptor has a fixed and known way to
Expand Down

0 comments on commit a3cab4b

Please sign in to comment.