Skip to content

Commit

Permalink
Update analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosfelt committed Jul 25, 2023
1 parent 69606a1 commit 35692ff
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 655 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions multitask/visualization/cn_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from .plots import (
get_wandb_run_dfs,
make_yld_comparison_plot,
make_comparison_plot,
make_categorical_comparison_plot,
)
from summit import *
Expand Down Expand Up @@ -137,7 +137,7 @@ def baumgartner_cn_auxiliary_one_baumgartner_cn(

# Make yield subplot
ax_yld = fig_yld.add_subplot(4, 3, k)
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down Expand Up @@ -333,7 +333,7 @@ def baumgartner_cn_auxiliary_all_baumgartner_cn(

# Make yield subplot
ax = fig_yld.add_subplot(1, 4, i)
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down
19 changes: 13 additions & 6 deletions multitask/visualization/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@ def make_average_plot(
ax: Axes,
label: Optional[str] = None,
color: Optional[str] = None,
bounds: Optional[List[float]] = None,
dropna: bool = True,
):
yields = [df[output_name] for df in results]
if dropna:
yields = [df.dropna() for df in yields]
yields = np.array(yields)
mean_yield = np.mean(yields, axis=0)
std_yield = np.std(yields, axis=0)
x = np.arange(1, len(mean_yield) + 1, 1).astype(int)
ax.plot(x, mean_yield, label=label, linewidth=4, c=color)
top = mean_yield + 1.96 * std_yield
bottom = mean_yield - 1.96 * std_yield
bottom = np.clip(bottom, 0, 100)
top = np.clip(top, 0, 100)
if bounds is None:
bounds = [0, 100]
bottom = np.clip(bottom, bounds[0], bounds[1])
top = np.clip(top, bounds[0], bounds[1])
ax.fill_between(
x,
bottom,
Expand Down Expand Up @@ -59,8 +65,8 @@ def make_repeats_plot(
)


def make_yld_comparison_plot(
*args, output_name: str, ax: Axes, plot_type: str = "average"
def make_comparison_plot(
*args, output_name: str, ax: Axes, plot_type: str = "average", n_experiments:int=20
):
for arg in args:
if plot_type == "average":
Expand All @@ -70,6 +76,7 @@ def make_yld_comparison_plot(
ax,
label=arg["label"],
color=arg.get("color"),
bounds=arg.get("bounds"),
)
elif plot_type == "repeats":
make_repeats_plot(
Expand All @@ -86,8 +93,8 @@ def make_yld_comparison_plot(
ax.legend(prop=fontdict, framealpha=0.0)
else:
ax.legend(prop=fontdict)
ax.set_xlim(0, 20)
ax.set_xticks(np.arange(0, 20, 2).astype(int))
ax.set_xlim(0, n_experiments)
ax.set_xticks(np.arange(0, n_experiments, 2).astype(int))
ax.tick_params(direction="in")
return ax

Expand Down
14 changes: 7 additions & 7 deletions multitask/visualization/suzuki_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from multitask.etl.etl_baumgartner_suzuki import ligands, pre_catalysts
from .plots import (
make_yld_comparison_plot,
make_comparison_plot,
make_categorical_comparison_plot,
get_wandb_run_dfs,
)
Expand Down Expand Up @@ -128,7 +128,7 @@ def baumgartner_suzuki_auxiliary_one_reizman_suzuki(

# Make yield comparison subplot
ax_yld = fig_yld.add_subplot(1, 4, k)
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down Expand Up @@ -307,7 +307,7 @@ def baumgartner_suzuki_auxiliary_all_reizman_suzuki(
][:num_repeats]

# Make comparison subplot
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down Expand Up @@ -484,7 +484,7 @@ def reizman_suzuki_auxiliary_one_baumgartner_suzuki(

# Make subplot
ax_yld = fig_yld.add_subplot(1, 4, k)
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down Expand Up @@ -674,7 +674,7 @@ def reizman_suzuki_auxiliary_one_reizman_suzuki(

# Make subplot
ax_yld = fig_yld.add_subplot(4, 3, k)
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down Expand Up @@ -864,7 +864,7 @@ def reizman_suzuki_auxiliary_all_baumgartner_suzuki(

# Make subplot
ax_yld = fig_yld.add_subplot(1, 4, k)
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down Expand Up @@ -1057,7 +1057,7 @@ def reizman_suzuki_auxiliary_all_reizman_suzuki(

# Make subplot
ax_yld = fig_yld.add_subplot(1, 4, k)
make_yld_comparison_plot(
make_comparison_plot(
dict(results=stbo_dfs, label="STBO", color="#a50026"),
dict(results=stbo_head_start_dfs, label="STBO HS", color="#FDAE61"),
dict(results=mtbo_dfs, label="MTBO", color="#313695"),
Expand Down
Loading

0 comments on commit 35692ff

Please sign in to comment.