Skip to content

Commit

Permalink
Generate update set on conflict for upserts (#145)
Browse files Browse the repository at this point in the history
When upserting to tables with only key columns, typo would generate `do nothing`
as the conflict resolution. This had the disadvantage that the query would not
execute its `returning` part, which meant that when an upsert targeted a single
existing row, no rows would be returned and the generated code would cause an
error.

As a workaround, we generate `update set k = excluded.k` for an arbitrary key
column k instead. This causes the `returning` part of the query to run but
shouldn't change the value of the row as typo sees it.

- Add failing test case for upserting on an existing row
- Make the test case work by generating `update set` instead of `do nothing`
  • Loading branch information
kaaveland authored Oct 31, 2024
1 parent 2f44343 commit 4f36771
Show file tree
Hide file tree
Showing 100 changed files with 1,292 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class MaritalStatusRepoImpl extends MaritalStatusRepo {
${ParameterValue(unsaved.id, null, MaritalStatusId.toStatement)}::int8
)
on conflict ("id")
do nothing
do update set "id" = EXCLUDED."id"
returning "id"
"""
.executeInsert(MaritalStatusRow.rowParser(1).single)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ package object hardcoded {
case "VARCHAR" => "text[]"
case other => s"${other}[]"
}

override def jdbcType: scala.Int = java.sql.Types.ARRAY
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class MaritalStatusRepoImpl extends MaritalStatusRepo {
${ParameterValue(unsaved.id, null, MaritalStatusId.toStatement)}::int8
)
on conflict ("id")
do nothing
do update set "id" = EXCLUDED."id"
returning "id"
"""
.executeInsert(MaritalStatusRow.rowParser(1).single)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ package object hardcoded {
case "VARCHAR" => "text[]"
case other => s"${other}[]"
}

override def jdbcType: scala.Int = java.sql.Types.ARRAY
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class MaritalStatusRepoImpl extends MaritalStatusRepo {
${fromWrite(unsaved.id)(Write.fromPut(MaritalStatusId.put))}::int8
)
on conflict ("id")
do nothing
do update set "id" = EXCLUDED."id"
returning "id"
""".query(using MaritalStatusRow.read).unique
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class MaritalStatusRepoImpl extends MaritalStatusRepo {
${fromWrite(unsaved.id)(Write.fromPut(MaritalStatusId.put))}::int8
)
on conflict ("id")
do nothing
do update set "id" = EXCLUDED."id"
returning "id"
""".query(using MaritalStatusRow.read).unique
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class MaritalStatusRepoImpl extends MaritalStatusRepo {
${Segment.paramSegment(unsaved.id)(MaritalStatusId.setter)}::int8
)
on conflict ("id")
do nothing
do update set "id" = EXCLUDED."id"
returning "id"""".insertReturning(using MaritalStatusRow.jdbcDecoder)
}
/* NOTE: this functionality is not safe if you use auto-commit mode! it runs 3 SQL statements */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class MaritalStatusRepoImpl extends MaritalStatusRepo {
${Segment.paramSegment(unsaved.id)(MaritalStatusId.setter)}::int8
)
on conflict ("id")
do nothing
do update set "id" = EXCLUDED."id"
returning "id"""".insertReturning(using MaritalStatusRow.jdbcDecoder)
}
/* NOTE: this functionality is not safe if you use auto-commit mode! it runs 3 SQL statements */
Expand Down
6 changes: 6 additions & 0 deletions init/data/test-tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,9 @@ INSERT INTO issue142 (tabellkode)
VALUES ('aa'),
('bb')
;

-- Regression fixture for #145: a table whose every column is part of the
-- primary key, so an upsert has no non-key columns to assign in
-- `on conflict ... do update set`. Exercises the generated
-- `update set k = excluded.k` workaround that keeps `returning` alive.
create table only_pk_columns(
key_column_1 text not null,
key_column_2 int not null,
constraint only_pk_columns_pk primary key (key_column_1, key_column_2)
);
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class EmployeedepartmenthistoryRepoImpl extends EmployeedepartmenthistoryRepo {
val shiftid = compositeIds.map(_.shiftid)
SQL"""select "businessentityid", "departmentid", "shiftid", "startdate"::text, "enddate"::text, "modifieddate"::text
from "humanresources"."employeedepartmenthistory"
where ("businessentityid", "startdate", "departmentid", "shiftid")
where ("businessentityid", "startdate", "departmentid", "shiftid")
in (select unnest(${businessentityid}), unnest(${startdate}), unnest(${departmentid}), unnest(${shiftid}))
""".as(EmployeedepartmenthistoryRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class EmployeepayhistoryRepoImpl extends EmployeepayhistoryRepo {
val ratechangedate = compositeIds.map(_.ratechangedate)
SQL"""select "businessentityid", "ratechangedate"::text, "rate", "payfrequency", "modifieddate"::text
from "humanresources"."employeepayhistory"
where ("businessentityid", "ratechangedate")
where ("businessentityid", "ratechangedate")
in (select unnest(${businessentityid}), unnest(${ratechangedate}))
""".as(EmployeepayhistoryRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ package object adventureworks {
case "VARCHAR" => "text[]"
case other => s"${other}[]"
}

override def jdbcType: scala.Int = java.sql.Types.ARRAY
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class BusinessentityaddressRepoImpl extends BusinessentityaddressRepo {
val addresstypeid = compositeIds.map(_.addresstypeid)
SQL"""select "businessentityid", "addressid", "addresstypeid", "rowguid", "modifieddate"::text
from "person"."businessentityaddress"
where ("businessentityid", "addressid", "addresstypeid")
where ("businessentityid", "addressid", "addresstypeid")
in (select unnest(${businessentityid}), unnest(${addressid}), unnest(${addresstypeid}))
""".as(BusinessentityaddressRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class BusinessentitycontactRepoImpl extends BusinessentitycontactRepo {
val contacttypeid = compositeIds.map(_.contacttypeid)
SQL"""select "businessentityid", "personid", "contacttypeid", "rowguid", "modifieddate"::text
from "person"."businessentitycontact"
where ("businessentityid", "personid", "contacttypeid")
where ("businessentityid", "personid", "contacttypeid")
in (select unnest(${businessentityid}), unnest(${personid}), unnest(${contacttypeid}))
""".as(BusinessentitycontactRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class EmailaddressRepoImpl extends EmailaddressRepo {
val emailaddressid = compositeIds.map(_.emailaddressid)
SQL"""select "businessentityid", "emailaddressid", "emailaddress", "rowguid", "modifieddate"::text
from "person"."emailaddress"
where ("businessentityid", "emailaddressid")
where ("businessentityid", "emailaddressid")
in (select unnest(${businessentityid}), unnest(${emailaddressid}))
""".as(EmailaddressRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class PersonphoneRepoImpl extends PersonphoneRepo {
val phonenumbertypeid = compositeIds.map(_.phonenumbertypeid)
SQL"""select "businessentityid", "phonenumber", "phonenumbertypeid", "modifieddate"::text
from "person"."personphone"
where ("businessentityid", "phonenumber", "phonenumbertypeid")
where ("businessentityid", "phonenumber", "phonenumbertypeid")
in (select unnest(${businessentityid}), unnest(${phonenumber}), unnest(${phonenumbertypeid}))
""".as(PersonphoneRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ProductcosthistoryRepoImpl extends ProductcosthistoryRepo {
val startdate = compositeIds.map(_.startdate)
SQL"""select "productid", "startdate"::text, "enddate"::text, "standardcost", "modifieddate"::text
from "production"."productcosthistory"
where ("productid", "startdate")
where ("productid", "startdate")
in (select unnest(${productid}), unnest(${startdate}))
""".as(ProductcosthistoryRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class ProductdocumentRepoImpl extends ProductdocumentRepo {
val documentnode = compositeIds.map(_.documentnode)
SQL"""select "productid", "modifieddate"::text, "documentnode"
from "production"."productdocument"
where ("productid", "documentnode")
where ("productid", "documentnode")
in (select unnest(${productid}), unnest(${documentnode}))
""".as(ProductdocumentRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ProductinventoryRepoImpl extends ProductinventoryRepo {
val locationid = compositeIds.map(_.locationid)
SQL"""select "productid", "locationid", "shelf", "bin", "quantity", "rowguid", "modifieddate"::text
from "production"."productinventory"
where ("productid", "locationid")
where ("productid", "locationid")
in (select unnest(${productid}), unnest(${locationid}))
""".as(ProductinventoryRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ProductlistpricehistoryRepoImpl extends ProductlistpricehistoryRepo {
val startdate = compositeIds.map(_.startdate)
SQL"""select "productid", "startdate"::text, "enddate"::text, "listprice", "modifieddate"::text
from "production"."productlistpricehistory"
where ("productid", "startdate")
where ("productid", "startdate")
in (select unnest(${productid}), unnest(${startdate}))
""".as(ProductlistpricehistoryRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class ProductmodelillustrationRepoImpl extends ProductmodelillustrationRepo {
val illustrationid = compositeIds.map(_.illustrationid)
SQL"""select "productmodelid", "illustrationid", "modifieddate"::text
from "production"."productmodelillustration"
where ("productmodelid", "illustrationid")
where ("productmodelid", "illustrationid")
in (select unnest(${productmodelid}), unnest(${illustrationid}))
""".as(ProductmodelillustrationRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ProductmodelproductdescriptioncultureRepoImpl extends Productmodelproductd
val cultureid = compositeIds.map(_.cultureid)
SQL"""select "productmodelid", "productdescriptionid", "cultureid", "modifieddate"::text
from "production"."productmodelproductdescriptionculture"
where ("productmodelid", "productdescriptionid", "cultureid")
where ("productmodelid", "productdescriptionid", "cultureid")
in (select unnest(${productmodelid}), unnest(${productdescriptionid}), unnest(${cultureid}))
""".as(ProductmodelproductdescriptioncultureRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class ProductproductphotoRepoImpl extends ProductproductphotoRepo {
val productphotoid = compositeIds.map(_.productphotoid)
SQL"""select "productid", "productphotoid", "primary", "modifieddate"::text
from "production"."productproductphoto"
where ("productid", "productphotoid")
where ("productid", "productphotoid")
in (select unnest(${productid}), unnest(${productphotoid}))
""".as(ProductproductphotoRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class WorkorderroutingRepoImpl extends WorkorderroutingRepo {
val operationsequence = compositeIds.map(_.operationsequence)
SQL"""select "workorderid", "productid", "operationsequence", "locationid", "scheduledstartdate"::text, "scheduledenddate"::text, "actualstartdate"::text, "actualenddate"::text, "actualresourcehrs", "plannedcost", "actualcost", "modifieddate"::text
from "production"."workorderrouting"
where ("workorderid", "productid", "operationsequence")
where ("workorderid", "productid", "operationsequence")
in (select unnest(${workorderid}), unnest(${productid}), unnest(${operationsequence}))
""".as(WorkorderroutingRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class FlaffRepoImpl extends FlaffRepo {
val specifier = compositeIds.map(_.specifier)
SQL"""select "code", "another_code", "some_number", "specifier", "parentspecifier"
from "public"."flaff"
where ("code", "another_code", "some_number", "specifier")
where ("code", "another_code", "some_number", "specifier")
in (select unnest(${code}), unnest(${anotherCode}), unnest(${someNumber}), unnest(${specifier}))
""".as(FlaffRow.rowParser(1).*)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Issue142RepoImpl extends Issue142Repo {
${ParameterValue(unsaved.tabellkode, null, Issue142Id.toStatement)}
)
on conflict ("tabellkode")
do nothing
do update set "tabellkode" = EXCLUDED."tabellkode"
returning "tabellkode"
"""
.executeInsert(Issue142Row.rowParser(1).single)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Issue1422RepoImpl extends Issue1422Repo {
${ParameterValue(unsaved.tabellkode, null, Issue142Id.toStatement)}
)
on conflict ("tabellkode")
do nothing
do update set "tabellkode" = EXCLUDED."tabellkode"
returning "tabellkode"
"""
.executeInsert(Issue1422Row.rowParser(1).single)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* File has been automatically generated by `typo`.
*
* IF YOU CHANGE THIS FILE YOUR CHANGES WILL BE OVERWRITTEN.
*/
package adventureworks
package public
package only_pk_columns

import typo.dsl.Path
import typo.dsl.Required
import typo.dsl.SqlExpr
import typo.dsl.SqlExpr.CompositeIn
import typo.dsl.SqlExpr.CompositeIn.TuplePart
import typo.dsl.SqlExpr.FieldLikeNoHkt
import typo.dsl.SqlExpr.IdField
import typo.dsl.Structure.Relation

/** Typed field accessors for `public.only_pk_columns`, used by the typo query DSL
  * to build type-safe predicates over the table's (composite) primary key.
  */
trait OnlyPkColumnsFields {
/** DSL handle for the `key_column_1` primary-key column (`text`). */
def keyColumn1: IdField[String, OnlyPkColumnsRow]
/** DSL handle for the `key_column_2` primary-key column (`int`). */
def keyColumn2: IdField[Int, OnlyPkColumnsRow]
/** Predicate matching exactly one row: both key columns equal to `compositeId`'s parts (conjunction of equalities). */
def compositeIdIs(compositeId: OnlyPkColumnsId): SqlExpr[Boolean, Required] =
keyColumn1.isEqual(compositeId.keyColumn1).and(keyColumn2.isEqual(compositeId.keyColumn2))
/** Predicate matching any row whose (key_column_1, key_column_2) tuple appears in `compositeIds` (composite IN). */
def compositeIdIn(compositeIds: Array[OnlyPkColumnsId]): SqlExpr[Boolean, Required] =
new CompositeIn(compositeIds)(TuplePart(keyColumn1)(_.keyColumn1), TuplePart(keyColumn2)(_.keyColumn2))

}

object OnlyPkColumnsFields {
/** Root structure for `public.only_pk_columns`, entry point for DSL queries (path starts empty). */
lazy val structure: Relation[OnlyPkColumnsFields, OnlyPkColumnsRow] =
new Impl(Nil)

/** Relation implementation carrying the current join/alias path (`_path`). */
private final class Impl(val _path: List[Path])
extends Relation[OnlyPkColumnsFields, OnlyPkColumnsRow] {

override lazy val fields: OnlyPkColumnsFields = new OnlyPkColumnsFields {
// IdField args: path, column name, two optional settings, getter, copy-based setter.
// NOTE(review): the Some("int4") presumably is a SQL type/cast hint for the int column — confirm against typo.dsl.SqlExpr.IdField.
override def keyColumn1 = IdField[String, OnlyPkColumnsRow](_path, "key_column_1", None, None, x => x.keyColumn1, (row, value) => row.copy(keyColumn1 = value))
override def keyColumn2 = IdField[Int, OnlyPkColumnsRow](_path, "key_column_2", None, Some("int4"), x => x.keyColumn2, (row, value) => row.copy(keyColumn2 = value))
}

// All columns of the table, in declaration order (here: exactly the two key columns).
override lazy val columns: List[FieldLikeNoHkt[?, OnlyPkColumnsRow]] =
List[FieldLikeNoHkt[?, OnlyPkColumnsRow]](fields.keyColumn1, fields.keyColumn2)

// Re-root the relation under a new path (used when the DSL composes joins/aliases).
override def copy(path: List[Path]): Impl =
new Impl(path)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* File has been automatically generated by `typo`.
*
* IF YOU CHANGE THIS FILE YOUR CHANGES WILL BE OVERWRITTEN.
*/
package adventureworks
package public
package only_pk_columns

import play.api.libs.json.JsObject
import play.api.libs.json.JsResult
import play.api.libs.json.JsValue
import play.api.libs.json.OWrites
import play.api.libs.json.Reads
import play.api.libs.json.Writes
import scala.collection.immutable.ListMap
import scala.util.Try

/** Type for the composite primary key of table `public.only_pk_columns` */
case class OnlyPkColumnsId(
keyColumn1: String,
keyColumn2: Int
)
object OnlyPkColumnsId {
/** Lexicographic ordering over (key_column_1, key_column_2). */
implicit lazy val ordering: Ordering[OnlyPkColumnsId] =
Ordering.by(id => (id.keyColumn1, id.keyColumn2))
/** JSON reader: extracts both key fields, turning any extraction failure into a JsError. */
implicit lazy val reads: Reads[OnlyPkColumnsId] = Reads[OnlyPkColumnsId] { json =>
JsResult.fromTry(
Try(
OnlyPkColumnsId(
keyColumn1 = (json \ "key_column_1").as(Reads.StringReads),
keyColumn2 = (json \ "key_column_2").as(Reads.IntReads)
)
)
)
}
/** JSON writer: emits both key fields, preserving declaration order via ListMap. */
implicit lazy val writes: OWrites[OnlyPkColumnsId] = OWrites[OnlyPkColumnsId] { id =>
new JsObject(ListMap[String, JsValue](
"key_column_1" -> Writes.StringWrites.writes(id.keyColumn1),
"key_column_2" -> Writes.IntWrites.writes(id.keyColumn2)
))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* File has been automatically generated by `typo`.
*
* IF YOU CHANGE THIS FILE YOUR CHANGES WILL BE OVERWRITTEN.
*/
package adventureworks
package public
package only_pk_columns

import java.sql.Connection
import typo.dsl.DeleteBuilder
import typo.dsl.SelectBuilder
import typo.dsl.UpdateBuilder

/** Repository interface for `public.only_pk_columns`; implementations run against a JDBC [[java.sql.Connection]]. */
trait OnlyPkColumnsRepo {
/** Builder for composing a DELETE statement with the typo DSL. */
def delete: DeleteBuilder[OnlyPkColumnsFields, OnlyPkColumnsRow]
/** Deletes the row with the given composite id; returns true if a row was deleted. */
def deleteById(compositeId: OnlyPkColumnsId)(implicit c: Connection): Boolean
/** Deletes all rows whose composite id is in `compositeIds`; returns the number of deleted rows. */
def deleteByIds(compositeIds: Array[OnlyPkColumnsId])(implicit c: Connection): Int
/** Inserts `unsaved` and returns the row as persisted. */
def insert(unsaved: OnlyPkColumnsRow)(implicit c: Connection): OnlyPkColumnsRow
/** Inserts rows from the iterator in batches of `batchSize`; returns the number of inserted rows. */
def insertStreaming(unsaved: Iterator[OnlyPkColumnsRow], batchSize: Int = 10000)(implicit c: Connection): Long
/** Builder for composing a SELECT statement with the typo DSL. */
def select: SelectBuilder[OnlyPkColumnsFields, OnlyPkColumnsRow]
/** Returns every row of the table. */
def selectAll(implicit c: Connection): List[OnlyPkColumnsRow]
/** Looks up a single row by composite id. */
def selectById(compositeId: OnlyPkColumnsId)(implicit c: Connection): Option[OnlyPkColumnsRow]
/** Returns all rows whose composite id is in `compositeIds`. */
def selectByIds(compositeIds: Array[OnlyPkColumnsId])(implicit c: Connection): List[OnlyPkColumnsRow]
/** Like [[selectByIds]], but keyed by composite id for easy lookup of which ids were found. */
def selectByIdsTracked(compositeIds: Array[OnlyPkColumnsId])(implicit c: Connection): Map[OnlyPkColumnsId, OnlyPkColumnsRow]
/** Builder for composing an UPDATE statement with the typo DSL. */
def update: UpdateBuilder[OnlyPkColumnsFields, OnlyPkColumnsRow]
/** Inserts or updates `unsaved` (INSERT ... ON CONFLICT); returns the resulting row. */
def upsert(unsaved: OnlyPkColumnsRow)(implicit c: Connection): OnlyPkColumnsRow
/** Upserts every row in `unsaved`; returns the resulting rows. */
def upsertBatch(unsaved: Iterable[OnlyPkColumnsRow])(implicit c: Connection): List[OnlyPkColumnsRow]
/* NOTE: this functionality is not safe if you use auto-commit mode! it runs 3 SQL statements */
def upsertStreaming(unsaved: Iterator[OnlyPkColumnsRow], batchSize: Int = 10000)(implicit c: Connection): Int
}
Loading

0 comments on commit 4f36771

Please sign in to comment.