Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modbus/TCP Security #97

Merged
merged 19 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.digitalpetri.modbus.internal.util.ExecutionQueue;
import com.digitalpetri.modbus.serial.SerialPortTransportConfig;
import com.digitalpetri.modbus.serial.SerialPortTransportConfig.Builder;
import com.digitalpetri.modbus.server.ModbusRequestContext.ModbusRtuRequestContext;
import com.digitalpetri.modbus.server.ModbusRtuServerTransport;
import com.fazecast.jSerialComm.SerialPort;
import com.fazecast.jSerialComm.SerialPortDataListener;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

/**
* Configuration for a {@link NettyTcpClientTransport}.
Expand All @@ -23,6 +26,9 @@
* {@link Bootstrap}.
* @param pipelineCustomizer a {@link Consumer} that can be used to customize the Netty
* {@link ChannelPipeline}.
* @param tlsEnabled whether to enable TLS (Modbus/TCP Security).
* @param keyManagerFactory the {@link KeyManagerFactory} to use if TLS is enabled.
* @param trustManagerFactory the {@link TrustManagerFactory} to use if TLS is enabled.
*/
public record NettyClientTransportConfig(
String hostname,
Expand All @@ -33,9 +39,12 @@ public record NettyClientTransportConfig(
EventLoopGroup eventLoopGroup,
ExecutorService executor,
Consumer<Bootstrap> bootstrapCustomizer,
Consumer<ChannelPipeline> pipelineCustomizer
Consumer<ChannelPipeline> pipelineCustomizer,
boolean tlsEnabled,
Optional<KeyManagerFactory> keyManagerFactory,
Optional<TrustManagerFactory> trustManagerFactory
) {

/**
* Create a new {@link NettyClientTransportConfig} with a callback that allows customizing the
* configuration.
Expand All @@ -59,7 +68,7 @@ public static class Builder {
/**
* The port to connect to.
*/
public int port = 502;
public int port = -1;

/**
* The connect timeout.
Expand Down Expand Up @@ -100,16 +109,42 @@ public static class Builder {
*/
public Consumer<ChannelPipeline> pipelineCustomizer = p -> {};

/**
* Whether to enable TLS (Modbus/TCP Security).
*/
public boolean tlsEnabled = false;

/**
* The {@link KeyManagerFactory} to use if TLS is enabled.
*/
public KeyManagerFactory keyManagerFactory = null;

/**
* The {@link TrustManagerFactory} to use if TLS is enabled.
*/
public TrustManagerFactory trustManagerFactory = null;

public NettyClientTransportConfig build() {
if (hostname == null) {
throw new NullPointerException("hostname must not be null");
}
if (port == -1) {
port = tlsEnabled ? 802 : 502;
}
if (eventLoopGroup == null) {
eventLoopGroup = Netty.sharedEventLoop();
}
if (executor == null) {
executor = Modbus.sharedExecutor();
}
if (tlsEnabled) {
if (keyManagerFactory == null) {
throw new NullPointerException("keyManagerFactory must not be null");
}
if (trustManagerFactory == null) {
throw new NullPointerException("trustManagerFactory must not be null");
}
}

return new NettyClientTransportConfig(
hostname,
Expand All @@ -120,7 +155,10 @@ public NettyClientTransportConfig build() {
eventLoopGroup,
executor,
bootstrapCustomizer,
pipelineCustomizer
pipelineCustomizer,
tlsEnabled,
Optional.ofNullable(keyManagerFactory),
Optional.ofNullable(trustManagerFactory)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProtocols;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -158,7 +162,20 @@ public CompletableFuture<Channel> connect(FsmContext<State, Event> fsmContext) {
.option(ChannelOption.TCP_NODELAY, Boolean.TRUE)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel channel) {
protected void initChannel(SocketChannel channel) throws Exception {
if (config.tlsEnabled()) {
SslContext sslContext = SslContextBuilder.forClient()
.clientAuth(ClientAuth.REQUIRE)
.keyManager(config.keyManagerFactory().orElseThrow())
.trustManager(config.trustManagerFactory().orElseThrow())
.protocols(SslProtocols.TLS_v1_2, SslProtocols.TLS_v1_3)
.build();

channel.pipeline().addLast(
sslContext.newHandler(channel.alloc(), config.hostname(), config.port())
);
}

channel.pipeline().addLast(new ModbusRtuClientFrameReceiver());

config.pipelineCustomizer().accept(channel.pipeline());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProtocols;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
Expand Down Expand Up @@ -163,6 +168,13 @@ protected void channelRead0(ChannelHandlerContext ctx, ModbusTcpFrame frame) {
executionQueue.submit(() -> frameReceiver.accept(frame));
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.error("Exception caught", cause);
ctx.close();
}

}

private class ModbusTcpChannelActions implements ChannelActions {
Expand All @@ -174,15 +186,7 @@ public CompletableFuture<Channel> connect(FsmContext<State, Event> fsmContext) {
.group(config.eventLoopGroup())
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) config.connectTimeout().toMillis())
.option(ChannelOption.TCP_NODELAY, Boolean.TRUE)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel channel) {
channel.pipeline().addLast(new ModbusTcpCodec());
channel.pipeline().addLast(new ModbusTcpFrameHandler());

config.pipelineCustomizer().accept(channel.pipeline());
}
});
.handler(newChannelInitializer());

config.bootstrapCustomizer().accept(bootstrap);

Expand All @@ -191,7 +195,21 @@ protected void initChannel(SocketChannel channel) {
bootstrap.connect(config.hostname(), config.port()).addListener(
(ChannelFutureListener) channelFuture -> {
if (channelFuture.isSuccess()) {
future.complete(channelFuture.channel());
Channel channel = channelFuture.channel();

if (config.tlsEnabled()) {
channel.pipeline().get(SslHandler.class).handshakeFuture().addListener(
handshakeFuture -> {
if (handshakeFuture.isSuccess()) {
future.complete(channel);
} else {
future.completeExceptionally(handshakeFuture.cause());
}
}
);
} else {
future.complete(channel);
}
} else {
future.completeExceptionally(channelFuture.cause());
}
Expand All @@ -201,6 +219,31 @@ protected void initChannel(SocketChannel channel) {
return future;
}

private ChannelInitializer<SocketChannel> newChannelInitializer() {
return new ChannelInitializer<>() {
@Override
protected void initChannel(SocketChannel channel) throws Exception {
if (config.tlsEnabled()) {
SslContext sslContext = SslContextBuilder.forClient()
.clientAuth(ClientAuth.REQUIRE)
.keyManager(config.keyManagerFactory().orElseThrow())
.trustManager(config.trustManagerFactory().orElseThrow())
.protocols(SslProtocols.TLS_v1_2, SslProtocols.TLS_v1_3)
.build();

channel.pipeline().addLast(
sslContext.newHandler(channel.alloc(), config.hostname(), config.port())
);
}

channel.pipeline().addLast(new ModbusTcpCodec());
channel.pipeline().addLast(new ModbusTcpFrameHandler());

config.pipelineCustomizer().accept(channel.pipeline());
}
};
}

@Override
public CompletableFuture<Void> disconnect(
FsmContext<State, Event> fsmContext, Channel channel) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package com.digitalpetri.modbus.tcp.security;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

public class SecurityUtil {

/**
* Create a {@link KeyManagerFactory} from a private key and certificates.
*
* @param privateKey the private key.
* @param certificates the certificates.
* @return a {@link KeyManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
* @throws IOException if an error occurs.
*/
public static KeyManagerFactory createKeyManagerFactory(
PrivateKey privateKey,
X509Certificate... certificates
) throws GeneralSecurityException, IOException {

KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);

keyStore.setKeyEntry("key", privateKey, new char[0], certificates);

return createKeyManagerFactory(keyStore, new char[0]);
}

/**
* Create a {@link KeyManagerFactory} from a {@link KeyStore}.
*
* @param keyStore the {@link KeyStore}.
* @param keyStorePassword the password for the {@link KeyStore}.
* @return a {@link KeyManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
*/
public static KeyManagerFactory createKeyManagerFactory(
KeyStore keyStore,
char[] keyStorePassword
) throws GeneralSecurityException {

KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());

keyManagerFactory.init(keyStore, keyStorePassword);

return keyManagerFactory;
}

/**
* Create a {@link TrustManagerFactory} from certificates.
*
* @param certificates the certificates.
* @return a {@link TrustManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
* @throws IOException if an error occurs.
*/
public static TrustManagerFactory createTrustManagerFactory(
X509Certificate... certificates
) throws GeneralSecurityException, IOException {

KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);

for (int i = 0; i < certificates.length; i++) {
keyStore.setCertificateEntry("cert" + i, certificates[i]);
}

return createTrustManagerFactory(keyStore);
}

/**
* Create a {@link TrustManagerFactory} from a {@link KeyStore}.
*
* @param keyStore the {@link KeyStore}.
* @return a {@link TrustManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
*/
public static TrustManagerFactory createTrustManagerFactory(
KeyStore keyStore
) throws GeneralSecurityException {

TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());

trustManagerFactory.init(keyStore);

return trustManagerFactory;
}

}
Loading
Loading