diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py index 6585290..0bcf899 100644 --- a/dmlcloud/core/callbacks.py +++ b/dmlcloud/core/callbacks.py @@ -9,7 +9,6 @@ import torch from omegaconf import OmegaConf from progress_table import ProgressTable -from torch.utils.tensorboard import SummaryWriter from ..git import git_diff from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector @@ -426,8 +425,14 @@ class TensorboardCallback(Callback): def __init__(self, log_dir: Union[str, Path]): self.log_dir = Path(log_dir) + try: + from torch.utils.tensorboard import SummaryWriter # noqa: F401 + except ImportError: + raise ImportError('tensorflow is required for the TensorboardCallback') def pre_run(self, pipe): + from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(log_dir=self.log_dir) def post_epoch(self, stage: 'Stage'):