Skip to content

Commit

Permalink
Extends marshmallow (#203)
Browse files Browse the repository at this point in the history
* Better handling of extended models - special handling for marshmallow and ui/marshmallow to simplify the extended model

* Handling imports in extended schemas
  • Loading branch information
mesemus authored Jul 17, 2023
1 parent b65baa7 commit db07dfc
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def _create_marshmallow_field(
):
f = []
datatypes.call_components(section.item, field_accessor, fields=f)
if not f:
return

item_field: MarshmallowField = f[0]
f = []
super()._create_marshmallow_field(
Expand All @@ -22,6 +25,9 @@ def _create_marshmallow_field(
fields=f,
item_field=item_field,
)
if not f:
return

fld: MarshmallowField = f[0]
fld.imports.extend(item_field.imports)
fld.reference = item_field.reference
Expand Down
11 changes: 8 additions & 3 deletions oarepo_model_builder/profiles/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from oarepo_model_builder.schema import ModelSchema
from oarepo_model_builder.utils.deepmerge import deepmerge
from oarepo_model_builder.utils.dict import dict_get
from oarepo_model_builder.utils.python_name import base_name


class RecordProfile(Profile):
Expand Down Expand Up @@ -114,10 +115,14 @@ def marshmallow_merge(target, source, stack):
return None # use default merging

if "class" in source:
target.setdefault("base-classes", []).append(source["class"])
target.setdefault("base-classes", []).append(base_name(source["class"]))
target.setdefault("imports", []).append({"import": source["class"]})
elif "base-classes" in source:
target.setdefault("base-classes", []).extend(source["base-classes"])
target_base_classes = target.setdefault("base-classes", [])
target_imports = target.setdefault("imports", [])
for bc in source["base-classes"]:
target.setdefault("imports", []).append({"import": bc})
target_base_classes.append(base_name(bc))
if "." in bc:
target_imports.append({"import": bc})
target_imports.extend(source.get("imports", []))
return target
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = oarepo-model-builder
version = 4.0.28
version = 4.0.29
description = A utility library that generates OARepo required data model files from a JSON specification file
authors = Miroslav Bauer <[email protected]>, Miroslav Simek <[email protected]>
readme = README.md
Expand Down
11 changes: 7 additions & 4 deletions tests/test_extend.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pprint import pprint

import json5

from oarepo_model_builder.entrypoints import create_builder_from_entrypoints
Expand Down Expand Up @@ -54,13 +56,14 @@ def test_extend_in_model():
)
RecordProfile().build(model, "record", ["record"], "", builder)
schema = fs.read("test/services/records/schema.py")
assert "class TestSchema(aaa.BlahSchema)" in schema
assert "class TestMetadataSchema(aaa.BlahMetadataSchema)" in schema
pprint(schema)
assert "class TestSchema(BlahSchema)" in schema
assert "class TestMetadataSchema(BlahMetadataSchema)" in schema
assert "metadata = ma.fields.Nested(lambda: TestMetadataSchema())" in schema

schema = fs.read("test/services/records/ui_schema.py")
assert "class TestUISchema(aaa.BlahUISchema)" in schema
assert "class TestMetadataUISchema(aaa.BlahMetadataUISchema)" in schema
assert "class TestUISchema(BlahUISchema)" in schema
assert "class TestMetadataUISchema(BlahMetadataUISchema)" in schema
assert "metadata = ma.fields.Nested(lambda: TestMetadataUISchema())" in schema


Expand Down

0 comments on commit db07dfc

Please sign in to comment.