Skip to content

Commit

Permalink
[NO-ISSUE] Implement Origin check for terminal interpreter WebSocket …
Browse files Browse the repository at this point in the history
…connections

### What is this PR for?

This PR adds an Origin check to ensure that WebSocket connections are initiated from trusted sources only.
By validating the `Origin` header in the initial WebSocket handshake, we can prevent unauthorized or malicious websites from establishing WebSocket connections with our server.

Changes:
- Added server-side validation of the `Origin` header during WebSocket connection requests.

Other security enhancements may be needed and can be handled in future iterations.

### What type of PR is it?

Improvement

### Todos
* [ ] - Task

### How should this be tested?
* Strongly recommended: add automated unit tests for any new or changed behavior
* Outline any manual steps to test the PR here.

### Screenshots (if appropriate)

### Questions:
* Does the license files need to update? No
* Is there breaking changes for older versions? No
* Does this needs documentation? No

Closes #4823 from tbonelee/websocket.

Signed-off-by: Jongyoul Lee <[email protected]>
(cherry picked from commit 3575a3c)
Signed-off-by: Jongyoul Lee <[email protected]>
  • Loading branch information
tbonelee authored and jongyoul committed Nov 2, 2024
1 parent 8941cd8 commit f4847ea
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public TerminalInterpreter(Properties property) {
private InterpreterContext intpContext;

private int terminalPort = 0;
private String terminalHostIp;

// Internal and external IP mapping of zeppelin server
private HashMap<String, String> mapIpMapping = new HashMap<>();
Expand Down Expand Up @@ -109,7 +110,11 @@ public InterpreterResult internalInterpret(String cmd, InterpreterContext contex
if (null == terminalThread) {
try {
terminalPort = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces();
terminalThread = new TerminalThread(terminalPort);
terminalHostIp = RemoteInterpreterUtils.findAvailableHostAddress();
LOGGER.info("Terminal host IP: " + terminalHostIp);
LOGGER.info("Terminal port: " + terminalPort);
String allowedOrigin = generateOrigin(terminalHostIp, terminalPort);
terminalThread = new TerminalThread(terminalPort, allowedOrigin);
terminalThread.start();
} catch (IOException e) {
LOGGER.error(e.getMessage(), e);
Expand All @@ -136,20 +141,20 @@ public InterpreterResult internalInterpret(String cmd, InterpreterContext contex
mapIpMapping = gson.fromJson(strIpMapping, new TypeToken<Map<String, String>>(){}.getType());
}

createTerminalDashboard(context.getNoteId(), context.getParagraphId(), terminalPort);
createTerminalDashboard(context.getNoteId(), context.getParagraphId(),
terminalHostIp, terminalPort);

return new InterpreterResult(Code.SUCCESS);
}

public void createTerminalDashboard(String noteId, String paragraphId, int port) {
String hostName = "", hostIp = "";
public void createTerminalDashboard(String noteId, String paragraphId, String hostIp, int port) {
String hostName = "";
URL urlTemplate = Resources.getResource("ui_templates/terminal-dashboard.jinja");
String template = null;
try {
template = Resources.toString(urlTemplate, Charsets.UTF_8);
InetAddress addr = InetAddress.getLocalHost();
hostName = addr.getHostName().toString();
hostIp = RemoteInterpreterUtils.findAvailableHostAddress();

// Internal and external IP mapping of zeppelin server
if (mapIpMapping.containsKey(hostIp)) {
Expand All @@ -164,7 +169,7 @@ public void createTerminalDashboard(String noteId, String paragraphId, int port)
Jinjava jinjava = new Jinjava();
HashMap<String, Object> jinjaParams = new HashMap();
Date now = new Date();
String terminalServerUrl = "http://" + hostIp + ":" + port +
String terminalServerUrl = generateOrigin(hostIp, port) +
"?noteId=" + noteId + "&paragraphId=" + paragraphId + "&t=" + now.getTime();
jinjaParams.put("HOST_NAME", hostName);
jinjaParams.put("HOST_IP", hostIp);
Expand All @@ -183,6 +188,10 @@ public void createTerminalDashboard(String noteId, String paragraphId, int port)
}
}

private String generateOrigin(String hostIp, int port) {
return "http://" + hostIp + ":" + port;
}

@Override
public void cancel(InterpreterContext context) {
}
Expand Down Expand Up @@ -238,6 +247,11 @@ public int getTerminalPort() {
return terminalPort;
}

@VisibleForTesting
public String getTerminalHostIp() {
return terminalHostIp;
}

@VisibleForTesting
public boolean terminalThreadIsRunning() {
return terminalThread.isRunning();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.zeppelin.shell.terminal;

import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpointConfig;

import org.apache.zeppelin.shell.terminal.websocket.TerminalSessionConfigurator;
import org.apache.zeppelin.shell.terminal.websocket.TerminalSocket;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
Expand All @@ -38,9 +40,11 @@ public class TerminalThread extends Thread {
private Server jettyServer = new Server();

private int port = 0;
private String allwedOrigin;

public TerminalThread(int port) {
public TerminalThread(int port, String allwedOrigin) {
this.port = port;
this.allwedOrigin = allwedOrigin;
}

public void run() {
Expand Down Expand Up @@ -72,7 +76,10 @@ public void run() {

try {
ServerContainer container = WebSocketServerContainerInitializer.configureContext(context);
container.addEndpoint(TerminalSocket.class);
container.addEndpoint(
ServerEndpointConfig.Builder.create(TerminalSocket.class, "/")
.configurator(new TerminalSessionConfigurator(allwedOrigin))
.build());
jettyServer.start();
jettyServer.join();
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.zeppelin.shell.terminal.websocket;

import javax.websocket.server.ServerEndpointConfig.Configurator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TerminalSessionConfigurator extends Configurator {
private static final Logger LOGGER = LoggerFactory.getLogger(TerminalSessionConfigurator.class);
private String allowedOrigin;

public TerminalSessionConfigurator(String allowedOrigin) {
this.allowedOrigin = allowedOrigin;
}

@Override
public boolean checkOrigin(String originHeaderValue) {
boolean allowed = allowedOrigin.equals(originHeaderValue);
LOGGER.info("Checking origin for TerminalSessionConfigurator: " +
originHeaderValue + " allowed: " + allowed);
return allowed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

package org.apache.zeppelin.shell;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Builder;
import javax.websocket.ClientEndpointConfig.Configurator;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterResult;
Expand All @@ -34,6 +40,7 @@
import javax.websocket.WebSocketContainer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
Expand Down Expand Up @@ -81,11 +88,17 @@ void testInvalidCommand() {
boolean running = terminal.terminalThreadIsRunning();
assertTrue(running);

URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() + "/terminal/");
URI webSocketConnectionUri = URI.create("ws://" + terminal.getTerminalHostIp() +
":" + terminal.getTerminalPort() + "/terminal/");
LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
String origin = "http://" + terminal.getTerminalHostIp() + ":" + terminal.getTerminalPort();
LOGGER.info("origin: " + origin);
ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin);
webSocketContainer = ContainerProvider.getWebSocketContainer();

// Attempt Connect
session = webSocketContainer.connectToServer(TerminalSocketTest.class, uri);
session = webSocketContainer.connectToServer(
TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri);

// Send Start terminal service message
String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," +
Expand Down Expand Up @@ -161,11 +174,17 @@ void testValidCommand() {
boolean running = terminal.terminalThreadIsRunning();
assertTrue(running);

URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() + "/terminal/");
URI webSocketConnectionUri = URI.create("ws://" + terminal.getTerminalHostIp() +
":" + terminal.getTerminalPort() + "/terminal/");
LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
String origin = "http://" + terminal.getTerminalHostIp() + ":" + terminal.getTerminalPort();
LOGGER.info("origin: " + origin);
ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin);
webSocketContainer = ContainerProvider.getWebSocketContainer();

// Attempt Connect
session = webSocketContainer.connectToServer(TerminalSocketTest.class, uri);
session = webSocketContainer.connectToServer(
TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri);

// Send Start terminal service message
String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," +
Expand Down Expand Up @@ -229,4 +248,118 @@ void testValidCommand() {
}
}
}

@Test
void testValidOrigin() {
Session session = null;

// mock connect terminal
boolean running = terminal.terminalThreadIsRunning();
assertTrue(running);

URI webSocketConnectionUri = URI.create("ws://" + terminal.getTerminalHostIp() +
":" + terminal.getTerminalPort() + "/terminal/");
LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
String origin = "http://" + terminal.getTerminalHostIp() + ":" + terminal.getTerminalPort();
LOGGER.info("origin: " + origin);
ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin);
WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer();

Throwable exception = null;
try {
// Attempt Connect
session = webSocketContainer.connectToServer(
TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri);
} catch (DeploymentException e) {
exception = e;
} catch (IOException e) {
exception = e;
} finally {
if (session != null) {
try {
session.close();
} catch (IOException e) {
LOGGER.error(e.getMessage(), e);
}
}

// Force lifecycle stop when done with container.
// This is to free up threads and resources that the
// JSR-356 container allocates. But unfortunately
// the JSR-356 spec does not handle lifecycles (yet)
if (webSocketContainer instanceof LifeCycle) {
try {
((LifeCycle) webSocketContainer).stop();
} catch (Exception e) {
LOGGER.error(e.getMessage(), e);
}
}
}

assertNull(exception);
}

@Test
void testInvalidOrigin() {
Session session = null;

// mock connect terminal
boolean running = terminal.terminalThreadIsRunning();
assertTrue(running);

URI webSocketConnectionUri = URI.create("ws://" + terminal.getTerminalHostIp() +
":" + terminal.getTerminalPort() + "/terminal/");
LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
String origin = "http://invalid-origin";
LOGGER.info("origin: " + origin);
ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin);
WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer();

Throwable exception = null;
try {
// Attempt Connect
session = webSocketContainer.connectToServer(
TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri);
} catch (DeploymentException e) {
exception = e;
} catch (IOException e) {
exception = e;
} finally {
if (session != null) {
try {
session.close();
} catch (IOException e) {
LOGGER.error(e.getMessage(), e);
}
}

// Force lifecycle stop when done with container.
// This is to free up threads and resources that the
// JSR-356 container allocates. But unfortunately
// the JSR-356 spec does not handle lifecycles (yet)
if (webSocketContainer instanceof LifeCycle) {
try {
((LifeCycle) webSocketContainer).stop();
} catch (Exception e) {
LOGGER.error(e.getMessage(), e);
}
}
}

assertTrue(exception instanceof IOException);
assertEquals("Connect failure", exception.getMessage());
}

private static ClientEndpointConfig getOriginRequestHeaderConfig(String origin) {
Configurator configurator = new Configurator() {
@Override
public void beforeRequest(Map<String, List<String>> headers) {
headers.put("Origin", Arrays.asList(origin));
}
};
ClientEndpointConfig clientEndpointConfig = Builder.create()
.configurator(configurator)
.build();
return clientEndpointConfig;
}
}
Loading

0 comments on commit f4847ea

Please sign in to comment.