Skip to content

Commit

Permalink
Insert NULL for missing nullable fields on bulk insert.
Browse files Browse the repository at this point in the history
Fixes #2368 and replaces #2369
  • Loading branch information
coleifer committed Mar 5, 2021
1 parent ac5126d commit 12d05a0
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 3 deletions.
10 changes: 8 additions & 2 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -2633,11 +2633,15 @@ def _generate_insert(self, insert, ctx):
if col not in seen:
columns.append(col)

nullable_columns = set()
value_lookups = {}
for column in columns:
lookups = [column, column.name]
if isinstance(column, Field) and column.name != column.column_name:
lookups.append(column.column_name)
if isinstance(column, Field):
if column.name != column.column_name:
lookups.append(column.column_name)
if column.null:
nullable_columns.add(column)
value_lookups[column] = lookups

ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ')
Expand Down Expand Up @@ -2671,6 +2675,8 @@ def _generate_insert(self, insert, ctx):
val = defaults[column]
if callable_(val):
val = val()
elif column in nullable_columns:
val = None
else:
raise ValueError('Missing value for %s.' % column.name)

Expand Down
7 changes: 7 additions & 0 deletions tests/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,10 @@ class Meta:
indexes = [
SQL('CREATE UNIQUE INDEX "ukvp_kve" ON "ukvp" ("key", "value") '
'WHERE "extra" > 1')]


class DfltM(TestModel):
name = CharField()
dflt1 = IntegerField(default=1)
dflt2 = IntegerField(default=lambda: 2)
dfltn = IntegerField(null=True)
17 changes: 16 additions & 1 deletion tests/model_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Meta:

class TestModelSQL(ModelDatabaseTestCase):
database = get_in_memory_db()
requires = [Category, CKM, Note, Person, Relationship, Sample, User]
requires = [Category, CKM, Note, Person, Relationship, Sample, User, DfltM]

def test_select(self):
query = (Person
Expand Down Expand Up @@ -440,6 +440,21 @@ def test_insert_many_defaults(self):
'INSERT INTO "sample" ("counter", "value") VALUES (?, ?), (?, ?)'),
[3, 1., 2, 2.])

def test_insert_many_defaults_nulls(self):
data = [
{'name': 'd1'},
{'name': 'd2', 'dflt1': 10},
{'name': 'd3', 'dflt2': 30},
{'name': 'd4', 'dfltn': 40}]
fields = [DfltM.name, DfltM.dflt1, DfltM.dflt2, DfltM.dfltn]
self.assertSQL(DfltM.insert_many(data, fields=fields), (
'INSERT INTO "dflt_m" ("name", "dflt1", "dflt2", "dfltn") VALUES '
'(?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?)'),
['d1', 1, 2, None,
'd2', 10, 2, None,
'd3', 1, 30, None,
'd4', 1, 2, 40])

def test_insert_many_list_with_fields(self):
data = [(i,) for i in ('charlie', 'huey', 'zaizee')]
query = User.insert_many(data, fields=[User.username])
Expand Down
19 changes: 19 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,25 @@ def test_insert_many(self):
names = [u.username for u in User.select().order_by(User.username)]
self.assertEqual(names, ['u%02d' % i for i in range(100)])

@requires_models(DfltM)
def test_insert_many_defaults_nullable(self):
data = [
{'name': 'd1'},
{'name': 'd2', 'dflt1': 10},
{'name': 'd3', 'dflt2': 30},
{'name': 'd4', 'dfltn': 40}]
fields = [DfltM.name, DfltM.dflt1, DfltM.dflt2, DfltM.dfltn]
DfltM.insert_many(data, fields).execute()

expected = [
('d1', 1, 2, None),
('d2', 10, 2, None),
('d3', 1, 30, None),
('d4', 1, 2, 40)]
query = DfltM.select().order_by(DfltM.name)
actual = [(d.name, d.dflt1, d.dflt2, d.dfltn) for d in query]
self.assertEqual(actual, expected)

@requires_models(User, Tweet)
def test_create(self):
with self.assertQueryCount(1):
Expand Down

0 comments on commit 12d05a0

Please sign in to comment.