Skip to content

Commit

Permalink
Better handling of extended models - special handling for marshmallow…
Browse files Browse the repository at this point in the history
… and ui/marshmallow to simplify the extended model (#202)
  • Loading branch information
mesemus authored Jul 17, 2023
1 parent cc9c9ff commit b65baa7
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 11 deletions.
35 changes: 34 additions & 1 deletion oarepo_model_builder/profiles/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def build(
**kwargs,
):
current_model = dict_get(model.schema, model_path)

# record has type "model" if not explicitly stated otherwise
if "type" not in current_model:
current_model["type"] = "model"

if "extend" in current_model:
self.handle_extend(
current_model["extend"],
Expand All @@ -35,6 +40,7 @@ def build(
current_model,
builder,
)

return super().build(
model, profile, model_path, output_directory, builder, **kwargs
)
Expand Down Expand Up @@ -87,4 +93,31 @@ def handle_extend(
ExtendProfile().build(extended_model, profile, model_path, "", builder)

loaded_model = json5.loads(fs.read("model.json5"))
deepmerge(current_model, loaded_model)
# extension means that:
# 1. deep merge everything as usual, but
# 2. keep special attention to marshmallow - if parent defined inside schema, move
# extension's marshmallow class to base classes, add imports and do not merge

deepmerge(current_model, loaded_model, dictmerge=marshmallow_merge)


def marshmallow_merge(target, source, stack):
# if target...marshmallow does not exist, source...marshmallow (that is, marshmallow from extended model)
# is copied automatically
if not stack or stack[-1] != "marshmallow":
if "properties" in target:
# it is an object, add marshmallow and ui/marshmallow sections, so that correct merging is performed
target.setdefault("marshmallow", {})
target.setdefault("ui", {}).setdefault("marshmallow", {})
# and use the default merging

return None # use default merging

if "class" in source:
target.setdefault("base-classes", []).append(source["class"])
target.setdefault("imports", []).append({"import": source["class"]})
elif "base-classes" in source:
target.setdefault("base-classes", []).extend(source["base-classes"])
for bc in source["base-classes"]:
target.setdefault("imports", []).append({"import": bc})
return target
40 changes: 31 additions & 9 deletions oarepo_model_builder/utils/deepmerge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import copy
from typing import Union


def deepmerge(target, source, stack=None, listmerge="overwrite"):
def deepmerge(
target,
source,
stack=None,
listmerge: Union[str, callable] = "overwrite",
dictmerge=None,
):
if stack is None:
stack = []

Expand All @@ -11,13 +18,22 @@ def deepmerge(target, source, stack=None, listmerge="overwrite"):
raise AttributeError(
f"Incompatible source and target on path {stack}: source {source}, target {target}"
)
for k, v in source.items():
if k not in target:
target[k] = source[k]
else:
target[k] = deepmerge(
target[k], source[k], stack + [k], listmerge=listmerge
)
if dictmerge:
merged = dictmerge(target, source, stack)
else:
merged = None
if merged is None:
for k, v in source.items():
if k not in target:
target[k] = source[k]
else:
target[k] = deepmerge(
target[k],
source[k],
stack + [k],
listmerge=listmerge,
dictmerge=dictmerge,
)
elif isinstance(target, list):
if source is not None:
if not isinstance(source, list):
Expand All @@ -27,7 +43,11 @@ def deepmerge(target, source, stack=None, listmerge="overwrite"):
if listmerge == "overwrite":
for idx in range(min(len(source), len(target))):
target[idx] = deepmerge(
target[idx], source[idx], stack + [idx], listmerge=listmerge
target[idx],
source[idx],
stack + [idx],
listmerge=listmerge,
dictmerge=dictmerge,
)
for idx in range(len(target), len(source)):
target.append(source[idx])
Expand All @@ -36,6 +56,8 @@ def deepmerge(target, source, stack=None, listmerge="overwrite"):
elif listmerge == "keep":
if len(source) > len(target):
target.extend(source[len(target) :])
elif callable(listmerge):
listmerge(target, source, stack)
else:
raise AttributeError(
'listmerge must be one of "overwrite", "extend" or "keep"'
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = oarepo-model-builder
version = 4.0.27
version = 4.0.28
description = A utility library that generates OARepo required data model files from a JSON specification file
authors = Miroslav Bauer <[email protected]>, Miroslav Simek <[email protected]>
readme = README.md
Expand Down
49 changes: 49 additions & 0 deletions tests/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from oarepo_model_builder.entrypoints import create_builder_from_entrypoints
from oarepo_model_builder.fs import InMemoryFileSystem
from oarepo_model_builder.profiles.extend import ExtendProfile
from oarepo_model_builder.profiles.record import RecordProfile
from oarepo_model_builder.schema import ModelSchema


Expand Down Expand Up @@ -39,6 +40,54 @@ def test_extend_property_preprocessor():
check_marshmallow(loaded_model, "")


def test_extend_in_model():
fs = InMemoryFileSystem()
builder = create_builder_from_entrypoints(
filesystem=fs,
)

model = ModelSchema(
"/tmp/test.json", # NOSONAR: this is just a dummy path
content=extension_model,
included_models={"extended-model": lambda parent_schema: extended_model},
validate=True,
)
RecordProfile().build(model, "record", ["record"], "", builder)
schema = fs.read("test/services/records/schema.py")
assert "class TestSchema(aaa.BlahSchema)" in schema
assert "class TestMetadataSchema(aaa.BlahMetadataSchema)" in schema
assert "metadata = ma.fields.Nested(lambda: TestMetadataSchema())" in schema

schema = fs.read("test/services/records/ui_schema.py")
assert "class TestUISchema(aaa.BlahUISchema)" in schema
assert "class TestMetadataUISchema(aaa.BlahMetadataUISchema)" in schema
assert "metadata = ma.fields.Nested(lambda: TestMetadataUISchema())" in schema


extension_model = {
"record": {
"module": {"qualified": "test"},
"extend": "extended-model",
"properties": {"metadata": {"properties": {}}},
},
"settings": {
"python": {"use-black": False, "use-isort": False, "use-autoflake": False}
},
}

extended_model = {
"marshmallow": {"class": "aaa.BlahSchema"},
"ui": {"marshmallow": {"class": "aaa.BlahUISchema"}},
"properties": {
"metadata": {
"type": "object",
"marshmallow": {"class": "aaa.BlahMetadataSchema"},
"ui": {"marshmallow": {"class": "aaa.BlahMetadataUISchema"}},
"properties": {},
}
},
}

nr_documents_model = {
"record": {
"type": "object",
Expand Down

0 comments on commit b65baa7

Please sign in to comment.