From 0f59e9fd70605fc9666976f3df9e6b53d76e1cb3 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Mon, 24 Apr 2023 04:34:38 -0400 Subject: [PATCH] Auto-initialize DVC repo (#539) * initialize dvc * add tests * drop dvc repo * drop .dvciginore --- setup.cfg | 1 + src/dvclive/dvc.py | 8 ++++++-- src/dvclive/live.py | 12 ++++++++++-- tests/conftest.py | 1 + tests/test_dvc.py | 11 ++++++++++- tests/test_main.py | 1 + 6 files changed, 29 insertions(+), 5 deletions(-) diff --git a/setup.cfg b/setup.cfg index cf3a216e..2a53fa30 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,7 @@ install_requires= dvc-studio-client>=0.7.0,<1 funcy ruamel.yaml + scmrepo [options.extras_require] image = numpy diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index 76d67725..bc0cf0f7 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -76,12 +76,16 @@ def make_checkpoint(): def get_dvc_repo(): from dvc.exceptions import NotDvcRepoError from dvc.repo import Repo - from dvc.scm import SCMError + from dvc.scm import Git, SCMError + from scmrepo.exceptions import SCMError as GitSCMError try: return Repo() except (NotDvcRepoError, SCMError): - return None + try: + return Repo.init(Git().root_dir) + except GitSCMError: + return None def make_dvcyaml(live): diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 42c1f3b6..5342a883 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -124,8 +124,16 @@ def _init_dvc(self): if self._dvc_repo is None: if self._save_dvc_exp: logger.warning( - "Can't save experiment without a DVC Repo." - "\nYou can create a DVC Repo by calling `dvc init`." + "Can't save experiment without a Git Repo." + "\nCreate a Git repo (`git init`) and commit (`git commit`)." + ) + self._save_dvc_exp = False + return + if self._dvc_repo.scm.no_commits: + if self._save_dvc_exp: + logger.warning( + "Can't save experiment to an empty Git Repo." + "\nAdd a commit (`git commit`) to save experiments." ) self._save_dvc_exp = False return diff --git a/tests/conftest.py b/tests/conftest.py index b611e16e..cad9ac73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ def mocked_dvc_repo(tmp_dir, mocker): _dvc_repo.index.stages = [] _dvc_repo.scm.get_rev.return_value = "f" * 40 _dvc_repo.scm.get_ref.return_value = None + _dvc_repo.scm.no_commits = False _dvc_repo.experiments.save.return_value = "e" * 40 _dvc_repo.root_dir = tmp_dir mocker.patch("dvclive.live.get_dvc_repo", return_value=_dvc_repo) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 89390e3f..5acd725e 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -17,10 +17,17 @@ def test_get_dvc_repo(tmp_dir): assert get_dvc_repo() is None Git.init(tmp_dir) - Repo.init(tmp_dir) assert isinstance(get_dvc_repo(), Repo) +def test_get_dvc_repo_subdir(tmp_dir): + Git.init(tmp_dir) + subdir = tmp_dir / "sub" + subdir.mkdir() + os.chdir(subdir) + assert get_dvc_repo().root_dir == str(tmp_dir) + + def test_make_dvcyaml_empty(tmp_dir): live = Live() make_dvcyaml(live) @@ -154,6 +161,7 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker): dvc_repo.index.stages = [dvc_stage, dvc_file] dvc_repo.scm.get_rev.return_value = "current_rev" dvc_repo.scm.get_ref.return_value = None + dvc_repo.scm.no_commits = False with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo): live = Live(save_dvc_exp=True) assert live._save_dvc_exp @@ -193,6 +201,7 @@ def test_exp_save_with_dvc_files(tmp_dir, mocker): dvc_repo.index.stages = [dvc_file] dvc_repo.scm.get_rev.return_value = "current_rev" dvc_repo.scm.get_ref.return_value = None + dvc_repo.scm.no_commits = False with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo): live = Live(save_dvc_exp=True) diff --git a/tests/test_main.py b/tests/test_main.py index adeec46a..d73f2388 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -452,6 +452,7 @@ def test_vscode_dvclive_only_signal_file(tmp_dir, dvc_root, mocker): dvc_repo.index.stages = [] dvc_repo.scm.get_rev.return_value = "current_rev" dvc_repo.scm.get_ref.return_value = None + dvc_repo.scm.no_commits = False with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo), mocker.patch( "dvclive.live.os.getpid", return_value=test_pid ):