diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8d5b1e510..338ed2497 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,23 +24,23 @@ concurrency: jobs: build: - name: Build and Test + name: Test strategy: matrix: - os: [ubuntu-latest] + os: [ubuntu-22.04] scala: [2.12, 2.13, 3] java: [temurin@11] runs-on: ${{ matrix.os }} timeout-minutes: 60 steps: - - name: Install sbt - uses: sbt/setup-sbt@v1 - - name: Checkout current branch (full) uses: actions/checkout@v4 with: fetch-depth: 0 + - name: Setup sbt + uses: sbt/setup-sbt@v1 + - name: Setup Java (temurin@11) id: setup-java-temurin-11 if: matrix.java == 'temurin@11' @@ -64,25 +64,25 @@ jobs: run: sbt githubWorkflowCheck - name: Check formatting - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-22.04' run: sbt '++ ${{ matrix.scala }}' scalafmtCheckAll 'project /' scalafmtSbtCheck - name: Test run: sbt '++ ${{ matrix.scala }}' freeGen2 test - name: Check binary compatibility - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-22.04' run: sbt '++ ${{ matrix.scala }}' mimaReportBinaryIssues - name: Generate API documentation - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-22.04' run: sbt '++ ${{ matrix.scala }}' doc - name: Check there are no uncommitted changes in git (to catch generated files that weren't committed) run: sbt '++ ${{ matrix.scala }}' checkGitNoUncommittedChanges - - name: Check Doc Site (2.13.14 only) - if: matrix.scala == '2.13.14' + - name: Check Doc Site (2.13 only) + if: matrix.scala == '2.13' run: sbt '++ ${{ matrix.scala }}' docs/makeSite - name: Make target directories @@ -106,18 +106,18 @@ jobs: if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') strategy: matrix: 
- os: [ubuntu-latest] + os: [ubuntu-22.04] java: [temurin@11] runs-on: ${{ matrix.os }} steps: - - name: Install sbt - uses: sbt/setup-sbt@v1 - - name: Checkout current branch (full) uses: actions/checkout@v4 with: fetch-depth: 0 + - name: Setup sbt + uses: sbt/setup-sbt@v1 + - name: Setup Java (temurin@11) id: setup-java-temurin-11 if: matrix.java == 'temurin@11' @@ -190,18 +190,18 @@ jobs: if: github.event.repository.fork == false && github.event_name != 'pull_request' strategy: matrix: - os: [ubuntu-latest] + os: [ubuntu-22.04] java: [temurin@11] runs-on: ${{ matrix.os }} steps: - - name: Install sbt - uses: sbt/setup-sbt@v1 - - name: Checkout current branch (full) uses: actions/checkout@v4 with: fetch-depth: 0 + - name: Setup sbt + uses: sbt/setup-sbt@v1 + - name: Setup Java (temurin@11) id: setup-java-temurin-11 if: matrix.java == 'temurin@11' diff --git a/.mergify.yml b/.mergify.yml index 4688c3292..51e4bad4e 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -10,9 +10,9 @@ pull_request_rules: conditions: - author=scala-steward - body~=labels:.*early-semver-patch - - status-success=Build and Test (ubuntu-latest, 2.12, temurin@11) - - status-success=Build and Test (ubuntu-latest, 2.13, temurin@11) - - status-success=Build and Test (ubuntu-latest, 3, temurin@11) + - status-success=Test (ubuntu-22.04, 2.12, temurin@11) + - status-success=Test (ubuntu-22.04, 2.13, temurin@11) + - status-success=Test (ubuntu-22.04, 3, temurin@11) actions: merge: {} - name: Label bench PRs @@ -161,9 +161,9 @@ pull_request_rules: remove: [] - name: merge-when-ci-pass conditions: - - status-success=Build and Test (ubuntu-latest, 2.12, temurin@11) - - status-success=Build and Test (ubuntu-latest, 2.13, temurin@11) - - status-success=Build and Test (ubuntu-latest, 3, temurin@11) + - status-success=Test (ubuntu-22.04, 2.12, temurin@11) + - status-success=Test (ubuntu-22.04, 2.13, temurin@11) + - status-success=Test (ubuntu-22.04, 3, temurin@11) - label=merge-on-build-success actions: 
merge: {} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8eef597e3..df33cf7b9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,20 +8,27 @@ Running the tests or building the documentation site requires connection to the which you can spin up using docker-compose: ``` -docker-compose up -d --force-update +docker compose up -d --force-recreate ``` -After that, in SBT you can run `test` to run tests, and `makeSite` to build the doc site +Note: If you're using Apple Silicon MacBooks (M1, M2, etc.), you need to enable "Use Rosetta for x86_64/amd64 emulation on Apple Silicon" since there is no ARM64 image for postgis yet. -If you're editing code generation related code, you should reload the SBT project and then run the `freeGen2` SBT task -before compiling or running tests. +With the containers started, in SBT you can run `test` to run tests, and `makeSite` to build the doc site. + +## Fixing warnings + +To improve code quality and reduce bugs, we enable many stricter Scala compiler flags via the +[sbt-tpolecat](https://github.com/typelevel/sbt-tpolecat) plugin and in CI all warnings will be treated as errors. + +For a more pleasant development experience, we default to `tpolecatDevMode` so warnings do not cause compilation errors. +You can use the sbt command `tpolecatCiMode` to enable strict mode and help catch any warnings you missed. ## Caveats when working on the code ## Avoiding internal cyclic module dependencies For end users, doobie provides the aliases for high and low level APIs -such as `doobie.hi.HC`, `doobie.free.FPS`. +such as `doobie.hi.HC`, `doobie.free.FPS`. Due to how the modules depend on one another, internally in doobie we cannot use these aliases because it'll lead to cyclic module dependencies and cause runtime errors. @@ -49,6 +56,3 @@ To update the doc site, check out the tag first.
git checkout v1.2.3 sbt docs/publishMicrosite ``` - - - diff --git a/build.sbt b/build.sbt index 42a565a6e..b2368d721 100644 --- a/build.sbt +++ b/build.sbt @@ -4,25 +4,25 @@ import org.typelevel.sbt.tpolecat.{DevMode, CiMode} // Library versions all in one place, for convenience and sanity. lazy val catsVersion = "2.12.0" -lazy val catsEffectVersion = "3.5.6" +lazy val catsEffectVersion = "3.5.7" lazy val circeVersion = "0.14.10" lazy val fs2Version = "3.11.0" lazy val h2Version = "1.4.200" lazy val hikariVersion = "6.2.1" // N.B. Hikari v4 introduces a breaking change via slf4j v2 lazy val kindProjectorVersion = "0.11.2" -lazy val mysqlVersion = "9.1.0" +lazy val mysqlVersion = "9.2.0" lazy val log4catsVersion = "2.7.0" lazy val postGisVersion = "2024.1.0" lazy val postgresVersion = "42.7.4" -lazy val refinedVersion = "0.11.2" +lazy val refinedVersion = "0.11.3" lazy val scalaCheckVersion = "1.15.4" lazy val scalatestVersion = "3.2.18" -lazy val munitVersion = "1.0.2" +lazy val munitVersion = "1.0.4" lazy val shapelessVersion = "2.3.12" lazy val silencerVersion = "1.7.1" -lazy val specs2Version = "4.20.7" -lazy val scala212Version = "2.12.19" -lazy val scala213Version = "2.13.14" +lazy val specs2Version = "4.20.9" +lazy val scala212Version = "2.12.20" +lazy val scala213Version = "2.13.15" lazy val scala3Version = "3.3.4" // scala-steward:off lazy val slf4jVersion = "1.7.36" @@ -33,6 +33,7 @@ lazy val weaverVersion = "0.8.4" ThisBuild / tlBaseVersion := "1.0" ThisBuild / tlCiReleaseBranches := Seq("main") // publish snapshots on `main` ThisBuild / tlCiScalafmtCheck := true +//ThisBuild / scalaVersion := scala212Version ThisBuild / scalaVersion := scala213Version //ThisBuild / scalaVersion := scala3Version ThisBuild / crossScalaVersions := Seq(scala212Version, scala213Version, scala3Version) @@ -68,8 +69,8 @@ ThisBuild / githubWorkflowBuildPostamble ++= Seq( ), WorkflowStep.Sbt( commands = List("docs/makeSite"), - name = Some(s"Check Doc Site ($scala213Version 
only)"), - cond = Some(s"matrix.scala == '$scala213Version'") + name = Some(s"Check Doc Site (2.13 only)"), + cond = Some(s"matrix.scala == '2.13'") ) ) @@ -98,9 +99,12 @@ lazy val compilerFlags = Seq( Compile / doc / scalacOptions --= Seq( "-Xfatal-warnings" ), -// Test / scalacOptions --= Seq( -// "-Xfatal-warnings" -// ), + // Disable warning when @nowarn annotation isn't suppressing a warning + // to simplify cross-building + // because 2.12 @nowarn doesn't actually do anything.. https://github.com/scala/bug/issues/12313 + scalacOptions ++= Seq( + "-Wconf:cat=unused-nowarn:s" + ), scalacOptions ++= (if (tlIsScala3.value) // Handle irrefutable patterns in for comprehensions Seq("-source:future", "-language:adhocExtensions") @@ -249,8 +253,7 @@ lazy val core = project ).filterNot(_ => tlIsScala3.value) ++ Seq( "org.tpolecat" %% "typename" % "1.1.0", "com.h2database" % "h2" % h2Version % "test", - "org.postgresql" % "postgresql" % postgresVersion % "test", - "org.mockito" % "mockito-core" % "5.12.0" % Test + "org.postgresql" % "postgresql" % postgresVersion % "test" ), Compile / unmanagedSourceDirectories += { val sourceDir = (Compile / sourceDirectory).value @@ -493,7 +496,12 @@ lazy val bench = project .enablePlugins(NoPublishPlugin) .enablePlugins(AutomateHeaderPlugin) .enablePlugins(JmhPlugin) - .dependsOn(core, postgres) + .settings( + libraryDependencies ++= (if (scalaVersion.value == scala212Version) + Seq("org.scala-lang.modules" %% "scala-collection-compat" % "2.12.0") + else Seq.empty) + ) + .dependsOn(core, postgres, hikari) .settings(doobieSettings) lazy val docs = project @@ -539,6 +547,7 @@ lazy val docs = project "scalaVersion" -> scalaVersion.value ), mdocIn := baseDirectory.value / "src" / "main" / "mdoc", + mdocExtraArguments ++= Seq("--no-link-hygiene"), Compile / paradox / sourceDirectory := mdocOut.value, makeSite := makeSite.dependsOn(mdoc.toTask("")).value ) diff --git a/docker-compose.yml b/docker-compose.yml index 998912a54..06d174b3b 
100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,4 @@ -version: '3.1' +version: "3.1" services: postgres: @@ -16,9 +16,8 @@ services: limits: memory: 500M - mysql: - image: mysql:8.0-debian + image: mysql:8.0 environment: MYSQL_ROOT_PASSWORD: password MYSQL_DATABASE: world diff --git a/init/postgres/test-db.sql b/init/postgres/test-db.sql index 28ff80161..c7c98ec84 100644 --- a/init/postgres/test-db.sql +++ b/init/postgres/test-db.sql @@ -6,6 +6,14 @@ create extension postgis; create extension hstore; create type myenum as enum ('foo', 'bar', 'invalid'); +create schema other_schema; + +set search_path to other_schema; + +create type other_enum as enum ('a', 'b'); + +set search_path to public; + -- -- The sample data used in the world database is Copyright Statistics -- Finland, http://www.stat.fi/worldinfigures. diff --git a/modules/bench/src/main/scala/doobie/bench/large.scala b/modules/bench/src/main/scala/doobie/bench/large.scala new file mode 100644 index 000000000..d35dddd14 --- /dev/null +++ b/modules/bench/src/main/scala/doobie/bench/large.scala @@ -0,0 +1,193 @@ +// Copyright (c) 2013-2020 Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). 
+// For more information see LICENSE or https://opensource.org/licenses/MIT + +package doobie.bench + +import cats.effect.IO +import com.zaxxer.hikari.{HikariConfig, HikariDataSource} +import doobie.* +import doobie.syntax.all.* +import org.openjdk.jmh.annotations.* +import org.openjdk.jmh.infra.Blackhole +import scala.util.Using + +@State(Scope.Benchmark) +@OperationsPerInvocation(10000) // We process 10k rows so adjust the benchmark output accordingly +class LargeRow { + import cats.effect.unsafe.implicits.global + + private val hikariConfig = { + val config = new HikariConfig() + config.setDriverClassName("org.postgresql.Driver") + config.setJdbcUrl("jdbc:postgresql:world") + config.setUsername("postgres") + config.setPassword("password") + config.setMaximumPoolSize(2) + config + } + + val pool = new HikariDataSource(hikariConfig) + + val (xa, cleanup) = { + (for { + connectEC <- ExecutionContexts.fixedThreadPool[IO](hikariConfig.getMaximumPoolSize) + } yield Transactor.fromDataSource[IO].apply[HikariDataSource](pool, connectEC)).allocated.unsafeRunSync() + } + + @Setup(Level.Trial) + def setup(): Unit = { + val connio = for { + _ <- sql"""DROP TABLE IF EXISTS data""".update.run + _ <- sql"""CREATE TABLE data ( + col1 DOUBLE PRECISION, + col2 VARCHAR(50), + col3 INTEGER, + col4 VARCHAR(50), + col5 DOUBLE PRECISION, + col6 DOUBLE PRECISION, + col7 VARCHAR(50), + col8 VARCHAR(50) + );""".update.run + _ <- sql"select setseed(0.5)".query[Unit].unique // deterministic seed + _ <- sql"""INSERT INTO data (col1, col2, col3, col4, col5, col6, col7, col8) + SELECT random(), random() :: text, (random() * 1000) :: int, random() :: text, random(), random(), random() :: text, random() :: text + FROM generate_series(1, 10000) + """.update.run + } yield () + + connio.transact(xa).unsafeRunSync() + } + + @TearDown(Level.Trial) + def teardown(): Unit = { + pool.close() + cleanup.unsafeRunSync() + } + + @Benchmark + def tuple(bh: Blackhole): Unit = { + bh.consume(sql"""SELECT 
col1, col2, col3, col4, col5, col6, col7, col8 FROM data""" + .query[(Double, String, Int, String, Double, Double, String, String)].to[List].transact(xa).unsafeRunSync()) + } + + @Benchmark + def tupleOpt(bh: Blackhole): Unit = { + bh.consume(sql"""SELECT col1, col2, col3, col4, col5, col6, col7, col8 FROM data""" + .query[Option[(Double, String, Int, String, Double, Double, String, String)]].to[List].transact(xa).unsafeRunSync()) + } + + @Benchmark + def semiautoDerivedComplex(bh: Blackhole): Unit = { + import SemiautoDerivedInstances.* + bh.consume(sql"""SELECT col1, col2, col3, col4, col5, col6, col7, col8 FROM data""" + .query[Complex].to[List].transact(xa).unsafeRunSync()) + } + + @Benchmark + def semiautoDerivedComplexOpt(bh: Blackhole): Unit = { + import SemiautoDerivedInstances.* + bh.consume(sql"""SELECT col1, col2, col3, col4, col5, col6, col7, col8 FROM data""" + .query[Option[Complex]].to[List].transact(xa).unsafeRunSync()) + } + + @Benchmark + def autoDerivedComplex(bh: Blackhole): Unit = { + import doobie.implicits.* + bh.consume(sql"""SELECT col1, col2, col3, col4, col5, col6, col7, col8 FROM data""" + .query[Complex].to[List].transact(xa).unsafeRunSync()) + } + + @Benchmark + def autoDerivedComplexOpt(bh: Blackhole): Unit = { + import doobie.implicits.* + bh.consume(sql"""SELECT col1, col2, col3, col4, col5, col6, col7, col8 FROM data""" + .query[Option[Complex]].to[List].transact(xa).unsafeRunSync()) + } + + @Benchmark + def rawJdbcComplex(bh: Blackhole): Unit = { + var l: List[Complex] = null + Using.resource(pool.getConnection()) { c => + Using.resource(c.prepareStatement("SELECT col1, col2, col3, col4, col5, col6, col7, col8 FROM data")) { ps => + Using.resource(ps.executeQuery()) { rs => + val m = scala.collection.mutable.ListBuffer.empty[Complex] + while (rs.next()) { + m += Complex( + DSIS( + DS( + rs.getDouble(1), + rs.getString(2) + ), + IS( + rs.getInt(3), + rs.getString(4) + ) + ), + DDSS( + DD( + rs.getDouble(5), + rs.getDouble(6) + ), + 
SS( + rs.getString(7), + rs.getString(8) + ) + ) + ) + } + l = m.toList + } + } + + } + bh.consume(l) + } + + @Benchmark + def rawJdbcTuple(bh: Blackhole): Unit = { + type Tup = (Double, String, Int, String, Double, Double, String, String) + var l: List[Tup] = null + Using.resource(pool.getConnection()) { c => + Using.resource(c.prepareStatement("SELECT col1, col2, col3, col4, col5, col6, col7, col8 FROM data")) { ps => + Using.resource(ps.executeQuery()) { rs => + val m = + scala.collection.mutable.ListBuffer.empty[Tup] + while (rs.next()) { + m += Tuple8( + rs.getDouble(1), + rs.getString(2), + rs.getInt(3), + rs.getString(4), + rs.getDouble(5), + rs.getDouble(6), + rs.getString(7), + rs.getString(8) + ) + } + l = m.toList + } + } + + } + bh.consume(l) + } +} + +case class IS(i: Int, s: String) +case class DS(d: Double, s: String) +case class DSIS(ds: DS, is: IS) +case class DD(d0: Double, d1: Double) +case class SS(s0: String, s1: String) +case class DDSS(dd: DD, ss: SS) +case class Complex(dsis: DSIS, ddss: DDSS) + +object SemiautoDerivedInstances { + implicit val isRead: Read[IS] = Read.derived + implicit val dsRead: Read[DS] = Read.derived + implicit val dsisRead: Read[DSIS] = Read.derived + implicit val ddRead: Read[DD] = Read.derived + implicit val ssRead: Read[SS] = Read.derived + implicit val ddssRead: Read[DDSS] = Read.derived + implicit val cRead: Read[Complex] = Read.derived +} diff --git a/modules/core/src/main/scala-2/doobie/util/GetPlatform.scala b/modules/core/src/main/scala-2/doobie/util/GetPlatform.scala index 0967c24f0..0d0568f8d 100644 --- a/modules/core/src/main/scala-2/doobie/util/GetPlatform.scala +++ b/modules/core/src/main/scala-2/doobie/util/GetPlatform.scala @@ -11,13 +11,15 @@ trait GetPlatform { import doobie.util.compat.=:= /** @group Instances */ - @deprecated("Use Get.derived instead to derive instances explicitly", "1.0.0-RC6") def unaryProductGet[A, L <: HList, H, T <: HList]( implicit G: Generic.Aux[A, L], C: IsHCons.Aux[L, H, 
T], H: Lazy[Get[H]], E: (H :: HNil) =:= L - ): MkGet[A] = MkGet.unaryProductGet + ): Get[A] = { + void(C) // C drives inference but is not used directly + H.value.tmap[A](h => G.from(h :: HNil)) + } } diff --git a/modules/core/src/main/scala-2/doobie/util/MkGetPlatform.scala b/modules/core/src/main/scala-2/doobie/util/MkGetPlatform.scala deleted file mode 100644 index 1a805d0e1..000000000 --- a/modules/core/src/main/scala-2/doobie/util/MkGetPlatform.scala +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2013-2020 Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). -// For more information see LICENSE or https://opensource.org/licenses/MIT - -package doobie.util - -import shapeless._ -import shapeless.ops.hlist.IsHCons - -trait MkGetPlatform { - import doobie.util.compat.=:= - - /** @group Instances */ - implicit def unaryProductGet[A, L <: HList, H, T <: HList]( - implicit - G: Generic.Aux[A, L], - C: IsHCons.Aux[L, H, T], - H: Lazy[Get[H]], - E: (H :: HNil) =:= L - ): MkGet[A] = { - void(C) // C drives inference but is not used directly - val get = H.value.tmap[A](h => G.from(h :: HNil)) - MkGet.lift(get) - } - -} diff --git a/modules/core/src/main/scala-2/doobie/util/MkPutPlatform.scala b/modules/core/src/main/scala-2/doobie/util/MkPutPlatform.scala deleted file mode 100644 index 2a4d32c3b..000000000 --- a/modules/core/src/main/scala-2/doobie/util/MkPutPlatform.scala +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2013-2020 Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). 
-// For more information see LICENSE or https://opensource.org/licenses/MIT - -package doobie.util - -import shapeless._ -import shapeless.ops.hlist.IsHCons - -trait MkPutPlatform { - import doobie.util.compat.=:= - - /** @group Instances */ - implicit def unaryProductPut[A, L <: HList, H, T <: HList]( - implicit - G: Generic.Aux[A, L], - C: IsHCons.Aux[L, H, T], - H: Lazy[Put[H]], - E: (H :: HNil) =:= L - ): MkPut[A] = { - void(E) // E is a necessary constraint but isn't used directly - val put = H.value.contramap[A](a => G.to(a).head) - MkPut.lift(put) - } - -} diff --git a/modules/core/src/main/scala-2/doobie/util/MkReadPlatform.scala b/modules/core/src/main/scala-2/doobie/util/MkReadPlatform.scala index eb2d14074..19764df2b 100644 --- a/modules/core/src/main/scala-2/doobie/util/MkReadPlatform.scala +++ b/modules/core/src/main/scala-2/doobie/util/MkReadPlatform.scala @@ -4,142 +4,91 @@ package doobie.util -import shapeless.{HList, HNil, ::, Generic, Lazy, <:!<, OrElse} -import shapeless.labelled.{field, FieldType} +import shapeless.{HList, HNil, ::, Generic, Lazy, OrElse} +import shapeless.labelled.FieldType -trait MkReadPlatform extends LowerPriorityRead { +trait MkReadPlatform extends LowerPriorityMkRead { // Derivation base case for product types (1-element) implicit def productBase[H]( - implicit H: Read[H] OrElse MkRead[H] - ): MkRead[H :: HNil] = { - val head = H.unify - - new MkRead[H :: HNil]( - head.gets, - (rs, n) => head.unsafeGet(rs, n) :: HNil + implicit H: Read[H] OrElse Derived[MkRead[H]] + ): Derived[MkRead[H :: HNil]] = { + val headInstance = H.fold(identity, _.instance) + + new Derived( + new MkRead( + headInstance.map(_ :: HNil) + ) ) } // Derivation base case for shapeless record (1-element) implicit def recordBase[K <: Symbol, H]( - implicit H: Read[H] OrElse MkRead[H] - ): MkRead[FieldType[K, H] :: HNil] = { - val head = H.unify - - new MkRead[FieldType[K, H] :: HNil]( - head.gets, - (rs, n) => field[K](head.unsafeGet(rs, n)) :: HNil + 
implicit H: Read[H] OrElse Derived[MkRead[H]] + ): Derived[MkRead[FieldType[K, H] :: HNil]] = { + val headInstance = H.fold(identity, _.instance) + + new Derived( + new MkRead( + new Read.Transform[FieldType[K, H] :: HNil, H]( + headInstance, + h => shapeless.labelled.field[K].apply(h) :: HNil + ) + ) ) } } -trait LowerPriorityRead extends EvenLowerPriorityRead { +trait LowerPriorityMkRead { // Derivation inductive case for product types implicit def product[H, T <: HList]( implicit - H: Read[H] OrElse MkRead[H], - T: MkRead[T] - ): MkRead[H :: T] = { - val head = H.unify - - new MkRead[H :: T]( - head.gets ++ T.gets, - (rs, n) => head.unsafeGet(rs, n) :: T.unsafeGet(rs, n + head.length) + H: Read[H] OrElse Derived[MkRead[H]], + T: Read[T] OrElse Derived[MkRead[T]] + ): Derived[MkRead[H :: T]] = { + val headInstance = H.fold(identity, _.instance) + val tailInstance = T.fold(identity, _.instance) + + new Derived( + new MkRead( + new Read.Composite[H :: T, H, T]( + headInstance, + tailInstance, + (h, t) => h :: t + ) + ) ) } // Derivation inductive case for shapeless records implicit def record[K <: Symbol, H, T <: HList]( implicit - H: Read[H] OrElse MkRead[H], - T: MkRead[T] - ): MkRead[FieldType[K, H] :: T] = { - val head = H.unify - - new MkRead[FieldType[K, H] :: T]( - head.gets ++ T.gets, - (rs, n) => field[K](head.unsafeGet(rs, n)) :: T.unsafeGet(rs, n + head.length) + H: Read[H] OrElse Derived[MkRead[H]], + T: Read[T] OrElse Derived[MkRead[T]] + ): Derived[MkRead[FieldType[K, H] :: T]] = { + val headInstance = H.fold(identity, _.instance) + val tailInstance = T.fold(identity, _.instance) + + new Derived( + new MkRead( + new Read.Composite[FieldType[K, H] :: T, H, T]( + headInstance, + tailInstance, + (h, t) => shapeless.labelled.field[K].apply(h) :: t + ) + ) ) } // Derivation for product types (i.e. 
case class) - implicit def generic[T, Repr](implicit gen: Generic.Aux[T, Repr], G: Lazy[MkRead[Repr]]): MkRead[T] = - new MkRead[T](G.value.gets, (rs, n) => gen.from(G.value.unsafeGet(rs, n))) - - // Derivation base case for Option of product types (1-element) - implicit def optProductBase[H]( - implicit - H: Read[Option[H]] OrElse MkRead[Option[H]], - N: H <:!< Option[α] forSome { type α } - ): MkRead[Option[H :: HNil]] = { - void(N) - val head = H.unify - - new MkRead[Option[H :: HNil]]( - head.gets, - (rs, n) => - head.unsafeGet(rs, n).map(_ :: HNil) - ) - } - - // Derivation base case for Option of product types (where the head element is Option) - implicit def optProductOptBase[H]( - implicit H: Read[Option[H]] OrElse MkRead[Option[H]] - ): MkRead[Option[Option[H] :: HNil]] = { - val head = H.unify - - new MkRead[Option[Option[H] :: HNil]]( - head.gets, - (rs, n) => head.unsafeGet(rs, n).map(h => Some(h) :: HNil) - ) - } - -} - -trait EvenLowerPriorityRead { - - // Read[Option[H]], Read[Option[T]] implies Read[Option[H *: T]] - implicit def optProduct[H, T <: HList]( + implicit def genericRead[T, Repr]( implicit - H: Read[Option[H]] OrElse MkRead[Option[H]], - T: MkRead[Option[T]], - N: H <:!< Option[α] forSome { type α } - ): MkRead[Option[H :: T]] = { - void(N) - val head = H.unify - - new MkRead[Option[H :: T]]( - head.gets ++ T.gets, - (rs, n) => - for { - h <- head.unsafeGet(rs, n) - t <- T.unsafeGet(rs, n + head.length) - } yield h :: t - ) + gen: Generic.Aux[T, Repr], + hlistRead: Lazy[Read[Repr] OrElse Derived[MkRead[Repr]]] + ): Derived[MkRead[T]] = { + val hlistInstance: Read[Repr] = hlistRead.value.fold(identity, _.instance) + new Derived(new MkRead(hlistInstance.map(gen.from))) } - // Read[Option[H]], Read[Option[T]] implies Read[Option[Option[H] *: T]] - implicit def optProductOpt[H, T <: HList]( - implicit - H: Read[Option[H]] OrElse MkRead[Option[H]], - T: MkRead[Option[T]] - ): MkRead[Option[Option[H] :: T]] = { - val head = H.unify - - new 
MkRead[Option[Option[H] :: T]]( - head.gets ++ T.gets, - (rs, n) => T.unsafeGet(rs, n + head.length).map(head.unsafeGet(rs, n) :: _) - ) - } - - // Derivation for optional of product types (i.e. case class) - implicit def ogeneric[A, Repr <: HList]( - implicit - G: Generic.Aux[A, Repr], - B: Lazy[MkRead[Option[Repr]]] - ): MkRead[Option[A]] = - new MkRead[Option[A]](B.value.gets, B.value.unsafeGet(_, _).map(G.from)) - } diff --git a/modules/core/src/main/scala-2/doobie/util/MkWritePlatform.scala b/modules/core/src/main/scala-2/doobie/util/MkWritePlatform.scala index b6572ad41..d5b1603c0 100644 --- a/modules/core/src/main/scala-2/doobie/util/MkWritePlatform.scala +++ b/modules/core/src/main/scala-2/doobie/util/MkWritePlatform.scala @@ -4,185 +4,94 @@ package doobie.util -import shapeless.{HList, HNil, ::, Generic, Lazy, <:!<, OrElse} -import shapeless.labelled.{FieldType} +import shapeless.{::, Generic, HList, HNil, Lazy, OrElse} +import shapeless.labelled.FieldType -trait MkWritePlatform extends LowerPriorityWrite { +trait MkWritePlatform extends LowerPriorityMkWrite { // Derivation base case for product types (1-element) implicit def productBase[H]( - implicit H: Write[H] OrElse MkWrite[H] - ): MkWrite[H :: HNil] = { - val head = H.unify - - new MkWrite[H :: HNil]( - head.puts, - { case h :: HNil => head.toList(h) }, - { case (ps, n, h :: HNil) => head.unsafeSet(ps, n, h); }, - { case (rs, n, h :: HNil) => head.unsafeUpdate(rs, n, h); } + implicit H: Write[H] OrElse Derived[MkWrite[H]] + ): Derived[MkWrite[H :: HNil]] = { + val head = H.fold(identity, _.instance) + + new Derived( + new MkWrite[H :: HNil]( + new Write.Composite(List(head), { case h :: HNil => List(h) }) + ) ) } // Derivation base case for shapelss record (1-element) implicit def recordBase[K <: Symbol, H]( - implicit H: Write[H] OrElse MkWrite[H] - ): MkWrite[FieldType[K, H] :: HNil] = { - val head = H.unify - - new MkWrite( - head.puts, - { case h :: HNil => head.toList(h) }, - { case (ps, n, h :: 
HNil) => head.unsafeSet(ps, n, h) }, - { case (rs, n, h :: HNil) => head.unsafeUpdate(rs, n, h) } + implicit H: Write[H] OrElse Derived[MkWrite[H]] + ): Derived[MkWrite[FieldType[K, H] :: HNil]] = { + val head = H.fold(identity, _.instance) + + new Derived( + new MkWrite( + new Write.Composite(List(head), { case h :: HNil => List(h) }) + ) ) } } -trait LowerPriorityWrite extends EvenLowerPriorityWrite { +trait LowerPriorityMkWrite { // Derivation inductive case for product types implicit def product[H, T <: HList]( implicit - H: Write[H] OrElse MkWrite[H], - T: MkWrite[T] - ): MkWrite[H :: T] = { - val head = H.unify - - new MkWrite( - head.puts ++ T.puts, - { case h :: t => head.toList(h) ++ T.toList(t) }, - { case (ps, n, h :: t) => head.unsafeSet(ps, n, h); T.unsafeSet(ps, n + head.length, t) }, - { case (rs, n, h :: t) => head.unsafeUpdate(rs, n, h); T.unsafeUpdate(rs, n + head.length, t) } + H: Write[H] OrElse Derived[MkWrite[H]], + T: Write[T] OrElse Derived[MkWrite[T]] + ): Derived[MkWrite[H :: T]] = { + val head = H.fold(identity, _.instance) + val tail = T.fold(identity, _.instance) + + new Derived( + new MkWrite[H :: T]( + new Write.Composite( + List(head, tail), + { case h :: t => List(h, t) } + ) + ) ) } - // Derivation for product types (i.e. 
case class) - implicit def generic[B, A](implicit gen: Generic.Aux[B, A], A: Lazy[MkWrite[A]]): MkWrite[B] = - new MkWrite[B]( - A.value.puts, - b => A.value.toList(gen.to(b)), - (ps, n, b) => A.value.unsafeSet(ps, n, gen.to(b)), - (rs, n, b) => A.value.unsafeUpdate(rs, n, gen.to(b)) - ) - // Derivation inductive case for shapeless records implicit def record[K <: Symbol, H, T <: HList]( implicit - H: Write[H] OrElse MkWrite[H], - T: MkWrite[T] - ): MkWrite[FieldType[K, H] :: T] = { - val head = H.unify - - new MkWrite( - head.puts ++ T.puts, - { case h :: t => head.toList(h) ++ T.toList(t) }, - { case (ps, n, h :: t) => head.unsafeSet(ps, n, h); T.unsafeSet(ps, n + head.length, t) }, - { case (rs, n, h :: t) => head.unsafeUpdate(rs, n, h); T.unsafeUpdate(rs, n + head.length, t) } + H: Write[H] OrElse Derived[MkWrite[H]], + T: Write[T] OrElse Derived[MkWrite[T]] + ): Derived[MkWrite[FieldType[K, H] :: T]] = { + val head = H.fold(identity, _.instance) + val tail = T.fold(identity, _.instance) + + new Derived( + new MkWrite( + new Write.Composite( + List(head, tail), + { + case h :: t => List(h, t) + } + ) + ) ) } - // Derivation base case for Option of product types (1-element) - implicit def optProductBase[H]( - implicit - H: Write[Option[H]] OrElse MkWrite[Option[H]], - N: H <:!< Option[α] forSome { type α } - ): MkWrite[Option[H :: HNil]] = { - void(N) - val head = H.unify - - def withHead[A](opt: Option[H :: HNil])(f: Option[H] => A): A = { - f(opt.map(_.head)) - } - - new MkWrite( - head.puts, - withHead(_)(head.toList(_)), - (ps, n, i) => withHead(i)(h => head.unsafeSet(ps, n, h)), - (rs, n, i) => withHead(i)(h => head.unsafeUpdate(rs, n, h)) - ) - - } - - // Derivation base case for Option of product types (where the head element is Option) - implicit def optProductOptBase[H]( - implicit H: Write[Option[H]] OrElse MkWrite[Option[H]] - ): MkWrite[Option[Option[H] :: HNil]] = { - val head = H.unify - - def withHead[A](opt: Option[Option[H] :: HNil])(f: 
Option[H] => A): A = { - opt match { - case Some(h :: _) => f(h) - case None => f(None) - } - } - - new MkWrite( - head.puts, - withHead(_) { h => head.toList(h) }, - (ps, n, i) => withHead(i) { h => head.unsafeSet(ps, n, h) }, - (rs, n, i) => withHead(i) { h => head.unsafeUpdate(rs, n, h) } - ) - - } - -} - -trait EvenLowerPriorityWrite { - - // Write[Option[H]], Write[Option[T]] implies Write[Option[H *: T]] - implicit def optPorduct[H, T <: HList]( - implicit - H: Write[Option[H]] OrElse MkWrite[Option[H]], - T: MkWrite[Option[T]], - N: H <:!< Option[α] forSome { type α } - ): MkWrite[Option[H :: T]] = { - void(N) - val head = H.unify - - def split[A](i: Option[H :: T])(f: (Option[H], Option[T]) => A): A = - i.fold(f(None, None)) { case h :: t => f(Some(h), Some(t)) } - - new MkWrite( - head.puts ++ T.puts, - split(_) { (h, t) => head.toList(h) ++ T.toList(t) }, - (ps, n, i) => split(i) { (h, t) => head.unsafeSet(ps, n, h); T.unsafeSet(ps, n + head.length, t) }, - (rs, n, i) => split(i) { (h, t) => head.unsafeUpdate(rs, n, h); T.unsafeUpdate(rs, n + head.length, t) } - ) - - } - - // Write[Option[H]], Write[Option[T]] implies Write[Option[Option[H] *: T]] - implicit def optProductOpt[H, T <: HList]( + // Derivation for product types (i.e. 
case class) + implicit def genericWrite[A, Repr <: HList]( implicit - H: Write[Option[H]] OrElse MkWrite[Option[H]], - T: MkWrite[Option[T]] - ): MkWrite[Option[Option[H] :: T]] = { - val head = H.unify - - def split[A](i: Option[Option[H] :: T])(f: (Option[H], Option[T]) => A): A = - i.fold(f(None, None)) { case oh :: t => f(oh, Some(t)) } - - new MkWrite( - head.puts ++ T.puts, - split(_) { (h, t) => head.toList(h) ++ T.toList(t) }, - (ps, n, i) => split(i) { (h, t) => head.unsafeSet(ps, n, h); T.unsafeSet(ps, n + head.length, t) }, - (rs, n, i) => split(i) { (h, t) => head.unsafeUpdate(rs, n, h); T.unsafeUpdate(rs, n + head.length, t) } + gen: Generic.Aux[A, Repr], + hlistWrite: Lazy[Write[Repr] OrElse Derived[MkWrite[Repr]]] + ): Derived[MkWrite[A]] = { + val g = hlistWrite.value.fold(identity, _.instance) + + new Derived( + new MkWrite[A]( + new Write.Composite(List(g), a => List(gen.to(a))) + ) ) - } - // Derivation for optional of product types (i.e. case class) - implicit def ogeneric[B, A <: HList]( - implicit - G: Generic.Aux[B, A], - A: Lazy[MkWrite[Option[A]]] - ): MkWrite[Option[B]] = - new MkWrite( - A.value.puts, - b => A.value.toList(b.map(G.to)), - (rs, n, a) => A.value.unsafeSet(rs, n, a.map(G.to)), - (rs, n, a) => A.value.unsafeUpdate(rs, n, a.map(G.to)) - ) - } diff --git a/modules/core/src/main/scala-2/doobie/util/PutPlatform.scala b/modules/core/src/main/scala-2/doobie/util/PutPlatform.scala index 26e3c5cf5..c237bc6fb 100644 --- a/modules/core/src/main/scala-2/doobie/util/PutPlatform.scala +++ b/modules/core/src/main/scala-2/doobie/util/PutPlatform.scala @@ -11,13 +11,15 @@ trait PutPlatform { import doobie.util.compat.=:= /** @group Instances */ - @deprecated("Use Put.derived instead to derive instances explicitly", "1.0.0-RC6") def unaryProductPut[A, L <: HList, H, T <: HList]( implicit G: Generic.Aux[A, L], C: IsHCons.Aux[L, H, T], H: Lazy[Put[H]], E: (H :: HNil) =:= L - ): MkPut[A] = MkPut.unaryProductPut + ): Put[A] = { + void(E) // E is 
a necessary constraint but isn't used directly + H.value.contramap[A](a => G.to(a).head) + } } diff --git a/modules/core/src/main/scala-2/doobie/util/ReadPlatform.scala b/modules/core/src/main/scala-2/doobie/util/ReadPlatform.scala index 64a5e7371..c8255543b 100644 --- a/modules/core/src/main/scala-2/doobie/util/ReadPlatform.scala +++ b/modules/core/src/main/scala-2/doobie/util/ReadPlatform.scala @@ -4,32 +4,51 @@ package doobie.util -import shapeless.{Generic, HList, IsTuple, Lazy} +import shapeless.labelled.FieldType +import shapeless.{Generic, HList, IsTuple, Lazy, OrElse} +import shapeless.{::, HNil} -trait ReadPlatform { +trait ReadPlatform extends LowerPriority1ReadPlatform { // Derivation for product types (i.e. case class) implicit def genericTuple[A, Repr <: HList](implicit gen: Generic.Aux[A, Repr], - G: Lazy[MkRead[Repr]], + G: Lazy[Read[Repr]], isTuple: IsTuple[A] - ): MkRead[A] = { + ): Read[A] = { val _ = isTuple - MkRead.generic[A, Repr] + implicit val r: Lazy[Read[Repr] OrElse Derived[MkRead[Repr]]] = G.map(OrElse.primary(_)) + MkRead.genericRead[A, Repr].instance } - // Derivation for optional of product types (i.e. 
case class) - implicit def ogenericTuple[A, Repr <: HList]( + @deprecated("Read.generic has been renamed to Read.derived to align with Scala 3 derivation", "1.0.0-RC6") + def generic[T, Repr <: HList]( implicit - G: Generic.Aux[A, Repr], - B: Lazy[MkRead[Option[Repr]]], - isTuple: IsTuple[A] - ): MkRead[Option[A]] = { - val _ = isTuple - MkRead.ogeneric[A, Repr] - } + gen: Generic.Aux[T, Repr], + G: Lazy[Read[Repr] OrElse Derived[MkRead[Repr]]] + ): Read[T] = + MkRead.genericRead[T, Repr].instance + + implicit def recordBase[K <: Symbol, H]( + implicit H: Read[H] + ): Read[FieldType[K, H] :: HNil] = MkRead.recordBase[K, H].instance + + implicit def productBase[H]( + implicit H: Read[H] + ): Read[H :: HNil] = MkRead.productBase[H].instance +} + +trait LowerPriority1ReadPlatform extends LowestPriorityRead { - @deprecated("Use Read.derived instead to derive instances explicitly", "1.0.0-RC6") - def generic[T, Repr](implicit gen: Generic.Aux[T, Repr], G: Lazy[MkRead[Repr]]): MkRead[T] = - MkRead.generic[T, Repr] + implicit def product[H, T <: HList]( + implicit + H: Read[H], + T: Read[T] + ): Read[H :: T] = MkRead.product[H, T].instance + + implicit def record[K <: Symbol, H, T <: HList]( + implicit + H: Read[H], + T: Read[T] + ): Read[FieldType[K, H] :: T] = MkRead.record[K, H, T].instance } diff --git a/modules/core/src/main/scala-2/doobie/util/WritePlatform.scala b/modules/core/src/main/scala-2/doobie/util/WritePlatform.scala index 0553067d2..f1adc32bd 100644 --- a/modules/core/src/main/scala-2/doobie/util/WritePlatform.scala +++ b/modules/core/src/main/scala-2/doobie/util/WritePlatform.scala @@ -4,31 +4,53 @@ package doobie.util -import shapeless.{Generic, HList, IsTuple, Lazy} +import shapeless.* +import shapeless.labelled.FieldType -trait WritePlatform { +trait WritePlatform extends LowerPriority1WritePlatform { - implicit def genericTuple[A, Repr]( + implicit def genericTuple[A, Repr <: HList]( implicit gen: Generic.Aux[A, Repr], - A: Lazy[MkWrite[Repr]], + G: 
Lazy[Write[Repr]], isTuple: IsTuple[A] - ): MkWrite[A] = { + ): Write[A] = { val _ = isTuple - MkWrite.generic[A, Repr] + implicit val hlistWrite: Lazy[Write[Repr] OrElse Derived[MkWrite[Repr]]] = G.map(OrElse.primary(_)) + MkWrite.genericWrite[A, Repr].instance } - implicit def ogenericTuple[A, Repr <: HList]( - implicit - G: Generic.Aux[A, Repr], - A: Lazy[MkWrite[Option[Repr]]], - isTuple: IsTuple[A] - ): MkWrite[Option[A]] = { - val _ = isTuple - MkWrite.ogeneric[A, Repr] + @deprecated("Write.generic has been renamed to Write.derived to align with Scala 3 derivation", "1.0.0-RC6") + def generic[T, Repr <: HList](implicit + gen: Generic.Aux[T, Repr], + A: Write[Repr] OrElse Derived[MkWrite[Repr]] + ): Write[T] = { + implicit val hlistWrite: Lazy[Write[Repr] OrElse Derived[MkWrite[Repr]]] = A + MkWrite.genericWrite[T, Repr].instance } - @deprecated("Use Write.derived instead to derive instances explicitly", "1.0.0-RC6") - def generic[T, Repr](implicit gen: Generic.Aux[T, Repr], A: Lazy[MkWrite[Repr]]): MkWrite[T] = - MkWrite.generic[T, Repr] + implicit def recordBase[K <: Symbol, H]( + implicit H: Write[H] + ): Write[FieldType[K, H] :: HNil] = MkWrite.recordBase[K, H].instance + + implicit def productBase[H]( + implicit H: Write[H] + ): Write[H :: HNil] = MkWrite.productBase[H].instance + +} + +trait LowerPriority1WritePlatform extends LowestPriorityWrite { + + implicit def product[H, T <: HList]( + implicit + H: Write[H], + T: Write[T] + ): Write[H :: T] = MkWrite.product[H, T].instance + + implicit def record[K <: Symbol, H, T <: HList]( + implicit + H: Write[H], + T: Write[T] + ): Write[FieldType[K, H] :: T] = MkWrite.record[K, H, T].instance + } diff --git a/modules/core/src/main/scala-3/doobie/util/MkGetPlatform.scala b/modules/core/src/main/scala-3/doobie/util/MkGetPlatform.scala deleted file mode 100644 index 833dbcabc..000000000 --- a/modules/core/src/main/scala-3/doobie/util/MkGetPlatform.scala +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2013-2020 
Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). -// For more information see LICENSE or https://opensource.org/licenses/MIT - -package doobie.util - -import scala.deriving.Mirror - -trait MkGetPlatform: - - // Get is available for single-element products. - given unaryProductGet[P <: Product, A]( - using - p: Mirror.ProductOf[P], - i: p.MirroredElemTypes =:= (A *: EmptyTuple), - g: Get[A] - ): MkGet[P] = { - val get = g.map(a => p.fromProduct(a *: EmptyTuple)) - MkGet.lift(get) - } diff --git a/modules/core/src/main/scala-3/doobie/util/MkPutPlatform.scala b/modules/core/src/main/scala-3/doobie/util/MkPutPlatform.scala deleted file mode 100644 index 9c9afcacb..000000000 --- a/modules/core/src/main/scala-3/doobie/util/MkPutPlatform.scala +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2013-2020 Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). -// For more information see LICENSE or https://opensource.org/licenses/MIT - -package doobie.util - -import scala.deriving.Mirror - -trait MkPutPlatform: - - // Put is available for single-element products. - given unaryProductPut[P <: Product, A]( - using - m: Mirror.ProductOf[P], - i: m.MirroredElemTypes =:= (A *: EmptyTuple), - p: Put[A] - ): MkPut[P] = { - val put: Put[P] = p.contramap(p => i(Tuple.fromProductTyped(p)).head) - MkPut.lift(put) - } diff --git a/modules/core/src/main/scala-3/doobie/util/MkReadPlatform.scala b/modules/core/src/main/scala-3/doobie/util/MkReadPlatform.scala index 5a002a261..70cbf6745 100644 --- a/modules/core/src/main/scala-3/doobie/util/MkReadPlatform.scala +++ b/modules/core/src/main/scala-3/doobie/util/MkReadPlatform.scala @@ -6,103 +6,54 @@ package doobie.util import scala.deriving.Mirror import doobie.util.shapeless.OrElse +import scala.util.NotGiven trait MkReadPlatform: - // Generic Read for products. - given derived[P <: Product, A]( - using + // Derivation for product types (i.e. 
case class) + implicit def derived[P <: Product, A]( + implicit m: Mirror.ProductOf[P], i: A =:= m.MirroredElemTypes, - w: MkRead[A] - ): MkRead[P] = { - val read = w.map(a => m.fromProduct(i(a))) - MkRead.lift(read) - } - - // Generic Read for option of products. - given derivedOption[P <: Product, A]( - using - m: Mirror.ProductOf[P], - i: A =:= m.MirroredElemTypes, - w: MkRead[Option[A]] - ): MkRead[Option[P]] = { - val read = w.map(a => a.map(a => m.fromProduct(i(a)))) - MkRead.lift(read) - } - - // Derivation base case for product types (1-element) - given productBase[H]( - using H: Read[H] `OrElse` MkRead[H] - ): MkRead[H *: EmptyTuple] = { - val head = H.unify - new MkRead( - head.gets, - (rs, n) => head.unsafeGet(rs, n) *: EmptyTuple - ) - } - - // Read for head and tail. - given product[H, T <: Tuple]( - using - H: Read[H] `OrElse` MkRead[H], - T: MkRead[T] - ): MkRead[H *: T] = { - val head = H.unify - - new MkRead[H *: T]( - head.gets ++ T.gets, - (rs, n) => head.unsafeGet(rs, n) *: T.unsafeGet(rs, n + head.length) - ) - } - - given optProductBase[H]( - using H: Read[Option[H]] `OrElse` MkRead[Option[H]] - ): MkRead[Option[H *: EmptyTuple]] = { - val head = H.unify - MkRead[Option[H *: EmptyTuple]]( - head.gets, - (rs, n) => head.unsafeGet(rs, n).map(_ *: EmptyTuple) - ) - } - - given optProduct[H, T <: Tuple]( - using - H: Read[Option[H]] `OrElse` MkRead[Option[H]], - T: MkRead[Option[T]] - ): MkRead[Option[H *: T]] = { - val head = H.unify - - new MkRead[Option[H *: T]]( - head.gets ++ T.gets, - (rs, n) => - for { - h <- head.unsafeGet(rs, n) - t <- T.unsafeGet(rs, n + head.length) - } yield h *: t - ) - } - - given optProductOptBase[H]( - using H: Read[Option[H]] `OrElse` MkRead[Option[H]] - ): MkRead[Option[Option[H] *: EmptyTuple]] = { - val head = H.unify - - MkRead[Option[Option[H] *: EmptyTuple]]( - head.gets, - (rs, n) => head.unsafeGet(rs, n).map(h => Some(h) *: EmptyTuple) + r: Read[A] `OrElse` Derived[MkRead[A]], + isNotCaseObj: 
NotGiven[m.MirroredElemTypes =:= EmptyTuple] + ): Derived[MkRead[P]] = { + val _ = isNotCaseObj + val read = r.fold(identity, _.instance).map(a => m.fromProduct(i(a))) + new Derived(new MkRead(read)) + } + + // Derivation base case for tuple (1-element) + implicit def productBase[H]( + implicit H: Read[H] `OrElse` Derived[MkRead[H]] + ): Derived[MkRead[H *: EmptyTuple]] = { + val headInstance = H.fold(identity, _.instance) + new Derived( + new MkRead( + Read.Transform( + headInstance, + h => h *: EmptyTuple + ) + )) + } + + // Derivation inductive case for tuples + implicit def product[H, T <: Tuple]( + implicit + H: Read[H] `OrElse` Derived[MkRead[H]], + T: Read[T] `OrElse` Derived[MkRead[T]] + ): Derived[MkRead[H *: T]] = { + val headInstance = H.fold(identity, _.instance) + val tailInstance = T.fold(identity, _.instance) + + new Derived( + new MkRead( + Read.Composite( + headInstance, + tailInstance, + (h, t) => h *: t + ) + ) ) - } - - given optProductOpt[H, T <: Tuple]( - using - H: Read[Option[H]] `OrElse` MkRead[Option[H]], - T: MkRead[Option[T]] - ): MkRead[Option[Option[H] *: T]] = { - val head = H.unify - new MkRead[Option[Option[H] *: T]]( - head.gets ++ T.gets, - (rs, n) => T.unsafeGet(rs, n + head.length).map(head.unsafeGet(rs, n) *: _) - ) } diff --git a/modules/core/src/main/scala-3/doobie/util/MkWritePlatform.scala b/modules/core/src/main/scala-3/doobie/util/MkWritePlatform.scala index 78d200e60..91f393017 100644 --- a/modules/core/src/main/scala-3/doobie/util/MkWritePlatform.scala +++ b/modules/core/src/main/scala-3/doobie/util/MkWritePlatform.scala @@ -6,130 +6,50 @@ package doobie.util import scala.deriving.Mirror import doobie.util.shapeless.OrElse +import scala.util.NotGiven trait MkWritePlatform: // Derivation for product types (i.e. 
case class) - given derived[P <: Product, A]( - using + implicit def derived[P <: Product, A]( + implicit m: Mirror.ProductOf[P], i: m.MirroredElemTypes =:= A, - w: MkWrite[A] - ): MkWrite[P] = - val write: Write[P] = w.contramap(p => i(Tuple.fromProductTyped(p))) - MkWrite.lift(write) - - // Derivation for optional product types - given derivedOption[P <: Product, A]( - using - m: Mirror.ProductOf[P], - i: m.MirroredElemTypes =:= A, - w: MkWrite[Option[A]] - ): MkWrite[Option[P]] = - val write: Write[Option[P]] = w.contramap(op => op.map(p => i(Tuple.fromProductTyped(p)))) - MkWrite.lift(write) - - // Derivation base case for product types (1-element) - given productBase[H]( - using H: Write[H] `OrElse` MkWrite[H] - ): MkWrite[H *: EmptyTuple] = { - val head = H.unify - MkWrite( - head.puts, - { case h *: t => head.toList(h) }, - { case (ps, n, h *: t) => head.unsafeSet(ps, n, h) }, - { case (rs, n, h *: t) => head.unsafeUpdate(rs, n, h) } + w: Write[A] `OrElse` Derived[MkWrite[A]], + isNotCaseObj: NotGiven[m.MirroredElemTypes =:= EmptyTuple] + ): Derived[MkWrite[P]] = + val _ = isNotCaseObj + val write: Write[P] = w.fold(identity, _.instance).contramap(p => i(Tuple.fromProductTyped(p))) + new Derived( + new MkWrite(write) ) - } - // Derivation inductive case for product types - given product[H, T <: Tuple]( - using - H: Write[H] `OrElse` MkWrite[H], - T: MkWrite[T] - ): MkWrite[H *: T] = { - val head = H.unify - - MkWrite( - head.puts ++ T.puts, - { case h *: t => head.toList(h) ++ T.toList(t) }, - { case (ps, n, h *: t) => head.unsafeSet(ps, n, h); T.unsafeSet(ps, n + head.length, t) }, - { case (rs, n, h *: t) => head.unsafeUpdate(rs, n, h); T.unsafeUpdate(rs, n + head.length, t) } + // Derivation base case for tuple (1-element) + implicit def productBase[H]( + implicit H: Write[H] `OrElse` Derived[MkWrite[H]] + ): Derived[MkWrite[H *: EmptyTuple]] = { + val headInstance = H.fold(identity, _.instance) + new Derived( + new MkWrite(Write.Composite( + 
List(headInstance), + { case h *: EmptyTuple => List(h) } + )) ) } - // Derivation base case for Option of product types (1-element) - given optProductBase[H]( - using H: Write[Option[H]] `OrElse` MkWrite[Option[H]] - ): MkWrite[Option[H *: EmptyTuple]] = { - val head = H.unify - - MkWrite[Option[H *: EmptyTuple]]( - head.puts, - i => head.toList(i.map { case h *: EmptyTuple => h }), - (ps, n, i) => head.unsafeSet(ps, n, i.map { case h *: EmptyTuple => h }), - (rs, n, i) => head.unsafeUpdate(rs, n, i.map { case h *: EmptyTuple => h }) + // Derivation inductive case for tuples + implicit def product[H, T <: Tuple]( + implicit + H: Write[H] `OrElse` Derived[MkWrite[H]], + T: Write[T] `OrElse` Derived[MkWrite[T]] + ): Derived[MkWrite[H *: T]] = { + val headWrite = H.fold(identity, _.instance) + val tailWrite = T.fold(identity, _.instance) + + new Derived( + new MkWrite(Write.Composite( + List(headWrite, tailWrite), + { case h *: t => List(h, t) } + )) ) } - - // Write[Option[H]], Write[Option[T]] implies Write[Option[H *: T]] - given optProduct[H, T <: Tuple]( - using - H: Write[Option[H]] `OrElse` MkWrite[Option[H]], - T: MkWrite[Option[T]] - ): MkWrite[Option[H *: T]] = - val head = H.unify - - def split[A](i: Option[H *: T])(f: (Option[H], Option[T]) => A): A = - i.fold(f(None, None)) { case h *: t => f(Some(h), Some(t)) } - - MkWrite( - head.puts ++ T.puts, - split(_) { (h, t) => head.toList(h) ++ T.toList(t) }, - (ps, n, i) => - split(i) { (h, t) => - head.unsafeSet(ps, n, h); T.unsafeSet(ps, n + head.length, t) - }, - (rs, n, i) => - split(i) { (h, t) => - head.unsafeUpdate(rs, n, h); T.unsafeUpdate(rs, n + head.length, t) - } - ) - - // Derivation base case for Option of product types (where the head element is Option) - given optProductOptBase[H]( - using H: Write[Option[H]] `OrElse` MkWrite[Option[H]] - ): MkWrite[Option[Option[H] *: EmptyTuple]] = { - val head = H.unify - - MkWrite[Option[Option[H] *: EmptyTuple]]( - head.puts, - i => head.toList(i.flatMap { 
case ho *: EmptyTuple => ho }), - (ps, n, i) => head.unsafeSet(ps, n, i.flatMap { case ho *: EmptyTuple => ho }), - (rs, n, i) => head.unsafeUpdate(rs, n, i.flatMap { case ho *: EmptyTuple => ho }) - ) - } - - // Write[Option[H]], Write[Option[T]] implies Write[Option[Option[H] *: T]] - given optProductOpt[H, T <: Tuple]( - using - H: Write[Option[H]] `OrElse` MkWrite[Option[H]], - T: MkWrite[Option[T]] - ): MkWrite[Option[Option[H] *: T]] = - val head = H.unify - - def split[A](i: Option[Option[H] *: T])(f: (Option[H], Option[T]) => A): A = - i.fold(f(None, None)) { case oh *: t => f(oh, Some(t)) } - - MkWrite( - head.puts ++ T.puts, - split(_) { (h, t) => head.toList(h) ++ T.toList(t) }, - (ps, n, i) => - split(i) { (h, t) => - head.unsafeSet(ps, n, h); T.unsafeSet(ps, n + head.length, t) - }, - (rs, n, i) => - split(i) { (h, t) => - head.unsafeUpdate(rs, n, h); T.unsafeUpdate(rs, n + head.length, t) - } - ) diff --git a/modules/core/src/main/scala-3/doobie/util/PutPlatform.scala b/modules/core/src/main/scala-3/doobie/util/PutPlatform.scala index 9a7151ebe..f39f8a136 100644 --- a/modules/core/src/main/scala-3/doobie/util/PutPlatform.scala +++ b/modules/core/src/main/scala-3/doobie/util/PutPlatform.scala @@ -4,4 +4,4 @@ package doobie.util -trait PutPlatform {} +trait PutPlatform diff --git a/modules/core/src/main/scala-3/doobie/util/ReadPlatform.scala b/modules/core/src/main/scala-3/doobie/util/ReadPlatform.scala index f6e0fa795..f14bf0d38 100644 --- a/modules/core/src/main/scala-3/doobie/util/ReadPlatform.scala +++ b/modules/core/src/main/scala-3/doobie/util/ReadPlatform.scala @@ -4,23 +4,16 @@ package doobie.util -import scala.deriving.Mirror +trait ReadPlatform extends LowestPriorityRead: -trait ReadPlatform: - // Generic Read for products. 
- given derivedTuple[P <: Tuple, A]( - using - m: Mirror.ProductOf[P], - i: A =:= m.MirroredElemTypes, - w: MkRead[A] - ): MkRead[P] = - MkRead.derived[P, A] + given tupleBase[H]( + using H: Read[H] + ): Read[H *: EmptyTuple] = + H.map(h => h *: EmptyTuple) - // Generic Read for option of products. - given derivedOptionTuple[P <: Tuple, A]( + given tuple[H, T <: Tuple]( using - m: Mirror.ProductOf[P], - i: A =:= m.MirroredElemTypes, - w: MkRead[Option[A]] - ): MkRead[Option[P]] = - MkRead.derivedOption[P, A] + H: Read[H], + T: Read[T] + ): Read[H *: T] = + Read.Composite(H, T, (h, t) => h *: t) diff --git a/modules/core/src/main/scala-3/doobie/util/WritePlatform.scala b/modules/core/src/main/scala-3/doobie/util/WritePlatform.scala index 3e6989fe5..71864ba7b 100644 --- a/modules/core/src/main/scala-3/doobie/util/WritePlatform.scala +++ b/modules/core/src/main/scala-3/doobie/util/WritePlatform.scala @@ -4,24 +4,26 @@ package doobie.util -import scala.deriving.Mirror +trait WritePlatform extends LowestPriorityWrite: -trait WritePlatform: + given tupleBase[H]( + using H: Write[H] + ): Write[H *: EmptyTuple] = + Write.Composite[H *: EmptyTuple]( + List(H), + { + case h *: EmptyTuple => List(h) + } + ) - // Derivation for product types (i.e. 
case class) - given derivedTuple[P <: Tuple, A]( + given tuple[H, T <: Tuple]( using - m: Mirror.ProductOf[P], - i: m.MirroredElemTypes =:= A, - w: MkWrite[A] - ): MkWrite[P] = - MkWrite.derived[P, A] - - // Derivation for optional product types - given derivedOptionTuple[P <: Tuple, A]( - using - m: Mirror.ProductOf[P], - i: m.MirroredElemTypes =:= A, - w: MkWrite[Option[A]] - ): MkWrite[Option[P]] = - MkWrite.derivedOption[P, A] + H: Write[H], + T: Write[T] + ): Write[H *: T] = + Write.Composite( + List(H, T), + { + case h *: t => List(h, t) + } + ) diff --git a/modules/core/src/main/scala/doobie/generic/auto.scala b/modules/core/src/main/scala/doobie/generic/auto.scala index e8c73c6cf..c78ac237f 100644 --- a/modules/core/src/main/scala/doobie/generic/auto.scala +++ b/modules/core/src/main/scala/doobie/generic/auto.scala @@ -4,20 +4,10 @@ package doobie.generic -import doobie.util.meta.Meta -import doobie.util.{Get, Put, Read, Write} +import doobie.util.{Read, Write} trait AutoDerivation - extends Get.Auto - with Put.Auto - with Read.Auto + extends Read.Auto with Write.Auto -object auto extends AutoDerivation { - - // re-export these instances so `Meta` takes priority, must be in the object - implicit def metaProjectionGet[A](implicit m: Meta[A]): Get[A] = Get.metaProjection - implicit def metaProjectionPut[A](implicit m: Meta[A]): Put[A] = Put.metaProjectionWrite - implicit def fromGetRead[A](implicit G: Get[A]): Read[A] = Read.fromGet - implicit def fromPutWrite[A](implicit P: Put[A]): Write[A] = Write.fromPut -} +object auto extends AutoDerivation diff --git a/modules/core/src/main/scala/doobie/package.scala b/modules/core/src/main/scala/doobie/package.scala index db89940ab..1ea5f99f6 100644 --- a/modules/core/src/main/scala/doobie/package.scala +++ b/modules/core/src/main/scala/doobie/package.scala @@ -29,12 +29,6 @@ package object doobie with LegacyMeta with syntax.AllSyntax { - // re-export these instances so `Meta` takes priority, must be in the object - 
implicit def metaProjectionGet[A](implicit m: Meta[A]): Get[A] = Get.metaProjection - implicit def metaProjectionPut[A](implicit m: Meta[A]): Put[A] = Put.metaProjectionWrite - implicit def fromGetRead[A](implicit G: Get[A]): Read[A] = Read.fromGet - implicit def fromPutWrite[A](implicit P: Put[A]): Write[A] = Write.fromPut - /** Only use this import if: * 1. You're NOT using one of the database doobie has direct java.time isntances for (PostgreSQL / MySQL). (They * have more accurate column type checks) 2. Your driver natively supports java.time.* types diff --git a/modules/core/src/test/scala-3/doobie/util/GetSuitePlatform.scala b/modules/core/src/main/scala/doobie/util/Derived.scala similarity index 55% rename from modules/core/src/test/scala-3/doobie/util/GetSuitePlatform.scala rename to modules/core/src/main/scala/doobie/util/Derived.scala index c82893dca..07c40fbc2 100644 --- a/modules/core/src/test/scala-3/doobie/util/GetSuitePlatform.scala +++ b/modules/core/src/main/scala/doobie/util/Derived.scala @@ -2,11 +2,6 @@ // This software is licensed under the MIT License (MIT). 
// For more information see LICENSE or https://opensource.org/licenses/MIT -package doobie -package util +package doobie.util -trait GetSuitePlatform { self: munit.FunSuite => - - test("Get should be derived for unary products (AnyVal)".ignore) {} - -} +class Derived[+I](val instance: I) extends AnyVal diff --git a/modules/core/src/main/scala/doobie/util/analysis.scala b/modules/core/src/main/scala/doobie/util/analysis.scala index de07c8b07..97e8733d3 100644 --- a/modules/core/src/main/scala/doobie/util/analysis.scala +++ b/modules/core/src/main/scala/doobie/util/analysis.scala @@ -135,11 +135,12 @@ object analysis { columnAlignment: List[(Get[?], NullabilityKnown) `Ior` ColumnMeta] ) { - def parameterMisalignments: List[ParameterMisalignment] = + def parameterMisalignments: List[ParameterMisalignment] = { parameterAlignment.zipWithIndex.collect { case (Ior.Left(_), n) => ParameterMisalignment(n + 1, None) case (Ior.Right(p), n) => ParameterMisalignment(n + 1, Some(p)) } + } private def hasParameterTypeErrors[A](put: Put[A], paramMeta: ParameterMeta): Boolean = { !put.jdbcTargets.contains_(paramMeta.jdbcType) || diff --git a/modules/core/src/main/scala/doobie/util/fragment.scala b/modules/core/src/main/scala/doobie/util/fragment.scala index ad7f732cb..13bf95720 100644 --- a/modules/core/src/main/scala/doobie/util/fragment.scala +++ b/modules/core/src/main/scala/doobie/util/fragment.scala @@ -6,7 +6,6 @@ package doobie.util import cats.* import cats.data.Chain -import doobie.enumerated.Nullability.* import doobie.free.connection.ConnectionIO import doobie.free.preparedstatement.PreparedStatementIO import doobie.util.pos.Pos @@ -14,9 +13,6 @@ import doobie.hi.connection as IHC import doobie.util.query.{Query, Query0} import doobie.util.update.{Update, Update0} -import java.sql.{PreparedStatement, ResultSet} -import scala.Predef.{augmentString, implicitly} - /** Module defining the `Fragment` data type. 
*/ object fragment { @@ -35,42 +31,20 @@ object fragment { private implicit lazy val write: Write[elems.type] = { import Elem.* - val puts: List[(Put[?], NullabilityKnown)] = + val writes: List[Write[?]] = elems.map { - case Arg(_, p) => (p, NoNulls) - case Opt(_, p) => (p, Nullable) + case Arg(_, p) => new Write.Single(p) + case Opt(_, p) => new Write.SingleOpt(p) }.toList - val toList: elems.type => List[Any] = elems => - elems.map { - case Arg(a, _) => a - case Opt(a, _) => a - }.toList - - val unsafeSet: (PreparedStatement, Int, elems.type) => Unit = { (ps, n, elems) => - var index = n - elems.iterator.foreach { e => - e match { - case Arg(a, p) => p.unsafeSetNonNullable(ps, index, a) - case Opt(a, p) => p.unsafeSetNullable(ps, index, a) - } - index += 1 - } - } - - val unsafeUpdate: (ResultSet, Int, elems.type) => Unit = { (ps, n, elems) => - var index = n - elems.iterator.foreach { e => - e match { - case Arg(a, p) => p.unsafeUpdateNonNullable(ps, index, a) - case Opt(a, p) => p.unsafeUpdateNullable(ps, index, a) - } - index += 1 - } - } - - Write(puts, toList, unsafeSet, unsafeUpdate) - + new Write.Composite( + writes, + elems => + elems.map { + case Arg(a, _) => a + case Opt(aOpt, _) => aOpt + }.toList + ) } /** Construct a program in ConnectionIO that constructs and prepares a PreparedStatement, with further handling diff --git a/modules/core/src/main/scala/doobie/util/get.scala b/modules/core/src/main/scala/doobie/util/get.scala index e1a9557b1..44dca67e7 100644 --- a/modules/core/src/main/scala/doobie/util/get.scala +++ b/modules/core/src/main/scala/doobie/util/get.scala @@ -84,12 +84,6 @@ object Get extends GetInstances with GetPlatform { def apply[A](implicit ev: Get[A]): ev.type = ev - def derived[A](implicit ev: MkGet[A]): Get[A] = ev - - trait Auto { - implicit def deriveGet[A](implicit ev: MkGet[A]): Get[A] = ev - } - /** Get instance for a basic JDBC type. 
*/ object Basic { @@ -213,23 +207,3 @@ trait GetInstances { ev.tmap(_.toVector) } - -sealed abstract class MkGet[A]( - override val typeStack: NonEmptyList[Option[String]], - override val jdbcSources: NonEmptyList[JdbcType], - override val jdbcSourceSecondary: List[JdbcType], - override val vendorTypeNames: List[String], - override val get: Coyoneda[(ResultSet, Int) => *, A] -) extends Get[A](typeStack, jdbcSources, jdbcSourceSecondary, vendorTypeNames, get) - -object MkGet extends MkGetPlatform { - - def lift[A](g: Get[A]): MkGet[A] = - new MkGet[A]( - typeStack = g.typeStack, - jdbcSources = g.jdbcSources, - jdbcSourceSecondary = g.jdbcSourceSecondary, - vendorTypeNames = g.vendorTypeNames, - get = g.get - ) {} -} diff --git a/modules/core/src/main/scala/doobie/util/meta/meta.scala b/modules/core/src/main/scala/doobie/util/meta/meta.scala index 900c6baec..12d036dee 100644 --- a/modules/core/src/main/scala/doobie/util/meta/meta.scala +++ b/modules/core/src/main/scala/doobie/util/meta/meta.scala @@ -130,13 +130,13 @@ trait MetaConstructors { ) def array[A >: Null <: AnyRef]( - elementType: String, - schemaH: String, - schemaT: String* + elementTypeName: String, // Used in Put to set the array element type + arrayTypeName: String, + additionalArrayTypeNames: String* ): Meta[Array[A]] = new Meta[Array[A]]( - Get.Advanced.array[A](NonEmptyList(schemaH, schemaT.toList)), - Put.Advanced.array[A](NonEmptyList(schemaH, schemaT.toList), elementType) + Get.Advanced.array[A](NonEmptyList(arrayTypeName, additionalArrayTypeNames.toList)), + Put.Advanced.array[A](NonEmptyList(arrayTypeName, additionalArrayTypeNames.toList), elementTypeName) ) def other[A >: Null <: AnyRef: TypeName: ClassTag]( diff --git a/modules/core/src/main/scala/doobie/util/put.scala b/modules/core/src/main/scala/doobie/util/put.scala index 752f6df16..c79eff1c4 100644 --- a/modules/core/src/main/scala/doobie/util/put.scala +++ b/modules/core/src/main/scala/doobie/util/put.scala @@ -82,16 +82,10 @@ sealed 
abstract class Put[A]( } -object Put extends PutInstances with PutPlatform { +object Put extends PutInstances { def apply[A](implicit ev: Put[A]): ev.type = ev - def derived[A](implicit ev: MkPut[A]): Put[A] = ev - - trait Auto { - implicit def derivePut[A](implicit ev: MkPut[A]): Put[A] = ev - } - object Basic { def apply[A]( @@ -208,7 +202,7 @@ object Put extends PutInstances with PutPlatform { } -trait PutInstances { +trait PutInstances extends PutPlatform { /** @group Instances */ implicit val ContravariantPut: Contravariant[Put] = @@ -226,23 +220,3 @@ trait PutInstances { ev.tcontramap(_.toArray) } - -sealed abstract class MkPut[A]( - override val typeStack: NonEmptyList[Option[String]], - override val jdbcTargets: NonEmptyList[JdbcType], - override val vendorTypeNames: List[String], - override val put: ContravariantCoyoneda[(PreparedStatement, Int, *) => Unit, A], - override val update: ContravariantCoyoneda[(ResultSet, Int, *) => Unit, A] -) extends Put[A](typeStack, jdbcTargets, vendorTypeNames, put, update) - -object MkPut extends MkPutPlatform { - - def lift[A](g: Put[A]): MkPut[A] = - new MkPut[A]( - typeStack = g.typeStack, - jdbcTargets = g.jdbcTargets, - vendorTypeNames = g.vendorTypeNames, - put = g.put, - update = g.update - ) {} -} diff --git a/modules/core/src/main/scala/doobie/util/read.scala b/modules/core/src/main/scala/doobie/util/read.scala index a8f3e4bdb..14aa0e688 100644 --- a/modules/core/src/main/scala/doobie/util/read.scala +++ b/modules/core/src/main/scala/doobie/util/read.scala @@ -4,13 +4,14 @@ package doobie.util -import cats.* -import doobie.free.ResultSetIO -import doobie.enumerated.Nullability.* +import cats.Applicative +import doobie.ResultSetIO +import doobie.enumerated.Nullability +import doobie.enumerated.Nullability.{NoNulls, NullabilityKnown} +import doobie.free.resultset as IFRS import java.sql.ResultSet import scala.annotation.implicitNotFound -import doobie.free.resultset as IFRS @implicitNotFound(""" Cannot find or 
construct a Read instance for type: @@ -27,7 +28,7 @@ some debugging hints: version. - For types you expect to map to a single column ensure that a Get instance is in scope. -- For case classes, HLists, and shapeless records ensure that each element +- For case classes, shapeless HLists/records ensure that each element has a Read instance in scope. - Lather, rinse, repeat, recursively until you find the problematic bit. @@ -42,67 +43,135 @@ and similarly with Get: And find the missing instance and construct it as needed. Refer to Chapter 12 of the book of doobie for more information. """) -sealed abstract class Read[A]( - val gets: List[(Get[?], NullabilityKnown)], - val unsafeGet: (ResultSet, Int) => A -) { - - final lazy val length: Int = gets.length - - def map[B](f: A => B): Read[B] = - new Read(gets, (rs, n) => f(unsafeGet(rs, n))) {} - - def ap[B](ff: Read[A => B]): Read[B] = - new Read(ff.gets ++ gets, (rs, n) => ff.unsafeGet(rs, n)(unsafeGet(rs, n + ff.length))) {} +sealed trait Read[A] { + def unsafeGet(rs: ResultSet, startIdx: Int): A + def gets: List[(Get[?], NullabilityKnown)] + def toOpt: Read[Option[A]] + def length: Int - def get(n: Int): ResultSetIO[A] = + final def get(n: Int): ResultSetIO[A] = IFRS.raw(unsafeGet(_, n)) + final def map[B](f: A => B): Read[B] = new Read.Transform[B, A](this, f) + + final def ap[B](ff: Read[A => B]): Read[B] = { + new Read.Composite[B, A => B, A](ff, this, (f, a) => f(a)) + } } -object Read extends ReadPlatform { +object Read extends LowerPriority1Read { - def apply[A]( - gets: List[(Get[?], NullabilityKnown)], - unsafeGet: (ResultSet, Int) => A - ): Read[A] = new Read(gets, unsafeGet) {} + def apply[A](implicit ev: Read[A]): Read[A] = ev - def apply[A](implicit ev: Read[A]): ev.type = ev + def derived[A](implicit + @implicitNotFound( + "Cannot derive Read instance. 
Please check that each field in the case class has a Read instance or can derive one") + ev: Derived[MkRead[A]] + ): Read[A] = ev.instance.underlying - def derived[A](implicit ev: MkRead[A]): Read[A] = ev - - trait Auto { - implicit def deriveRead[A](implicit ev: MkRead[A]): Read[A] = ev - } + trait Auto extends MkReadInstances implicit val ReadApply: Applicative[Read] = new Applicative[Read] { def ap[A, B](ff: Read[A => B])(fa: Read[A]): Read[B] = fa.ap(ff) - def pure[A](x: A): Read[A] = new Read(Nil, (_, _) => x) {} + def pure[A](x: A): Read[A] = unitRead.map(_ => x) override def map[A, B](fa: Read[A])(f: A => B): Read[B] = fa.map(f) } - implicit val unit: Read[Unit] = - Read(Nil, (_, _) => ()) + implicit val unitRead: Read[Unit] = new Read[Unit] { + override def unsafeGet(rs: ResultSet, startIdx: Int): Unit = { + () // Does not read anything from ResultSet + } + override def gets: List[(Get[?], NullabilityKnown)] = List.empty + override def toOpt: Read[Option[Unit]] = this.map(_ => Some(())) + override def length: Int = 0 + } + + /** Simple instance wrapping a Get. i.e. single column non-null value */ + class Single[A](get: Get[A]) extends Read[A] { + def unsafeGet(rs: ResultSet, startIdx: Int): A = + get.unsafeGetNonNullable(rs, startIdx) + + override def toOpt: Read[Option[A]] = new SingleOpt(get) + + override def gets: List[(Get[?], NullabilityKnown)] = List(get -> NoNulls) + + override val length: Int = 1 + + } + + /** Simple instance wrapping a Get. i.e. 
single column nullable value */ + class SingleOpt[A](get: Get[A]) extends Read[Option[A]] { + def unsafeGet(rs: ResultSet, startIdx: Int): Option[A] = + get.unsafeGetNullable(rs, startIdx) + + override def toOpt: Read[Option[Option[A]]] = new Transform[Option[Option[A]], Option[A]](this, a => Some(a)) + override def gets: List[(Get[?], NullabilityKnown)] = List(get -> Nullability.Nullable) + + override val length: Int = 1 + } + + class Transform[A, From](underlyingRead: Read[From], f: From => A) extends Read[A] { + override def unsafeGet(rs: ResultSet, startIdx: Int): A = f(underlyingRead.unsafeGet(rs, startIdx)) + override def gets: List[(Get[?], NullabilityKnown)] = underlyingRead.gets + override def toOpt: Read[Option[A]] = + new Transform[Option[A], Option[From]](underlyingRead.toOpt, opt => opt.map(f)) + override lazy val length: Int = underlyingRead.length + } + + /** A Read instance consists of multiple underlying Read instances */ + class Composite[A, S0, S1](read0: Read[S0], read1: Read[S1], f: (S0, S1) => A) extends Read[A] { + override def unsafeGet(rs: ResultSet, startIdx: Int): A = { + val r0 = read0.unsafeGet(rs, startIdx) + val r1 = read1.unsafeGet(rs, startIdx + read0.length) + f(r0, r1) + } + + override lazy val gets: List[(Get[?], NullabilityKnown)] = + read0.gets ++ read1.gets - implicit val optionUnit: Read[Option[Unit]] = - Read(Nil, (_, _) => Some(())) + override def toOpt: Read[Option[A]] = { + val readOpt0 = read0.toOpt + val readOpt1 = read1.toOpt + new Composite[Option[A], Option[S0], Option[S1]]( + readOpt0, + readOpt1, + { + case (Some(s0), Some(s1)) => Some(f(s0, s1)) + case _ => None + }) - implicit def fromGet[A](implicit ev: Get[A]): Read[A] = - new Read(List((ev, NoNulls)), ev.unsafeGetNonNullable) {} + } + override lazy val length: Int = read0.length + read1.length + } - implicit def fromGetOption[A](implicit ev: Get[A]): Read[Option[A]] = - new Read(List((ev, Nullable)), ev.unsafeGetNullable) {} +} + +trait LowerPriority1Read 
extends LowerPriority2Read { + + implicit def fromReadOption[A](implicit read: Read[A]): Read[Option[A]] = read.toOpt } -final class MkRead[A]( - override val gets: List[(Get[?], NullabilityKnown)], - override val unsafeGet: (ResultSet, Int) => A -) extends Read[A](gets, unsafeGet) +trait LowerPriority2Read extends ReadPlatform { + + implicit def fromGet[A](implicit get: Get[A]): Read[A] = new Read.Single(get) -object MkRead extends MkReadPlatform { + implicit def fromGetOption[A](implicit get: Get[A]): Read[Option[A]] = new Read.SingleOpt(get) + +} - def lift[A](r: Read[A]): MkRead[A] = - new MkRead[A](r.gets, r.unsafeGet) +trait LowestPriorityRead { + implicit def fromDerived[A](implicit ev: Derived[Read[A]]): Read[A] = ev.instance } + +final class MkRead[A](val underlying: Read[A]) extends Read[A] { + override def unsafeGet(rs: ResultSet, startIdx: Int): A = underlying.unsafeGet(rs, startIdx) + override def gets: List[(Get[?], NullabilityKnown)] = underlying.gets + override def toOpt: Read[Option[A]] = underlying.toOpt + override def length: Int = underlying.length +} + +object MkRead extends MkReadInstances + +trait MkReadInstances extends MkReadPlatform diff --git a/modules/core/src/main/scala/doobie/util/write.scala b/modules/core/src/main/scala/doobie/util/write.scala index eb06bb09e..84dfacc56 100644 --- a/modules/core/src/main/scala/doobie/util/write.scala +++ b/modules/core/src/main/scala/doobie/util/write.scala @@ -5,6 +5,7 @@ package doobie.util import cats.ContravariantSemigroupal +import doobie.enumerated.Nullability import doobie.enumerated.Nullability.* import doobie.free.{PreparedStatementIO, ResultSetIO} @@ -28,7 +29,7 @@ some debugging hints: version. - For types you expect to map to a single column ensure that a Put instance is in scope. -- For case classes, HLists, and shapeless records ensure that each element +- For case classes, shapeless HLists/records ensure that each element has a Write instance in scope. 
- Lather, rinse, repeat, recursively until you find the problematic bit. @@ -43,40 +44,26 @@ and similarly with Put: And find the missing instance and construct it as needed. Refer to Chapter 12 of the book of doobie for more information. """) -sealed abstract class Write[A]( - val puts: List[(Put[?], NullabilityKnown)], - val toList: A => List[Any], - val unsafeSet: (PreparedStatement, Int, A) => Unit, - val unsafeUpdate: (ResultSet, Int, A) => Unit -) { - - lazy val length = puts.length - - def set(n: Int, a: A): PreparedStatementIO[Unit] = +sealed trait Write[A] { + def puts: List[(Put[?], NullabilityKnown)] + def toList(a: A): List[Any] + def unsafeSet(ps: PreparedStatement, startIdx: Int, a: A): Unit + def unsafeUpdate(rs: ResultSet, startIdx: Int, a: A): Unit + def toOpt: Write[Option[A]] + def length: Int + + final def set(n: Int, a: A): PreparedStatementIO[Unit] = IFPS.raw(unsafeSet(_, n, a)) - def update(n: Int, a: A): ResultSetIO[Unit] = + final def update(n: Int, a: A): ResultSetIO[Unit] = IFRS.raw(unsafeUpdate(_, n, a)) - def contramap[B](f: B => A): Write[B] = - new Write[B]( - puts, - b => toList(f(b)), - (ps, n, a) => unsafeSet(ps, n, f(a)), - (rs, n, a) => unsafeUpdate(rs, n, f(a)) - ) {} - - def product[B](fb: Write[B]): Write[(A, B)] = - new Write[(A, B)]( - puts ++ fb.puts, - { case (a, b) => toList(a) ++ fb.toList(b) }, - { case (ps, n, (a, b)) => unsafeSet(ps, n, a); fb.unsafeSet(ps, n + length, b) }, - { case (rs, n, (a, b)) => unsafeUpdate(rs, n, a); fb.unsafeUpdate(rs, n + length, b) } - ) {} - - /** Given a value of type `A` and an appropriately parameterized SQL string we can construct a `Fragment`. If `sql` is - * unspecified a comma-separated list of `length` placeholders will be used. 
- */ + final def contramap[B](f: B => A): Write[B] = new Write.Composite[B](List(this), b => List(f(b))) + + final def product[B](fb: Write[B]): Write[(A, B)] = { + new Write.Composite[(A, B)](List(this, fb), tuple => List(tuple._1, tuple._2)) + } + def toFragment(a: A, sql: String = List.fill(length)("?").mkString(",")): Fragment = { val elems: List[Elem] = (puts zip toList(a)).map { case ((p: Put[a], NoNulls), a) => Elem.Arg(a.asInstanceOf[a], p) @@ -84,77 +71,132 @@ sealed abstract class Write[A]( } Fragment(sql, elems, None) } - } -object Write extends WritePlatform { - - def apply[A]( - puts: List[(Put[?], NullabilityKnown)], - toList: A => List[Any], - unsafeSet: (PreparedStatement, Int, A) => Unit, - unsafeUpdate: (ResultSet, Int, A) => Unit - ): Write[A] = new Write(puts, toList, unsafeSet, unsafeUpdate) {} - +object Write extends LowerPriority1Write { def apply[A](implicit A: Write[A]): Write[A] = A - def derived[A](implicit ev: MkWrite[A]): Write[A] = ev + def derived[A](implicit + @implicitNotFound( + "Cannot derive Write instance. Please check that each field in the case class has a Write instance or can derive one") + ev: Derived[MkWrite[A]] + ): Write[A] = ev.instance - trait Auto { - implicit def deriveWrite[A](implicit ev: MkWrite[A]): Write[A] = ev - } + trait Auto extends MkWriteInstances implicit val WriteContravariantSemigroupal: ContravariantSemigroupal[Write] = new ContravariantSemigroupal[Write] { - def contramap[A, B](fa: Write[A])(f: B => A) = fa.contramap(f) - def product[A, B](fa: Write[A], fb: Write[B]) = fa.product(fb) + def contramap[A, B](fa: Write[A])(f: B => A): Write[B] = fa.contramap(f) + def product[A, B](fa: Write[A], fb: Write[B]): Write[(A, B)] = fa.product(fb) } - private def doNothing[P, A](p: P, i: Int, a: A): Unit = { - void(p, i, a) + implicit val unitWrite: Write[Unit] = + new Composite[Unit](Nil, _ => List.empty) + + /** Simple instance wrapping a Put. i.e. 
single column non-null value */ + class Single[A](put: Put[A]) extends Write[A] { + override val length: Int = 1 + + override def unsafeSet(ps: PreparedStatement, startIdx: Int, a: A): Unit = + put.unsafeSetNonNullable(ps, startIdx, a) + + override def unsafeUpdate(rs: ResultSet, startIdx: Int, a: A): Unit = + put.unsafeUpdateNonNullable(rs, startIdx, a) + + override def puts: List[(Put[?], NullabilityKnown)] = List(put -> Nullability.NoNulls) + + override def toList(a: A): List[Any] = List(a) + + override def toOpt: Write[Option[A]] = new SingleOpt(put) } - private def empty[A](a: A): List[Any] = { - void(a) - List.empty + /** Simple instance wrapping a Put. i.e. single column nullable value */ + class SingleOpt[A](put: Put[A]) extends Write[Option[A]] { + override val length: Int = 1 + + override def unsafeSet(ps: PreparedStatement, startIdx: Int, a: Option[A]): Unit = + put.unsafeSetNullable(ps, startIdx, a) + + override def unsafeUpdate(rs: ResultSet, startIdx: Int, a: Option[A]): Unit = + put.unsafeUpdateNullable(rs, startIdx, a) + + override def puts: List[(Put[?], NullabilityKnown)] = List(put -> Nullability.Nullable) + + override def toList(a: Option[A]): List[Any] = List(a) + + override def toOpt: Write[Option[Option[A]]] = new Composite[Option[Option[A]]](List(this), x => List(x.flatten)) } - implicit val unitComposite: Write[Unit] = - Write[Unit](Nil, empty[Unit](_), doNothing[PreparedStatement, Unit](_, _, _), doNothing[ResultSet, Unit](_, _, _)) - - implicit val optionUnit: Write[Option[Unit]] = - Write[Option[Unit]]( - Nil, - empty[Option[Unit]](_), - doNothing[PreparedStatement, Option[Unit]](_, _, _), - doNothing[ResultSet, Option[Unit]](_, _, _)) - - implicit def fromPut[A](implicit P: Put[A]): Write[A] = - new Write[A]( - List((P, NoNulls)), - a => List(a), - (ps, n, a) => P.unsafeSetNonNullable(ps, n, a), - (rs, n, a) => P.unsafeUpdateNonNullable(rs, n, a) - ) {} - - implicit def fromPutOption[A](implicit P: Put[A]): Write[Option[A]] = - new 
Write[Option[A]]( - List((P, Nullable)), - a => List(a), - (ps, n, a) => P.unsafeSetNullable(ps, n, a), - (rs, n, a) => P.unsafeUpdateNullable(rs, n, a) - ) {} + /** A Write instance consists of multiple underlying Write instances */ + class Composite[A]( + writeInstances: List[Write[?]], + deconstruct: A => List[Any] + ) extends Write[A] { + override lazy val length: Int = writeInstances.map(_.length).sum + + // Make the types match up with deconstruct + private val anyWrites: List[Write[Any]] = writeInstances.asInstanceOf[List[Write[Any]]] + + override def unsafeSet(ps: PreparedStatement, startIdx: Int, a: A): Unit = { + val parts = deconstruct(a) + var idx = startIdx + anyWrites.zip(parts).foreach { case (w, p) => + w.unsafeSet(ps, idx, p) + idx += w.length + } + } + + override def unsafeUpdate(rs: ResultSet, startIdx: Int, a: A): Unit = { + val parts = deconstruct(a) + var idx = startIdx + anyWrites.zip(parts).foreach { case (w, p) => + w.unsafeUpdate(rs, idx, p) + idx += w.length + } + } + override lazy val puts: List[(Put[?], NullabilityKnown)] = writeInstances.flatMap(_.puts) + + override def toList(a: A): List[Any] = + anyWrites.zip(deconstruct(a)).flatMap { case (w, p) => w.toList(p) } + + override def toOpt: Write[Option[A]] = new Composite[Option[A]]( + writeInstances.map(_.toOpt), + { + case Some(a) => deconstruct(a).map(Some(_)) + case None => List.fill(writeInstances.length)(None) // All Nones + } + ) + } } -final class MkWrite[A]( - override val puts: List[(Put[?], NullabilityKnown)], - override val toList: A => List[Any], - override val unsafeSet: (PreparedStatement, Int, A) => Unit, - override val unsafeUpdate: (ResultSet, Int, A) => Unit -) extends Write[A](puts, toList, unsafeSet, unsafeUpdate) -object MkWrite extends MkWritePlatform { +trait LowerPriority1Write extends LowerPriority2Write { - def lift[A](w: Write[A]): MkWrite[A] = - new MkWrite[A](w.puts, w.toList, w.unsafeSet, w.unsafeUpdate) + implicit def optionalFromWrite[A](implicit write: 
Write[A]): Write[Option[A]] = + write.toOpt } + +trait LowerPriority2Write extends WritePlatform { + implicit def fromPut[A](implicit put: Put[A]): Write[A] = + new Write.Single(put) + + implicit def fromPutOption[A](implicit put: Put[A]): Write[Option[A]] = + new Write.SingleOpt(put) +} + +trait LowestPriorityWrite { + implicit def fromDerived[A](implicit ev: Derived[Write[A]]): Write[A] = ev.instance +} + +final class MkWrite[A](val instance: Write[A]) extends Write[A] { + override def puts: List[(Put[?], NullabilityKnown)] = instance.puts + override def toList(a: A): List[Any] = instance.toList(a) + override def unsafeSet(ps: PreparedStatement, startIdx: Int, a: A): Unit = instance.unsafeSet(ps, startIdx, a) + override def unsafeUpdate(rs: ResultSet, startIdx: Int, a: A): Unit = instance.unsafeUpdate(rs, startIdx, a) + override def toOpt: Write[Option[A]] = instance.toOpt + override def length: Int = instance.length +} + +object MkWrite extends MkWriteInstances + +trait MkWriteInstances extends MkWritePlatform diff --git a/modules/core/src/test/scala-2/doobie/util/GetSuitePlatform.scala b/modules/core/src/test/scala-2/doobie/util/GetSuitePlatform.scala deleted file mode 100644 index a6daa5f06..000000000 --- a/modules/core/src/test/scala-2/doobie/util/GetSuitePlatform.scala +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2013-2020 Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). 
-// For more information see LICENSE or https://opensource.org/licenses/MIT - -package doobie.util -import doobie.testutils.{VoidExtensions, assertContains} -import doobie.testutils.TestClasses.{CCIntString, PlainObj, CCAnyVal} - -trait GetSuitePlatform { self: munit.FunSuite => - - test("Get can be auto derived for unary products (AnyVal)") { - import doobie.generic.auto.* - - Get[CCAnyVal].void - } - - test("Get can be explicitly derived for unary products (AnyVal)") { - Get.derived[CCAnyVal].void - } - - test("Get should not be derived for non-unary products") { - import doobie.generic.auto.* - - assertContains(compileErrors("Get[CCIntString]"), "implicit value") - assertContains(compileErrors("Get[(Int, Int)]"), "implicit value") - assertContains(compileErrors("Get[PlainObj.type]"), "implicit value") - } - -} diff --git a/modules/core/src/test/scala-2/doobie/util/PutSuitePlatform.scala b/modules/core/src/test/scala-2/doobie/util/PutSuitePlatform.scala deleted file mode 100644 index a7eda5bed..000000000 --- a/modules/core/src/test/scala-2/doobie/util/PutSuitePlatform.scala +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2013-2020 Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). 
-// For more information see LICENSE or https://opensource.org/licenses/MIT - -package doobie.util -import doobie.testutils.{VoidExtensions, assertContains} -import doobie.testutils.TestClasses.{CCIntString, PlainObj, CCAnyVal} - -trait PutSuitePlatform { self: munit.FunSuite => - test("Put can be auto derived for unary products (AnyVal)") { - import doobie.generic.auto.* - - Put[CCAnyVal].void - } - - test("Put can be explicitly derived for unary products (AnyVal)") { - Put.derived[CCAnyVal].void - } - - test("Put should not be derived for non-unary products") { - import doobie.generic.auto.* - - assertContains(compileErrors("Put[CCIntString]"), "implicit value") - assertContains(compileErrors("Put[(Int, Int)]"), "implicit value") - assertContains(compileErrors("Put[PlainObj.type]"), "implicit value") - } - -} diff --git a/modules/core/src/test/scala-2/doobie/util/QueryLogSuitePlatform.scala b/modules/core/src/test/scala-2/doobie/util/QueryLogSuitePlatform.scala index b82855515..539e82b4e 100644 --- a/modules/core/src/test/scala-2/doobie/util/QueryLogSuitePlatform.scala +++ b/modules/core/src/test/scala-2/doobie/util/QueryLogSuitePlatform.scala @@ -8,7 +8,6 @@ import doobie.util.log.{Parameters, ProcessingFailure, Success} import shapeless._ trait QueryLogSuitePlatform { self: QueryLogSuite => - import doobie.generic.auto._ test("[Query] n-arg success") { val Sql = "select 1 where ? = ?" diff --git a/modules/core/src/test/scala-3/doobie/util/PutSuitePlatform.scala b/modules/core/src/test/scala-3/doobie/util/PutSuitePlatform.scala deleted file mode 100644 index 96305ccf2..000000000 --- a/modules/core/src/test/scala-3/doobie/util/PutSuitePlatform.scala +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2013-2020 Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). 
-// For more information see LICENSE or https://opensource.org/licenses/MIT - -package doobie.util -import doobie.testutils.assertContains - -import scala.annotation.nowarn - -trait PutSuitePlatform { self: munit.FunSuite => - - test("Put should be derived for unary products (AnyVal)".ignore) {} - - test("Put should not be derived for non-unary products") { - import doobie.generic.auto.* - import doobie.testutils.TestClasses.{CCIntString, PlainObj} - - assertContains(compileErrors("Put[CCIntString]"), "No given instance") - assertContains(compileErrors("Put[(Int, Int)]"), "No given instance") - assertContains(compileErrors("Put[PlainObj.type]"), "No given instance") - }: @nowarn("msg=.*unused.*") - -} diff --git a/modules/core/src/test/scala-3/doobie/util/QueryLogSuitePlatform.scala b/modules/core/src/test/scala-3/doobie/util/QueryLogSuitePlatform.scala index 6943cfcfe..81aa8532c 100644 --- a/modules/core/src/test/scala-3/doobie/util/QueryLogSuitePlatform.scala +++ b/modules/core/src/test/scala-3/doobie/util/QueryLogSuitePlatform.scala @@ -7,7 +7,6 @@ package doobie.util import doobie.util.log.{Parameters, Success, ProcessingFailure} trait QueryLogSuitePlatform { self: QueryLogSuite => - import doobie.generic.auto.* test("[Query] n-arg success") { val Sql = "select 1 where ? = ?" 
diff --git a/modules/core/src/test/scala/doobie/util/GetSuite.scala b/modules/core/src/test/scala/doobie/util/GetSuite.scala index e5d6c862e..35b985b44 100644 --- a/modules/core/src/test/scala/doobie/util/GetSuite.scala +++ b/modules/core/src/test/scala/doobie/util/GetSuite.scala @@ -9,7 +9,7 @@ import doobie.enumerated.JdbcType import doobie.testutils.VoidExtensions import doobie.util.transactor.Transactor -class GetSuite extends munit.FunSuite with GetSuitePlatform { +class GetSuite extends munit.FunSuite { case class X(x: Int) case class Q(x: String) @@ -22,35 +22,13 @@ class GetSuite extends munit.FunSuite with GetSuitePlatform { Get[String].void } - test("Get should be auto derived for unary products") { - import doobie.generic.auto.* - - Get[X].void - Get[Q].void - } - - test("Get is not auto derived without an import") { - compileErrors("Get[X]").void - compileErrors("Get[Q]").void - } - - test("Get can be manually derived for unary products") { - Get.derived[X].void - Get.derived[Q].void - } - - test("Get should not be derived for non-unary products") { - compileErrors("Get[Z]").void - compileErrors("Get[(Int, Int)]").void - compileErrors("Get[S.type]").void - } - } final case class Foo(s: String) final case class Bar(n: Int) -class GetDBSuite extends munit.CatsEffectSuite { +class GetDBSuite extends munit.FunSuite { + import cats.effect.unsafe.implicits.global import doobie.syntax.all.* lazy val xa = Transactor.fromDriverManager[IO]( @@ -63,42 +41,27 @@ class GetDBSuite extends munit.CatsEffectSuite { // Both of these will fail at runtime if called with a null value, we check that this is // avoided below. 
- implicit def FooMeta: Get[Foo] = Get[String].map(s => Foo(s.toUpperCase)) - implicit def barMeta: Get[Bar] = Get[Int].temap(n => if (n == 0) Left("cannot be 0") else Right(Bar(n))) + implicit val FooMeta: Get[Foo] = Get[String].map(s => Foo(s.toUpperCase)) + implicit val barMeta: Get[Bar] = Get[Int].temap(n => if (n == 0) Left("cannot be 0") else Right(Bar(n))) test("Get should not allow map to observe null on the read side (AnyRef)") { - val x = sql"select null".query[Option[Foo]].unique.transact(xa) - x.assertEquals(None) + val x = sql"select null".query[Option[Foo]].unique.transact(xa).unsafeRunSync() + assertEquals(x, None) } test("Get should read non-null value (AnyRef)") { - val x = sql"select 'abc'".query[Foo].unique.transact(xa) - x.assertEquals(Foo("ABC")) + val x = sql"select 'abc'".query[Foo].unique.transact(xa).unsafeRunSync() + assertEquals(x, Foo("ABC")) } test("Get should error when reading a NULL into an unlifted Scala type (AnyRef)") { - def x = sql"select null".query[Foo].unique.transact(xa).attempt - x.assertEquals(Left(doobie.util.invariant.NonNullableColumnRead(1, JdbcType.Char))) - } - - test("Get should not allow map to observe null on the read side (AnyVal)") { - val x = sql"select null".query[Option[Bar]].unique.transact(xa) - x.assertEquals(None) - } - - test("Get should read non-null value (AnyVal)") { - val x = sql"select 1".query[Bar].unique.transact(xa) - x.assertEquals(Bar(1)) - } - - test("Get should error when reading a NULL into an unlifted Scala type (AnyVal)") { - def x = sql"select null".query[Bar].unique.transact(xa).attempt - x.assertEquals(Left(doobie.util.invariant.NonNullableColumnRead(1, JdbcType.Integer))) + def x = sql"select null".query[Foo].unique.transact(xa).attempt.unsafeRunSync() + assertEquals(x, Left(doobie.util.invariant.NonNullableColumnRead(1, JdbcType.Char))) } test("Get should error when reading an incorrect value") { - def x = sql"select 0".query[Bar].unique.transact(xa).attempt - 
x.assertEquals(Left(doobie.util.invariant.InvalidValue[Int, Bar](0, "cannot be 0"))) + def x = sql"select 0".query[Bar].unique.transact(xa).attempt.unsafeRunSync() + assertEquals(x, Left(doobie.util.invariant.InvalidValue[Int, Bar](0, "cannot be 0"))) } } diff --git a/modules/core/src/test/scala/doobie/util/PutSuite.scala b/modules/core/src/test/scala/doobie/util/PutSuite.scala index a8548e211..fd2f4b0c4 100644 --- a/modules/core/src/test/scala/doobie/util/PutSuite.scala +++ b/modules/core/src/test/scala/doobie/util/PutSuite.scala @@ -8,7 +8,7 @@ import cats.effect.IO import doobie.testutils.VoidExtensions import doobie.util.transactor.Transactor -class PutSuite extends munit.FunSuite with PutSuitePlatform { +class PutSuite extends munit.FunSuite { case class X(x: Int) case class Q(x: String) @@ -34,20 +34,4 @@ class PutSuite extends munit.FunSuite with PutSuitePlatform { Put[String].void } - test("Put should be auto derived for unary products") { - import doobie.generic.auto.* - - Put[X].void - Put[Q].void - } - - test("Put is not auto derived without an import") { - compileErrors("Put[X]").void - compileErrors("Put[Q]").void - } - - test("Put can be manually derived for unary products") { - Put.derived[X].void - Put.derived[Q].void - } } diff --git a/modules/core/src/test/scala/doobie/util/ReadSuite.scala b/modules/core/src/test/scala/doobie/util/ReadSuite.scala index 69d76edc5..327fedefc 100644 --- a/modules/core/src/test/scala/doobie/util/ReadSuite.scala +++ b/modules/core/src/test/scala/doobie/util/ReadSuite.scala @@ -8,9 +8,17 @@ import cats.effect.IO import doobie.util.TestTypes.* import doobie.util.transactor.Transactor import doobie.testutils.VoidExtensions -import munit.CatsEffectSuite +import doobie.syntax.all.* +import doobie.{ConnectionIO, Query} +import doobie.util.analysis.{Analysis, ColumnMisalignment, ColumnTypeError, ColumnTypeWarning, NullabilityMisalignment} +import doobie.util.fragment.Fragment +import munit.Location -class ReadSuite extends 
CatsEffectSuite with ReadSuitePlatform { +import scala.annotation.nowarn + +class ReadSuite extends munit.FunSuite with ReadSuitePlatform { + + import cats.effect.unsafe.implicits.global val xa = Transactor.fromDriverManager[IO]( driver = "org.h2.Driver", @@ -20,34 +28,19 @@ class ReadSuite extends CatsEffectSuite with ReadSuitePlatform { logHandler = None ) - test("Read should exist for some fancy types") { - import doobie.generic.auto.* - - Read[Int].void - Read[(Int, Int)].void - Read[(Int, Int, String)].void - Read[(Int, (Int, String))].void - } - - test("Read is not auto derived for case classes without importing auto derive import") { - assert(compileErrors("Read[LenStr1]").contains("Cannot find or construct")) - } - - test("Read should not be derivable for case objects") { - assert(compileErrors("Read[CaseObj.type]").contains("Cannot find or construct")) - assert(compileErrors("Read[Option[CaseObj.type]]").contains("Cannot find or construct")) - } - - test("Read is auto derived for tuples without an import") { + test("Read is available for tuples without an import when all elements have a Write instance") { Read[(Int, Int)].void Read[(Int, Int, String)].void Read[(Int, (Int, String))].void Read[Option[(Int, Int)]].void Read[Option[(Int, Option[(String, Int)])]].void + + // But shouldn't automatically derive anything that doesn't already have a Read instance + assert(compileErrors("Read[(Int, TrivialCaseClass)]").contains("Cannot find or construct")) } - test("Read is still auto derived for tuples when import is present (no ambiguous implicits)") { + test("Read is still auto derived for tuples when import is present (no ambiguous implicits) ") { import doobie.generic.auto.* Read[(Int, Int)].void Read[(Int, Int, String)].void @@ -55,45 +48,86 @@ class ReadSuite extends CatsEffectSuite with ReadSuitePlatform { Read[Option[(Int, Int)]].void Read[Option[(Int, Option[(String, Int)])]].void + + Read[(ComplexCaseClass, Int)].void + Read[(Int, ComplexCaseClass)].void 
} - test("Read can be manually derived") { - Read.derived[LenStr1] + test("Read is not auto derived for case classes without importing auto derive import") { + assert(compileErrors("Read[TrivialCaseClass]").contains("Cannot find or construct")) } - test("Read should exist for Unit") { - import doobie.generic.auto.* + test("Semiauto derivation selects custom Read instances when available") { + implicit val i0: Read[HasCustomReadWrite0] = Read.derived[HasCustomReadWrite0] + assertEquals(i0.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite0(CustomReadWrite("x_R"), "y")) - Read[Unit] - assertEquals(Read[(Int, Unit)].length, 1) - } + implicit val i1: Read[HasCustomReadWrite1] = Read.derived[HasCustomReadWrite1] + assertEquals(i1.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite1("x", CustomReadWrite("y_R"))) - test("Read should exist for option of some fancy types") { - import doobie.generic.auto.* + implicit val iOpt0: Read[HasOptCustomReadWrite0] = Read.derived[HasOptCustomReadWrite0] + assertEquals(iOpt0.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite0(Some(CustomReadWrite("x_R")), "y")) - Read[Option[Int]].void - Read[Option[(Int, Int)]].void - Read[Option[(Int, Int, String)]].void - Read[Option[(Int, (Int, String))]].void - Read[Option[(Int, Option[(Int, String)])]].void - Read[ComplexCaseClass].void + implicit val iOpt1: Read[HasOptCustomReadWrite1] = Read.derived[HasOptCustomReadWrite1] + assertEquals(iOpt1.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite1("x", Some(CustomReadWrite("y_R")))) } - test("Read should exist for option of Unit") { - import doobie.generic.auto.* + test("Semiauto derivation selects custom Get instances to use for Read when available") { + implicit val i0: Read[HasCustomGetPut0] = Read.derived[HasCustomGetPut0] + assertEquals(i0.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut0(CustomGetPut("x_G"), "y")) + + implicit val i1: 
Read[HasCustomGetPut1] = Read.derived[HasCustomGetPut1] + assertEquals(i1.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut1("x", CustomGetPut("y_G"))) - Read[Option[Unit]].void - assertEquals(Read[Option[(Int, Unit)]].length, 1).void + implicit val iOpt0: Read[HasOptCustomGetPut0] = Read.derived[HasOptCustomGetPut0] + assertEquals(iOpt0.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut0(Some(CustomGetPut("x_G")), "y")) + + implicit val iOpt1: Read[HasOptCustomGetPut1] = Read.derived[HasOptCustomGetPut1] + assertEquals(iOpt1.length, 2) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut1("x", Some(CustomGetPut("y_G")))) } - test("Read should select multi-column instance by default") { - import doobie.generic.auto.* + test("Automatic derivation selects custom Read instances when available") { + import doobie.implicits.* - assertEquals(Read[LenStr1].length, 2).void + insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite0(CustomReadWrite("x_R"), "y")) + insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite1("x", CustomReadWrite("y_R"))) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite0(Some(CustomReadWrite("x_R")), "y")) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite1("x", Some(CustomReadWrite("y_R")))) } - test("Read should select 1-column instance when available") { - assertEquals(Read[LenStr2].length, 1).void + test("Automatic derivation selects custom Get instances to use for Read when available") { + import doobie.implicits.* + insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut0(CustomGetPut("x_G"), "y")) + insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut1("x", CustomGetPut("y_G"))) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut0(Some(CustomGetPut("x_G")), "y")) + insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut1("x", Some(CustomGetPut("y_G")))) + } + + test("Read should not be derivable for case objects") { + val expectedDeriveError = + if 
(util.Properties.versionString.startsWith("version 2.12")) + "could not find implicit" + else + "Cannot derive" + assert(compileErrors("Read.derived[CaseObj.type]").contains(expectedDeriveError)) + assert(compileErrors("Read.derived[Option[CaseObj.type]]").contains(expectedDeriveError)) + + import doobie.implicits.* + assert(compileErrors("Read[CaseObj.type]").contains("not find or construct")) + assert(compileErrors("Read[Option[CaseObj.type]]").contains("not find or construct")) + }: @nowarn("msg=.*(u|U)nused import.*") + + test("Read should exist for Unit/Option[Unit]") { + assertEquals(Read[Unit].length, 0) + assertEquals(Read[Option[Unit]].length, 0) + assertEquals(Read[(Int, Unit)].length, 1) } test(".product should product the correct ordering of gets") { @@ -104,7 +138,17 @@ class ReadSuite extends CatsEffectSuite with ReadSuitePlatform { val p = readInt.product(readString) - assertEquals(p.gets, readInt.gets ++ readString.gets) + assertEquals(p.gets, (readInt.gets ++ readString.gets)) + } + + test(".map should correctly transform the value") { + import doobie.implicits.* + implicit val r: Read[WrappedSimpleCaseClass] = Read[SimpleCaseClass].map(s => + WrappedSimpleCaseClass( + s.copy(s = "custom") + )) + + insertTuple3AndCheckRead((1, "s1", "s2"), WrappedSimpleCaseClass(SimpleCaseClass(Some(1), "custom", Some("s2")))) } /* @@ -116,13 +160,12 @@ class ReadSuite extends CatsEffectSuite with ReadSuitePlatform { val frag = sql"SELECT 1, NULL, 3, NULL" val q1 = frag.query[Option[(Int, Option[Int], Int, Option[Int])]].to[List] - // This result doesn't seem ideal, because we should know that Int isn't - // nullable, so the correct result is Some((1, None, 3, None)) - // But with how things are wired at the moment this isn't possible - q1.transact(xa).assertEquals(List(None)) + val o1 = q1.transact(xa).unsafeRunSync() + assertEquals(o1, List(Some((1, None, 3, None)))) val q2 = frag.query[Option[(Int, Int, Int, Int)]].to[List] - 
q2.transact(xa).assertEquals(List(None)) + val o2 = q2.transact(xa).unsafeRunSync() + assertEquals(o2, List(None)) } test("Read should read correct columns for instances with Option (Some)") { @@ -130,10 +173,12 @@ class ReadSuite extends CatsEffectSuite with ReadSuitePlatform { val frag = sql"SELECT 1, 2, 3, 4" val q1 = frag.query[Option[(Int, Option[Int], Int, Option[Int])]].to[List] - q1.transact(xa).assertEquals(List(Some((1, Some(2), 3, Some(4))))) + val o1 = q1.transact(xa).unsafeRunSync() + assertEquals(o1, List(Some((1, Some(2), 3, Some(4))))) val q2 = frag.query[Option[(Int, Int, Int, Int)]].to[List] - q2.transact(xa).assertEquals(List(Some((1, 2, 3, 4)))) + val o2 = q2.transact(xa).unsafeRunSync() + assertEquals(o2, List(Some((1, 2, 3, 4)))) } test("Read should select correct columns when combined with `ap`") { @@ -145,7 +190,10 @@ class ReadSuite extends CatsEffectSuite with ReadSuitePlatform { val c = (r, r, r, r, r).tupled val q = sql"SELECT 1, 2, 3, 4, 5".query(using c).to[List] - q.transact(xa).assertEquals(List((1, 2, 3, 4, 5))) + + val o = q.transact(xa).unsafeRunSync() + + assertEquals(o, List((1, 2, 3, 4, 5))) } test("Read should select correct columns when combined with `product`") { @@ -155,7 +203,130 @@ class ReadSuite extends CatsEffectSuite with ReadSuitePlatform { val r = Read[Int].product(Read[Int].product(Read[Int])) val q = sql"SELECT 1, 2, 3".query(using r).to[List] - q.transact(xa).assertEquals(List((1, (2, 3)))) + val o = q.transact(xa).unsafeRunSync() + + assertEquals(o, List((1, (2, 3)))) + } + + test("Read typechecking should work for Tuples") { + val frag = sql"SELECT 1, 's', 3.0 :: DOUBLE" + + assertSuccessTypecheckRead[(Int, String, Double)](frag) + assertSuccessTypecheckRead[(Int, (String, Double))](frag) + assertSuccessTypecheckRead[((Int, String), Double)](frag) + + assertSuccessTypecheckRead[(Int, Option[String], Double)](frag) + assertSuccessTypecheckRead[(Option[Int], Option[(String, Double)])](frag) + 
assertSuccessTypecheckRead[Option[((Int, String), Double)]](frag) + + assertWarnedTypecheckRead[(Boolean, String, Double)](frag) + + assertMisalignedTypecheckRead[(Int, String)](frag) + assertMisalignedTypecheckRead[(Int, String, Double, Int)](frag) + + } + + test("Read typechecking should work for case classes") { + implicit val rscc: Read[SimpleCaseClass] = Read.derived[SimpleCaseClass] + implicit val rccc: Read[ComplexCaseClass] = Read.derived[ComplexCaseClass] + implicit val rwscc: Read[WrappedSimpleCaseClass] = + rscc.map(WrappedSimpleCaseClass.apply) // Test map doesn't break typechecking + + assertSuccessTypecheckRead( + sql"create table tab(c1 int, c2 varchar not null, c3 varchar)".update.run.flatMap(_ => + sql"SELECT c1,c2,c3 from tab".query[SimpleCaseClass].analysis) + ) + assertSuccessTypecheckRead( + sql"create table tab(c1 int, c2 varchar not null, c3 varchar)".update.run.flatMap(_ => + sql"SELECT c1,c2,c3 from tab".query[WrappedSimpleCaseClass].analysis) + ) + + assertSuccessTypecheckRead( + sql"create table tab(c1 int, c2 varchar, c3 varchar)".update.run.flatMap(_ => + sql"SELECT c1,c2,c3 from tab".query[Option[SimpleCaseClass]].analysis) + ) + assertSuccessTypecheckRead( + sql"create table tab(c1 int, c2 varchar, c3 varchar)".update.run.flatMap(_ => + sql"SELECT c1,c2,c3 from tab".query[Option[WrappedSimpleCaseClass]].analysis) + ) + + assertTypeErrorTypecheckRead( + sql"create table tab(c1 binary, c2 varchar not null, c3 varchar)".update.run.flatMap(_ => + sql"SELECT c1,c2,c3 from tab".query[SimpleCaseClass].analysis) + ) + + assertMisalignedNullabilityTypecheckRead( + sql"create table tab(c1 int, c2 varchar, c3 varchar)".update.run.flatMap(_ => + sql"SELECT c1,c2,c3 from tab".query[SimpleCaseClass].analysis) + ) + + assertSuccessTypecheckRead( + sql"create table tab(c1 int, c2 varchar not null, c3 varchar, c4 int, c5 varchar, c6 varchar, c7 int, c8 varchar not null)" + .update.run.flatMap(_ => + sql"SELECT c1,c2,c3,c4,c5,c6,c7,c8 from 
tab".query[ComplexCaseClass].analysis) + ) + + assertTypeErrorTypecheckRead( + sql"create table tab(c1 binary, c2 varchar not null, c3 varchar, c4 int, c5 varchar, c6 varchar, c7 int, c8 varchar not null)" + .update.run.flatMap(_ => + sql"SELECT c1,c2,c3,c4,c5,c6,c7,c8 from tab".query[ComplexCaseClass].analysis) + ) + + assertMisalignedNullabilityTypecheckRead( + sql"create table tab(c1 int, c2 varchar, c3 varchar, c4 int, c5 varchar, c6 varchar, c7 int, c8 varchar not null)" + .update.run.flatMap(_ => + sql"SELECT c1,c2,c3,c4,c5,c6,c7,c8 from tab".query[ComplexCaseClass].analysis) + ) + + } + + private def insertTuple3AndCheckRead[Tup <: (?, ?, ?): Write, A: Read](in: Tup, expectedOut: A)(implicit + loc: Location + ): Unit = { + val res = Query[Tup, A]("SELECT ?, ?, ?").unique(in).transact(xa) + .unsafeRunSync() + assertEquals(res, expectedOut) + } + + private def insertTuple2AndCheckRead[Tup <: (?, ?): Write, A: Read](in: Tup, expectedOut: A)(implicit + loc: Location + ): Unit = { + val res = Query[Tup, A]("SELECT ?, ?").unique(in).transact(xa) + .unsafeRunSync() + assertEquals(res, expectedOut) + } + + private def assertSuccessTypecheckRead(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = { + val analysisResult = connio.transact(xa).unsafeRunSync() + assertEquals(analysisResult.columnAlignmentErrors, Nil) + } + + private def assertSuccessTypecheckRead[A: Read](frag: Fragment)(implicit loc: Location): Unit = { + assertSuccessTypecheckRead(frag.query[A].analysis) + } + + private def assertWarnedTypecheckRead[A: Read](frag: Fragment)(implicit loc: Location): Unit = { + val analysisResult = frag.query[A].analysis.transact(xa).unsafeRunSync() + val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass) + assertEquals(errorClasses, List(classOf[ColumnTypeWarning])) + } + + private def assertTypeErrorTypecheckRead(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = { + val analysisResult = connio.transact(xa).unsafeRunSync() + 
val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass) + assertEquals(errorClasses, List(classOf[ColumnTypeError])) + } + + private def assertMisalignedNullabilityTypecheckRead(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = { + val analysisResult = connio.transact(xa).unsafeRunSync() + val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass) + assertEquals(errorClasses, List(classOf[NullabilityMisalignment])) + } + + private def assertMisalignedTypecheckRead[A: Read](frag: Fragment)(implicit loc: Location): Unit = { + val analysisResult = frag.query[A].analysis.transact(xa).unsafeRunSync() + val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass) + assertEquals(errorClasses, List(classOf[ColumnMisalignment])) } } diff --git a/modules/core/src/test/scala/doobie/util/TestTypes.scala b/modules/core/src/test/scala/doobie/util/TestTypes.scala index 8607103f4..d0217cd98 100644 --- a/modules/core/src/test/scala/doobie/util/TestTypes.scala +++ b/modules/core/src/test/scala/doobie/util/TestTypes.scala @@ -4,20 +4,36 @@ package doobie.util -import doobie.util.meta.Meta - object TestTypes { - case class LenStr1(n: Int, s: String) - - case class LenStr2(n: Int, s: String) - object LenStr2 { - implicit val LenStrMeta: Meta[LenStr2] = - Meta[String].timap(s => LenStr2(s.length, s))(_.s) - } - case object CaseObj + case class TrivialCaseClass(i: Int) case class SimpleCaseClass(i: Option[Int], s: String, os: Option[String]) case class ComplexCaseClass(sc: SimpleCaseClass, osc: Option[SimpleCaseClass], i: Option[Int], s: String) + case class WrappedSimpleCaseClass(sc: SimpleCaseClass) + + case class HasCustomReadWrite0(c: CustomReadWrite, s: String) + case class HasCustomReadWrite1(s: String, c: CustomReadWrite) + case class HasOptCustomReadWrite0(c: Option[CustomReadWrite], s: String) + case class HasOptCustomReadWrite1(s: String, c: Option[CustomReadWrite]) + + case class CustomReadWrite(s: String) + + object 
CustomReadWrite { + implicit val write: Write[CustomReadWrite] = Write.fromPut[String].contramap(a => a.s.concat("_W")) + implicit val read: Read[CustomReadWrite] = Read.fromGet[String].map(str => CustomReadWrite(str.concat("_R"))) + } + + case class HasCustomGetPut0(c: CustomGetPut, s: String) + case class HasCustomGetPut1(s: String, c: CustomGetPut) + case class HasOptCustomGetPut0(c: Option[CustomGetPut], s: String) + case class HasOptCustomGetPut1(s: String, c: Option[CustomGetPut]) + + case class CustomGetPut(s: String) + + object CustomGetPut { + implicit val put: Put[CustomGetPut] = Put[String].contramap(a => a.s.concat("_P")) + implicit val get: Get[CustomGetPut] = Get[String].tmap(a => CustomGetPut(a.concat("_G"))) + } } diff --git a/modules/core/src/test/scala/doobie/util/WriteSuite.scala b/modules/core/src/test/scala/doobie/util/WriteSuite.scala index 382e636b1..564dd6ff4 100644 --- a/modules/core/src/test/scala/doobie/util/WriteSuite.scala +++ b/modules/core/src/test/scala/doobie/util/WriteSuite.scala @@ -4,13 +4,18 @@ package doobie.util -import doobie.Transactor -import doobie.Update +import doobie.{ConnectionIO, Query, Transactor, Update} import doobie.util.TestTypes.* import cats.effect.IO +import cats.effect.unsafe.implicits.global import doobie.testutils.VoidExtensions +import doobie.syntax.all.* +import doobie.util.analysis.{Analysis, ParameterMisalignment, ParameterTypeError} +import munit.Location -class WriteSuite extends munit.CatsEffectSuite with WriteSuitePlatform { +import scala.annotation.nowarn + +class WriteSuite extends munit.FunSuite with WriteSuitePlatform { val xa: Transactor[IO] = Transactor.fromDriverManager[IO]( driver = "org.h2.Driver", @@ -20,83 +25,109 @@ class WriteSuite extends munit.CatsEffectSuite with WriteSuitePlatform { logHandler = None ) - test("Write should exist for some fancy types") { - import doobie.generic.auto.* - - Write[Int].void - Write[(Int, Int)].void - Write[(Int, Int, String)].void - Write[(Int, (Int, 
String))].void - Write[ComplexCaseClass].void - } - - test("Write is auto derived for tuples without an import") { + test("Write is available for tuples without an import when all elements have a Write instance") { Write[(Int, Int)].void Write[(Int, Int, String)].void Write[(Int, (Int, String))].void Write[Option[(Int, Int)]].void Write[Option[(Int, Option[(String, Int)])]].void + + // But shouldn't automatically derive anything that doesn't already have a Read instance + assert(compileErrors("Write[(Int, TrivialCaseClass)]").contains("Cannot find or construct")) } test("Write is still auto derived for tuples when import is present (no ambiguous implicits) ") { - import doobie.generic.auto.* + import doobie.implicits.* Write[(Int, Int)].void Write[(Int, Int, String)].void Write[(Int, (Int, String))].void Write[Option[(Int, Int)]].void Write[Option[(Int, Option[(String, Int)])]].void + + Write[(ComplexCaseClass, Int)].void + Write[(Int, ComplexCaseClass)].void } test("Write is not auto derived for case classes") { - assert(compileErrors("Write[LenStr1]").contains("Cannot find or construct")) + assert(compileErrors("Write[TrivialCaseClass]").contains("Cannot find or construct")) } - test("Write should not be derivable for case objects") { - assert(compileErrors("Write[CaseObj.type]").contains("Cannot find or construct")) - assert(compileErrors("Write[Option[CaseObj.type]]").contains("Cannot find or construct")) - } + test("Semiauto derivation selects custom Write instances when available") { + implicit val i0: Write[HasCustomReadWrite0] = Write.derived[HasCustomReadWrite0] + assertEquals(i0.length, 2) + writeAndCheckTuple2(HasCustomReadWrite0(CustomReadWrite("x"), "y"), ("x_W", "y")) - test("Write can be manually derived") { - Write.derived[LenStr1].void - } + implicit val i1: Write[HasCustomReadWrite1] = Write.derived[HasCustomReadWrite1] + assertEquals(i1.length, 2) + writeAndCheckTuple2(HasCustomReadWrite1("x", CustomReadWrite("y")), ("x", "y_W")) - test("Write 
should exist for Unit") { - import doobie.generic.auto.* + implicit val iOpt0: Write[HasOptCustomReadWrite0] = Write.derived[HasOptCustomReadWrite0] + assertEquals(iOpt0.length, 2) + writeAndCheckTuple2(HasOptCustomReadWrite0(Some(CustomReadWrite("x")), "y"), ("x_W", "y")) - Write[Unit].void - assertEquals(Write[(Int, Unit)].length, 1) + implicit val iOpt1: Write[HasOptCustomReadWrite1] = Write.derived[HasOptCustomReadWrite1] + assertEquals(iOpt1.length, 2) + writeAndCheckTuple2(HasOptCustomReadWrite1("x", Some(CustomReadWrite("y"))), ("x", "y_W")) } - test("Write should exist for option of some fancy types") { - import doobie.generic.auto.* + test("Semiauto derivation selects custom Put instances to use for Write when available") { + implicit val i0: Write[HasCustomGetPut0] = Write.derived[HasCustomGetPut0] + assertEquals(i0.length, 2) + writeAndCheckTuple2(HasCustomGetPut0(CustomGetPut("x"), "y"), ("x_P", "y")) - Write[Option[Int]].void - Write[Option[(Int, Int)]].void - Write[Option[(Int, Int, String)]].void - Write[Option[(Int, (Int, String))]].void - Write[Option[(Int, Option[(Int, String)])]].void - } + implicit val i1: Write[HasCustomGetPut1] = Write.derived[HasCustomGetPut1] + assertEquals(i1.length, 2) + writeAndCheckTuple2(HasCustomGetPut1("x", CustomGetPut("y")), ("x", "y_P")) - test("Write should exist for option of Unit") { - import doobie.generic.auto.* + implicit val iOpt0: Write[HasOptCustomGetPut0] = Write.derived[HasOptCustomGetPut0] + assertEquals(iOpt0.length, 2) + writeAndCheckTuple2(HasOptCustomGetPut0(Some(CustomGetPut("x")), "y"), ("x_P", "y")) - Write[Option[Unit]].void - assertEquals(Write[Option[(Int, Unit)]].length, 1) + implicit val iOpt1: Write[HasOptCustomGetPut1] = Write.derived[HasOptCustomGetPut1] + assertEquals(iOpt1.length, 2) + writeAndCheckTuple2(HasOptCustomGetPut1("x", Some(CustomGetPut("y"))), ("x", "y_P")) } - test("Write should select multi-column instance by default") { - import doobie.generic.auto.* + test("Automatic 
derivation selects custom Write instances when available") { + import doobie.implicits.* - assertEquals(Write[LenStr1].length, 2) + writeAndCheckTuple2(HasCustomReadWrite0(CustomReadWrite("x"), "y"), ("x_W", "y")) + writeAndCheckTuple2(HasCustomReadWrite1("x", CustomReadWrite("y")), ("x", "y_W")) + writeAndCheckTuple2(HasOptCustomReadWrite0(Some(CustomReadWrite("x")), "y"), ("x_W", "y")) + writeAndCheckTuple2(HasOptCustomReadWrite1("x", Some(CustomReadWrite("y"))), ("x", "y_W")) } - test("Write should select 1-column instance when available") { - assertEquals(Write[LenStr2].length, 1) + test("Automatic derivation selects custom Put instances to use for Write when available") { + import doobie.implicits.* + writeAndCheckTuple2(HasCustomGetPut0(CustomGetPut("x"), "y"), ("x_P", "y")) + writeAndCheckTuple2(HasCustomGetPut1("x", CustomGetPut("y")), ("x", "y_P")) + writeAndCheckTuple2(HasOptCustomGetPut0(Some(CustomGetPut("x")), "y"), ("x_P", "y")) + writeAndCheckTuple2(HasOptCustomGetPut1("x", Some(CustomGetPut("y"))), ("x", "y_P")) } - test("Write should correct set parameters for Option instances ") { + test("Write should not be derivable for case objects") { + val expectedDeriveError = + if (util.Properties.versionString.startsWith("version 2.12")) + "could not find implicit" + else + "Cannot derive" + assert(compileErrors("Write.derived[CaseObj.type]").contains(expectedDeriveError)) + assert(compileErrors("Write.derived[Option[CaseObj.type]]").contains(expectedDeriveError)) + + import doobie.implicits.* + assert(compileErrors("Write[Option[CaseObj.type]]").contains("not find or construct")) + assert(compileErrors("Write[CaseObj.type]").contains("not find or construct")) + }: @nowarn("msg=.*(u|U)nused import.*") + + test("Write should exist for Unit/Option[Unit]") { + assertEquals(Write[Unit].length, 0) + assertEquals(Write[Option[Unit]].length, 0) + assertEquals(Write[(Int, Unit)].length, 1) + } + + test("Write should correctly set parameters for Option instances ") 
{ import doobie.implicits.* (for { _ <- sql"create temp table t1 (a int, b int)".update.run @@ -118,19 +149,126 @@ class WriteSuite extends munit.CatsEffectSuite with WriteSuitePlatform { )) }) .transact(xa) + .unsafeRunSync() } test("Write should yield correct error when Some(null) inserted") { - testNullPut((null, Some("b"))).interceptMessage[RuntimeException]( - "Expected non-nullable param at 1. Use Option to describe nullable values.") + interceptMessage[RuntimeException]("Expected non-nullable param at 2. Use Option to describe nullable values.") { + testNullPut(("a", Some(null))) + } } test("Write should yield correct error when null inserted into non-nullable field") { - testNullPut((null, Some("b"))).interceptMessage[RuntimeException]( - "Expected non-nullable param at 1. Use Option to describe nullable values.") + interceptMessage[RuntimeException]("Expected non-nullable param at 1. Use Option to describe nullable values.") { + testNullPut((null, Some("b"))) + } + } + + test(".contramap correctly transformers the input value") { + import doobie.implicits.* + implicit val w: Write[WrappedSimpleCaseClass] = Write[SimpleCaseClass].contramap(v => + v.sc.copy( + s = "custom" + )) + + writeAndCheckTuple3(WrappedSimpleCaseClass(SimpleCaseClass(Some(1), "s1", Some("s2"))), (1, "custom", "s2")) + } + + test("Write typechecking should work for tuples") { + val createTable = sql"create temp table tab(c1 int, c2 varchar not null, c3 double)".update.run + val createAllNullableTable = sql"create temp table tab(c1 int, c2 varchar, c3 double)".update.run + val insertSql = "INSERT INTO tab VALUES (?,?,?)" + + assertSuccessTypecheckWrite( + createTable.flatMap(_ => Update[(Option[Int], String, Double)](insertSql).analysis)) + assertSuccessTypecheckWrite( + createTable.flatMap(_ => Update[((Option[Int], String), Double)](insertSql).analysis)) + assertSuccessTypecheckWrite( + createTable.flatMap(_ => Update[(Option[Int], String, Option[Double])](insertSql).analysis)) + 
assertSuccessTypecheckWrite( + createAllNullableTable.flatMap(_ => Update[(Option[Int], Option[String], Option[Double])](insertSql).analysis)) + assertSuccessTypecheckWrite( + createAllNullableTable.flatMap(_ => Update[Option[(Option[Int], String, Double)]](insertSql).analysis)) + assertSuccessTypecheckWrite( + createAllNullableTable.flatMap(_ => Update[Option[(Int, Option[(String, Double)])]](insertSql).analysis)) + + assertMisalignedTypecheckWrite(createTable.flatMap(_ => Update[(Option[Int], String)](insertSql).analysis)) + assertMisalignedTypecheckWrite(createTable.flatMap(_ => + Update[(Option[Int], String, Double, Int)](insertSql).analysis)) + + assertTypeErrorTypecheckWrite( + sql"create temp table tab(c1 binary not null, c2 varchar not null, c3 int)".update.run.flatMap(_ => + Update[(Int, String, Option[Int])](insertSql).analysis) + ) } - private def testNullPut(input: (String, Option[String])): IO[Int] = { + test("Write typechecking should work for case classes") { + implicit val wscc: Write[SimpleCaseClass] = Write.derived[SimpleCaseClass] + implicit val wccc: Write[ComplexCaseClass] = Write.derived[ComplexCaseClass] + implicit val wwscc: Write[WrappedSimpleCaseClass] = + wscc.contramap(_.sc) // Testing contramap doesn't break typechecking + + val createTable = sql"create temp table tab(c1 int, c2 varchar not null, c3 varchar)".update.run + + val insertSimpleSql = "INSERT INTO tab VALUES (?,?,?)" + + assertSuccessTypecheckWrite(createTable.flatMap(_ => Update[SimpleCaseClass](insertSimpleSql).analysis)) + assertSuccessTypecheckWrite(createTable.flatMap(_ => Update[WrappedSimpleCaseClass](insertSimpleSql).analysis)) + + // This shouldn't pass but JDBC driver (at least for h2) doesn't tell us when a parameter should be not-nullable + assertSuccessTypecheckWrite(createTable.flatMap(_ => Update[Option[SimpleCaseClass]](insertSimpleSql).analysis)) + assertSuccessTypecheckWrite(createTable.flatMap(_ => + 
Update[Option[WrappedSimpleCaseClass]](insertSimpleSql).analysis)) + + val insertComplexSql = "INSERT INTO tab VALUES (?,?,?,?,?,?,?,?)" + + assertSuccessTypecheckWrite( + sql"create temp table tab(c1 int, c2 varchar, c3 varchar, c4 int, c5 varchar, c6 varchar, c7 int, c8 varchar not null)" + .update.run + .flatMap(_ => Update[ComplexCaseClass](insertComplexSql).analysis) + ) + + assertTypeErrorTypecheckWrite( + sql"create temp table tab(c1 int, c2 varchar, c3 varchar, c4 BINARY, c5 varchar, c6 varchar, c7 int, c8 varchar not null)" + .update.run + .flatMap(_ => Update[ComplexCaseClass](insertComplexSql).analysis) + ) + } + + private def assertSuccessTypecheckWrite(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = { + val analysisResult = connio.transact(xa).unsafeRunSync() + assertEquals(analysisResult.parameterAlignmentErrors, Nil) + } + + private def assertMisalignedTypecheckWrite(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = { + val analysisResult = connio.transact(xa).unsafeRunSync() + val errorClasses = analysisResult.parameterAlignmentErrors.map(_.getClass) + assertEquals(errorClasses, List(classOf[ParameterMisalignment])) + } + + private def assertTypeErrorTypecheckWrite(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = { + val analysisResult = connio.transact(xa).unsafeRunSync() + val errorClasses = analysisResult.parameterAlignmentErrors.map(_.getClass) + assertEquals(errorClasses, List(classOf[ParameterTypeError])) + } + + private def writeAndCheckTuple2[A: Write, Tup <: (?, ?): Read](in: A, expectedOut: Tup)(implicit + loc: Location + ): Unit = { + val res = Query[A, Tup]("SELECT ?, ?").unique(in).transact(xa) + .unsafeRunSync() + assertEquals(res, expectedOut) + } + + private def writeAndCheckTuple3[A: Write, Tup <: (?, ?, ?): Read](in: A, expectedOut: Tup)(implicit + loc: Location + ): Unit = { + val res = Query[A, Tup]("SELECT ?, ?, ?").unique(in).transact(xa) + .unsafeRunSync() + assertEquals(res, 
expectedOut) + } + + private def testNullPut(input: (String, Option[String])): Int = { import doobie.implicits.* (for { @@ -138,6 +276,9 @@ class WriteSuite extends munit.CatsEffectSuite with WriteSuitePlatform { n <- Update[(String, Option[String])]("insert into t0 (a, b) values (?, ?)").run(input) } yield n) .transact(xa) + .unsafeRunSync() } } + +object WriteSuite {} diff --git a/modules/docs/src/main/mdoc/docs/11-Arrays.md b/modules/docs/src/main/mdoc/docs/11-Arrays.md index 0db0d49d7..e77b578cd 100644 --- a/modules/docs/src/main/mdoc/docs/11-Arrays.md +++ b/modules/docs/src/main/mdoc/docs/11-Arrays.md @@ -60,7 +60,7 @@ val create = (drop *> create).unsafeRunSync() ``` -**doobie** maps SQL array columns to `Array`, `List`, and `Vector` by default. No special handling is required, other than importing the vendor-specific array support above. +**doobie** maps SQL array columns to `Array`, `List`, and `Vector` by default for standard types like `String` or `Int`. No special handling is required, other than importing the vendor-specific array support above. 
```scala mdoc:silent case class Person(id: Long, name: String, pets: List[String]) @@ -93,3 +93,49 @@ sql"select array['foo','bar','baz']".query[Option[List[String]]].quick.unsafeRun sql"select array['foo',NULL,'baz']".query[List[Option[String]]].quick.unsafeRunSync() sql"select array['foo',NULL,'baz']".query[Option[List[Option[String]]]].quick.unsafeRunSync() ``` + +### Array of enums + +For reading from and writing to a column that is an array of enum, you can use `doobie.postgres.implicits.arrayOfEnum` +to create a `Meta` instance for your enum type: + +```scala mdoc +import doobie.postgres.implicits.arrayOfEnum + +sealed trait MyEnum + +object MyEnum { + case object Foo extends MyEnum + + case object Bar extends MyEnum + + private val typeName = "myenum" + + def fromStrUnsafe(s: String): MyEnum = s match { + case "foo" => Foo + case "bar" => Bar + case other => throw new RuntimeException(s"Unexpected value '$other' for MyEnum") + } + + def toStr(e: MyEnum): String = e match { + case Foo => "foo" + case Bar => "bar" + } + + implicit val MyEnumArrayMeta: Meta[Array[MyEnum]] = + arrayOfEnum[MyEnum]( + enumTypeName = typeName, + fromStr = fromStrUnsafe, + toStr = toStr + ) + +} +``` + +and you can now map the array of enum column into an `Array[MyEnum]`, `List[MyEnum]`, `Vector[MyEnum]`: + +```scala mdoc +sql"select array['foo', 'bar'] :: myenum[]".query[List[MyEnum]].quick.unsafeRunSync() +``` + +For an example of using an enum type from another schema, please see [OtherEnum.scala](https://github.com/typelevel/doobie/blob/main/modules/postgres/src/test/scala/doobie/postgres/enums/OtherEnum.scala) diff --git a/modules/docs/src/main/mdoc/docs/15-Extensions-PostgreSQL.md b/modules/docs/src/main/mdoc/docs/15-Extensions-PostgreSQL.md index 503c7c888..98623da91 100644 --- a/modules/docs/src/main/mdoc/docs/15-Extensions-PostgreSQL.md +++ b/modules/docs/src/main/mdoc/docs/15-Extensions-PostgreSQL.md @@ -116,7 +116,7 @@ object MyEnum extends Enumeration { val foo, bar = 
Value } -implicit val MyEnumMeta = pgEnum(MyEnum, "myenum") +implicit val MyEnumMeta: Meta[MyEnum.Value] = pgEnum(MyEnum, "myenum") ``` ```scala mdoc @@ -214,14 +214,14 @@ In addition to the general types above, **doobie** provides mappings for the fol [Geographic types](http://postgis.net/workshops/postgis-intro/geography.html) mappings are defined in a different object (`pgisgeographyimplicits`), to allow geometric types using geodetic coordinates. -``` +```scala import doobie.postgres.pgisgeographyimplicits._ // or define the implicit conversion manually -implicit val geographyPoint: Meta[Point] = - doobie.postgres.pgisgeographyimplicits.PointType +implicit val geographyPoint: Meta[Point] = doobie.postgres.pgisgeographyimplicits.PointType ``` + - Point - Polygon - MultiPoint @@ -242,17 +242,17 @@ The following range types are supported, and map to **doobie** generic `Range[T] - the `tstzrange` schema type maps to `Range[java.time.OffsetDateTime]` Non empty range maps to: -```scala mdoc:silent +```scala case class NonEmptyRange[T](lowerBound: Option[T], upperBound: Option[T], edge: Edge) extends Range[T] ``` Empty range maps to: -```scala mdoc:silent +```scala case object EmptyRange extends Range[Nothing] ``` To control the inclusive and exclusive bounds according to the [PostgreSQL](https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-INCLUSIVITY) specification you need to use a special `Edge` enumeration when creating a `Range`: -```scala mdoc:silent +```scala object Edge { case object ExclExcl extends Edge case object ExclIncl extends Edge @@ -271,13 +271,14 @@ import doobie.postgres.rangeimplicits._ To create for example custom implementation of `Range[Byte]` you can use the public method which declared in the following package `doobie.postgres.rangeimplicits`: -```scala mdoc:silent +```scala def rangeMeta[T](sqlRangeType: String)(encode: T => String, decode: String => T): Meta[Range[T]] ``` For a `Range[Byte]`, the meta and bounds encoder and 
decoder would appear as follows: ```scala mdoc:silent import doobie.postgres.rangeimplicits._ +import doobie.postgres.types.Range implicit val byteRangeMeta: Meta[Range[Byte]] = rangeMeta[Byte]("int4range")(_.toString, _.toByte) diff --git a/modules/docs/src/main/mdoc/docs/17-FAQ.md b/modules/docs/src/main/mdoc/docs/17-FAQ.md index 279584421..c432cdd09 100644 --- a/modules/docs/src/main/mdoc/docs/17-FAQ.md +++ b/modules/docs/src/main/mdoc/docs/17-FAQ.md @@ -150,7 +150,7 @@ As of **doobie** 0.4 there is a reasonable solution to the logging/instrumentati There are a lot of ways to handle `SQLXML` so there is no pre-defined strategy, but here is one that maps `scala.xml.Elem` to `SQLXML` via streaming. ```scala mdoc:silent -import doobie.enum.JdbcType.Other +import doobie.enumerated.JdbcType.Other import java.sql.SQLXML import scala.xml.{ XML, Elem } @@ -181,7 +181,7 @@ Domains with check constraints will type check as DISTINCT. For Doobie later tha ```scala mdoc:silent import cats.data.NonEmptyList import doobie._ -import doobie.enum.JdbcType +import doobie.enumerated.JdbcType object distinct { diff --git a/modules/docs/src/main/mdoc/docs/18-Related-Projects.md b/modules/docs/src/main/mdoc/docs/18-Related-Projects.md index 1bcfdd33d..b8fd7dc9f 100644 --- a/modules/docs/src/main/mdoc/docs/18-Related-Projects.md +++ b/modules/docs/src/main/mdoc/docs/18-Related-Projects.md @@ -5,4 +5,5 @@ A non-exhaustive list of projects that supplement Doobie: - [DoobieRoll](https://github.com/jatcwang/doobieroll) - collection of utilities to make working with Doobie / SQL even easier. - [TableColumns](https://jatcwang.github.io/doobieroll/docs/tablecolumns) - Ensure fields in your SQL are consistently named and ordered. - [Assembler](https://jatcwang.github.io/doobieroll/docs/assembler) - Assemble SQL query results into hierarchical domain models. -- [doobie-typesafe](https://github.com/arturaz/doobie-typesafe) - type-safe table definitions for Doobie. 
\ No newline at end of file +- [doobie-typesafe](https://github.com/arturaz/doobie-typesafe) - type-safe table definitions for Doobie. +- [otel4s-doobie](https://github.com/arturaz/otel4s-doobie) - [Otel4s](https://github.com/typelevel/otel4s) integration. diff --git a/modules/munit/src/test/scala/doobie/munit/CheckerTests.scala b/modules/munit/src/test/scala/doobie/munit/CheckerTests.scala index 296211fac..c31cf40b0 100644 --- a/modules/munit/src/test/scala/doobie/munit/CheckerTests.scala +++ b/modules/munit/src/test/scala/doobie/munit/CheckerTests.scala @@ -47,8 +47,6 @@ trait CheckerChecks[M[_]] extends FunSuite with Checker[M] { } test("Read should select correct columns for checking when combined with `ap`") { - import doobie.generic.auto.* - val readInt = Read[(Int, Int)] val readIntToInt: Read[Tuple2[Int, Int] => String] = Read[(String, String)].map(i => k => s"$i,$k") diff --git a/modules/postgres/src/main/scala/doobie/postgres/Instances.scala b/modules/postgres/src/main/scala/doobie/postgres/Instances.scala index a30b8202c..a88d1af46 100644 --- a/modules/postgres/src/main/scala/doobie/postgres/Instances.scala +++ b/modules/postgres/src/main/scala/doobie/postgres/Instances.scala @@ -170,6 +170,28 @@ trait Instances { .timap(_.map(_.map(a => if (a == null) null else BigDecimal.apply(a))))(_.map(_.map(a => if (a == null) null else a.bigDecimal))) + /** Create a Meta instance to allow reading and writing into an array of enum, with stricter typechecking support to + * verify that the column we're inserting into must match the enum array type. 
+ * + * @param enumTypeName + * Name of the enum type + * @param fromStr + * Function to convert each element to the Scala type when reading from the database + * @param toStr + * Function to convert each element to string when writing to the database + * @return + */ + def arrayOfEnum[A: ClassTag]( + enumTypeName: String, + fromStr: String => A, + toStr: A => String + ): Meta[Array[A]] = { + Meta.Advanced.array[String]( + enumTypeName, + arrayTypeName = s"_$enumTypeName" + ).timap(arr => arr.map(fromStr))(arr => arr.map(toStr)) + } + // So, it turns out that arrays of structs don't work because something is missing from the // implementation. So this means we will only be able to support primitive types for arrays. // diff --git a/modules/postgres/src/test/scala/doobie/postgres/PgArraySuite.scala b/modules/postgres/src/test/scala/doobie/postgres/PgArraySuite.scala new file mode 100644 index 000000000..9a64b93cb --- /dev/null +++ b/modules/postgres/src/test/scala/doobie/postgres/PgArraySuite.scala @@ -0,0 +1,114 @@ +// Copyright (c) 2013-2020 Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). 
+// For more information see LICENSE or https://opensource.org/licenses/MIT + +package doobie.postgres + +import cats.effect.IO +import doobie.Transactor +import doobie.postgres.enums.{MyEnum, OtherEnum} +import doobie.postgres.implicits.* +import doobie.syntax.all.* +import doobie.util.analysis.{ColumnTypeError, ParameterTypeError} + +class PgArraySuite extends munit.CatsEffectSuite { + + val transactor: Transactor[IO] = Transactor.fromDriverManager[IO]( + driver = "org.postgresql.Driver", + url = "jdbc:postgresql:world", + user = "postgres", + password = "password", + logHandler = None + ) + + private val listOfMyEnums: List[MyEnum] = List(MyEnum.Foo, MyEnum.Bar) + + private val listOfOtherEnums: List[OtherEnum] = List(OtherEnum.A, OtherEnum.B) + + test("array of custom string type: read correctly and typechecks") { + val q = sql"select array['foo', 'bar'] :: myenum[]".query[List[MyEnum]] + (for { + _ <- q.analysis + .map(ana => assertEquals(ana.columnAlignmentErrors, List.empty)) + + _ <- q.unique.map(assertEquals(_, listOfMyEnums)) + + _ <- sql"select array['foo', 'bar']".query[List[MyEnum]].analysis.map(_.columnAlignmentErrors) + .map { + case List(e: ColumnTypeError) => assertEquals(e.schema.vendorTypeName, "_text") + case other => fail(s"Unexpected typecheck result: $other") + } + } yield ()) + .transact(transactor) + } + + test("array of custom string type: writes correctly and typechecks") { + val q = sql"insert into temp_myenum (arr) values ($listOfMyEnums)".update + (for { + _ <- sql"drop table if exists temp_myenum".update.run + _ <- sql"create table temp_myenum(arr myenum[] not null)".update.run + _ <- q.analysis.map(_.columnAlignmentErrors).map(ana => assertEquals(ana, List.empty)) + _ <- q.run + _ <- sql"select arr from temp_myenum".query[List[MyEnum]].unique + .map(assertEquals(_, listOfMyEnums)) + + _ <- sql"insert into temp_myenum (arr) values (${List("foo")})".update.analysis + .map(_.parameterAlignmentErrors) + .map { + case List(e: 
ParameterTypeError) => assertEquals(e.vendorTypeName, "_myenum") + case other => fail(s"Unexpected typecheck result: $other") + } + } yield ()) + .transact(transactor) + } + + test("array of custom type in another schema: read correctly and typechecks") { + val q = sql"select array['a', 'b'] :: other_schema.other_enum[]".query[List[OtherEnum]] + (for { + _ <- q.analysis + .map(ana => assertEquals(ana.columnAlignmentErrors, List.empty)) + + _ <- q.unique.map(assertEquals(_, listOfOtherEnums)) + + _ <- sql"select array['a', 'b']".query[List[OtherEnum]].analysis.map(_.columnAlignmentErrors) + .map { + case List(e: ColumnTypeError) => assertEquals(e.schema.vendorTypeName, "_text") + case other => fail(s"Unexpected typecheck result: $other") + } + + _ <- sql"select array['a', 'b'] :: other_schema.other_enum[]".query[List[String]].analysis.map( + _.columnAlignmentErrors) + .map { + case List(e: ColumnTypeError) => assertEquals(e.schema.vendorTypeName, """"other_schema"."_other_enum"""") + case other => fail(s"Unexpected typecheck result: $other") + } + } yield ()) + .transact(transactor) + } + + test("array of custom type in another schema: writes correctly and typechecks") { + val q = sql"insert into temp_otherenum (arr) values ($listOfOtherEnums)".update + (for { + _ <- sql"drop table if exists temp_otherenum".update.run + _ <- sql"create table temp_otherenum(arr other_schema.other_enum[] not null)".update.run + _ <- q.analysis.map(_.parameterAlignmentErrors).map(ana => assertEquals(ana, List.empty)) + _ <- q.run + _ <- sql"select arr from temp_otherenum".query[List[OtherEnum]].to[List] + .map(assertEquals(_, List(listOfOtherEnums))) + + _ <- sql"insert into temp_otherenum (arr) values (${List("a")})".update.analysis + .map(_.parameterAlignmentErrors) + .map { + case List(e: ParameterTypeError) => { + // pgjdbc is a bit crazy. 
If you have inserted into the table already then it'll report the parameter type as + // _other_enum, or otherwise "other_schema"."_other_enum".. + assertEquals(e.vendorTypeName, "_other_enum") + // assertEquals(e.vendorTypeName, s""""other_schema"."_other_enum"""") + } + case other => fail(s"Unexpected typecheck result: $other") + } + } yield ()) + .transact(transactor) + } + +} diff --git a/modules/postgres/src/test/scala/doobie/postgres/enums/MyEnum.scala b/modules/postgres/src/test/scala/doobie/postgres/enums/MyEnum.scala index 7dcd2f4c2..cd376e6dc 100644 --- a/modules/postgres/src/test/scala/doobie/postgres/enums/MyEnum.scala +++ b/modules/postgres/src/test/scala/doobie/postgres/enums/MyEnum.scala @@ -6,6 +6,7 @@ package doobie.postgres.enums import doobie.Meta import doobie.postgres.implicits.* +import doobie.postgres.implicits.arrayOfEnum // create type myenum as enum ('foo', 'bar') <-- part of setup sealed trait MyEnum @@ -13,15 +14,30 @@ object MyEnum { case object Foo extends MyEnum case object Bar extends MyEnum + def fromStringUnsafe(s: String): MyEnum = s match { + case "foo" => Foo + case "bar" => Bar + } + + def asString(e: MyEnum): String = e match { + case Foo => "foo" + case Bar => "bar" + } + + private val typeName = "myenum" + implicit val MyEnumMeta: Meta[MyEnum] = pgEnumString( - "myenum", - { - case "foo" => Foo - case "bar" => Bar - }, - { - case Foo => "foo" - case Bar => "bar" - }) + typeName, + fromStringUnsafe, + asString + ) + + implicit val MyEnumArrayMeta: Meta[Array[MyEnum]] = + arrayOfEnum[MyEnum]( + typeName, + fromStringUnsafe, + asString + ) + } diff --git a/modules/postgres/src/test/scala/doobie/postgres/enums/OtherEnum.scala b/modules/postgres/src/test/scala/doobie/postgres/enums/OtherEnum.scala new file mode 100644 index 000000000..30451d05e --- /dev/null +++ b/modules/postgres/src/test/scala/doobie/postgres/enums/OtherEnum.scala @@ -0,0 +1,32 @@ +// Copyright (c) 2013-2020 Rob Norris and Contributors +// This software is 
licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package doobie.postgres.enums + +import doobie.Meta + +// This is an enum type defined in another schema (See other_enum in test-db.sql) +sealed abstract class OtherEnum(val strValue: String) + +object OtherEnum { + case object A extends OtherEnum("a") + + case object B extends OtherEnum("b") + + private def fromStrUnsafe(s: String): OtherEnum = s match { + case "a" => A + case "b" => B + } + + private val elementTypeNameUnqualified = "other_enum" + private val elementTypeName = s""""other_schema"."$elementTypeNameUnqualified"""" + private val arrayTypeName = s""""other_schema"."_$elementTypeNameUnqualified"""" + + implicit val arrayMeta: Meta[Array[OtherEnum]] = + Meta.Advanced.array[String]( + elementTypeName, + arrayTypeName, + s"_$elementTypeNameUnqualified" + ).timap(arr => arr.map(fromStrUnsafe))(arr => arr.map(_.strValue)) +} diff --git a/modules/weaver/src/test/scala/doobie/weaver/CheckerTests.scala b/modules/weaver/src/test/scala/doobie/weaver/CheckerTests.scala index 97a703c4d..43989044a 100644 --- a/modules/weaver/src/test/scala/doobie/weaver/CheckerTests.scala +++ b/modules/weaver/src/test/scala/doobie/weaver/CheckerTests.scala @@ -54,8 +54,6 @@ object CheckerTests extends IOSuite with IOChecker { } test("Read should select correct columns for checking when combined with `ap`") { implicit transactor => - import doobie.generic.auto.* - val readInt = Read[(Int, Int)] val readIntToInt: Read[Tuple2[Int, Int] => String] = Read[(String, String)].map(i => k => s"$i,$k") diff --git a/project/build.properties b/project/build.properties index db1723b08..73df629ac 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.10.5 +sbt.version=1.10.7 diff --git a/project/plugins.sbt b/project/plugins.sbt index 0665f1a71..94f22c016 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -2,13 +2,13 @@ 
addSbtPlugin("com.lightbend.paradox" % "sbt-paradox" % "0.10.7") addSbtPlugin("com.github.sbt" % "sbt-site" % "1.7.0") addSbtPlugin("com.github.sbt" % "sbt-site-paradox" % "1.7.0") addSbtPlugin("com.github.sbt" % "sbt-ghpages" % "0.8.0") -addSbtPlugin("org.typelevel" % "sbt-typelevel-ci-release" % "0.7.4") -addSbtPlugin("org.typelevel" % "sbt-typelevel-mergify" % "0.7.4") +addSbtPlugin("org.typelevel" % "sbt-typelevel-ci-release" % "0.7.7") +addSbtPlugin("org.typelevel" % "sbt-typelevel-mergify" % "0.7.7") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.2.2") addSbtPlugin("com.timushev.sbt" % "sbt-updates" % "0.6.2") addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.7") addSbtPlugin("de.heikoseeberger" % "sbt-header" % "5.10.0") addSbtPlugin("org.typelevel" % "sbt-tpolecat" % "0.5.2") -addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.6.1") +addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.6.2") addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") addDependencyTreePlugin