Skip to content

Commit

Permalink
fix nasty symbols bug in plotly plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Nov 3, 2023
1 parent 8941b11 commit b8816e3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
21 changes: 13 additions & 8 deletions psiflow/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ def _to_wandb(

df_na = df[df["temperature"].isna()]
df_not_na = df[df["temperature"].notna()]

# sort to get markers right!
df_not_na = df_not_na.sort_values(by="marker_symbol")
df_na = df_na.sort_values(by="marker_symbol")

cmap = cc.cm.CET_I1
colors = [mcolors.to_hex(cmap(i)) for i in np.linspace(0, 1, cmap.N)]

Expand All @@ -363,7 +368,7 @@ def _to_wandb(
"marker_symbol",
],
symbol="marker_symbol",
symbol_sequence=["star-diamond", "circle"], # reversed?
symbol_sequence=["circle", "star-diamond"],
color_discrete_sequence=["darkgray"],
)
figure = px.scatter(
Expand All @@ -379,23 +384,23 @@ def _to_wandb(
"temperature",
],
symbol="marker_symbol",
symbol_sequence=["star-diamond", "circle"],
symbol_sequence=["circle", "star-diamond"],
color="temperature",
color_continuous_scale=colors,
)
for trace in figure_.data:
figure.add_trace(trace)
figure.update_traces( # wandb cannot deal with lines in non-circle symbols!
marker={"size": 12},
selector=dict(marker_symbol="star-diamond"),
)
figure.update_traces(
marker={
"size": 10,
"line": dict(width=1.0, color="DarkSlateGray"),
},
selector=dict(marker_symbol="circle"),
)
figure.update_traces( # wandb cannot deal with lines in non-circle symbols!
marker={"size": 12},
selector=dict(marker_symbol="star-diamond"),
)
figure.update_traces(
hovertemplate=(
"<b>iteration</b>: %{customdata[0]}<br>"
Expand Down Expand Up @@ -491,7 +496,7 @@ def log_walker(
i,
state,
error,
walker.is_reset(),
condition,
identifier,
disagreement,
**metadata_dict,
Expand All @@ -507,7 +512,7 @@ def log_walker(
self.iteration,
i,
walker.counter,
walker.is_reset(),
condition,
temperature,
)

Expand Down
3 changes: 2 additions & 1 deletion psiflow/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def log_evaluation_model(
*errors
)
else:
assert condition
s += "\twalker reset"
s += "\n"
logger.info(s)
Expand Down Expand Up @@ -188,14 +189,14 @@ def sample_with_model(
error=errors[i],
error_thresholds=error_thresholds_for_reset,
)
walkers[i].reset(condition)
log_evaluation_model(
i, metadatas[i], states[i], errors[i], condition, identifier
)
if metrics is not None:
metrics.log_walker(
i, walkers[i], metadatas[i], states[i], errors[i], condition, identifier
)
walkers[i].reset(condition)
return Dataset(states).labeled(), identifier


Expand Down

0 comments on commit b8816e3

Please sign in to comment.