Skip to content

Commit

Permalink
Pruning extended classes (#255)
Browse files Browse the repository at this point in the history
* Pruning extended classes

* Removed marshmallow from children

* Removed unneeded parameter

* Removed dead code

* Updating only non-existing keys
  • Loading branch information
mesemus authored Feb 28, 2024
1 parent ecac851 commit c5b3740
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 76 deletions.
11 changes: 11 additions & 0 deletions oarepo_model_builder/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ def load_model(
"oarepo_model_builder.loaders.extend", None
),
},
post_reference_processors={
ModelSchema.REF_KEYWORD: load_entry_points_list(
"oarepo_model_builder.loaders.post.ref", None
),
ModelSchema.USE_KEYWORD: load_entry_points_list(
"oarepo_model_builder.loaders.post.use", None
),
ModelSchema.EXTEND_KEYWORD: load_entry_points_list(
"oarepo_model_builder.loaders.post.extend", None
),
},
)
for config in configs:
load_config(schema, config, loaders)
Expand Down
156 changes: 101 additions & 55 deletions oarepo_model_builder/loaders/extend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from oarepo_model_builder.schema import ModelSchema
from oarepo_model_builder.utils.python_name import package_name
from oarepo_model_builder.validation import InvalidModelException


def extract_extended_record(included_data, *, context, **kwargs):
Expand Down Expand Up @@ -32,70 +30,28 @@ def extract_extended_record(included_data, *, context, **kwargs):

def extend_modify_marshmallow(included_data, *, context, **kwargs):
"""
This processor modified marshmallow of the extended object. At first, it puts
marshmallow and ui back to the included data. Then, for the top-level marshmallow & ui.marshmallow
it converts class -> base-classes and adds import for that.
For the properties, it marks them as read=False and write=False and for each object, it marks it as
generate=False - this way, the classes will be reused from the already existing library and not
generated again.
This processor moves the marshmallow section of the base record to base-class-marshmallow
and base-class-ui-marshmallow. It also sets the from-base-class flag to True.
"""

def remove_marshmallow_from_children(node):
def mark_as_from_base_class(node):
ret = {**node}
node_properties = ret.pop("properties", None)
node_items = ret.pop("items", None)

if "marshmallow" in ret:
convert_marshmallow_class_to_base_class(ret["marshmallow"])
elif "properties" in ret:
raise InvalidModelException(
f"marshmallow section not in object {node}. "
f"Please pass generated model (records.json5), not the source model."
)

if "ui" in ret and "marshmallow" in ret["ui"]:
convert_marshmallow_class_to_base_class(ret["ui"]["marshmallow"])
elif "properties" in ret:
raise InvalidModelException(
f"ui.marshmallow section not in object {node}. "
f"Please pass generated model (records.json5), not the source model."
)
ret["from-base-class"] = True

if node_properties:
properties = ret.setdefault("properties", {})
for k, v in node_properties.items():
remove_marshmallow_from_child(v)
v = remove_marshmallow_from_children(v)
v = mark_as_from_base_class(v)
properties[k] = v
if node_items:
remove_marshmallow_from_child(node.item)
ret["items"] = remove_marshmallow_from_children(node.item)
return ret

def remove_marshmallow_from_child(child):
# for object/nested, do not set the read & write to False because
# the extending schema might add more properties.
# This will generate unnecessary classes, but these might be dealt
# on later in marshmallow generator
if "properties" not in child and "items" not in child:
marshmallow = child.setdefault("marshmallow", {})
marshmallow.update({"read": False, "write": False})

ui_marshmallow = child.setdefault("ui", {}).setdefault("marshmallow", {})
ui_marshmallow.update({"read": False, "write": False})

def convert_marshmallow_class_to_base_class(marshmallow):
# pop module & package
marshmallow.pop("module", None)
marshmallow.pop("package", None)

if "class" not in marshmallow:
return
clz = marshmallow.pop("class")
marshmallow.setdefault("base-classes", []).insert(0, clz)
marshmallow.setdefault("imports", []).append(
{"import": package_name(clz), "alias": package_name(clz)}
ret["items"] = mark_as_from_base_class(node.item)
ret["base-class-marshmallow"] = ret.pop("marshmallow", {})
ret["base-class-ui-marshmallow"] = ret.setdefault("ui", {}).pop(
"marshmallow", {}
)
return ret

def as_array(x):
if isinstance(x, list):
Expand Down Expand Up @@ -124,7 +80,7 @@ def replace_use_with_extend(data):

included_data["marshmallow"] = context["props"].get("marshmallow", {})
included_data["ui"] = context["props"].get("ui", {})
ret = remove_marshmallow_from_children(included_data)
ret = mark_as_from_base_class(included_data)

for ext in (
ModelSchema.EXTEND_KEYWORD,
Expand All @@ -136,3 +92,93 @@ def replace_use_with_extend(data):

replace_use_with_extend(ret)
return ret


def post_extend_modify_marshmallow(*, element, **kwargs):
def convert_schema_classes(node):
node_properties = node.get("properties", None)
node_items = node.get("items", None)

was_inherited = "from-base-class" in node
if not was_inherited:
return False

contains_only_inherited_properties = node.pop("from-base-class", False)
if node_properties:
for k, v in node_properties.items():
prop_contains_only_inherited_properties = convert_schema_classes(v)
if not prop_contains_only_inherited_properties:
contains_only_inherited_properties = False
elif node_items:
contains_only_inherited_properties = (
convert_schema_classes(node_items)
and contains_only_inherited_properties
)
base_class_marshmallow = node.pop("base-class-marshmallow", {})
base_class_ui_marshmallow = node.pop("base-class-ui-marshmallow", {})

def update_marshmallow(new_marshmallow, base_marshmallow):
if new_marshmallow.get("generate", True) is False:
# the class is set to not generate -> if there is a class, do not change it,
# if not, set it to the base class
if not new_marshmallow.get("class") and base_marshmallow.get("class"):
new_marshmallow["class"] = base_marshmallow["class"]
return

if "items" in node:
# array itself does not have a marshmallow, so no need to modify this
_update_non_existing(new_marshmallow, base_marshmallow)
return

if "properties" not in node:
# primitive data type -> set it not to be generated unless the field says otherwise
if "read" not in new_marshmallow:
new_marshmallow["read"] = False
if "write" not in new_marshmallow:
new_marshmallow["write"] = False
for k, v in base_marshmallow.items():
if k not in new_marshmallow:
new_marshmallow[k] = v
return

# now we have an object to modify - convert to base classes only if there are extra properties
convert_to_base_classes = (
node_properties and not contains_only_inherited_properties
)

if "class" in new_marshmallow:
# someone added class to the new_marshmallow, so we do not want to change it
convert_to_base_classes = True

if convert_to_base_classes:
if base_marshmallow.get("class"):
new_marshmallow.setdefault("base-classes", []).insert(
0, base_marshmallow["class"]
)
new_marshmallow["generate"] = True

elif contains_only_inherited_properties:
# keep the base class marshmallow, but do not generate the class as it has been generated
# in the extended library
new_marshmallow.clear()
new_marshmallow.update(base_marshmallow)
new_marshmallow["generate"] = False

else:
_update_non_existing(new_marshmallow, base_marshmallow)

update_marshmallow(node.setdefault("marshmallow", {}), base_class_marshmallow)
update_marshmallow(
node.setdefault("ui", {}).setdefault("marshmallow", {}),
base_class_ui_marshmallow,
)

return contains_only_inherited_properties

convert_schema_classes(element)


def _update_non_existing(target, source):
for k, v in source.items():
if k not in target:
target[k] = v
18 changes: 17 additions & 1 deletion oarepo_model_builder/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
validate=True,
source_locations=None,
reference_processors=None,
post_reference_processors=None,
):
"""
Creates and parses model schema
Expand All @@ -48,7 +49,15 @@ def __init__(
self.USE_KEYWORD: [],
self.EXTEND_KEYWORD: [],
},
reference_processors,
reference_processors or {},
)
self._post_reference_processors = deepmerge(
{
self.REF_KEYWORD: [],
self.USE_KEYWORD: [],
self.EXTEND_KEYWORD: [],
},
post_reference_processors or {},
)

if content is not None:
Expand Down Expand Up @@ -234,6 +243,13 @@ def _load_and_merge_reference(self, element, key, name, stack):
context=context,
)
deepmerge(element, included_data, [], listmerge="keep")
for rp in self._post_reference_processors[key]:
rp(
element=element,
key=key,
name=name,
context=context,
)

@property
def abs_path(self):
Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = oarepo-model-builder
version = 4.0.76
version = 4.0.77
description = A utility library that generates OARepo required data model files from a JSON specification file
authors = Miroslav Bauer <[email protected]>, Miroslav Simek <[email protected]>
readme = README.md
Expand Down Expand Up @@ -173,6 +173,9 @@ oarepo_model_builder.loaders.extend =
0100-extract_record = oarepo_model_builder.loaders.extend:extract_extended_record
1000-modify-marshmallow = oarepo_model_builder.loaders.extend:extend_modify_marshmallow

oarepo_model_builder.loaders.post.extend =
1000-modify-marshmallow = oarepo_model_builder.loaders.extend:post_extend_modify_marshmallow


oarepo_model_builder.settings =
1000-python = oarepo_model_builder.settngs:python.json
Expand Down
36 changes: 17 additions & 19 deletions tests/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from oarepo_model_builder.loaders.extend import (
extend_modify_marshmallow,
extract_extended_record,
post_extend_modify_marshmallow,
)
from oarepo_model_builder.schema import ModelSchema

Expand All @@ -26,24 +27,28 @@ def test_extend_marshmallow():
extend_modify_marshmallow,
]
},
post_reference_processors={
ModelSchema.EXTEND_KEYWORD: [
post_extend_modify_marshmallow,
]
},
)
builder.build(model, "record", ["record"], "")

loaded_model = json5.loads(fs.read("test/models/records.json"))
assert loaded_model["model"]["marshmallow"] == {
"base-classes": ["aaa.BlahSchema"],
"class": "test.services.records.schema.TestSchema",
"base-classes": ["marshmallow.Schema"],
"class": "aaa.BlahSchema",
"extra-code": "",
"generate": True,
"imports": [{"alias": "aaa", "import": "aaa"}],
"generate": False,
"module": "test.services.records.schema",
}
assert loaded_model["model"]["ui"]["marshmallow"] == {
"base-classes": ["aaa.BlahUISchema"],
"class": "test.services.records.ui_schema.TestUISchema",
"base-classes": ["oarepo_runtime.services.schema.ui.InvenioUISchema"],
"class": "aaa.BlahUISchema",
"extra-code": "",
"generate": True,
"imports": [{"alias": "aaa", "import": "aaa"}],
"generate": False,
"imports": [],
"module": "test.services.records.ui_schema",
}
# assert that "a" is read & write false
Expand All @@ -53,17 +58,10 @@ def test_extend_marshmallow():
assert property_a["ui"]["marshmallow"]["read"] is False
assert property_a["ui"]["marshmallow"]["write"] is False

schema = fs.read("test/services/records/schema.py")
print(schema)
assert "class TestSchema(BlahSchema)" in schema
assert "class TestMetadataSchema(BlahMetadataSchema)" in schema
assert "metadata = ma_fields.Nested(lambda: TestMetadataSchema())" in schema

schema = fs.read("test/services/records/ui_schema.py")
print(schema)
assert "class TestUISchema(BlahUISchema)" in schema
assert "class TestMetadataUISchema(BlahMetadataUISchema)" in schema
assert "metadata = ma_fields.Nested(lambda: TestMetadataUISchema())" in schema
service_config = fs.read("test/services/records/config.py")
print(service_config)
assert "from aaa import BlahSchema" in service_config
assert "schema = BlahSchema" in service_config


extension_model = {
Expand Down

0 comments on commit c5b3740

Please sign in to comment.