diff --git a/psiflow/metrics.py b/psiflow/metrics.py
index c6d95e7..3f69f36 100644
--- a/psiflow/metrics.py
+++ b/psiflow/metrics.py
@@ -343,6 +343,11 @@ def _to_wandb(
df_na = df[df["temperature"].isna()]
df_not_na = df[df["temperature"].notna()]
+
+ # sort to get markers right!
+ df_not_na = df_not_na.sort_values(by="marker_symbol")
+ df_na = df_na.sort_values(by="marker_symbol")
+
cmap = cc.cm.CET_I1
colors = [mcolors.to_hex(cmap(i)) for i in np.linspace(0, 1, cmap.N)]
@@ -363,7 +368,7 @@ def _to_wandb(
"marker_symbol",
],
symbol="marker_symbol",
- symbol_sequence=["star-diamond", "circle"], # reversed?
+ symbol_sequence=["circle", "star-diamond"],
color_discrete_sequence=["darkgray"],
)
figure = px.scatter(
@@ -379,12 +384,16 @@ def _to_wandb(
"temperature",
],
symbol="marker_symbol",
- symbol_sequence=["star-diamond", "circle"],
+ symbol_sequence=["circle", "star-diamond"],
color="temperature",
color_continuous_scale=colors,
)
for trace in figure_.data:
figure.add_trace(trace)
+ figure.update_traces( # wandb cannot deal with lines in non-circle symbols!
+ marker={"size": 12},
+ selector=dict(marker_symbol="star-diamond"),
+ )
figure.update_traces(
marker={
"size": 10,
@@ -392,10 +401,6 @@ def _to_wandb(
},
selector=dict(marker_symbol="circle"),
)
- figure.update_traces( # wandb cannot deal with lines in non-circle symbols!
- marker={"size": 12},
- selector=dict(marker_symbol="star-diamond"),
- )
figure.update_traces(
hovertemplate=(
"iteration: %{customdata[0]}
"
@@ -491,7 +496,7 @@ def log_walker(
i,
state,
error,
- walker.is_reset(),
+ condition,
identifier,
disagreement,
**metadata_dict,
@@ -507,7 +512,7 @@ def log_walker(
self.iteration,
i,
walker.counter,
- walker.is_reset(),
+ condition,
temperature,
)
diff --git a/psiflow/sampling.py b/psiflow/sampling.py
index 05a755b..eddb5ba 100644
--- a/psiflow/sampling.py
+++ b/psiflow/sampling.py
@@ -112,6 +112,7 @@ def log_evaluation_model(
*errors
)
else:
+ assert condition
s += "\twalker reset"
s += "\n"
logger.info(s)
@@ -188,7 +189,6 @@ def sample_with_model(
error=errors[i],
error_thresholds=error_thresholds_for_reset,
)
- walkers[i].reset(condition)
log_evaluation_model(
i, metadatas[i], states[i], errors[i], condition, identifier
)
@@ -196,6 +196,7 @@ def sample_with_model(
metrics.log_walker(
i, walkers[i], metadatas[i], states[i], errors[i], condition, identifier
)
+ walkers[i].reset(condition)
return Dataset(states).labeled(), identifier