Skip to content

Commit

Permalink
Model-graph resolution for compound select, better column resolution.
Browse files Browse the repository at this point in the history
When doing a compound query where columns are selected from multiple
models, try to use the model-graph cursor wrapper. Also prevent
attribute overwriting when using flat cursor wrapper.

Refs #1579.
  • Loading branch information
coleifer committed Apr 18, 2018
1 parent 6bb5af2 commit 80a2048
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
16 changes: 10 additions & 6 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -3470,7 +3470,7 @@ def _initialize_columns(self):
def _row_to_dict(self, row):
result = {}
for i in range(self.ncols):
result[self.columns[i]] = row[i]
result.setdefault(self.columns[i], row[i]) # Do not overwite.
return result

process_row = _row_to_dict
Expand Down Expand Up @@ -5555,6 +5555,11 @@ def __iter__(self):
self.execute()
return iter(self._cursor_wrapper)

@Node.copy
def objects(self, constructor=None):
self._row_type = ROW.CONSTRUCTOR
self._constructor = self.model if constructor is None else constructor

def prefetch(self, *subqueries):
return prefetch(self, *subqueries)

Expand Down Expand Up @@ -5592,6 +5597,9 @@ def __init__(self, model, *args, **kwargs):
self.model = model
super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs)

def _get_model_cursor_wrapper(self, cursor):
return self.lhs._get_model_cursor_wrapper(cursor)


class ModelSelect(BaseModelSelect, Select):
def __init__(self, model, fields_or_models, is_default=False):
Expand Down Expand Up @@ -5625,11 +5633,6 @@ def switch(self, ctx=None):
self._join_ctx = self.model if ctx is None else ctx
return self

@Node.copy
def objects(self, constructor=None):
self._row_type = ROW.CONSTRUCTOR
self._constructor = self.model if constructor is None else constructor

def _get_model(self, src):
if is_model(src):
return src, True
Expand Down Expand Up @@ -6031,6 +6034,7 @@ def process_row(self, row):

for i in range(self.ncols):
attr = columns[i]
if attr in result: continue # Don't overwrite if we have dupes.
if converters[i] is not None:
result[attr] = converters[i](row[i])
else:
Expand Down
51 changes: 51 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,57 @@ def test_compound_select(self):
self.assertEqual([row.value for row in c2], [0, 1, 5, 8, 9])
self.assertEqual(c2.count(), 5)

@requires_models(User, Tweet)
def test_union_column_resolution(self):
u1 = User.create(username='u1')
u2 = User.create(username='u2')
q1 = User.select().where(User.id == 1)
q2 = User.select()
union = q1 | q2
self.assertSQL(union, (
'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" '
'WHERE ("t1"."id" = ?) '
'UNION '
'SELECT "t2"."id", "t2"."username" FROM "users" AS "t2"'), [1])

results = [(user.id, user.username) for user in union]
self.assertEqual(sorted(results), [
(1, 'u1'),
(2, 'u2')])

t1_1 = Tweet.create(user=u1, content='u1-t1')
t1_2 = Tweet.create(user=u1, content='u1-t2')
t2_1 = Tweet.create(user=u2, content='u2-t1')
q1 = Tweet.select(Tweet, User).join(User).where(User.id == 1)
q2 = Tweet.select(Tweet, User).join(User)
union = q1 | q2
self.assertSQL(union, (
'SELECT "t1"."id", "t1"."user_id", "t1"."content", '
'"t1"."timestamp", "t2"."id", "t2"."username" '
'FROM "tweet" AS "t1" '
'INNER JOIN "users" AS "t2" ON ("t1"."user_id" = "t2"."id") '
'WHERE ("t2"."id" = ?) '
'UNION '
'SELECT "t3"."id", "t3"."user_id", "t3"."content", '
'"t3"."timestamp", "t4"."id", "t4"."username" '
'FROM "tweet" AS "t3" '
'INNER JOIN "users" AS "t4" ON ("t3"."user_id" = "t4"."id")'), [1])

with self.assertQueryCount(1):
results = [(t.id, t.content, t.user.username) for t in union]
self.assertEqual(sorted(results), [
(1, 'u1-t1', 'u1'),
(2, 'u1-t2', 'u1'),
(3, 'u2-t1', 'u2')])

union_flat = (q1 | q2).objects()
with self.assertQueryCount(1):
results = [(t.id, t.content, t.username) for t in union_flat]
self.assertEqual(sorted(results), [
(1, 'u1-t1', 'u1'),
(2, 'u1-t2', 'u1'),
(3, 'u2-t1', 'u2')])

@requires_models(Category)
def test_self_referential_fk(self):
self.assertTrue(Category.parent.rel_model is Category)
Expand Down

0 comments on commit 80a2048

Please sign in to comment.