diff --git a/psiflow/metrics.py b/psiflow/metrics.py index c6d95e7..3f69f36 100644 --- a/psiflow/metrics.py +++ b/psiflow/metrics.py @@ -343,6 +343,11 @@ def _to_wandb( df_na = df[df["temperature"].isna()] df_not_na = df[df["temperature"].notna()] + + # sort to get markers right! + df_not_na = df_not_na.sort_values(by="marker_symbol") + df_na = df_na.sort_values(by="marker_symbol") + cmap = cc.cm.CET_I1 colors = [mcolors.to_hex(cmap(i)) for i in np.linspace(0, 1, cmap.N)] @@ -363,7 +368,7 @@ def _to_wandb( "marker_symbol", ], symbol="marker_symbol", - symbol_sequence=["star-diamond", "circle"], # reversed? + symbol_sequence=["circle", "star-diamond"], color_discrete_sequence=["darkgray"], ) figure = px.scatter( @@ -379,12 +384,16 @@ def _to_wandb( "temperature", ], symbol="marker_symbol", - symbol_sequence=["star-diamond", "circle"], + symbol_sequence=["circle", "star-diamond"], color="temperature", color_continuous_scale=colors, ) for trace in figure_.data: figure.add_trace(trace) + figure.update_traces( # wandb cannot deal with lines in non-circle symbols! + marker={"size": 12}, + selector=dict(marker_symbol="star-diamond"), + ) figure.update_traces( marker={ "size": 10, @@ -392,10 +401,6 @@ def _to_wandb( }, selector=dict(marker_symbol="circle"), ) - figure.update_traces( # wandb cannot deal with lines in non-circle symbols! - marker={"size": 12}, - selector=dict(marker_symbol="star-diamond"), - ) figure.update_traces( hovertemplate=( "iteration: %{customdata[0]}
" @@ -491,7 +496,7 @@ def log_walker( i, state, error, - walker.is_reset(), + condition, identifier, disagreement, **metadata_dict, @@ -507,7 +512,7 @@ def log_walker( self.iteration, i, walker.counter, - walker.is_reset(), + condition, temperature, ) diff --git a/psiflow/sampling.py b/psiflow/sampling.py index 05a755b..eddb5ba 100644 --- a/psiflow/sampling.py +++ b/psiflow/sampling.py @@ -112,6 +112,7 @@ def log_evaluation_model( *errors ) else: + assert condition s += "\twalker reset" s += "\n" logger.info(s) @@ -188,7 +189,6 @@ def sample_with_model( error=errors[i], error_thresholds=error_thresholds_for_reset, ) - walkers[i].reset(condition) log_evaluation_model( i, metadatas[i], states[i], errors[i], condition, identifier ) @@ -196,6 +196,7 @@ def sample_with_model( metrics.log_walker( i, walkers[i], metadatas[i], states[i], errors[i], condition, identifier ) + walkers[i].reset(condition) return Dataset(states).labeled(), identifier