Skip to content

Commit

Permalink
lightning: Only force init if report="notebook". (#595)
Browse files Browse the repository at this point in the history
Closes #594.
  • Loading branch information
daavoo authored Jun 8, 2023
1 parent 9088996 commit aad6a01
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def __init__(
self._live_init["dir"] = dir
self._experiment = experiment
self._version = run_name
# Force Live instantiation
self.experiment # noqa: B018
if report == "notebook":
# Force Live instantiation
self.experiment # noqa: B018

@property
def name(self):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_frameworks/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, Dataset

from dvclive import Live
from dvclive.lightning import DVCLiveLogger
except ImportError:
pytest.skip("skipping pytorch_lightning tests", allow_module_level=True)
Expand Down Expand Up @@ -239,3 +240,14 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio
# Without `self.experiment._latest_studio_step -= 1`
# This would be empty
assert len(val_loss["data"]) == 1


def test_lightning_force_init(tmp_dir, mocker):
"""Regression test for https://github.com/iterative/dvclive/issues/594
Only call Live.__init__ when report is notebook.
"""
init = mocker.spy(Live, "__init__")
DVCLiveLogger()
init.assert_not_called()
DVCLiveLogger(report="notebook")
init.assert_called_once()

0 comments on commit aad6a01

Please sign in to comment.