diff --git a/psiflow/metrics.py b/psiflow/metrics.py
index 2cd7ae2..d46c060 100644
--- a/psiflow/metrics.py
+++ b/psiflow/metrics.py
@@ -164,9 +164,13 @@ def fix_plotly_layout(figure):
def _to_wandb(
wandb_id: str,
wandb_project: str,
+ wandb_api_key: str,
inputs: list = [],
) -> None:
import os
+
+ os.environ["WANDB_API_KEY"] = wandb_api_key
+ os.environ["WANDB_SILENT"] = "True"
import tempfile
import numpy as np
from pathlib import Path
@@ -289,7 +293,6 @@ def _to_wandb(
figure.update_layout(yaxis_title="forces RMSE [meV/atom]")
figure.update_layout(xaxis_title="" + x_axis + "")
figures[title] = figure
- os.environ["WANDB_SILENT"] = "True"
path_wandb = Path(tempfile.mkdtemp())
wandb.init(id=wandb_id, dir=path_wandb, project=wandb_project, resume="allow")
wandb.log(figures)
@@ -530,5 +533,6 @@ def to_wandb(self):
return to_wandb(
self.wandb_id,
self.wandb_project,
+ os.environ['WANDB_API_KEY'],
inputs=[self.metrics],
)