Skip to content

Commit

Permalink
WSRequest: Normalize URL
Browse files Browse the repository at this point in the history
  • Loading branch information
htmldoug committed Jan 22, 2019
1 parent 4c4538b commit 23cbbac
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (C) 2009-2019 Lightbend Inc. <https://www.lightbend.com>
*/

package play.api.libs.ws.ahc

import java.util.concurrent.TimeUnit

import akka.stream.Materializer
import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole

/**
* JMH benchmarks for building a request with `wsClient.url(...)`.
* Each benchmark only constructs the request and hands it to the Blackhole; no request is executed.
*
* ==Quick Run from sbt==
*
* > bench/jmh:run .*StandaloneAhcWSRequestBench
*
* ==Using Oracle Flight Recorder==
*
* To record a Flight Recorder file from a JMH run, run it using the jmh.extras.JFR profiler:
* > bench/jmh:run -prof jmh.extras.JFR .*StandaloneAhcWSRequestBench
*
* Compare your results before/after on your machine. Don't trust the ones in scaladoc.
*
* Sample benchmark results:
* {{{
* > bench/jmh:run .*StandaloneAhcWSRequestBench
* [info] Benchmark Mode Cnt Score Error Units
* [info] StandaloneAhcWSRequestBench.urlNoParams avgt 5 326.443 ± 3.712 ns/op
* [info] StandaloneAhcWSRequestBench.urlWithParams avgt 5 1562.871 ± 16.736 ns/op
* }}}
*
* @see https://github.com/ktoso/sbt-jmh
*/
// Results are reported as average time per operation, in nanoseconds, from a single forked JVM.
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@BenchmarkMode(Array(Mode.AverageTime))
@Fork(jvmArgsAppend = Array("-Xmx350m", "-XX:+HeapDumpOnOutOfMemoryError"), value = 1)
@State(Scope.Benchmark)
class StandaloneAhcWSRequestBench {

// Deliberately null: the implicit is only needed to construct the client below.
private implicit val materializer: Materializer = null // we're not actually going to execute anything.
// Client under test; closed again in teardown().
private val wsClient = StandaloneAhcWSClient()

// Builds a request from a URL that has a path but no query string.
@Benchmark
def urlNoParams(bh: Blackhole): Unit = {
bh.consume(wsClient.url("https://www.example.com/foo/bar/a/b"))
}

// Builds a request from a URL whose query string contains an unencoded "& = " segment.
@Benchmark
def urlWithParams(bh: Blackhole): Unit = {
bh.consume(wsClient.url("https://www.example.com?foo=bar& = "))
}

// JMH teardown hook: closes the WS client created for this benchmark state.
@TearDown
def teardown(): Unit = wsClient.close()
}
28 changes: 23 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ val disableDocs = Seq[Setting[_]](

val disablePublishing = Seq[Setting[_]](
publishArtifact := false,
skip in publish := true,
crossScalaVersions := Seq(scala212)
skip in publish := true
)

lazy val shadeAssemblySettings = commonSettings ++ Seq(
Expand Down Expand Up @@ -465,7 +464,6 @@ lazy val `integration-tests` = project.in(file("integration-tests"))
.settings(disableDocs)
.settings(disablePublishing)
.settings(
crossScalaVersions := Seq(scala213, scala212, scala211),
fork in Test := true,
concurrentRestrictions += Tags.limitAll(1), // only one integration test at a time
testOptions in Test := Seq(Tests.Argument(TestFrameworks.JUnit, "-a", "-v")),
Expand All @@ -480,6 +478,24 @@ lazy val `integration-tests` = project.in(file("integration-tests"))
)
.disablePlugins(sbtassembly.AssemblyPlugin)

//---------------------------------------------------------------
// Benchmarks (run manually)
//---------------------------------------------------------------

// JMH benchmark project located in ./bench.
// It depends on the standalone WS modules it measures, reuses the common and formatting
// settings, and is excluded from documentation generation and publishing.
lazy val bench = project
.in(file("bench"))
.enablePlugins(JmhPlugin)
.dependsOn(
`play-ws-standalone`,
`play-ws-standalone-json`,
`play-ws-standalone-xml`,
`play-ahc-ws-standalone`
)
.settings(commonSettings)
.settings(formattingSettings)
.settings(disableDocs)
.settings(disablePublishing)

//---------------------------------------------------------------
// Root Project
//---------------------------------------------------------------
Expand All @@ -496,13 +512,15 @@ lazy val root = project
.settings(formattingSettings)
.settings(disableDocs)
.settings(disablePublishing)
.settings(crossScalaVersions := Seq(scala212))
.aggregate(
`shaded`,
`play-ws-standalone`,
`play-ws-standalone-json`,
`play-ws-standalone-xml`,
`play-ahc-ws-standalone`,
`integration-tests`
`integration-tests`,
bench
)
.disablePlugins(sbtassembly.AssemblyPlugin)

Expand Down Expand Up @@ -545,4 +563,4 @@ checkCodeFormat := {
}
}

addCommandAlias("validateCode", ";scalariformFormat;test:scalariformFormat;headerCheck;test:headerCheck;checkCodeFormat")
addCommandAlias("validateCode", ";scalariformFormat;test:scalariformFormat;headerCheck;test:headerCheck;checkCodeFormat")
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,26 @@

package play.api.libs.ws.ahc

import javax.inject.Inject
import java.net.URLDecoder
import java.util.Collections

import akka.stream.Materializer
import akka.stream.scaladsl.Source
import akka.util.ByteString
import com.typesafe.sslconfig.ssl.SystemConfiguration
import com.typesafe.sslconfig.ssl.debug.DebugConfiguration
import javax.inject.Inject
import play.api.libs.ws.ahc.cache._
import play.api.libs.ws.{ EmptyBody, StandaloneWSClient, StandaloneWSRequest }
import play.shaded.ahc.org.asynchttpclient.uri.Uri
import play.shaded.ahc.org.asynchttpclient.util.{ MiscUtils, UriEncoder }
import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse, _ }

import scala.collection.immutable.TreeMap
import scala.compat.java8.FunctionConverters
import scala.concurrent.{ Await, Future, Promise }
import scala.util.control.NonFatal
import scala.util.{ Failure, Success, Try }

/**
* A WS client backed by an AsyncHttpClient.
Expand All @@ -40,8 +45,7 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
}

def url(url: String): StandaloneWSRequest = {
validate(url)
StandaloneAhcWSRequest(
val req = StandaloneAhcWSRequest(
client = this,
url = url,
method = "GET",
Expand All @@ -57,6 +61,8 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
proxyServer = None,
disableUrlEncoding = None
)

StandaloneAhcWSClient.normalize(req)
}

private[ahc] def execute(request: Request): Future[StandaloneAhcWSResponse] = {
Expand All @@ -76,18 +82,6 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
result.future
}

/**
 * Checks that `url` can be parsed by [[Uri.create]].
 *
 * Both failure modes of the parser (see
 * https://github.com/AsyncHttpClient/async-http-client/issues/1149) are reported
 * uniformly as an [[IllegalArgumentException]] carrying the original exception as cause.
 */
private def validate(url: String): Unit = {
  try {
    Uri.create(url)
  } catch {
    // The parser may fail with either exception type; both mean "invalid URL" to the caller.
    case cause @ (_: IllegalArgumentException | _: NullPointerException) =>
      throw new IllegalArgumentException(s"Invalid URL $url", cause)
  }
}

private[ahc] def executeStream(request: Request): Future[StreamedResponse] = {
val promise = Promise[StreamedResponse]()

Expand Down Expand Up @@ -117,12 +111,12 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici

Await.result(result, StandaloneAhcWSClient.blockingTimeout)
}

}

object StandaloneAhcWSClient {

import scala.concurrent.duration._

val blockingTimeout = 50.milliseconds
val elementLimit = 13 // 13 8192k blocks is roughly 100k
private val logger = org.slf4j.LoggerFactory.getLogger(this.getClass)
Expand Down Expand Up @@ -164,5 +158,130 @@ object StandaloneAhcWSClient {
new SystemConfiguration(loggerFactory).configure(config.wsClientConfig.ssl)
wsClient
}

/**
 * Normalizes the URL of a request so that:
 *
 *  1. the path of [[StandaloneWSRequest.url]] is encoded, e.g.
 *     ws.url("http://example.com/foo bar") ->
 *     ws.url("http://example.com/foo%20bar")
 *
 *  2. query params embedded in the URL end up in [[StandaloneWSRequest.queryString]], e.g.
 *     ws.url("http://example.com/?foo=bar") ->
 *     ws.url("http://example.com/").withQueryString("foo" -> "bar")
 *
 * @param req the request whose URL should be checked and, if necessary, repaired
 * @return `req` itself when its URL is already acceptable, otherwise a repaired copy
 */
@throws[IllegalArgumentException]("if the url is unrepairable")
private[ahc] def normalize(req: StandaloneAhcWSRequest): StandaloneWSRequest = {
  // Rejects a missing URI component; `complaint` is appended verbatim to the offending URL.
  def failIfBlank(component: String, complaint: String): Unit =
    if (MiscUtils.isEmpty(component)) {
      throw new IllegalArgumentException(req.url + complaint)
    }

  if (req.url.indexOf('?') >= 0) {
    // The URL carries its own query string: move those params into the queryString map.
    repair(req)
  } else {
    Try(req.uri) match {
      case Failure(_) =>
        // The URL did not parse as-is. UriEncoder.FIXING can sometimes recover it.
        repair(req)

      case Success(parsed) =>
        /*
         * [[Uri.create()]] throws if the host or scheme is missing.
         * Performing those checks against the already-parsed [[java.net.URI]]
         * avoids the cost of re-parsing the URL string.
         *
         * @see https://github.com/AsyncHttpClient/async-http-client/issues/1149
         */
        failIfBlank(parsed.getScheme, " could not be parsed into a proper Uri, missing scheme")
        failIfBlank(parsed.getHost, " could not be parsed into a proper Uri, missing host")

        req
    }
  }
}

/**
 * Re-encodes the request URL through [[toUri]] (which applies [[UriEncoder.FIXING]],
 * the same encoder async-http-client uses before executing a request) and writes
 * the result back onto the request: the URL without its query, plus the raw query
 * parsed into query-string parameters.
 *
 * @param req the request whose URL needs fixing
 * @return a copy of `req` with the repaired URL and query parameters
 */
@throws[IllegalArgumentException]("if the url is unrepairable")
private def repair(req: StandaloneAhcWSRequest): StandaloneWSRequest = {
  val attempt = Try {
    val fixedUri: Uri = toUri(req)
    val asJavaUri = fixedUri.toJavaNetURI
    val urlWithoutQuery = fixedUri.withNewQuery(null).toUrl
    setUri(req, urlWithoutQuery, Option(asJavaUri.getRawQuery))
  }

  attempt match {
    case Success(repaired) => repaired
    // Any non-fatal failure along the way means the URL cannot be salvaged.
    case Failure(cause) => throw new IllegalArgumentException(s"Invalid URL ${req.url}", cause)
  }
}

/**
 * Produces an AHC [[Uri]] for the request with every part URL-encoded by [[UriEncoder.FIXING]].
 * The query of the result contains the params found in [[StandaloneWSRequest.url]] followed by
 * those held in [[StandaloneWSRequest.queryString]].
 *
 * @param req the request supplying the URL and the extra query parameters
 * @return the encoded Uri
 */
private def toUri(req: StandaloneWSRequest): Uri = {
  val parsed = Uri.create(req.url)
  val extraParams = req.queryString

  // Only rebuild the query when there is something to merge into it.
  val merged: Uri =
    if (extraParams.isEmpty) parsed
    else parsed.withNewQuery(combineQuery(parsed.getQuery, extraParams))

  // FIXING.encode() encodes ONLY unencoded parts, leaving encoded parts untouched.
  UriEncoder.FIXING.encode(merged, Collections.emptyList())
}

/**
 * Appends `params` to an existing raw query string, yielding `key=value` pairs joined by '&'.
 *
 * Keys and values taken from `params` are percent-encoded (UTF-8) before being appended.
 * Without this, a stored value containing query syntax such as '&', '=', '+' or '%' would be
 * split or decoded into something else when the combined query is later broken apart on
 * '&' / '=' and URL-decoded by `setUri`. Values made only of unreserved characters are
 * appended exactly as before.
 *
 * @param query  the raw query already present on the URL; appended verbatim. May be null,
 *               in which case only `params` contribute to the result.
 * @param params additional parameters; every value of every key becomes its own pair.
 *               An empty value is rendered as the bare key without '='.
 * @return the combined query string (empty when there is nothing to combine)
 */
private def combineQuery(query: String, params: Map[String, Seq[String]]): String = {
  import java.net.URLEncoder

  // URLEncoder emits '+' for a space; switch to "%20" so the output is purely percent-encoded.
  // A literal '+' in the input has already become "%2B" at this point, so it is unaffected.
  def encode(part: String): String = URLEncoder.encode(part, "UTF-8").replace("+", "%20")

  val sb = new StringBuilder
  // The existing query is assumed to be encoded already, so it is copied as-is.
  if (query != null) {
    sb.append(query)
  }

  for {
    (key, values) <- params
    value <- values
  } {
    // Separator is needed after the existing query as well as between appended pairs.
    if (sb.nonEmpty) {
      sb.append('&')
    }
    sb.append(encode(key))
    if (value.nonEmpty) {
      sb.append('=').append(encode(value))
    }
  }

  sb.toString
}

/**
 * Replaces the [[StandaloneWSRequest.url]] and [[StandaloneWSRequest.queryString]] of `req`
 * with the supplied values, discarding the originals.
 *
 * @param req                the request to update
 * @param urlNoQueryParams   the new URL, without any query component
 * @param encodedQueryString the raw (still URL-encoded) query string, if any; it is split on
 *                           '&' and each key/value is URL-decoded as UTF-8
 * @return the updated request
 */
private def setUri(
  req: StandaloneAhcWSRequest,
  urlNoQueryParams: String,
  encodedQueryString: Option[String]): StandaloneWSRequest = {

  def decode(raw: String): String = URLDecoder.decode(raw, "UTF-8")

  // https://stackoverflow.com/a/13592567 for the parsing rules:
  // a pair without '=' (or starting with '=') is treated as a key with an empty value.
  def toParam(pair: String): (String, String) = {
    val eq = pair.indexOf('=')
    if (eq > 0) {
      val name = decode(pair.substring(0, eq))
      name -> decode(pair.substring(eq + 1))
    } else {
      decode(pair) -> ""
    }
  }

  val decodedParams: List[(String, String)] =
    encodedQueryString.fold(List.empty[(String, String)]) { raw =>
      raw.split('&').toList.map(toParam)
    }

  req
    .withUrl(urlNoQueryParams)
    .withQueryStringParameters(decodedParams: _*)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import java.nio.charset.{ Charset, StandardCharsets }

import akka.stream.Materializer
import akka.stream.scaladsl.Sink
import play.api.libs.ws.{ StandaloneWSRequest, _ }
import play.api.libs.ws._
import play.shaded.ahc.io.netty.buffer.Unpooled
import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaders
import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme
Expand Down Expand Up @@ -185,7 +185,10 @@ case class StandaloneAhcWSRequest(
withMethod(method).execute()
}

override def withUrl(url: String): Self = copy(url = url)
// The copy holding the raw URL is never exposed directly: it is passed through
// StandaloneAhcWSClient.normalize, whose result is what callers receive.
override def withUrl(url: String): Self =
  StandaloneAhcWSClient.normalize(copy(url = url))

override def withMethod(method: String): Self = copy(method = method)

Expand Down
Loading

0 comments on commit 23cbbac

Please sign in to comment.