diff --git a/java/client/build.gradle b/java/client/build.gradle
index 4f73647f0c..be9120cf3f 100644
--- a/java/client/build.gradle
+++ b/java/client/build.gradle
@@ -18,15 +18,25 @@ dependencies {
implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: '4.1.100.Final', classifier: 'linux-x86_64'
implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-x86_64'
implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-aarch_64'
+
+ //lombok
+ compileOnly 'org.projectlombok:lombok:1.18.30'
+ annotationProcessor 'org.projectlombok:lombok:1.18.30'
+ testCompileOnly 'org.projectlombok:lombok:1.18.30'
+ testAnnotationProcessor 'org.projectlombok:lombok:1.18.30'
+
+ // junit
+ testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
+ testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4'
}
tasks.register('protobuf', Exec) {
doFirst {
- project.mkdir(Paths.get(project.projectDir.path, 'src/main/java/babushka/protobuf').toString())
+ project.mkdir(Paths.get(project.projectDir.path, 'src/main/java/babushka/models/protobuf').toString())
}
commandLine 'protoc',
'-Iprotobuf=babushka-core/src/protobuf/',
- '--java_out=java/client/src/main/java/babushka/protobuf',
+ '--java_out=java/client/src/main/java/babushka/models/protobuf',
'babushka-core/src/protobuf/connection_request.proto',
'babushka-core/src/protobuf/redis_request.proto',
'babushka-core/src/protobuf/response.proto'
@@ -35,7 +45,7 @@ tasks.register('protobuf', Exec) {
tasks.register('cleanProtobuf') {
doFirst {
- project.delete(Paths.get(project.projectDir.path, 'src/main/java/babushka/protobuf').toString())
+ project.delete(Paths.get(project.projectDir.path, 'src/main/java/babushka/models/protobuf').toString())
}
}
diff --git a/java/client/src/main/java/babushka/BabushkaCoreNativeDefinitions.java b/java/client/src/main/java/babushka/BabushkaCoreNativeDefinitions.java
deleted file mode 100644
index a16871b99a..0000000000
--- a/java/client/src/main/java/babushka/BabushkaCoreNativeDefinitions.java
+++ /dev/null
@@ -1,11 +0,0 @@
-package babushka;
-
-public class BabushkaCoreNativeDefinitions {
- public static native String startSocketListenerExternal() throws Exception;
-
- public static native Object valueFromPointer(long pointer);
-
- static {
- System.loadLibrary("javababushka");
- }
-}
diff --git a/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java b/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java
new file mode 100644
index 0000000000..5ebaa03969
--- /dev/null
+++ b/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java
@@ -0,0 +1,65 @@
+package babushka.connectors.handlers;
+
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.commons.lang3.tuple.Pair;
+import response.ResponseOuterClass.Response;
+
+/** Holder for resources required to dispatch responses and used by {@link ReadHandler}. */
+public class CallbackDispatcher {
+ /** Unique request ID (callback ID). Thread-safe. */
+ private final AtomicInteger requestId = new AtomicInteger(0);
+
+ /**
+ * Storage of Futures to handle responses. Map key is callback id, which starts from 1.
+ * Each future is a promise for every submitted by user request.
+ */
+ private final Map> responses = new ConcurrentHashMap<>();
+
+ /**
+ * Storage for connection request similar to {@link #responses}. Unfortunately, connection
+ * requests can't be stored in the same storage, because callback ID = 0 is hardcoded for
+ * connection requests.
+ */
+ private final CompletableFuture connectionPromise = new CompletableFuture<>();
+
+ /**
+ * Register a new request to be sent. Once response received, the given future completes with it.
+ *
+ * @return A pair of unique callback ID which should set into request and a client promise for
+ * response.
+ */
+ public Pair> registerRequest() {
+ int callbackId = requestId.incrementAndGet();
+ var future = new CompletableFuture();
+ responses.put(callbackId, future);
+ return Pair.of(callbackId, future);
+ }
+
+ public CompletableFuture registerConnection() {
+ return connectionPromise;
+ }
+
+ /**
+ * Complete the corresponding client promise and free resources.
+ *
+ * @param response A response received
+ */
+ public void completeRequest(Response response) {
+ int callbackId = response.getCallbackIdx();
+ if (callbackId == 0) {
+ connectionPromise.completeAsync(() -> response);
+ } else {
+ responses.get(callbackId).completeAsync(() -> response);
+ responses.remove(callbackId);
+ }
+ }
+
+ public void shutdownGracefully() {
+ connectionPromise.cancel(false);
+ responses.values().forEach(future -> future.cancel(false));
+ responses.clear();
+ }
+}
diff --git a/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java b/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java
new file mode 100644
index 0000000000..40feea417b
--- /dev/null
+++ b/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java
@@ -0,0 +1,78 @@
+package babushka.connectors.handlers;
+
+import babushka.connectors.resources.Platform;
+import connection_request.ConnectionRequestOuterClass.ConnectionRequest;
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.unix.DomainSocketAddress;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+import redis_request.RedisRequestOuterClass.RedisRequest;
+import response.ResponseOuterClass.Response;
+
+/**
+ * Class responsible for handling calls to/from a netty.io {@link Channel}.
+ * Uses a {@link CallbackDispatcher} to record callbacks of every request sent.
+ */
+public class ChannelHandler {
+
+ private static final String THREAD_POOL_NAME = "babushka-channel";
+
+ private final Channel channel;
+ private final CallbackDispatcher callbackDispatcher;
+
+ /** Open a new channel for a new client. */
+ public ChannelHandler(CallbackDispatcher callbackDispatcher, String socketPath) {
+ channel =
+ new Bootstrap()
+ // TODO let user specify the thread pool or pool size as an option
+ .group(Platform.createNettyThreadPool(THREAD_POOL_NAME, Optional.empty()))
+ .channel(Platform.getClientUdsNettyChannelType())
+ .handler(new ProtobufSocketChannelInitializer(callbackDispatcher))
+ .connect(new DomainSocketAddress(socketPath))
+ // TODO call here .sync() if needed or remove this comment
+ .channel();
+ this.callbackDispatcher = callbackDispatcher;
+ }
+
+ /**
+ * Complete a protobuf message and write it to the channel (to UDS).
+ *
+ * @param request Incomplete request, function completes it by setting callback ID
+ * @param flush True to flush immediately
+ * @return A response promise
+ */
+ public CompletableFuture write(RedisRequest.Builder request, boolean flush) {
+ var commandId = callbackDispatcher.registerRequest();
+ request.setCallbackIdx(commandId.getKey());
+
+ if (flush) {
+ channel.writeAndFlush(request.build());
+ } else {
+ channel.write(request.build());
+ }
+ return commandId.getValue();
+ }
+
+ /**
+ * Write a protobuf message to the channel (to UDS).
+ *
+ * @param request A connection request
+ * @return A connection promise
+ */
+ public CompletableFuture connect(ConnectionRequest request) {
+ channel.writeAndFlush(request);
+ return callbackDispatcher.registerConnection();
+ }
+
+ private final AtomicBoolean closed = new AtomicBoolean(false);
+
+ /** Closes the UDS connection and frees corresponding resources. */
+ public void close() {
+ if (closed.compareAndSet(false, true)) {
+ channel.close();
+ callbackDispatcher.shutdownGracefully();
+ }
+ }
+}
diff --git a/java/client/src/main/java/babushka/connectors/handlers/ProtobufSocketChannelInitializer.java b/java/client/src/main/java/babushka/connectors/handlers/ProtobufSocketChannelInitializer.java
new file mode 100644
index 0000000000..06c4c03f02
--- /dev/null
+++ b/java/client/src/main/java/babushka/connectors/handlers/ProtobufSocketChannelInitializer.java
@@ -0,0 +1,31 @@
+package babushka.connectors.handlers;
+
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.unix.UnixChannel;
+import io.netty.handler.codec.protobuf.ProtobufDecoder;
+import io.netty.handler.codec.protobuf.ProtobufEncoder;
+import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder;
+import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+import response.ResponseOuterClass.Response;
+
+/** Builder for the channel used by {@link ChannelHandler}. */
+@RequiredArgsConstructor
+public class ProtobufSocketChannelInitializer extends ChannelInitializer {
+
+ private final CallbackDispatcher callbackDispatcher;
+
+ @Override
+ public void initChannel(@NonNull UnixChannel ch) {
+ ch.pipeline()
+ // https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html
+ .addLast("frameDecoder", new ProtobufVarint32FrameDecoder())
+ .addLast("frameEncoder", new ProtobufVarint32LengthFieldPrepender())
+ .addLast("protobufDecoder", new ProtobufDecoder(Response.getDefaultInstance()))
+ .addLast("protobufEncoder", new ProtobufEncoder())
+ .addLast(new ReadHandler(callbackDispatcher))
+ .addLast(new ChannelOutboundHandlerAdapter());
+ }
+}
diff --git a/java/client/src/main/java/babushka/connectors/handlers/ReadHandler.java b/java/client/src/main/java/babushka/connectors/handlers/ReadHandler.java
new file mode 100644
index 0000000000..63aedf001e
--- /dev/null
+++ b/java/client/src/main/java/babushka/connectors/handlers/ReadHandler.java
@@ -0,0 +1,28 @@
+package babushka.connectors.handlers;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+import response.ResponseOuterClass.Response;
+
+/** Handler for inbound traffic though UDS. Used by Netty. */
+@RequiredArgsConstructor
+public class ReadHandler extends ChannelInboundHandlerAdapter {
+
+ private final CallbackDispatcher callbackDispatcher;
+
+ /** Submit responses from babushka to an instance {@link CallbackDispatcher} to handle them. */
+ @Override
+ public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) {
+ callbackDispatcher.completeRequest((Response) msg);
+ }
+
+ /** Handles uncaught exceptions from {@link #channelRead(ChannelHandlerContext, Object)}. */
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ System.out.printf("=== exceptionCaught %s %s %n", ctx, cause);
+ cause.printStackTrace(System.err);
+ super.exceptionCaught(ctx, cause);
+ }
+}
diff --git a/java/client/src/main/java/babushka/connectors/resources/Platform.java b/java/client/src/main/java/babushka/connectors/resources/Platform.java
new file mode 100644
index 0000000000..0e6c3edec9
--- /dev/null
+++ b/java/client/src/main/java/babushka/connectors/resources/Platform.java
@@ -0,0 +1,139 @@
+package babushka.connectors.resources;
+
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.epoll.Epoll;
+import io.netty.channel.epoll.EpollDomainSocketChannel;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.kqueue.KQueue;
+import io.netty.channel.kqueue.KQueueDomainSocketChannel;
+import io.netty.channel.kqueue.KQueueEventLoopGroup;
+import io.netty.channel.unix.DomainSocketChannel;
+import io.netty.util.concurrent.DefaultThreadFactory;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Supplier;
+import lombok.AccessLevel;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.ToString;
+import lombok.experimental.UtilityClass;
+
+/**
+ * An auxiliary class purposed to detect platform (OS + JVM) {@link Capabilities} and allocate
+ * corresponding resources.
+ */
+@UtilityClass
+public class Platform {
+
+ @Getter
+ @AllArgsConstructor(access = AccessLevel.PRIVATE)
+ @ToString
+ private static class Capabilities {
+ private final boolean isKQueueAvailable;
+ private final boolean isEPollAvailable;
+ // TODO support IO-Uring and NIO
+ private final boolean isIOUringAvailable;
+ // At the moment, Windows is not supported
+ // Probably we should use NIO (NioEventLoopGroup) for Windows.
+ private final boolean isNIOAvailable;
+ }
+
+ /** Detected platform (OS + JVM) capabilities. Not supposed to be changed in runtime. */
+ @Getter
+ private static final Capabilities capabilities =
+ new Capabilities(isKQueueAvailable(), isEPollAvailable(), false, false);
+
+ /**
+ * Thread pools supplied to Netty to perform all async IO.
+ * Map key is supposed to be pool name + thread count as a string concat product.
+ */
+ private static final Map groups = new ConcurrentHashMap<>();
+
+ /** Detect kqueue availability. */
+ private static boolean isKQueueAvailable() {
+ try {
+ Class.forName("io.netty.channel.kqueue.KQueue");
+ return KQueue.isAvailable();
+ } catch (ClassNotFoundException e) {
+ return false;
+ }
+ }
+
+ /** Detect epoll availability. */
+ private static boolean isEPollAvailable() {
+ try {
+ Class.forName("io.netty.channel.epoll.Epoll");
+ return Epoll.isAvailable();
+ } catch (ClassNotFoundException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Allocate Netty thread pool required to manage connection. A thread pool could be shared across
+ * multiple connections.
+ *
+ * @return A new thread pool.
+ */
+ public static EventLoopGroup createNettyThreadPool(String prefix, Optional threadLimit) {
+ int threadCount = threadLimit.orElse(Runtime.getRuntime().availableProcessors());
+ if (capabilities.isKQueueAvailable()) {
+ var name = prefix + "-kqueue-elg";
+ return getOrCreate(
+ name + threadCount,
+ () -> new KQueueEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
+ } else if (capabilities.isEPollAvailable()) {
+ var name = prefix + "-epoll-elg";
+ return getOrCreate(
+ name + threadCount,
+ () -> new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
+ }
+ // TODO support IO-Uring and NIO
+
+ throw new RuntimeException("Current platform supports no known thread pool types");
+ }
+
+ /**
+ * Get a cached thread pool from {@link #groups} or create a new one by given lambda and cache.
+ */
+ private static EventLoopGroup getOrCreate(String name, Supplier supplier) {
+ if (groups.containsKey(name)) {
+ return groups.get(name);
+ }
+ EventLoopGroup group = supplier.get();
+ groups.put(name, group);
+ return group;
+ }
+
+ /**
+ * Get a channel class required by Netty to open a client UDS channel.
+ *
+ * @return Return a class supported by the current platform.
+ */
+ public static Class extends DomainSocketChannel> getClientUdsNettyChannelType() {
+ if (capabilities.isKQueueAvailable()) {
+ return KQueueDomainSocketChannel.class;
+ }
+ if (capabilities.isEPollAvailable()) {
+ return EpollDomainSocketChannel.class;
+ }
+ throw new RuntimeException("Current platform supports no known socket types");
+ }
+
+ /**
+ * A JVM shutdown hook to be registered. It is responsible for closing connection and freeing
+ * resources. It is recommended to use a class instead of lambda to ensure that it is called.
+ * See {@link Runtime#addShutdownHook}.
+ */
+ private static class ShutdownHook implements Runnable {
+ @Override
+ public void run() {
+ groups.values().forEach(EventLoopGroup::shutdownGracefully);
+ }
+ }
+
+ static {
+ Runtime.getRuntime().addShutdownHook(new Thread(new ShutdownHook(), "Babushka-shutdown-hook"));
+ }
+}
diff --git a/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java b/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java
new file mode 100644
index 0000000000..6d4ec45121
--- /dev/null
+++ b/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java
@@ -0,0 +1,25 @@
+package babushka.ffi.resolvers;
+
+public class BabushkaCoreNativeDefinitions {
+ public static native String startSocketListenerExternal() throws Exception;
+
+ public static native Object valueFromPointer(long pointer);
+
+ static {
+ System.loadLibrary("javababushka");
+ }
+
+ /**
+ * Make an FFI call to obtain the socket path.
+ *
+ * @return A UDS path.
+ */
+ public static String getSocket() {
+ try {
+ return startSocketListenerExternal();
+ } catch (Exception | UnsatisfiedLinkError e) {
+ System.err.printf("Failed to create a UDS connection: %s%n%n", e);
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/java/src/lib.rs b/java/src/lib.rs
index 3cd7bc5ed7..13577f0805 100644
--- a/java/src/lib.rs
+++ b/java/src/lib.rs
@@ -42,7 +42,9 @@ fn redis_value_to_java(mut env: JNIEnv, val: Value) -> JObject {
}
#[no_mangle]
-pub extern "system" fn Java_babushka_BabushkaCoreNativeDefinitions_valueFromPointer<'local>(
+pub extern "system" fn Java_babushka_ffi_resolvers_BabushkaCoreNativeDefinitions_valueFromPointer<
+ 'local,
+>(
env: JNIEnv<'local>,
_class: JClass<'local>,
pointer: jlong,
@@ -52,7 +54,7 @@ pub extern "system" fn Java_babushka_BabushkaCoreNativeDefinitions_valueFromPoin
}
#[no_mangle]
-pub extern "system" fn Java_babushka_BabushkaCoreNativeDefinitions_startSocketListenerExternal<
+pub extern "system" fn Java_babushka_ffi_resolvers_BabushkaCoreNativeDefinitions_startSocketListenerExternal<
'local,
>(
env: JNIEnv<'local>,