Skip to content

Commit

Permalink
Remove multi-model union restriction.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Mar 28, 2018
1 parent b319bbd commit 67ceaf3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 31 deletions.
29 changes: 6 additions & 23 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -5568,26 +5568,6 @@ def __init__(self, model, *args, **kwargs):
self.model = model
super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs)

def get_query_models(self):
accum = set()
stack = [self.lhs, self.rhs]
while stack:
src = stack.pop()
if isinstance(src, ModelCompoundSelectQuery):
stack.extend([src.lhs, src.rhs])
elif isinstance(src, ModelSelect):
accum.add(src.model)
return accum

def _get_cursor_wrapper(self, cursor):
row_type = self._row_type or self.default_row_type
if row_type == ROW.MODEL and len(self.get_query_models()) > 1:
raise ValueError('Compound queries involving multiple model '
'classes must use dicts(), tuples() or '
'namedtuples() for the row type.')
return (super(ModelCompoundSelectQuery, self)
._get_cursor_wrapper(cursor))


class ModelSelect(BaseModelSelect, Select):
def __init__(self, model, fields_or_models, is_default=False):
Expand Down Expand Up @@ -5948,16 +5928,19 @@ def _initialize_columns(self):
self.fields = fields = [None] * self.ncols

for idx, description_item in enumerate(description):
column = description_item[0]
column = description_item[0].strip('"')
dot_index = column.find('.')
if dot_index != -1:
column = column[dot_index + 1:]

self.columns.append(column.strip('"'))
self.columns.append(column)
try:
raw_node = self.select[idx]
except IndexError:
continue
if column in combined:
raw_node = node = combined[column]
else:
continue
else:
node = raw_node.unwrap()

Expand Down
17 changes: 9 additions & 8 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2385,18 +2385,19 @@ def _inner():
for filename, data in file_data:
CFile.create(filename=filename, data=data, timestamp=make_ts())

def test_cannot_mix_models_with_model_row_type(self):
def test_mix_models_with_model_row_type(self):
cast = 'CHAR' if IS_MYSQL else 'TEXT'
lhs = CNote.select(CNote.id.cast(cast).alias('id'),
lhs = CNote.select(CNote.id.cast(cast).alias('id_text'),
CNote.content, CNote.timestamp)
rhs = CFile.select(CFile.filename, CFile.data, CFile.timestamp)
query = (lhs | rhs)
self.assertRaises(ValueError, query.execute)
query = (lhs | rhs).order_by(SQL('timestamp')).limit(4)

# Can use tuples/dicts/namedtuples.
query.tuples().execute()
query.dicts().execute()
query.namedtuples().execute()
data = [(n.id_text, n.content, n.timestamp) for n in query]
self.assertEqual(data, [
('1', 'note-a', self.ts(1)),
('2', 'note-b', self.ts(2)),
('3', 'note-c', self.ts(3)),
('peewee.txt', 'peewee orm', self.ts(4))])

def test_mixed_models_tuple_row_type(self):
cast = 'CHAR' if IS_MYSQL else 'TEXT'
Expand Down

0 comments on commit 67ceaf3

Please sign in to comment.