Skip to content

Commit

Permalink
Merge pull request #1554 from fedejinich/write-time
Browse files Browse the repository at this point in the history
Configurable write timeout for websocket server
  • Loading branch information
aeidelman authored Jul 15, 2021
2 parents 10a1941 + 61cc3a4 commit e0589f3
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 26 deletions.
5 changes: 3 additions & 2 deletions rskj-core/src/main/java/co/rsk/RskContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -1595,12 +1595,13 @@ private Web3WebSocketServer getWeb3WebSocketServer() {
new BlockchainBranchComparator(getBlockStore())
)
);
RskJsonRpcHandler jsonRpcHandler = new RskJsonRpcHandler(emitter, jsonRpcSerializer);
RskWebSocketJsonRpcHandler jsonRpcHandler = new RskWebSocketJsonRpcHandler(emitter, jsonRpcSerializer);
web3WebSocketServer = new Web3WebSocketServer(
rskSystemProperties.rpcWebSocketBindAddress(),
rskSystemProperties.rpcWebSocketPort(),
jsonRpcHandler,
getJsonRpcWeb3ServerHandler()
getJsonRpcWeb3ServerHandler(),
rskSystemProperties.rpcWebSocketServerWriteTimeoutSeconds()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import co.rsk.rpc.modules.RskJsonRpcRequestVisitor;
import co.rsk.rpc.modules.eth.subscribe.EthSubscribeRequest;
import co.rsk.rpc.modules.eth.subscribe.EthUnsubscribeRequest;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandler.Sharable;
Expand All @@ -48,25 +49,25 @@
*/

@Sharable
public class RskJsonRpcHandler
public class RskWebSocketJsonRpcHandler
extends SimpleChannelInboundHandler<ByteBufHolder>
implements RskJsonRpcRequestVisitor {
private static final Logger LOGGER = LoggerFactory.getLogger(RskJsonRpcHandler.class);
private static final Logger LOGGER = LoggerFactory.getLogger(RskWebSocketJsonRpcHandler.class);

private final EthSubscriptionNotificationEmitter emitter;
private final JsonRpcSerializer serializer;

public RskJsonRpcHandler(EthSubscriptionNotificationEmitter emitter, JsonRpcSerializer serializer) {
public RskWebSocketJsonRpcHandler(EthSubscriptionNotificationEmitter emitter, JsonRpcSerializer serializer) {
this.emitter = emitter;
this.serializer = serializer;
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBufHolder msg) {
try {
RskJsonRpcRequest request = serializer.deserializeRequest(
new ByteBufInputStream(msg.copy().content())
);
ByteBuf content = msg.copy().content();

try (ByteBufInputStream source = new ByteBufInputStream(content)){
RskJsonRpcRequest request = serializer.deserializeRequest(source);

// TODO(mc) we should support the ModuleDescription method filters
JsonRpcResultOrError resultOrError = request.accept(this, ctx);
Expand All @@ -75,10 +76,13 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBufHolder msg) {
return;
} catch (IOException e) {
LOGGER.trace("Not a known or valid JsonRpcRequest", e);

// We need to release this resource, netty only takes care about 'ByteBufHolder msg'
content.release(content.refCnt());
}

// delegate to the next handler if the message can't be matched to a known JSON-RPC request
ctx.fireChannelRead(msg.retain());
ctx.fireChannelRead(msg);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package co.rsk.rpc.netty;

import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.timeout.WriteTimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RskWebSocketServerProtocolHandler extends WebSocketServerProtocolHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(RskWebSocketServerProtocolHandler.class);
public static final String WRITE_TIMEOUT_REASON = "Exceeded write timout";
public static final int NORMAL_CLOSE_WEBSOCKET_STATUS = 1000;

public RskWebSocketServerProtocolHandler(String websocketPath) {
super(websocketPath);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if(cause instanceof WriteTimeoutException) {
ctx.writeAndFlush(new CloseWebSocketFrame(NORMAL_CLOSE_WEBSOCKET_STATUS, WRITE_TIMEOUT_REASON)).addListener(ChannelFutureListener.CLOSE);
LOGGER.error("Write timeout exceeded, closing web socket channel", cause);
} else {
super.exceptionCaught(ctx, cause);
}
}
}
22 changes: 14 additions & 8 deletions rskj-core/src/main/java/co/rsk/rpc/netty/Web3WebSocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,40 @@
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.concurrent.TimeUnit;

public class Web3WebSocketServer implements InternalService {
private static final Logger logger = LoggerFactory.getLogger(Web3WebSocketServer.class);
private static final int HTTP_MAX_CONTENT_LENGTH = 1024 * 1024 * 5;

private final InetAddress host;
private final int port;
private final RskJsonRpcHandler jsonRpcHandler;
private final RskWebSocketJsonRpcHandler webSocketJsonRpcHandler;
private final JsonRpcWeb3ServerHandler web3ServerHandler;
private final EventLoopGroup bossGroup;
private final EventLoopGroup workerGroup;
private @Nullable ChannelFuture webSocketChannel;
private final int serverWriteTimeoutSeconds;

public Web3WebSocketServer(
InetAddress host,
int port,
RskJsonRpcHandler jsonRpcHandler,
JsonRpcWeb3ServerHandler web3ServerHandler) {
RskWebSocketJsonRpcHandler webSocketJsonRpcHandler,
JsonRpcWeb3ServerHandler web3ServerHandler,
int serverWriteTimeoutSeconds) {
this.host = host;
this.port = port;
this.jsonRpcHandler = jsonRpcHandler;
this.webSocketJsonRpcHandler = webSocketJsonRpcHandler;
this.web3ServerHandler = web3ServerHandler;
this.bossGroup = new NioEventLoopGroup();
this.workerGroup = new NioEventLoopGroup();
this.serverWriteTimeoutSeconds = serverWriteTimeoutSeconds;
}

@Override
Expand All @@ -70,9 +75,10 @@ public void start() {
protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new HttpServerCodec());
p.addLast(new HttpObjectAggregator(1024 * 1024 * 5));
p.addLast(new WebSocketServerProtocolHandler("/websocket"));
p.addLast(jsonRpcHandler);
p.addLast(new HttpObjectAggregator(HTTP_MAX_CONTENT_LENGTH));
p.addLast(new WriteTimeoutHandler(serverWriteTimeoutSeconds, TimeUnit.SECONDS));
p.addLast(new RskWebSocketServerProtocolHandler("/websocket"));
p.addLast(webSocketJsonRpcHandler);
p.addLast(web3ServerHandler);
p.addLast(new Web3ResultWebSocketResponseHandler());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public abstract class SystemProperties {
private static final String PROPERTY_RPC_WEBSOCKET_ENABLED = "rpc.providers.web.ws.enabled";
private static final String PROPERTY_RPC_WEBSOCKET_ADDRESS = "rpc.providers.web.ws.bind_address";
private static final String PROPERTY_RPC_WEBSOCKET_PORT = "rpc.providers.web.ws.port";
private static final String PROPERTY_RPC_WEBSOCKET_SERVER_WRITE_TIMEOUT_SECONDS = "rpc.providers.web.ws.server_write_timeout_seconds";

public static final String PROPERTY_PUBLIC_IP = "public.ip";
public static final String PROPERTY_BIND_ADDRESS = "bind_address";
Expand Down Expand Up @@ -612,6 +613,10 @@ public int rpcWebSocketPort() {
return configFromFiles.getInt(PROPERTY_RPC_WEBSOCKET_PORT);
}

public int rpcWebSocketServerWriteTimeoutSeconds() {
return configFromFiles.getInt(PROPERTY_RPC_WEBSOCKET_SERVER_WRITE_TIMEOUT_SECONDS);
}

public InetAddress rpcHttpBindAddress() {
return getWebBindAddress(PROPERTY_RPC_HTTP_ADDRESS);
}
Expand Down
1 change: 1 addition & 0 deletions rskj-core/src/main/resources/expected.conf
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ rpc = {
enabled = <enabled>
bind_address = <bind_address>
port = <port>
server_write_timeout_seconds = <timeout>
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions rskj-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ rpc {
enabled = false
bind_address = localhost
port = 4445
# Shuts down the server when it's not able to write a response after a certain period (expressed in seconds)
server_write_timeout_seconds = 30
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.*;

public class RskJsonRpcHandlerTest {
public class RskWebSocketJsonRpcHandlerTest {
private static final SubscriptionId SAMPLE_SUBSCRIPTION_ID = new SubscriptionId("0x3075");
private static final EthSubscribeRequest SAMPLE_SUBSCRIBE_REQUEST = new EthSubscribeRequest(
JsonRpcVersion.V2_0,
Expand All @@ -47,15 +47,15 @@ public class RskJsonRpcHandlerTest {

);

private RskJsonRpcHandler handler;
private RskWebSocketJsonRpcHandler handler;
private EthSubscriptionNotificationEmitter emitter;
private JsonRpcSerializer serializer;

@Before
public void setUp() {
emitter = mock(EthSubscriptionNotificationEmitter.class);
serializer = mock(JsonRpcSerializer.class);
handler = new RskJsonRpcHandler(emitter, serializer);
handler = new RskWebSocketJsonRpcHandler(emitter, serializer);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package co.rsk.rpc.netty;

import co.rsk.config.TestSystemProperties;
import co.rsk.rpc.JacksonBasedRpcSerializer;
import co.rsk.rpc.ModuleDescription;
import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -46,14 +47,14 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

public class Web3WebSocketServerTest {

private static JsonNodeFactory JSON_NODE_FACTORY = JsonNodeFactory.instance;
private static ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final int DEFAULT_WRITE_TIMEOUT_SECONDS = 30;

private ExecutorService wsExecutor;

Expand All @@ -68,13 +69,24 @@ public void smokeTest() throws Exception {
String mockResult = "output";
when(web3Mock.web3_sha3(anyString())).thenReturn(mockResult);

int randomPort = 9998;//new ServerSocket(0).getLocalPort();
int randomPort = 9998;

TestSystemProperties testSystemProperties = new TestSystemProperties();

List<ModuleDescription> filteredModules = Collections.singletonList(new ModuleDescription("web3", "1.0", true, Collections.emptyList(), Collections.emptyList()));
RskJsonRpcHandler handler = new RskJsonRpcHandler(null, new JacksonBasedRpcSerializer());
RskWebSocketJsonRpcHandler handler = new RskWebSocketJsonRpcHandler(null, new JacksonBasedRpcSerializer());
JsonRpcWeb3ServerHandler serverHandler = new JsonRpcWeb3ServerHandler(web3Mock, filteredModules);
int serverWriteTimeoutSeconds = testSystemProperties.rpcWebSocketServerWriteTimeoutSeconds();

assertEquals(DEFAULT_WRITE_TIMEOUT_SECONDS, serverWriteTimeoutSeconds);

Web3WebSocketServer websocketServer = new Web3WebSocketServer(InetAddress.getLoopbackAddress(), randomPort, handler, serverHandler);
Web3WebSocketServer websocketServer = new Web3WebSocketServer(
InetAddress.getLoopbackAddress(),
randomPort,
handler,
serverHandler,
serverWriteTimeoutSeconds
);
websocketServer.start();

OkHttpClient wsClient = new OkHttpClient();
Expand Down

0 comments on commit e0589f3

Please sign in to comment.