Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix index creation in Tortoise.generate_schemas() for MySQL and Postgres #1847

Merged
merged 4 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tests/fields/test_db_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.indexes import Index
from tests.testmodels import ModelWithIndexes


class CustomIndex(Index):
Expand All @@ -14,7 +15,7 @@ def __init__(self, *args, **kw):
self._foo = ""


class TestIndexHashEqualRepr(test.TestCase):
class TestIndexHashEqualRepr(test.SimpleTestCase):
def test_index_eq(self):
assert Index(fields=("id",)) == Index(fields=("id",))
assert CustomIndex(fields=("id",)) == CustomIndex(fields=("id",))
Expand Down Expand Up @@ -46,7 +47,7 @@ def test_index_repr(self):
assert repr(Index(fields=("id",), name="MyIndex")) == "Index(fields=['id'], name='MyIndex')"
assert repr(Index(Field("id"))) == f'Index({str(Field("id"))})'
assert repr(Index(Field("a"), name="Id")) == f"Index({str(Field('a'))}, name='Id')"
with self.assertRaises(ValueError):
with self.assertRaises(ConfigurationError):
Index(Field("id"), fields=("name",))


Expand Down Expand Up @@ -94,3 +95,11 @@ class TestIndexAliasUUID(TestIndexAlias):
class TestIndexAliasChar(TestIndexAlias):
Field = fields.CharField
init_kwargs = {"max_length": 10}


class TestModelWithIndexes(test.TestCase):
def test_meta(self):
self.assertEqual(ModelWithIndexes._meta.indexes, [Index(fields=("f1", "f2"))])
self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique)
20 changes: 10 additions & 10 deletions tests/schema/test_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,10 +724,10 @@ async def test_index_safe(self):
"""CREATE TABLE IF NOT EXISTS `index` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`full_text` LONGTEXT NOT NULL,
`geometry` GEOMETRY NOT NULL
) CHARACTER SET utf8mb4;
CREATE FULLTEXT INDEX IF NOT EXISTS `idx_index_full_te_3caba4` ON `index` (`full_text`) WITH PARSER ngram;
CREATE SPATIAL INDEX IF NOT EXISTS `idx_index_geometr_0b4dfb` ON `index` (`geometry`);""",
`geometry` GEOMETRY NOT NULL,
FULLTEXT KEY `idx_index_full_te_3caba4` (`full_text`) WITH PARSER ngram,
SPATIAL KEY `idx_index_geometr_0b4dfb` (`geometry`)
) CHARACTER SET utf8mb4;""",
)

async def test_index_unsafe(self):
Expand All @@ -738,10 +738,10 @@ async def test_index_unsafe(self):
"""CREATE TABLE `index` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`full_text` LONGTEXT NOT NULL,
`geometry` GEOMETRY NOT NULL
) CHARACTER SET utf8mb4;
CREATE FULLTEXT INDEX `idx_index_full_te_3caba4` ON `index` (`full_text`) WITH PARSER ngram;
CREATE SPATIAL INDEX `idx_index_geometr_0b4dfb` ON `index` (`geometry`);""",
`geometry` GEOMETRY NOT NULL,
FULLTEXT KEY `idx_index_full_te_3caba4` (`full_text`) WITH PARSER ngram,
SPATIAL KEY `idx_index_geometr_0b4dfb` (`geometry`)
) CHARACTER SET utf8mb4;""",
)

async def test_m2m_no_auto_create(self):
Expand Down Expand Up @@ -1102,7 +1102,7 @@ async def test_index_unsafe(self):
CREATE INDEX "idx_index_gist_c807bf" ON "index" USING GIST ("gist");
CREATE INDEX "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist");
CREATE INDEX "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash");
CREATE INDEX "idx_index_partial_c5be6a" ON "index" USING ("partial") WHERE id = 1;""",
CREATE INDEX "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1;""",
)

async def test_index_safe(self):
Expand All @@ -1126,7 +1126,7 @@ async def test_index_safe(self):
CREATE INDEX IF NOT EXISTS "idx_index_gist_c807bf" ON "index" USING GIST ("gist");
CREATE INDEX IF NOT EXISTS "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist");
CREATE INDEX IF NOT EXISTS "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash");
CREATE INDEX IF NOT EXISTS "idx_index_partial_c5be6a" ON "index" USING ("partial") WHERE id = 1;""",
CREATE INDEX IF NOT EXISTS "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1;""",
)

async def test_m2m_no_auto_create(self):
Expand Down
17 changes: 17 additions & 0 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tortoise import fields
from tortoise.exceptions import ValidationError
from tortoise.fields import NO_ACTION
from tortoise.indexes import Index
from tortoise.manager import Manager
from tortoise.models import Model
from tortoise.queryset import QuerySet
Expand Down Expand Up @@ -1050,3 +1051,19 @@ class BenchmarkManyFields(Model):
col_text4 = fields.TextField(null=True)
col_decimal4 = fields.DecimalField(12, 8, null=True)
col_json4 = fields.JSONField[dict](null=True)


class ModelWithIndexes(Model):
id = fields.IntField(primary_key=True)
indexed = fields.CharField(max_length=16, index=True)
unique_indexed = fields.CharField(max_length=16, unique=True)
f1 = fields.CharField(max_length=16)
f2 = fields.CharField(max_length=16)
u1 = fields.IntField()
u2 = fields.IntField()

class Meta:
indexes = [
Index(fields=["f1", "f2"]),
]
unique_together = [("u1", "u2")]
17 changes: 17 additions & 0 deletions tests/utils/test_describe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tests.testmodels import (
Event,
JSONFields,
ModelWithIndexes,
Reporter,
SourceFields,
StraightFields,
Expand Down Expand Up @@ -1561,3 +1562,19 @@ def test_describe_model_json_native(self):
"m2m_fields": [],
},
)

def test_describe_indexes_serializable(self):
val = ModelWithIndexes.describe()

self.assertEqual(
val["indexes"],
[{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}],
)

def test_describe_indexes_not_serializable(self):
val = ModelWithIndexes.describe(serializable=False)

self.assertEqual(
val["indexes"],
ModelWithIndexes._meta.indexes,
)
57 changes: 44 additions & 13 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re
from hashlib import sha256
from typing import TYPE_CHECKING, Any, List, Set, Type, Union, cast
from typing import TYPE_CHECKING, Any, List, Optional, Set, Type, Union, cast

from pypika_tortoise.context import DEFAULT_SQL_CONTEXT

from tortoise.exceptions import ConfigurationError
from tortoise.fields import JSONField, TextField, UUIDField
Expand All @@ -23,8 +25,10 @@ class BaseSchemaGenerator:
DIALECT = "sql"
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
FIELD_TEMPLATE = '"{name}" {type}{nullable}{unique}{primary}{default}{comment}'
INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});'
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(" INDEX", " UNIQUE INDEX")
INDEX_CREATE_TEMPLATE = (
'CREATE {index_type}INDEX {exists}"{index_name}" ON "{table_name}" ({fields}){extra};'
)
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
Expand Down Expand Up @@ -167,21 +171,33 @@ def _generate_fk_name(
)
return index_name

def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str:
def _get_index_sql(
self,
model: "Type[Model]",
field_names: List[str],
safe: bool,
index_name: Optional[str] = None,
index_type: Optional[str] = None,
extra: Optional[str] = None,
) -> str:
return self.INDEX_CREATE_TEMPLATE.format(
exists="IF NOT EXISTS " if safe else "",
index_name=self._generate_index_name("idx", model, field_names),
index_name=index_name or self._generate_index_name("idx", model, field_names),
index_type=f"{index_type} " if index_type else "",
table_name=model._meta.db_table,
fields=", ".join([self.quote(f) for f in field_names]),
extra=f"{extra}" if extra else "",
)

def _get_unique_index_sql(self, exists: str, table_name: str, field_names: List[str]) -> str:
index_name = self._generate_index_name("uidx", table_name, field_names)
return self.UNIQUE_INDEX_CREATE_TEMPLATE.format(
exists=exists,
index_name=index_name,
index_type="",
table_name=table_name,
fields=", ".join([self.quote(f) for f in field_names]),
extra="",
)

def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: List[str]) -> str:
Expand Down Expand Up @@ -324,22 +340,37 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
self._get_unique_constraint_sql(model, unique_together_to_create)
)

# Indexes.
_indexes = [
self._get_index_sql(model, [field_name], safe=safe) for field_name in fields_with_index
]

if model._meta.indexes:
for indexes_list in model._meta.indexes:
if not isinstance(indexes_list, Index):
indexes_to_create = []
for field in indexes_list:
for index in model._meta.indexes:
if not isinstance(index, Index):
fields = []
for field in index:
field_object = model._meta.fields_map[field]
indexes_to_create.append(field_object.source_field or field)
fields.append(field_object.source_field or field)

_indexes.append(self._get_index_sql(model, indexes_to_create, safe=safe))
_indexes.append(self._get_index_sql(model, fields, safe=safe))
else:
_indexes.append(indexes_list.get_sql(self, model, safe))
if index.fields:
fields = [f for f in index.fields]
elif index.expressions:
fields = [
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})"
for expression in index.expressions
]
else:
raise ConfigurationError(
"At least one field or expression is required to define an index."
)

_indexes.append(
self._get_index_sql(
model, fields, safe=safe, index_type=index.INDEX_TYPE, extra=index.extra
)
)

field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val]

Expand Down
23 changes: 22 additions & 1 deletion tortoise/backends/base_postgres/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from typing import TYPE_CHECKING, Any, List
from typing import TYPE_CHECKING, Any, List, Optional, Type

from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.converters import encoders
from tortoise.models import Model

if TYPE_CHECKING: # pragma: nocoverage
from .client import BasePostgresClient


class BasePostgresSchemaGenerator(BaseSchemaGenerator):
DIALECT = "postgres"
INDEX_CREATE_TEMPLATE = (
'CREATE INDEX {exists}"{index_name}" ON "{table_name}" {index_type}({fields}){extra};'
)
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
TABLE_COMMENT_TEMPLATE = "COMMENT ON TABLE \"{table}\" IS '{comment}';"
COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table}"."{column}" IS \'{comment}\';'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}'
Expand Down Expand Up @@ -61,3 +66,19 @@ def _escape_default_value(self, default: Any):
if isinstance(default, bool):
return default
return encoders.get(type(default))(default) # type: ignore

def _get_index_sql(
self,
model: "Type[Model]",
field_names: List[str],
safe: bool,
index_name: Optional[str] = None,
index_type: Optional[str] = None,
extra: Optional[str] = None,
) -> str:
if index_type:
index_type = f"USING {index_type}"

return super()._get_index_sql(
model, field_names, safe, index_name=index_name, index_type=index_type, extra=extra
)
16 changes: 13 additions & 3 deletions tortoise/backends/mssql/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Type
from typing import TYPE_CHECKING, Any, List, Optional, Type

from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.converters import encoders
Expand Down Expand Up @@ -59,8 +59,18 @@ def _column_default_generator(
def _escape_default_value(self, default: Any):
return encoders.get(type(default))(default) # type: ignore

def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str:
return super(MSSQLSchemaGenerator, self)._get_index_sql(model, field_names, False)
def _get_index_sql(
self,
model: "Type[Model]",
field_names: List[str],
safe: bool,
index_name: Optional[str] = None,
index_type: Optional[str] = None,
extra: Optional[str] = None,
) -> str:
return super(MSSQLSchemaGenerator, self)._get_index_sql(
model, field_names, False, index_name=index_name, index_type=index_type, extra=extra
)

def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
return super(MSSQLSchemaGenerator, self)._get_table_sql(model, False)
Expand Down
18 changes: 14 additions & 4 deletions tortoise/backends/mysql/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Type
from typing import TYPE_CHECKING, Any, List, Optional, Type

from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.converters import encoders
Expand All @@ -11,7 +11,7 @@
class MySQLSchemaGenerator(BaseSchemaGenerator):
DIALECT = "mysql"
TABLE_CREATE_TEMPLATE = "CREATE TABLE {exists}`{table_name}` ({fields}){extra}{comment};"
INDEX_CREATE_TEMPLATE = "KEY `{index_name}` ({fields})"
INDEX_CREATE_TEMPLATE = "{index_type}KEY `{index_name}` ({fields}){extra}"
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = "UNIQUE KEY `{index_name}` ({fields})"
UNIQUE_INDEX_CREATE_TEMPLATE = UNIQUE_CONSTRAINT_CREATE_TEMPLATE
FIELD_TEMPLATE = "`{name}` {type}{nullable}{unique}{primary}{comment}{default}"
Expand Down Expand Up @@ -68,9 +68,19 @@ def _column_default_generator(
def _escape_default_value(self, default: Any):
return encoders.get(type(default))(default) # type: ignore

def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str:
def _get_index_sql(
self,
model: "Type[Model]",
field_names: List[str],
safe: bool,
index_name: Optional[str] = None,
index_type: Optional[str] = None,
extra: Optional[str] = None,
) -> str:
"""Get index SQLs, but keep them for ourselves"""
index_create_sql = super()._get_index_sql(model, field_names, safe)
index_create_sql = super()._get_index_sql(
model, field_names, safe, index_name=index_name, index_type=index_type, extra=extra
)
self._field_indexes.append(index_create_sql)
return ""

Expand Down
16 changes: 13 additions & 3 deletions tortoise/backends/oracle/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Type
from typing import TYPE_CHECKING, Any, List, Optional, Type

from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.converters import encoders
Expand Down Expand Up @@ -85,8 +85,18 @@ def _column_default_generator(
def _escape_default_value(self, default: Any):
return encoders.get(type(default))(default) # type: ignore

def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str:
return super(OracleSchemaGenerator, self)._get_index_sql(model, field_names, False)
def _get_index_sql(
self,
model: "Type[Model]",
field_names: List[str],
safe: bool,
index_name: Optional[str] = None,
index_type: Optional[str] = None,
extra: Optional[str] = None,
) -> str:
return super(OracleSchemaGenerator, self)._get_index_sql(
model, field_names, False, index_name=index_name, index_type=index_type, extra=extra
)

def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
return super(OracleSchemaGenerator, self)._get_table_sql(model, False)
Expand Down
4 changes: 1 addition & 3 deletions tortoise/contrib/postgres/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@


class PostgreSQLIndex(PartialIndex):
INDEX_CREATE_TEMPLATE = (
"CREATE INDEX {exists}{index_name} ON {table_name} USING{index_type}({fields}){extra};"
)
pass


class BloomIndex(PostgreSQLIndex):
Expand Down
Loading
Loading