Skip to content

Commit

Permalink
small bugfixes in metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Oct 31, 2023
1 parent a2e4819 commit 732a55d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 36 deletions.
37 changes: 16 additions & 21 deletions psiflow/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class BaseLearning:
pretraining_nstates: int = 50
pretraining_amplitude_pos: float = 0.05
pretraining_amplitude_box: float = 0.0
metrics: Optional[Metrics] = Metrics()
metrics: Metrics = Metrics()
atomic_energies: dict[str, Union[float, AppFuture]] = field(
default_factory=lambda: {}
)
Expand All @@ -51,8 +51,7 @@ def __post_init__(self) -> None: # save self in output folder
self.atomic_energies = atomic_energies
config["path_output"] = str(self.path_output) # yaml requires str
config.pop("metrics")
if self.metrics is not None:
config["Metrics"] = self.metrics.as_dict()
config["Metrics"] = self.metrics.as_dict()
path_config = self.path_output / (self.__class__.__name__ + ".yaml")
if path_config.is_file():
logger.warning("overriding learning config file {}".format(path_config))
Expand All @@ -67,6 +66,7 @@ def run_pretraining(
reference: BaseReference,
walkers: list[BaseWalker],
) -> Dataset:
self.metrics.iteration = "pretraining"
nstates = self.pretraining_nstates
amplitude_pos = self.pretraining_amplitude_pos
amplitude_box = self.pretraining_amplitude_box
Expand Down Expand Up @@ -110,12 +110,11 @@ def run_pretraining(
data_train=data_train,
data_valid=data_valid,
)
if self.metrics is not None:
self.metrics.save(
self.path_output / "pretraining",
model=model,
dataset=data,
)
self.metrics.save(
self.path_output / "pretraining",
model=model,
dataset=data,
)
psiflow.wait()
return data

Expand All @@ -140,8 +139,7 @@ def initialize_run(
continue with online learning, build on top of initial data
"""
if self.metrics is not None:
self.metrics.insert_name(model)
self.metrics.insert_name(model)
if len(self.atomic_energies) > 0:
for element, energy in self.atomic_energies.items():
model.add_atomic_energy(element, energy)
Expand Down Expand Up @@ -227,6 +225,7 @@ def run(
for i in range(self.niterations):
if self.output_exists(str(i)):
continue # skip iterations in case of restarted run
self.metrics.iteration = i
self.update_walkers(walkers, initialize=(i == 0))
new_data, self.identifier = sample_with_model(
model,
Expand Down Expand Up @@ -254,8 +253,7 @@ def run(
data_train=data_train,
data_valid=data_valid,
)
if self.metrics is not None:
self.metrics.save(self.path_output / str(i), model, data)
self.metrics.save(self.path_output / str(i), model, data)
psiflow.wait()
return data

Expand All @@ -280,6 +278,7 @@ def run(
for i in range(self.niterations):
if self.output_exists(str(i)):
continue # skip iterations in case of restarted run
self.metrics.iteration = i
self.update_walkers(walkers, initialize=(i == 0))
new_data, self.identifier = sample_with_committee(
committee,
Expand All @@ -302,8 +301,7 @@ def run(
data_train=data_train,
data_valid=data_valid,
)
if self.metrics is not None:
self.metrics.save(self.path_output / str(i), committee.models[0], data)
self.metrics.save(self.path_output / str(i), committee.models[0], data)
committee.save(self.path_output / str(i) / "committee")
psiflow.wait()
return data
Expand Down Expand Up @@ -372,6 +370,7 @@ def run(
for i in range(self.niterations):
if self.output_exists(str(i)):
continue # skip iterations in case of restarted run
self.metrics.iteration = i
self.update_walkers(walkers, initialize=(i == 0))
new_data, self.identifier = sample_with_model(
model,
Expand Down Expand Up @@ -399,8 +398,7 @@ def run(
data_train=data_train,
data_valid=data_valid,
)
if self.metrics is not None:
self.metrics.save(self.path_output / str(i), model, data)
self.metrics.save(self.path_output / str(i), model, data)
psiflow.wait()
return data

Expand Down Expand Up @@ -429,9 +427,6 @@ def load_learning(path_output: Union[Path, str]):
atomic_energies[element] = energy
config["atomic_energies"] = atomic_energies
config["path_output"] = str(path_output)
if "Metrics" in config.keys():
metrics = Metrics(**config.pop("Metrics"))
else:
metrics = None
metrics = Metrics(**config.pop("Metrics", {}))
learning = learning_cls(metrics=metrics, **config)
return learning
28 changes: 13 additions & 15 deletions psiflow/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def _trace_identifier(
identifier_traces: dict,
state: FlowAtoms,
iteration: int,
iteration: Union[str, int],
walker_index: int,
nsteps: int,
condition: bool,
Expand All @@ -31,7 +31,6 @@ def _trace_identifier(
if not state == NullState: # same checks as sampling.py:assign_identifier
if state.reference_status:
identifier = state.info["identifier"]
print(identifier, condition)
assert identifier not in identifier_traces
identifier_traces[identifier] = (
iteration,
Expand Down Expand Up @@ -352,8 +351,8 @@ def _to_wandb(
if x_axis.startswith("CV") or (x_axis == "identifier"):
for y_axis in dataset_log:
if (y_axis == "e_rmse") or y_axis.startswith("f_rmse"):
figure = px.scatter(
data_frame=df_not_na,
figure_ = px.scatter(
data_frame=df_na,
x=x_axis,
y=y_axis,
custom_data=[
Expand All @@ -362,16 +361,13 @@ def _to_wandb(
"nsteps",
"identifier",
"marker_symbol",
"temperature",
],
symbol="marker_symbol",
symbol_sequence=["star-diamond", "circle"],
color="temperature",
color_continuous_scale=colors,
color_discrete_sequence=["darkgray"],
)
# Overlay the scatter plot for missing values
figure_ = px.scatter(
data_frame=df_na,
figure = px.scatter(
data_frame=df_not_na,
x=x_axis,
y=y_axis,
custom_data=[
Expand All @@ -380,22 +376,24 @@ def _to_wandb(
"nsteps",
"identifier",
"marker_symbol",
"temperature",
],
symbol="marker_symbol",
symbol_sequence=["star-diamond", "circle"],
color_discrete_sequence=["darkgray"],
color="temperature",
color_continuous_scale=colors,
)
for trace in figure_.data:
figure.add_trace(trace)
figure.update_traces(
marker={
"size": 11,
"line": dict(width=1.2, color="DarkSlateGray"),
"size": 10,
"line": dict(width=1.0, color="DarkSlateGray"),
},
selector=dict(marker_symbol="circle"),
)
figure.update_traces( # wandb cannot deal with lines in non-circle symbols!
marker={"size": 15},
marker={"size": 12},
selector=dict(marker_symbol="star-diamond"),
)
figure.update_traces(
Expand Down Expand Up @@ -508,7 +506,7 @@ def log_walker(
state,
self.iteration,
i,
metadata.counter,
walker.counter,
condition,
temperature,
)
Expand Down

0 comments on commit 732a55d

Please sign in to comment.