Skip to content

Commit

Permalink
sqlcommenter
Browse files Browse the repository at this point in the history
  • Loading branch information
guymers authored Nov 19, 2023
1 parent 1a69443 commit ec4ce2d
Show file tree
Hide file tree
Showing 12 changed files with 394 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = 3.7.14
version = 3.7.17
runner.dialect = scala213source3

align.preset = none
Expand Down
37 changes: 27 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
import FreeGen2.*

val catsVersion = "2.10.0"
val catsEffectVersion = "3.5.1"
val catsEffectVersion = "3.5.2"
val circeVersion = "0.14.6"
val fs2Version = "3.9.2"
val fs2Version = "3.9.3"
val h2Version = "2.2.224"
val hikariVersion = "5.0.1"
val hikariVersion = "5.1.0"
val magnoliaVersion = "1.1.3"
val munitVersion = "1.0.0-M10"
val mysqlVersion = "8.1.0"
val mysqlVersion = "8.2.0"
val openTelemetryVersion = "1.32.0"
val postgisVersion = "2021.1.0"
val postgresVersion = "42.6.0"
val scalatestVersion = "3.2.17"
val shapelessVersion = "2.3.10"
val slf4jVersion = "2.0.9"
val weaverVersion = "0.8.3"
val zioInteropCats = "23.0.0.8"
val zioVersion = "2.0.17"
val zioInteropCats = "23.1.0.0"
val zioVersion = "2.0.19"

val Scala213 = "2.13.12"
val Scala3 = "3.3.1"
Expand Down Expand Up @@ -130,6 +131,8 @@ lazy val noPublishSettings = Seq(
mimaPreviousArtifacts := Set.empty,
)

lazy val runningInIntelliJ = System.getProperty("idea.managed", "false").toBoolean

def filterScalacConsoleOpts(options: Seq[String]) = {
options.filterNot { opt =>
opt == "-Xfatal-warnings" || opt.startsWith("-Xlint") || opt.startsWith("-W")
Expand All @@ -142,16 +145,28 @@ def module(name: String) = Project(name, file(s"modules/$name"))
.settings(
mimaPreviousArtifacts := previousStableVersion.value.map(organization.value %% moduleName.value % _).toSet
)
.settings(
if (runningInIntelliJ) Seq(
Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "it" / "scala",
) else Seq.empty
)

def moduleIT(name: String) = Project(s"$name-it", file(s"modules/$name-it"))
.settings(moduleName := s"foobie-$name-it")
.settings(commonSettings)
.settings(
publish / skip := true,
Compile / javaSource := baseDirectory.value / ".." / name / "src" / "main-it" / "java",
Compile / scalaSource := baseDirectory.value / ".." / name / "src" / "main-it" / "scala",
Test / javaSource := baseDirectory.value / ".." / name / "src" / "it" / "java",
Test / scalaSource := baseDirectory.value / ".." / name / "src" / "it" / "scala",
Test / fork := true,
Test / javaOptions += "-Xmx1000m",
)
.settings(
// intellij complains about shared content roots, so it gets the source appended in `module`
if (runningInIntelliJ) Seq.empty else Seq(
Compile / javaSource := baseDirectory.value / ".." / name / "src" / "main-it" / "java",
Compile / scalaSource := baseDirectory.value / ".." / name / "src" / "main-it" / "scala",
Test / javaSource := baseDirectory.value / ".." / name / "src" / "it" / "java",
Test / scalaSource := baseDirectory.value / ".." / name / "src" / "it" / "scala",
)
)
.disablePlugins(MimaPlugin)

Expand Down Expand Up @@ -390,6 +405,7 @@ lazy val zio = module("zio")
"com.mysql" % "mysql-connector-j" % mysqlVersion % Optional,
"org.postgresql" % "postgresql" % postgresVersion % Optional,
"net.postgis" % "postgis-jdbc" % postgisVersion % Optional,
"io.opentelemetry" % "opentelemetry-api" % openTelemetryVersion % Optional,

"dev.zio" %% "zio-test" % zioVersion % Test,
"dev.zio" %% "zio-test-sbt" % zioVersion % Test,
Expand All @@ -403,6 +419,7 @@ lazy val `zio-it` = moduleIT("zio")
libraryDependencies ++= Seq(
"dev.zio" %% "zio-test" % zioVersion % Test,
"dev.zio" %% "zio-test-sbt" % zioVersion % Test,
"io.opentelemetry" % "opentelemetry-api" % openTelemetryVersion % Test,
),
)
.dependsOn(zio, postgres)
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ services:
shm_size: 128m
command: [
"postgres",
#"-c", "log_statement=all"
#"-c", "log_statement=all",
"-c", "max_connections=400",
"-c", "shared_buffers=250MB", # 25% of RAM
"-c", "effective_cache_size=700MB", # 70% of RAM
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ object PostgreSQLIntegrationSpec extends ZIOSpecDefault {
p <- pool(connectionConfig, config)
transactor = Transactor.fromPoolTransactional(p)
results <- run(transactor)
metrics = zio.internal.metrics.MetricRegistryExposed.snapshot
metrics <- ZIO.metrics.map(_.metrics)
} yield {
val metricPairs = metrics.map { p =>
val tags = p.metricKey.tags.toList.sortBy(_.key)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package zoobie.sqlcommenter

import doobie.syntax.string.*
import io.opentelemetry.api.common.AttributeKey
import io.opentelemetry.api.common.Attributes
import io.opentelemetry.api.trace.Span
import io.opentelemetry.api.trace.SpanContext
import io.opentelemetry.api.trace.StatusCode
import io.opentelemetry.api.trace.TraceFlags
import io.opentelemetry.api.trace.TraceState
import zio.Chunk
import zio.ZIO
import zio.durationInt
import zio.test.TestAspect
import zio.test.ZIOSpecDefault
import zio.test.assertCompletes
import zoobie.ConnectionPoolConfig
import zoobie.Transactor
import zoobie.postgres.PostgreSQLConnectionConfig
import zoobie.postgres.pool

import java.util.concurrent.TimeUnit

object SQLCommenterIntegrationSpec extends ZIOSpecDefault {

override val spec = test("SQLCommenterIntegrationSpec") {
val spanContext = new SpanContext {
override val getTraceId = "3b120af54ca6f7efacddf3e538dd4988"
override val getSpanId = "7cdf802020b41208"
override val getTraceFlags = TraceFlags.getSampled
override val getTraceState = TraceState.builder().put("key", "value").build()
override val isRemote = false
}
val span = new Span {
override def setAttribute[T](key: AttributeKey[T], value: T) = ???
override def addEvent(name: String, attributes: Attributes) = ???
override def addEvent(name: String, attributes: Attributes, timestamp: Long, unit: TimeUnit) = ???
override def setStatus(statusCode: StatusCode, description: String) = ???
override def recordException(exception: Throwable, additionalAttributes: Attributes) = ???
override def updateName(name: String) = ???
override def end(): Unit = ???
override def end(timestamp: Long, unit: TimeUnit): Unit = ???
override def isRecording = ???
override def getSpanContext = spanContext
}

for {
p <- pool(connectionConfig, config)
interpreter = TraceInterpreter.create(Transactor.kleisliInterpreter, ZIO.succeed(Some(span)))
transactor = Transactor(p.get, interpreter.ConnectionInterpreter, Transactor.strategies.transactional)
_ <- transactor.run(fr"SELECT 1".query[Int].unique)
} yield {
assertCompletes
}
}

override val aspects = super.aspects ++ Chunk(
TestAspect.timed,
TestAspect.timeout(90.seconds),
TestAspect.withLiveClock,
)

private lazy val connectionConfig = PostgreSQLConnectionConfig(
host = "localhost",
database = "world",
username = "postgres",
password = "password",
applicationName = "doobie",
)

private lazy val config = ConnectionPoolConfig(
name = "zoobie-postgres-it",
size = 5,
queueSize = 1_000,
maxConnectionLifetime = 30.seconds,
validationTimeout = 2.seconds,
)
}
4 changes: 3 additions & 1 deletion modules/zio/src/main/scala/zoobie/Transactor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ object Transactor {

private val sync: Sync[Task] = zio.interop.catz.asyncInstance[Any]

val interpreter: Interpreter[Task] = KleisliInterpreter(sync).ConnectionInterpreter
val kleisliInterpreter: KleisliInterpreter[Task] = KleisliInterpreter(sync)

val interpreter: Interpreter[Task] = kleisliInterpreter.ConnectionInterpreter

object strategies {
val noop: Strategy = Strategy.void
Expand Down
87 changes: 87 additions & 0 deletions modules/zio/src/main/scala/zoobie/sqlcommenter/SQLCommenter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package zoobie.sqlcommenter

import java.net.URLEncoder
import java.nio.charset.StandardCharsets
import scala.collection.immutable.SortedMap
import scala.jdk.CollectionConverters.*

// https://google.github.io/sqlcommenter/spec/
final case class SQLCommenter(
controller: Option[String],
action: Option[String],
framework: Option[String],
trace: Option[SQLCommenter.Trace],
) {
import SQLCommenter.serializeKeyValues

def format: String = {
val traceState = trace.flatMap(_.state).map { state =>
state
.filter { case (k, _) => k.nonEmpty }
.map { case (k, v) => s"$k=$v" }
.mkString(",")
}
val m = SortedMap(
"controller" -> controller,
"action" -> action,
"framework" -> framework,
"traceparent" -> trace.map(_.parent),
"tracestate" -> traceState,
).collect { case (k, Some(v)) => (k, v) }
serializeKeyValues(m)
}

}
object SQLCommenter {

final case class Trace(
traceId: String,
spanId: String,
options: Byte,
state: Option[Map[String, String]],
) {
def parent = String.format("00-%s-%s-%02X", traceId, spanId, options)
}
object Trace {

def fromOpenTelemetryContext(spanContext: io.opentelemetry.api.trace.SpanContext) = {
Option(spanContext).filter(_.isValid).map { ctx =>
val traceId = ctx.getTraceId
val spanId = ctx.getSpanId
val options = ctx.getTraceFlags

val state = Option(ctx.getTraceState).filter(!_.isEmpty).map { state =>
state.asMap().asScala.toMap
}

Trace(traceId = traceId, spanId = spanId, options.asByte, state)
}
}
}

private[sqlcommenter] val serializeKey =
urlEncode andThen escapeMetaCharacters
private[sqlcommenter] val serializeValue =
urlEncode andThen escapeMetaCharacters andThen sqlEscape

private[sqlcommenter] def serializeKeyValue(k: String, v: String) = s"${serializeKey(k)}=${serializeValue(v)}"

private[sqlcommenter] def serializeKeyValues(m: Map[String, String]) = {
if (m.isEmpty) ""
else m.toList.sorted.map(serializeKeyValue.tupled).mkString(",")
}

@SuppressWarnings(Array("org.wartremover.warts.Null"))
private def urlEncode(s: String) = {
URLEncoder.encode(s, StandardCharsets.UTF_8)
.replaceAll("%27", "'")
.replaceAll("\\+", "%20")
}
private def escapeMetaCharacters(s: String) = s.replaceAll("'", "\\\\'")
private def sqlEscape(s: String) = s"'$s'"

def affix(state: SQLCommenter, sql: String): String = {
val commentStr = state.format
if (commentStr.isEmpty) sql else sql.concat(s"\n/*${commentStr}*/")
}
}
Loading

0 comments on commit ec4ce2d

Please sign in to comment.