diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index dd48dc296..284899204 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -6,7 +6,7 @@ okio = "3.3.0" [libraries] antlr = "org.antlr:antlr4:4.9.3" -apacheThrift = "org.apache.thrift:libthrift:0.17.0" +apacheThrift = "org.apache.thrift:libthrift:0.19.0" clikt = "com.github.ajalt.clikt:clikt:3.1.0" dokka = "org.jetbrains.dokka:dokka-gradle-plugin:1.7.20" guava = "com.google.guava:guava:31.1-jre" diff --git a/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/CoroutineConformanceTests.kt b/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/CoroutineConformanceTests.kt index d1044762d..6567c5673 100644 --- a/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/CoroutineConformanceTests.kt +++ b/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/CoroutineConformanceTests.kt @@ -42,6 +42,7 @@ import com.microsoft.thrifty.testing.ServerProtocol import com.microsoft.thrifty.testing.ServerTransport import com.microsoft.thrifty.testing.TestServer import com.microsoft.thrifty.transport.FramedTransport +import com.microsoft.thrifty.transport.HttpTransport import com.microsoft.thrifty.transport.SocketTransport import com.microsoft.thrifty.transport.Transport import io.kotest.assertions.fail @@ -73,6 +74,9 @@ class NonblockingCompactCoroutineConformanceTest : CoroutineConformanceTests() @ServerConfig(transport = ServerTransport.NON_BLOCKING, protocol = ServerProtocol.JSON) class NonblockingJsonCoroutineConformanceTest : CoroutineConformanceTests() +@ServerConfig(transport = ServerTransport.HTTP, protocol = ServerProtocol.JSON) +class HttpJsonCoroutineConformanceTest : CoroutineConformanceTests() + /** * A test of auto-generated service code for the standard ThriftTest * service. @@ -103,12 +107,7 @@ abstract class CoroutineConformanceTests { @BeforeAll @JvmStatic fun beforeAll() { - val port = testServer.port() - val transport = SocketTransport.Builder("localhost", port) - .readTimeout(2000) - .build() - - transport.connect() + val transport = getTransportImpl() this.transport = decorateTransport(transport) this.protocol = createProtocol(this.transport) @@ -123,6 +122,18 @@ abstract class CoroutineConformanceTests { }) } + private fun getTransportImpl(): Transport { + return when(testServer.transport) { + ServerTransport.BLOCKING, ServerTransport.NON_BLOCKING -> + return SocketTransport.Builder("localhost", testServer.port()) + .readTimeout(2000) + .build() + .apply { connect() } + + ServerTransport.HTTP -> HttpTransport("http://localhost:${testServer.port()}/test/service") + } + } + /** * When overridden in a derived class, wraps the given transport * in a decorator, e.g. a framed transport. diff --git a/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/KotlinConformanceTest.kt b/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/KotlinConformanceTest.kt index de7808ff8..a7eb48aef 100644 --- a/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/KotlinConformanceTest.kt +++ b/thrifty-integration-tests/src/test/kotlin/com/microsoft/thrifty/integration/conformance/KotlinConformanceTest.kt @@ -42,6 +42,7 @@ import com.microsoft.thrifty.testing.ServerProtocol import com.microsoft.thrifty.testing.ServerTransport import com.microsoft.thrifty.testing.TestServer import com.microsoft.thrifty.transport.FramedTransport +import com.microsoft.thrifty.transport.HttpTransport import com.microsoft.thrifty.transport.SocketTransport import com.microsoft.thrifty.transport.Transport import io.kotest.matchers.should @@ -72,6 +73,9 @@ class NonblockingCompactConformanceTest : KotlinConformanceTest() @ServerConfig(transport = ServerTransport.NON_BLOCKING, protocol = ServerProtocol.JSON) class NonblockingJsonConformanceTest : KotlinConformanceTest() +@ServerConfig(transport = ServerTransport.HTTP, protocol = ServerProtocol.JSON) +class HttpJsonConformanceTest : KotlinConformanceTest() + /** * A test of auto-generated service code for the standard ThriftTest * service. @@ -102,12 +106,7 @@ abstract class KotlinConformanceTest { @BeforeAll @JvmStatic fun beforeAll() { - val port = testServer.port() - val transport = SocketTransport.Builder("localhost", port) - .readTimeout(2000) - .build() - - transport.connect() + val transport = getTransportImpl() this.transport = decorateTransport(transport) this.protocol = createProtocol(this.transport) @@ -122,6 +121,18 @@ abstract class KotlinConformanceTest { }) } + private fun getTransportImpl(): Transport { + return when(testServer.transport) { + ServerTransport.BLOCKING, ServerTransport.NON_BLOCKING -> + return SocketTransport.Builder("localhost", testServer.port()) + .readTimeout(2000) + .build() + .apply { connect() } + + ServerTransport.HTTP -> HttpTransport("http://localhost:${testServer.port()}/test/service") + } + } + /** * When overridden in a derived class, wraps the given transport * in a decorator, e.g. a framed transport. diff --git a/thrifty-runtime/src/jvmMain/kotlin/com/microsoft/thrifty/transport/HttpTransport.kt b/thrifty-runtime/src/jvmMain/kotlin/com/microsoft/thrifty/transport/HttpTransport.kt new file mode 100644 index 000000000..31b1dbb77 --- /dev/null +++ b/thrifty-runtime/src/jvmMain/kotlin/com/microsoft/thrifty/transport/HttpTransport.kt @@ -0,0 +1,186 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ + +// Adapted from Thrift sources; original license header follows: +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package com.microsoft.thrifty.transport + + +import com.microsoft.thrifty.internal.ProtocolException +import java.io.ByteArrayOutputStream +import java.io.InputStream +import java.net.HttpURLConnection +import java.net.URL + +/** + * HTTP implementation of the TTransport interface. Used for working with a + * Thrift web services implementation (using for example TServlet). + * + * THIS IMPLEMENTATION IS NOT THREAD-SAFE !!! + * + * Based on the official thrift java THttpTransport with the apache client support removed. + * Both due to wanting to avoid the additional dependency as well as it being a bit weird to have two + * implementations to switch between in the same class. + * + * Uses HttpURLConnection internally + * + * Also note that under high load, the HttpURLConnection implementation + * may exhaust the open file descriptor limit. + * + * @see [THRIFT-970](https://issues.apache.org/jira/browse/THRIFT-970) + */ +open class HttpTransport(url: String) : Transport { + private val url: URL = URL(url) + private var currentState: Transport = Writing() + private var connectTimeout: Int? = null + private var readTimeout: Int? = null + private val customHeaders = mutableMapOf() + private val sendBuffer = ByteArrayOutputStream() + + private inner class Writing : Transport { + override fun read(buffer: ByteArray, offset: Int, count: Int): Int { + throw ProtocolException("Currently in writing state") + } + + override fun write(buffer: ByteArray, offset: Int, count: Int) { + sendBuffer.write(buffer, offset, count) + } + + override fun flush() { + val bytesToSend = sendBuffer.toByteArray() + sendBuffer.reset() + send(bytesToSend) + } + + override fun close() { + // do nothing + } + } + + private inner class Reading(val inputStream: InputStream) : Transport { + override fun read(buffer: ByteArray, offset: Int, count: Int): Int { + val ret = inputStream.read(buffer, offset, count) + if (ret == -1) { + throw ProtocolException("No more data available.") + } + return ret + } + + override fun write(buffer: ByteArray, offset: Int, count: Int) { + throw ProtocolException("currently in reading state") + } + + override fun flush() { + throw ProtocolException("currently in reading state") + } + + override fun close() { + inputStream.close() + } + } + + fun send(data: ByteArray) { + // Create connection object + val connection = url.openConnection() as HttpURLConnection + + prepareConnection(connection) + // Make the request + connection.connect() + connection.outputStream.write(data) + val responseCode = connection.responseCode + if (responseCode != HttpURLConnection.HTTP_OK) { + throw ProtocolException("HTTP Response code: $responseCode") + } + + // Read the response + this.currentState = Reading(connection.inputStream) + } + + protected open fun prepareConnection(connection: HttpURLConnection) { + // Timeouts, only if explicitly set + connectTimeout?.let { connection.connectTimeout = it } + readTimeout?.let { connection.readTimeout = it } + + connection.requestMethod = "POST" + connection.setRequestProperty("Content-Type", "application/x-thrift") + connection.setRequestProperty("Accept", "application/x-thrift") + connection.setRequestProperty("User-Agent", "Java/THttpClient") + for ((key, value) in customHeaders) { + connection.setRequestProperty(key, value) + } + connection.doOutput = true + } + + fun setConnectTimeout(timeout: Int) { + connectTimeout = timeout + } + + fun setReadTimeout(timeout: Int) { + readTimeout = timeout + } + + fun setCustomHeaders(headers: Map) { + customHeaders.clear() + customHeaders.putAll(headers) + } + + fun setCustomHeader(key: String, value: String) { + customHeaders[key] = value + } + + override fun close() { + currentState.close() + } + + override fun read(buffer: ByteArray, offset: Int, count: Int): Int = currentState.read(buffer, offset, count) + + override fun write(buffer: ByteArray, offset: Int, count: Int) { + // this mirrors the original behaviour, though it is not very elegant. + // we don't know when the user is done reading, so when they start writing again, + // we just go with it. + if (currentState is Reading) { + currentState.close() + currentState = Writing() + } + currentState.write(buffer, offset, count) + } + + override fun flush() { + currentState.flush() + } +} diff --git a/thrifty-test-server/build.gradle b/thrifty-test-server/build.gradle index c1cb84d2d..dae022fb0 100644 --- a/thrifty-test-server/build.gradle +++ b/thrifty-test-server/build.gradle @@ -28,4 +28,5 @@ dependencies { implementation "commons-codec:commons-codec:1.15" implementation "org.apache.httpcomponents:httpclient:4.5.13" implementation "org.slf4j:slf4j-api:2.0.5" + implementation "org.apache.tomcat.embed:tomcat-embed-core:10.1.4" } diff --git a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/HttpServer.java b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/HttpServer.java new file mode 100644 index 000000000..798b4fd65 --- /dev/null +++ b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/HttpServer.java @@ -0,0 +1,69 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +package com.microsoft.thrifty.testing; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.core.StandardContext; +import org.apache.catalina.startup.Tomcat; + +import static com.microsoft.thrifty.testing.TestServer.getProtocolFactory; + +public class HttpServer implements TestServerInterface { + private Tomcat tomcat; + + @Override + public void run(ServerProtocol protocol, ServerTransport transport) { + if (transport != ServerTransport.HTTP) { + throw new IllegalArgumentException("only http transport supported"); + } + this.tomcat = new Tomcat(); + tomcat.setBaseDir(System.getProperty("user.dir") + "\\build"); + tomcat.setPort(0); + tomcat.getHost().setAutoDeploy(false); + + String contextPath = "/test"; + StandardContext context = new StandardContext(); + context.setPath(contextPath); + context.addLifecycleListener(new Tomcat.FixContextListener()); + tomcat.getHost().addChild(context); + tomcat.addServlet(contextPath, "testServlet", new TestServlet(getProtocolFactory(protocol))); + context.addServletMappingDecoded("/service", "testServlet"); + try { + tomcat.start(); + } catch (LifecycleException e) { + throw new RuntimeException(e); + } + } + + @Override + public int port() { + return tomcat.getConnector().getLocalPort(); + } + + @Override + public void close() { + try { + tomcat.stop(); + } catch (LifecycleException e) { + throw new RuntimeException(e); + } + } +} diff --git a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/ServerTransport.java b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/ServerTransport.java index 42770901f..e14b1d387 100644 --- a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/ServerTransport.java +++ b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/ServerTransport.java @@ -29,5 +29,6 @@ public enum ServerTransport { /** * A framed, non-blocking server socket,i.e. TNonblockingServerTransport. */ - NON_BLOCKING + NON_BLOCKING, + HTTP } diff --git a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/SocketBasedServer.java b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/SocketBasedServer.java new file mode 100644 index 000000000..5eb15692b --- /dev/null +++ b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/SocketBasedServer.java @@ -0,0 +1,176 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +package com.microsoft.thrifty.testing; + +import com.microsoft.thrifty.test.gen.ThriftTest; +import org.apache.thrift.TProcessor; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TCompactProtocol; +import org.apache.thrift.protocol.TJSONProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.server.TNonblockingServer; +import org.apache.thrift.server.TServer; +import org.apache.thrift.server.TThreadPoolServer; +import org.apache.thrift.transport.TNonblockingServerSocket; +import org.apache.thrift.transport.TNonblockingServerTransport; +import org.apache.thrift.transport.TServerSocket; +import org.apache.thrift.transport.TServerTransport; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class SocketBasedServer implements TestServerInterface { + private static final Logger LOG = Logger.getLogger(TestServer.class.getName()); + private TServerTransport serverTransport; + private TServer server; + private Thread serverThread; + + @Override + public void run(ServerProtocol protocol, ServerTransport transport) { + ThriftTestHandler handler = new ThriftTestHandler(System.out); + ThriftTest.Processor processor = new ThriftTest.Processor<>(handler); + + TProtocolFactory factory = TestServer.getProtocolFactory(protocol); + + serverTransport = getServerTransport(transport); + server = startServer(transport, processor, factory); + + final CountDownLatch latch = new CountDownLatch(1); + serverThread = new Thread(() -> { + latch.countDown(); + LOG.entering("TestServer", "serve"); + try { + server.serve(); + } catch (Throwable t) { + LOG.log(Level.SEVERE, "Error while serving", t); + } finally { + LOG.exiting("TestServer", "serve"); + } + }); + + serverThread.start(); + + try { + if (!latch.await(1, TimeUnit.SECONDS)) { + LOG.severe("Server thread failed to start"); + } + } catch (InterruptedException e) { + LOG.severe("Interrupted while waiting for server thread to start"); + e.printStackTrace(); + } + } + + @Override + public int port() { + if (serverTransport instanceof TServerSocket) { + return ((TServerSocket) serverTransport).getServerSocket().getLocalPort(); + } else if (serverTransport instanceof TNonblockingServerSocket) { + TNonblockingServerSocket sock = (TNonblockingServerSocket) serverTransport; + return sock.getPort(); + } else { + throw new AssertionError("Unexpected server transport type: " + serverTransport.getClass()); + } + } + @Override + public void close() { + cleanupServer(); + } + + private void cleanupServer() { + if (serverTransport != null) { + serverTransport.close(); + serverTransport = null; + } + + if (server != null) { + server.stop(); + server = null; + } + + if (serverThread != null) { + serverThread.interrupt(); + serverThread = null; + } + } + private TServerTransport getServerTransport(ServerTransport transport) { + switch (transport) { + case BLOCKING: return getBlockingServerTransport(); + case NON_BLOCKING: return getNonBlockingServerTransport(); + default: + throw new AssertionError("Invalid transport type: " + transport); + } + } + + private TServerTransport getBlockingServerTransport() { + try { + InetAddress localhost = InetAddress.getByName("localhost"); + InetSocketAddress socketAddress = new InetSocketAddress(localhost, 0); + TServerSocket.ServerSocketTransportArgs args = new TServerSocket.ServerSocketTransportArgs() + .bindAddr(socketAddress); + + return new TServerSocket(args); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + private TServerTransport getNonBlockingServerTransport() { + try { + InetAddress localhost = InetAddress.getByName("localhost"); + InetSocketAddress socketAddress = new InetSocketAddress(localhost, 0); + + return new TNonblockingServerSocket(socketAddress); + } catch (Exception e) { + throw new AssertionError(e); + } + } + private TServer startServer(ServerTransport transport, TProcessor processor, TProtocolFactory protocolFactory) { + switch (transport) { + case BLOCKING: return startBlockingServer(processor, protocolFactory); + case NON_BLOCKING: return startNonblockingServer(processor, protocolFactory); + default: + throw new AssertionError("Invalid transport type: " + transport); + } + } + + private TServer startBlockingServer(TProcessor processor, TProtocolFactory protocolFactory) { + TThreadPoolServer.Args args = new TThreadPoolServer.Args(serverTransport) + .processor(processor) + .protocolFactory(protocolFactory); + + return new TThreadPoolServer(args); + } + + private TServer startNonblockingServer(TProcessor processor, TProtocolFactory protocolFactory) { + TNonblockingServerTransport nonblockingTransport = (TNonblockingServerTransport) serverTransport; + TNonblockingServer.Args args = new TNonblockingServer.Args(nonblockingTransport) + .processor(processor) + .protocolFactory(protocolFactory); + + return new TNonblockingServer(args); + } + + +} diff --git a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServer.java b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServer.java index ebd1793b6..013991450 100644 --- a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServer.java +++ b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServer.java @@ -20,89 +20,26 @@ */ package com.microsoft.thrifty.testing; -import com.microsoft.thrifty.test.gen.ThriftTest; -import java.util.logging.Level; -import java.util.logging.Logger; -import org.apache.thrift.TProcessor; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TJSONProtocol; import org.apache.thrift.protocol.TProtocolFactory; -import org.apache.thrift.server.TNonblockingServer; -import org.apache.thrift.server.TServer; -import org.apache.thrift.server.TThreadPoolServer; -import org.apache.thrift.transport.TNonblockingServerSocket; -import org.apache.thrift.transport.TNonblockingServerTransport; -import org.apache.thrift.transport.TServerSocket; -import org.apache.thrift.transport.TServerTransport; import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.BeforeAllCallback; import org.junit.jupiter.api.extension.Extension; import org.junit.jupiter.api.extension.ExtensionContext; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - public class TestServer implements Extension, BeforeAllCallback, AfterAllCallback { - private static final Logger LOG = Logger.getLogger(TestServer.class.getName()); + private TestServerInterface serverImplementation; private ServerProtocol protocol; private ServerTransport transport; - private TServerTransport serverTransport; - private TServer server; - private Thread serverThread; private Class testClass; - public void run() { - ThriftTestHandler handler = new ThriftTestHandler(System.out); - ThriftTest.Processor processor = new ThriftTest.Processor<>(handler); - - TProtocolFactory factory = getProtocolFactory(); - - serverTransport = getServerTransport(); - server = startServer(processor, factory); - - final CountDownLatch latch = new CountDownLatch(1); - serverThread = new Thread(() -> { - latch.countDown(); - LOG.entering("TestServer", "serve"); - try { - server.serve(); - } catch (Throwable t) { - LOG.log(Level.SEVERE, "Error while serving", t); - } finally { - LOG.exiting("TestServer", "serve"); - } - }); - - serverThread.start(); - - try { - if (!latch.await(1, TimeUnit.SECONDS)) { - LOG.severe("Server thread failed to start"); - } - } catch (InterruptedException e) { - LOG.severe("Interrupted while waiting for server thread to start"); - e.printStackTrace(); - } - } - - public int port() { - if (serverTransport instanceof TServerSocket) { - return ((TServerSocket) serverTransport).getServerSocket().getLocalPort(); - } else if (serverTransport instanceof TNonblockingServerSocket) { - TNonblockingServerSocket sock = (TNonblockingServerSocket) serverTransport; - return sock.getPort(); - } else { - throw new AssertionError("Unexpected server transport type: " + serverTransport.getClass()); - } - } public ServerProtocol getProtocol() { return protocol; @@ -112,10 +49,6 @@ public ServerTransport getTransport() { return transport; } - public Class getTestClass() { - return testClass; - } - @Override public void beforeAll(ExtensionContext context) throws Exception { testClass = context.getRequiredTestClass(); @@ -124,69 +57,36 @@ public void beforeAll(ExtensionContext context) throws Exception { protocol = config != null ? config.protocol() : ServerProtocol.BINARY; transport = config != null ? config.transport() : ServerTransport.BLOCKING; - run(); + serverImplementation = getServerImplementation(transport); + serverImplementation.run(protocol, transport); } - @Override - public void afterAll(ExtensionContext context) throws Exception { - cleanupServer(); - } - - public void close() { - cleanupServer(); - } - - private void cleanupServer() { - if (serverTransport != null) { - serverTransport.close(); - serverTransport = null; - } - - if (server != null) { - server.stop(); - server = null; - } - - if (serverThread != null) { - serverThread.interrupt(); - serverThread = null; - } - } - - private TServerTransport getServerTransport() { + private TestServerInterface getServerImplementation(ServerTransport transport) { switch (transport) { - case BLOCKING: return getBlockingServerTransport(); - case NON_BLOCKING: return getNonBlockingServerTransport(); + case BLOCKING: + case NON_BLOCKING: + return new SocketBasedServer(); + case HTTP: + return new HttpServer(); default: throw new AssertionError("Invalid transport type: " + transport); } } - private TServerTransport getBlockingServerTransport() { - try { - InetAddress localhost = InetAddress.getByName("localhost"); - InetSocketAddress socketAddress = new InetSocketAddress(localhost, 0); - TServerSocket.ServerSocketTransportArgs args = new TServerSocket.ServerSocketTransportArgs() - .bindAddr(socketAddress); - - return new TServerSocket(args); - } catch (Exception e) { - throw new AssertionError(e); - } + @Override + public void afterAll(ExtensionContext context) throws Exception { + serverImplementation.close(); } - private TServerTransport getNonBlockingServerTransport() { - try { - InetAddress localhost = InetAddress.getByName("localhost"); - InetSocketAddress socketAddress = new InetSocketAddress(localhost, 0); + public int port() { + return serverImplementation.port(); + } - return new TNonblockingServerSocket(socketAddress); - } catch (Exception e) { - throw new AssertionError(e); - } + public void close() { + serverImplementation.close(); } - private TProtocolFactory getProtocolFactory() { + public static TProtocolFactory getProtocolFactory(ServerProtocol protocol) { switch (protocol) { case BINARY: return new TBinaryProtocol.Factory(); case COMPACT: return new TCompactProtocol.Factory(); @@ -196,29 +96,5 @@ private TProtocolFactory getProtocolFactory() { } } - private TServer startServer(TProcessor processor, TProtocolFactory protocolFactory) { - switch (transport) { - case BLOCKING: return startBlockingServer(processor, protocolFactory); - case NON_BLOCKING: return startNonblockingServer(processor, protocolFactory); - default: - throw new AssertionError("Invalid transport type: " + transport); - } - } - private TServer startBlockingServer(TProcessor processor, TProtocolFactory protocolFactory) { - TThreadPoolServer.Args args = new TThreadPoolServer.Args(serverTransport) - .processor(processor) - .protocolFactory(protocolFactory); - - return new TThreadPoolServer(args); - } - - private TServer startNonblockingServer(TProcessor processor, TProtocolFactory protocolFactory) { - TNonblockingServerTransport nonblockingTransport = (TNonblockingServerTransport) serverTransport; - TNonblockingServer.Args args = new TNonblockingServer.Args(nonblockingTransport) - .processor(processor) - .protocolFactory(protocolFactory); - - return new TNonblockingServer(args); - } } diff --git a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServerInterface.java b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServerInterface.java new file mode 100644 index 000000000..b56c0f649 --- /dev/null +++ b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServerInterface.java @@ -0,0 +1,31 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +package com.microsoft.thrifty.testing; + +import com.microsoft.thrifty.test.gen.ThriftTest; + +public interface TestServerInterface { + void run(ServerProtocol protocol, ServerTransport transport); + + int port(); + + void close(); +} diff --git a/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServlet.java b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServlet.java new file mode 100644 index 000000000..3d8903c61 --- /dev/null +++ b/thrifty-test-server/src/main/java/com/microsoft/thrifty/testing/TestServlet.java @@ -0,0 +1,71 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.microsoft.thrifty.testing; + +import com.microsoft.thrifty.test.gen.ThriftTest; +import org.apache.thrift.TProcessor; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.server.TExtensibleServlet; + +@SuppressWarnings("serial") +public class TestServlet extends TExtensibleServlet { + private final TProtocolFactory protocolFactory; + + public TestServlet(TProtocolFactory protocolFactory) { + this.protocolFactory = protocolFactory; + } + + @Override + protected TProtocolFactory getInProtocolFactory() { + return protocolFactory; + } + + @Override + protected TProtocolFactory getOutProtocolFactory() { + return protocolFactory; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + @Override + protected TProcessor getProcessor() { + ThriftTestHandler handler = new ThriftTestHandler(System.out); + return new ThriftTest.Processor<>(handler); + } +}