Skip to content

Commit

Permalink
Feat fix schema bump (#1215)
Browse files Browse the repository at this point in the history
* fix: changing how schema version bumps are discovered, to avoid re-bumping a major/minor change

* fix: refactor and fix schema bump code to pull original version properly

* fix: raise errors appropriately and skip properly

* test: coverage on new get_schema_json function

* chore: lint

* chore: docstrings

* tests: coverage on missing lines
  • Loading branch information
dbirman authored Jan 23, 2025
1 parent 0dfb05f commit 37cb5c3
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 26 deletions.
7 changes: 1 addition & 6 deletions src/aind_data_schema/core/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,12 +714,7 @@ class Procedures(AindCoreModel):
)
subject_procedures: List[
Annotated[
Union[
Surgery,
TrainingProtocol,
WaterRestriction,
OtherSubjectProcedure
],
Union[Surgery, TrainingProtocol, WaterRestriction, OtherSubjectProcedure],
Field(discriminator="procedure_type"),
]
] = Field(default=[], title="Subject Procedures")
Expand Down
66 changes: 51 additions & 15 deletions src/aind_data_schema/utils/schema_version_bump.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@ def __init__(self, commit_message: str = "", json_schemas_location: Path = Path(
self.commit_message = commit_message
self.json_schemas_location = json_schemas_location

def _get_schema_json(self, model: AindCoreModel) -> dict:
    """
    Get the json schema of a model

    The schema is read from ``<stem>_schema.json`` inside
    ``self.json_schemas_location``, where ``<stem>`` is the part of
    ``model.default_filename()`` before its first ".".

    Parameters
    ----------
    model : AindCoreModel
      The model to get the json schema of

    Returns
    -------
    dict
      The json schema of the model

    Raises
    ------
    FileNotFoundError
      If the expected schema file does not exist.

    """
    default_filename = model.default_filename()
    # partition() returns the whole name when there is no ".", so the schema
    # path is always bound; a name without an extension no longer fails with
    # an unbound-local error before the existence check can run.
    stem = default_filename.partition(".")[0]
    main_branch_schema_path = self.json_schemas_location / f"{stem}_schema.json"
    if not main_branch_schema_path.exists():
        raise FileNotFoundError(f"Schema file not found: {main_branch_schema_path}")
    with open(main_branch_schema_path, "r") as f:
        return json.load(f)

def _get_list_of_models_that_changed(self) -> List[AindCoreModel]:
"""
Get a list of core models that have been updated by comparing the json
Expand All @@ -46,20 +70,18 @@ def _get_list_of_models_that_changed(self) -> List[AindCoreModel]:
schemas_that_need_updating = []
for core_model in SchemaWriter.get_schemas():
core_model_json = core_model.model_json_schema()
default_filename = core_model.default_filename()
if default_filename.find(".") != -1:
schema_filename = default_filename[: default_filename.find(".")] + "_schema.json"
main_branch_schema_path = self.json_schemas_location / schema_filename
if main_branch_schema_path.exists():
with open(main_branch_schema_path, "r") as f:
main_branch_schema_contents = json.load(f)
diff = dictdiffer.diff(main_branch_schema_contents, core_model_json)
if len(list(diff)) > 0:
schemas_that_need_updating.append(core_model)
original_schema = self._get_schema_json(core_model)

diff_list = list(dictdiffer.diff(original_schema, core_model_json))

print(f"Diff for {core_model.__name__}: {diff_list}")
if len(diff_list) > 0:
schemas_that_need_updating.append(core_model)

print(f"Schemas that need updating: {[model.__name__ for model in schemas_that_need_updating]}")
return schemas_that_need_updating

@staticmethod
def _get_incremented_versions_map(models_that_changed: List[AindCoreModel]) -> Dict[AindCoreModel, str]:
def _get_incremented_versions_map(self, models_that_changed: List[AindCoreModel]) -> Dict[AindCoreModel, str]:
"""
Parameters
Expand All @@ -74,11 +96,24 @@ def _get_incremented_versions_map(models_that_changed: List[AindCoreModel]) -> D
"""
version_bump_map = {}
# TODO: Use commit message to determine version number to bump?
for model in models_that_changed:
# We only want to bump the patch if the major or minor versions didn't already change
# Load the current version of the model
original_schema = self._get_schema_json(model)
schema_version = original_schema.get("properties", {}).get("schema_version", {}).get("default")
if schema_version:
orig_ver = semver.Version.parse(schema_version)
else:
raise ValueError("Schema version not found in the schema file")

old_v = semver.Version.parse(model.model_fields["schema_version"].default)
new_v = old_v.bump_patch()
version_bump_map[model] = str(new_v)
if orig_ver.major == old_v.major and orig_ver.minor == old_v.minor:
print(f"Updating {model.__name__} from {old_v} to {old_v.bump_patch()}")
new_ver = old_v.bump_patch()
version_bump_map[model] = str(new_ver)
else:
print(f"Skipping {model.__name__}, major or minor version already updated")
new_ver = old_v
return version_bump_map

@staticmethod
Expand All @@ -98,6 +133,7 @@ def _get_updated_file(python_file_path: str, new_ver: str) -> list:
"""
new_file_contents = []
print(f"Updating {python_file_path} to version {new_ver}")
with open(python_file_path, "rb") as f:
file_lines = f.readlines()
for line in file_lines:
Expand Down
88 changes: 83 additions & 5 deletions tests/test_bump_schema_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from aind_data_schema.core.session import Session
from aind_data_schema.core.subject import Subject
from aind_data_schema.core.rig import Rig
from aind_data_schema.utils.json_writer import SchemaWriter
from aind_data_schema.utils.schema_version_bump import SchemaVersionHandler

Expand Down Expand Up @@ -39,16 +40,25 @@ def test_get_list_of_models_that_changed(self, mock_exists: MagicMock, mock_json
self.assertTrue(Session in models_that_changed)
self.assertTrue(Subject in models_that_changed)

def test_get_list_of_incremented_versions(self):
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_list_of_incremented_versions(self, mock_get_schema: MagicMock):
"""Tests get_list_of_incremented_versions method"""

handler = SchemaVersionHandler(json_schemas_location=Path("."))
old_subject_version = Subject.model_fields["schema_version"].default
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())
# Pycharm raises a warning about types that we can ignore
# noinspection PyTypeChecker

def side_effect(model):
"""Side effect for mock_get_schema"""
if model == Subject:
return {"properties": {"schema_version": {"default": old_subject_version}}}
elif model == Session:
return {"properties": {"schema_version": {"default": old_session_version}}}

mock_get_schema.side_effect = side_effect

model_map = handler._get_incremented_versions_map([Subject, Session])
expected_model_map = {Subject: new_subject_version, Session: new_session_version}
self.assertEqual(expected_model_map, model_map)
Expand All @@ -64,6 +74,58 @@ def test_write_new_file(self, mock_open: MagicMock):
mock_open.assert_called_once_with(file_path, "wb")
mock_open.return_value.__enter__().write.assert_has_calls([call(file_contents[0]), call(file_contents[1])])

@patch("builtins.open")
@patch("json.load")
@patch("pathlib.Path.exists")
def test_get_schema_json(self, mock_exists: MagicMock, mock_json_load: MagicMock, mock_open: MagicMock):
    """Checks that _get_schema_json opens the expected schema file and returns its parsed contents"""
    # Pretend the schema file exists and parses to a minimal schema dict
    mock_exists.return_value = True
    mock_json_load.return_value = {"properties": {"schema_version": {"default": "1.0.0"}}}

    fake_model = MagicMock()
    fake_model.default_filename.return_value = "test_model.json"

    version_handler = SchemaVersionHandler(json_schemas_location=Path("."))
    loaded_schema = version_handler._get_schema_json(fake_model)

    self.assertEqual(loaded_schema, {"properties": {"schema_version": {"default": "1.0.0"}}})
    # "test_model.json" maps to "test_model_schema.json" under the schemas location
    mock_open.assert_called_once_with(Path("./test_model_schema.json"), "r")
    mock_json_load.assert_called_once()

@patch("pathlib.Path.exists")
def test_get_schema_json_file_not_found(self, mock_exists: MagicMock):
    """Checks that _get_schema_json raises FileNotFoundError when the schema file does not exist"""
    mock_exists.return_value = False

    fake_model = MagicMock()
    fake_model.default_filename.return_value = "test_model.json"

    version_handler = SchemaVersionHandler(json_schemas_location=Path("."))
    self.assertRaises(FileNotFoundError, version_handler._get_schema_json, fake_model)

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_incremented_versions_map_exception(self, mock_get_schema: MagicMock):
    """Checks that an original schema without a schema_version default raises ValueError"""
    # An empty schema dict has no properties.schema_version.default entry
    mock_get_schema.return_value = {}

    version_handler = SchemaVersionHandler(json_schemas_location=Path("."))
    self.assertRaises(ValueError, version_handler._get_incremented_versions_map, [Subject])

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_incremented_versions_map_skip(self, mock_get_schema: MagicMock):
    """Test that a model is skipped (no patch bump) when its major or minor version already changed"""
    handler = SchemaVersionHandler(json_schemas_location=Path("."))

    # The stored schema reports "0.0.0"; this presumably differs from Subject's
    # current major/minor version, so the model should be left out of the map.
    mock_get_schema.return_value = {"properties": {"schema_version": {"default": "0.0.0"}}}

    empty_map = handler._get_incremented_versions_map([Subject])
    self.assertEqual(empty_map, {})

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._write_new_file")
def test_update_files(self, mock_write: MagicMock):
"""Tests the update_files method"""
Expand All @@ -72,9 +134,11 @@ def test_update_files(self, mock_write: MagicMock):
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())
old_rig_version = Rig.model_fields["schema_version"].default
new_rig_version = str(Version.parse(old_rig_version).bump_minor())
# Pycharm raises a warning about types that we can ignore
# noinspection PyTypeChecker
handler._update_files({Subject: new_subject_version, Session: new_session_version})
handler._update_files({Subject: new_subject_version, Session: new_session_version, Rig: new_rig_version})

expected_line_change0 = (
f'schema_version: SkipValidation[Literal["{new_subject_version}"]] = Field(default="{new_subject_version}")'
Expand All @@ -90,16 +154,30 @@ def test_update_files(self, mock_write: MagicMock):
self.assertTrue(expected_line_change1 in str(mock_write_args1[0]))
self.assertTrue("session.py" in str(mock_write_args1[1]))

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_list_of_models_that_changed")
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._update_files")
def test_run_job(self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock):
def test_run_job(
self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock, mock_get_schema: MagicMock
):
"""Tests run_job method"""

old_subject_version = Subject.model_fields["schema_version"].default
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())

mock_get_list_of_models.return_value = [Subject, Session]

def side_effect(model):
"""Return values for get_schema_json"""
if model == Subject:
return {"properties": {"schema_version": {"default": old_subject_version}}}
elif model == Session:
return {"properties": {"schema_version": {"default": old_session_version}}}

mock_get_schema.side_effect = side_effect

handler = SchemaVersionHandler(json_schemas_location=Path("."))
handler.run_job()
mock_update_files.assert_called_once_with({Subject: new_subject_version, Session: new_session_version})
Expand Down

0 comments on commit 37cb5c3

Please sign in to comment.