Skip to content

Commit

Permalink
Implement aliased model rename.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Jan 25, 2025
1 parent 076e672 commit ec6dd7e
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 7 deletions.
125 changes: 125 additions & 0 deletions syzygy/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,128 @@ class AlterField(StagedOperation, operations.AlterField):
"""
Subclass of ``AlterField`` that allows explicitly defining a stage.
"""


class AliasOperationMixin:
@staticmethod
def _create_instead_of_triggers(schema_editor, view_db_name, new_model):
quote = schema_editor.quote_name
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF INSERT ON {view_db_name}\n"
"BEGIN\n"
"INSERT INTO {new_table}({fields}) VALUES({values});\n"
"END"
).format(
trigger_name=f"{view_db_name}_insert",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
fields=", ".join(
quote(field.column) for field in new_model._meta.local_fields
),
values=", ".join(
f"NEW.{quote(field.column)}"
for field in new_model._meta.local_fields
),
)
)
for field in new_model._meta.local_fields:
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF UPDATE OF {column} ON {view_db_name}\n"
"BEGIN\n"
"UPDATE {new_table} SET {column}=NEW.{column} WHERE {pk}=NEW.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_update_{field.column}",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
column=quote(field.column),
pk=quote(new_model._meta.pk.column),
)
)
schema_editor.execute(
(
"CREATE TRIGGER {trigger_name} INSTEAD OF DELETE ON {view_db_name}\n"
"BEGIN\n"
"DELETE FROM {new_table} WHERE {pk}=OLD.{pk};\n"
"END"
).format(
trigger_name=f"{view_db_name}_delete",
view_db_name=quote(view_db_name),
new_table=quote(new_model._meta.db_table),
pk=quote(new_model._meta.pk.column),
)
)

@classmethod
def create_view(cls, schema_editor, view_db_name, new_model):
quote = schema_editor.quote_name
schema_editor.execute(
"CREATE VIEW {} AS SELECT * FROM {}".format(
quote(view_db_name), quote(new_model._meta.db_table)
)
)
if schema_editor.connection.vendor == "sqlite":
cls._create_instead_of_triggers(schema_editor, view_db_name, new_model)

@staticmethod
def drop_view(schema_editor, db_table):
schema_editor.execute("DROP VIEW {}".format(schema_editor.quote_name(db_table)))


class AliasedRenameModel(AliasOperationMixin, operations.RenameModel):
stage = Stage.PRE_DEPLOY

def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
return
old_model = from_state.apps.get_model(app_label, self.old_name)
view_db_name = old_model._meta.db_table
super().database_forwards(app_label, schema_editor, from_state, to_state)
self.create_view(schema_editor, view_db_name, new_model)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.old_name_lower)
alias = schema_editor.connection.alias
if not self.allow_migrate_model(alias, new_model):
return
self.drop_view(schema_editor, new_model._meta.db_table)
super().database_backwards(app_label, schema_editor, from_state, to_state)

def describe(self):
return "Rename model %s to %s while creating an alias for %s" % (
self.old_name,
self.new_name,
self.old_name,
)

def reduce(self, operation, app_label):
if (
isinstance(operation, UnaliasModel)
and operation.name_lower == self.new_name_lower
):
return [operations.RenameModel(self.old_name, self.new_name)]
return super().reduce(operation, app_label)


class UnaliasModel(AliasOperationMixin, operations.models.ModelOperation):
stage = Stage.POST_DEPLOY

def __init__(self, name, view_db_name):
self.view_db_name = view_db_name
super().__init__(name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
return
self.drop_view(self.view_db_name)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if not self.allow_migrate_model(schema_editor.connection.alias, model):
return
self.create_view(schema_editor, self.view_db_name, model)
81 changes: 74 additions & 7 deletions tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest import mock, skipUnless

from django.db import connection, migrations, models
from django.db.migrations.operations import RenameModel
from django.db.migrations.operations.base import Operation
from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.serializer import OperationSerializer
Expand All @@ -13,6 +14,8 @@
from syzygy.compat import field_db_default_supported
from syzygy.constants import Stage
from syzygy.operations import (
AliasedRenameModel,
UnaliasModel,
get_post_add_field_operation,
get_pre_add_field_operation,
get_pre_remove_field_operation,
Expand All @@ -36,16 +39,18 @@ def setUpClass(cls):
def tearDown(self):
super().tearDown()
with connection.cursor() as cursor:
created_tables = {
table.name
created_tables = [
(table_name, table.type == "v")
for table in connection.introspection.get_table_list(cursor)
} - self.tables
if (table_name := table.name) not in self.tables
]
if created_tables:
with connection.schema_editor() as schema_editor:
sql_delete_table = schema_editor.sql_delete_table
for table in created_tables:
for name, is_view in created_tables:
with connection.cursor() as cursor:
cursor.execute(sql_delete_table % {"table": table})
if is_view:
cursor.execute(f"DROP VIEW {name}")
else:
cursor.execute(f"DROP TABLE {name}")


class OperationTestCase(SchemaTestCase):
Expand Down Expand Up @@ -436,3 +441,65 @@ def test_defined_db_default(self):
get_pre_remove_field_operation(
"model", "field", models.IntegerField(db_default=42)
)


class AliasedRenameModelTests(OperationTestCase):
def test_describe(self):
self.assertEqual(
AliasedRenameModel("OldName", "NewName").describe(),
"Rename model OldName to NewName while creating an alias for OldName",
)

def _apply_forwards(self):
model_name = "TestModel"
field = models.IntegerField()
pre_state = self.apply_operations(
[
migrations.CreateModel(model_name, [("foo", field)]),
]
)
new_model_name = "NewTestModel"
post_state = self.apply_operations(
[
AliasedRenameModel(model_name, new_model_name),
],
pre_state,
)
return (pre_state, model_name), (post_state, new_model_name)

def test_database_forwards(self):
(pre_state, _), (post_state, new_model_name) = self._apply_forwards()
pre_model = pre_state.apps.get_model("tests", "testmodel")
pre_obj = pre_model.objects.create(foo=1)
if connection.vendor == "sqlite":
# SQLite doesn't allow the usage of RETURNING in INSTEAD OF INSERT
# triggers and thus the object has to be refetched.
pre_obj = pre_model.objects.latest("pk")
self.assertEqual(pre_model.objects.get(), pre_obj)
post_model = post_state.apps.get_model("tests", new_model_name)
self.assertEqual(post_model.objects.get().pk, pre_obj.pk)
pre_model.objects.all().delete()
post_obj = post_model.objects.create(foo=2)
self.assertEqual(post_model.objects.get(), post_obj)
self.assertEqual(pre_model.objects.get().pk, post_obj.pk)
pre_model.objects.update(foo=3)
self.assertEqual(post_model.objects.get().foo, 3)

def test_database_backwards(self):
(pre_state, model_name), (post_state, new_model_name) = self._apply_forwards()
with connection.schema_editor() as schema_editor:
AliasedRenameModel(model_name, new_model_name).database_backwards(
"tests", schema_editor, post_state, pre_state
)

def test_elidable(self):
model_name = "TestModel"
new_model_name = "NewTestModel"
operations = [
AliasedRenameModel(
model_name,
new_model_name,
),
UnaliasModel(new_model_name, "tests_testmodel"),
]
self.assert_optimizes_to(operations, [RenameModel(model_name, new_model_name)])

0 comments on commit ec6dd7e

Please sign in to comment.