diff --git a/java/client/build.gradle b/java/client/build.gradle index 4f73647f0c..be9120cf3f 100644 --- a/java/client/build.gradle +++ b/java/client/build.gradle @@ -18,15 +18,25 @@ dependencies { implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: '4.1.100.Final', classifier: 'linux-x86_64' implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-x86_64' implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-aarch_64' + + //lombok + compileOnly 'org.projectlombok:lombok:1.18.30' + annotationProcessor 'org.projectlombok:lombok:1.18.30' + testCompileOnly 'org.projectlombok:lombok:1.18.30' + testAnnotationProcessor 'org.projectlombok:lombok:1.18.30' + + // junit + testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') + testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4' } tasks.register('protobuf', Exec) { doFirst { - project.mkdir(Paths.get(project.projectDir.path, 'src/main/java/babushka/protobuf').toString()) + project.mkdir(Paths.get(project.projectDir.path, 'src/main/java/babushka/models/protobuf').toString()) } commandLine 'protoc', '-Iprotobuf=babushka-core/src/protobuf/', - '--java_out=java/client/src/main/java/babushka/protobuf', + '--java_out=java/client/src/main/java/babushka/models/protobuf', 'babushka-core/src/protobuf/connection_request.proto', 'babushka-core/src/protobuf/redis_request.proto', 'babushka-core/src/protobuf/response.proto' @@ -35,7 +45,7 @@ tasks.register('protobuf', Exec) { tasks.register('cleanProtobuf') { doFirst { - project.delete(Paths.get(project.projectDir.path, 'src/main/java/babushka/protobuf').toString()) + project.delete(Paths.get(project.projectDir.path, 'src/main/java/babushka/models/protobuf').toString()) } } diff --git a/java/client/src/main/java/babushka/BabushkaCoreNativeDefinitions.java b/java/client/src/main/java/babushka/BabushkaCoreNativeDefinitions.java deleted file mode 100644 index a16871b99a..0000000000 --- a/java/client/src/main/java/babushka/BabushkaCoreNativeDefinitions.java +++ /dev/null @@ -1,11 +0,0 @@ -package babushka; - -public class BabushkaCoreNativeDefinitions { - public static native String startSocketListenerExternal() throws Exception; - - public static native Object valueFromPointer(long pointer); - - static { - System.loadLibrary("javababushka"); - } -} diff --git a/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java b/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java new file mode 100644 index 0000000000..5ebaa03969 --- /dev/null +++ b/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java @@ -0,0 +1,65 @@ +package babushka.connectors.handlers; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.tuple.Pair; +import response.ResponseOuterClass.Response; + +/** Holder for resources required to dispatch responses and used by {@link ReadHandler}. */ +public class CallbackDispatcher { + /** Unique request ID (callback ID). Thread-safe. */ + private final AtomicInteger requestId = new AtomicInteger(0); + + /** + * Storage of Futures to handle responses. Map key is callback id, which starts from 1.
+ * Each future is a promise for every submitted by user request. + */ + private final Map> responses = new ConcurrentHashMap<>(); + + /** + * Storage for connection request similar to {@link #responses}. Unfortunately, connection + * requests can't be stored in the same storage, because callback ID = 0 is hardcoded for + * connection requests. + */ + private final CompletableFuture connectionPromise = new CompletableFuture<>(); + + /** + * Register a new request to be sent. Once response received, the given future completes with it. + * + * @return A pair of unique callback ID which should set into request and a client promise for + * response. + */ + public Pair> registerRequest() { + int callbackId = requestId.incrementAndGet(); + var future = new CompletableFuture(); + responses.put(callbackId, future); + return Pair.of(callbackId, future); + } + + public CompletableFuture registerConnection() { + return connectionPromise; + } + + /** + * Complete the corresponding client promise and free resources. + * + * @param response A response received + */ + public void completeRequest(Response response) { + int callbackId = response.getCallbackIdx(); + if (callbackId == 0) { + connectionPromise.completeAsync(() -> response); + } else { + responses.get(callbackId).completeAsync(() -> response); + responses.remove(callbackId); + } + } + + public void shutdownGracefully() { + connectionPromise.cancel(false); + responses.values().forEach(future -> future.cancel(false)); + responses.clear(); + } +} diff --git a/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java b/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java new file mode 100644 index 0000000000..40feea417b --- /dev/null +++ b/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java @@ -0,0 +1,78 @@ +package babushka.connectors.handlers; + +import babushka.connectors.resources.Platform; +import connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.unix.DomainSocketAddress; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import redis_request.RedisRequestOuterClass.RedisRequest; +import response.ResponseOuterClass.Response; + +/** + * Class responsible for handling calls to/from a netty.io {@link Channel}.
+ * Uses a {@link CallbackDispatcher} to record callbacks of every request sent. + */ +public class ChannelHandler { + + private static final String THREAD_POOL_NAME = "babushka-channel"; + + private final Channel channel; + private final CallbackDispatcher callbackDispatcher; + + /** Open a new channel for a new client. */ + public ChannelHandler(CallbackDispatcher callbackDispatcher, String socketPath) { + channel = + new Bootstrap() + // TODO let user specify the thread pool or pool size as an option + .group(Platform.createNettyThreadPool(THREAD_POOL_NAME, Optional.empty())) + .channel(Platform.getClientUdsNettyChannelType()) + .handler(new ProtobufSocketChannelInitializer(callbackDispatcher)) + .connect(new DomainSocketAddress(socketPath)) + // TODO call here .sync() if needed or remove this comment + .channel(); + this.callbackDispatcher = callbackDispatcher; + } + + /** + * Complete a protobuf message and write it to the channel (to UDS). + * + * @param request Incomplete request, function completes it by setting callback ID + * @param flush True to flush immediately + * @return A response promise + */ + public CompletableFuture write(RedisRequest.Builder request, boolean flush) { + var commandId = callbackDispatcher.registerRequest(); + request.setCallbackIdx(commandId.getKey()); + + if (flush) { + channel.writeAndFlush(request.build()); + } else { + channel.write(request.build()); + } + return commandId.getValue(); + } + + /** + * Write a protobuf message to the channel (to UDS). + * + * @param request A connection request + * @return A connection promise + */ + public CompletableFuture connect(ConnectionRequest request) { + channel.writeAndFlush(request); + return callbackDispatcher.registerConnection(); + } + + private final AtomicBoolean closed = new AtomicBoolean(false); + + /** Closes the UDS connection and frees corresponding resources. */ + public void close() { + if (closed.compareAndSet(false, true)) { + channel.close(); + callbackDispatcher.shutdownGracefully(); + } + } +} diff --git a/java/client/src/main/java/babushka/connectors/handlers/ProtobufSocketChannelInitializer.java b/java/client/src/main/java/babushka/connectors/handlers/ProtobufSocketChannelInitializer.java new file mode 100644 index 0000000000..06c4c03f02 --- /dev/null +++ b/java/client/src/main/java/babushka/connectors/handlers/ProtobufSocketChannelInitializer.java @@ -0,0 +1,31 @@ +package babushka.connectors.handlers; + +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.unix.UnixChannel; +import io.netty.handler.codec.protobuf.ProtobufDecoder; +import io.netty.handler.codec.protobuf.ProtobufEncoder; +import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder; +import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import response.ResponseOuterClass.Response; + +/** Builder for the channel used by {@link ChannelHandler}. */ +@RequiredArgsConstructor +public class ProtobufSocketChannelInitializer extends ChannelInitializer { + + private final CallbackDispatcher callbackDispatcher; + + @Override + public void initChannel(@NonNull UnixChannel ch) { + ch.pipeline() + // https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html + .addLast("frameDecoder", new ProtobufVarint32FrameDecoder()) + .addLast("frameEncoder", new ProtobufVarint32LengthFieldPrepender()) + .addLast("protobufDecoder", new ProtobufDecoder(Response.getDefaultInstance())) + .addLast("protobufEncoder", new ProtobufEncoder()) + .addLast(new ReadHandler(callbackDispatcher)) + .addLast(new ChannelOutboundHandlerAdapter()); + } +} diff --git a/java/client/src/main/java/babushka/connectors/handlers/ReadHandler.java b/java/client/src/main/java/babushka/connectors/handlers/ReadHandler.java new file mode 100644 index 0000000000..63aedf001e --- /dev/null +++ b/java/client/src/main/java/babushka/connectors/handlers/ReadHandler.java @@ -0,0 +1,28 @@ +package babushka.connectors.handlers; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import response.ResponseOuterClass.Response; + +/** Handler for inbound traffic though UDS. Used by Netty. */ +@RequiredArgsConstructor +public class ReadHandler extends ChannelInboundHandlerAdapter { + + private final CallbackDispatcher callbackDispatcher; + + /** Submit responses from babushka to an instance {@link CallbackDispatcher} to handle them. */ + @Override + public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) { + callbackDispatcher.completeRequest((Response) msg); + } + + /** Handles uncaught exceptions from {@link #channelRead(ChannelHandlerContext, Object)}. */ + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + System.out.printf("=== exceptionCaught %s %s %n", ctx, cause); + cause.printStackTrace(System.err); + super.exceptionCaught(ctx, cause); + } +} diff --git a/java/client/src/main/java/babushka/connectors/resources/Platform.java b/java/client/src/main/java/babushka/connectors/resources/Platform.java new file mode 100644 index 0000000000..0e6c3edec9 --- /dev/null +++ b/java/client/src/main/java/babushka/connectors/resources/Platform.java @@ -0,0 +1,139 @@ +package babushka.connectors.resources; + +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollDomainSocketChannel; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.kqueue.KQueue; +import io.netty.channel.kqueue.KQueueDomainSocketChannel; +import io.netty.channel.kqueue.KQueueEventLoopGroup; +import io.netty.channel.unix.DomainSocketChannel; +import io.netty.util.concurrent.DefaultThreadFactory; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.UtilityClass; + +/** + * An auxiliary class purposed to detect platform (OS + JVM) {@link Capabilities} and allocate + * corresponding resources. + */ +@UtilityClass +public class Platform { + + @Getter + @AllArgsConstructor(access = AccessLevel.PRIVATE) + @ToString + private static class Capabilities { + private final boolean isKQueueAvailable; + private final boolean isEPollAvailable; + // TODO support IO-Uring and NIO + private final boolean isIOUringAvailable; + // At the moment, Windows is not supported + // Probably we should use NIO (NioEventLoopGroup) for Windows. + private final boolean isNIOAvailable; + } + + /** Detected platform (OS + JVM) capabilities. Not supposed to be changed in runtime. */ + @Getter + private static final Capabilities capabilities = + new Capabilities(isKQueueAvailable(), isEPollAvailable(), false, false); + + /** + * Thread pools supplied to Netty to perform all async IO.
+ * Map key is supposed to be pool name + thread count as a string concat product. + */ + private static final Map groups = new ConcurrentHashMap<>(); + + /** Detect kqueue availability. */ + private static boolean isKQueueAvailable() { + try { + Class.forName("io.netty.channel.kqueue.KQueue"); + return KQueue.isAvailable(); + } catch (ClassNotFoundException e) { + return false; + } + } + + /** Detect epoll availability. */ + private static boolean isEPollAvailable() { + try { + Class.forName("io.netty.channel.epoll.Epoll"); + return Epoll.isAvailable(); + } catch (ClassNotFoundException e) { + return false; + } + } + + /** + * Allocate Netty thread pool required to manage connection. A thread pool could be shared across + * multiple connections. + * + * @return A new thread pool. + */ + public static EventLoopGroup createNettyThreadPool(String prefix, Optional threadLimit) { + int threadCount = threadLimit.orElse(Runtime.getRuntime().availableProcessors()); + if (capabilities.isKQueueAvailable()) { + var name = prefix + "-kqueue-elg"; + return getOrCreate( + name + threadCount, + () -> new KQueueEventLoopGroup(threadCount, new DefaultThreadFactory(name, true))); + } else if (capabilities.isEPollAvailable()) { + var name = prefix + "-epoll-elg"; + return getOrCreate( + name + threadCount, + () -> new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(name, true))); + } + // TODO support IO-Uring and NIO + + throw new RuntimeException("Current platform supports no known thread pool types"); + } + + /** + * Get a cached thread pool from {@link #groups} or create a new one by given lambda and cache. + */ + private static EventLoopGroup getOrCreate(String name, Supplier supplier) { + if (groups.containsKey(name)) { + return groups.get(name); + } + EventLoopGroup group = supplier.get(); + groups.put(name, group); + return group; + } + + /** + * Get a channel class required by Netty to open a client UDS channel. + * + * @return Return a class supported by the current platform. + */ + public static Class getClientUdsNettyChannelType() { + if (capabilities.isKQueueAvailable()) { + return KQueueDomainSocketChannel.class; + } + if (capabilities.isEPollAvailable()) { + return EpollDomainSocketChannel.class; + } + throw new RuntimeException("Current platform supports no known socket types"); + } + + /** + * A JVM shutdown hook to be registered. It is responsible for closing connection and freeing + * resources. It is recommended to use a class instead of lambda to ensure that it is called.
+ * See {@link Runtime#addShutdownHook}. + */ + private static class ShutdownHook implements Runnable { + @Override + public void run() { + groups.values().forEach(EventLoopGroup::shutdownGracefully); + } + } + + static { + Runtime.getRuntime().addShutdownHook(new Thread(new ShutdownHook(), "Babushka-shutdown-hook")); + } +} diff --git a/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java b/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java new file mode 100644 index 0000000000..6d4ec45121 --- /dev/null +++ b/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java @@ -0,0 +1,25 @@ +package babushka.ffi.resolvers; + +public class BabushkaCoreNativeDefinitions { + public static native String startSocketListenerExternal() throws Exception; + + public static native Object valueFromPointer(long pointer); + + static { + System.loadLibrary("javababushka"); + } + + /** + * Make an FFI call to obtain the socket path. + * + * @return A UDS path. + */ + public static String getSocket() { + try { + return startSocketListenerExternal(); + } catch (Exception | UnsatisfiedLinkError e) { + System.err.printf("Failed to create a UDS connection: %s%n%n", e); + throw new RuntimeException(e); + } + } +} diff --git a/java/src/lib.rs b/java/src/lib.rs index 3cd7bc5ed7..13577f0805 100644 --- a/java/src/lib.rs +++ b/java/src/lib.rs @@ -42,7 +42,9 @@ fn redis_value_to_java(mut env: JNIEnv, val: Value) -> JObject { } #[no_mangle] -pub extern "system" fn Java_babushka_BabushkaCoreNativeDefinitions_valueFromPointer<'local>( +pub extern "system" fn Java_babushka_ffi_resolvers_BabushkaCoreNativeDefinitions_valueFromPointer< + 'local, +>( env: JNIEnv<'local>, _class: JClass<'local>, pointer: jlong, @@ -52,7 +54,7 @@ pub extern "system" fn Java_babushka_BabushkaCoreNativeDefinitions_valueFromPoin } #[no_mangle] -pub extern "system" fn Java_babushka_BabushkaCoreNativeDefinitions_startSocketListenerExternal< +pub extern "system" fn Java_babushka_ffi_resolvers_BabushkaCoreNativeDefinitions_startSocketListenerExternal< 'local, >( env: JNIEnv<'local>,