From ff742c34684ad6fe535100d1e8fc52dd7cfa3654 Mon Sep 17 00:00:00 2001 From: Chris Conlon Date: Tue, 31 Dec 2024 17:30:54 -0700 Subject: [PATCH] JSSE: calling SSLSocket.close() should interrupt threads blocked in select()/poll() --- native/com_wolfssl_WolfSSLSession.c | 225 ++++++++++++++-- native/com_wolfssl_WolfSSLSession.h | 8 + src/java/com/wolfssl/WolfSSLSession.java | 25 ++ .../wolfssl/provider/jsse/WolfSSLSocket.java | 17 +- .../provider/jsse/test/WolfSSLSocketTest.java | 251 ++++++++++++++++++ 5 files changed, 499 insertions(+), 27 deletions(-) diff --git a/native/com_wolfssl_WolfSSLSession.c b/native/com_wolfssl_WolfSSLSession.c index ed9d909e..6781a464 100644 --- a/native/com_wolfssl_WolfSSLSession.c +++ b/native/com_wolfssl_WolfSSLSession.c @@ -72,8 +72,13 @@ static jobject g_crlCbIfaceObj; * function calls. Stored inside WOLFSSL app data, set with * wolfSSL_set_app_data(), retrieved with wolfSSL_get_app_data(). * Global callback objects are created with NewGlobalRef(), then freed - * inside freeSSL() with DeleteGlobalRef(). */ + * inside freeSSL() with DeleteGlobalRef(). + * + * interruptFds[2] is a pipe() used for non-Windows platforms. This pipe is + * used to interrupt threads blocked inside select()/poll() when a separate + * Java thread calls close() on the SSLSocket. */ typedef struct SSLAppData { + int interruptFds[2]; /* pipe for interrupting socketSelect() */ wolfSSL_Mutex* jniSessLock; /* WOLFSSL session lock */ jobject* g_verifySSLCbIfaceObj; /* Java verify callback [global ref] */ } SSLAppData; @@ -191,6 +196,36 @@ int NativeSSLVerifyCallback(int preverify_ok, WOLFSSL_X509_STORE_CTX* store) return retval; } +#ifndef USE_WINDOWS_API + +/* Close interrupt pipe() descriptors and reset back to -1. */ +static void closeInterruptPipe(SSLAppData* appData) +{ + if (appData != NULL) { + if (appData->interruptFds[0] != -1) { + close(appData->interruptFds[0]); + appData->interruptFds[0] = -1; + } + if (appData->interruptFds[1] != -1) { + close(appData->interruptFds[1]); + appData->interruptFds[1] = -1; + } + } +} + +/* Signal to threads blocked in select() or poll() to wake up, by writing + * one byte to the appData.interruptFds[1] pipe. */ +static void writeToInterruptPipe(SSLAppData* appData) +{ + if (appData != NULL) { + if (appData->interruptFds[1] != -1) { + write(appData->interruptFds[1], "1", 1); + } + } +} + +#endif /* !USE_WINDOWS_API */ + JNIEXPORT jlong JNICALL Java_com_wolfssl_WolfSSLSession_newSSL (JNIEnv* jenv, jobject jcl, jlong ctx) { @@ -251,10 +286,44 @@ JNIEXPORT jlong JNICALL Java_com_wolfssl_WolfSSLSession_newSSL wc_InitMutex(jniSessLock); appData->jniSessLock = jniSessLock; + /* set up interrupt pipe for SSLSocket.close() to use if/when needed. + * currently only non-Windows platforms supported due to Windows not + * supporting direct/same pipe() operation. Make read pipe non + * blocking since byte read from it could have already been taken + * out by either reader/writer thread before the other has a chance + * to read it. But, we only use it for waking us up and don't care + * much about actually reading the byte passed over the pipe. */ +#ifndef USE_WINDOWS_API + appData->interruptFds[0] = -1; + appData->interruptFds[1] = -1; + ret = pipe(appData->interruptFds); + if (ret == -1) { + printf("error setting up pipe() for interruptFds[] in newSSL\n"); + (*jenv)->DeleteGlobalRef(jenv, *g_cachedSSLObj); + XFREE(appData, NULL, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(g_cachedSSLObj, NULL, DYNAMIC_TYPE_TMP_BUFFER); + wolfSSL_free((WOLFSSL*)(uintptr_t)sslPtr); + return SSL_FAILURE; + } + + ret = fcntl(appData->interruptFds[0], F_SETFL, + fcntl(appData->interruptFds[0], F_GETFL, 0) | O_NONBLOCK); + if (ret < 0) { + printf("error setting interruptFds[0] non-blocking in newSSL\n"); + closeInterruptPipe(appData); + (*jenv)->DeleteGlobalRef(jenv, *g_cachedSSLObj); + XFREE(appData, NULL, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(g_cachedSSLObj, NULL, DYNAMIC_TYPE_TMP_BUFFER); + wolfSSL_free((WOLFSSL*)(uintptr_t)sslPtr); + return SSL_FAILURE; + } +#endif /* !USE_WINDOWS_API */ + /* cache associated WolfSSLSession jobject in native WOLFSSL */ ret = wolfSSL_set_jobject((WOLFSSL*)(uintptr_t)sslPtr, g_cachedSSLObj); if (ret != SSL_SUCCESS) { printf("error storing jobject in wolfSSL native session\n"); + closeInterruptPipe(appData); (*jenv)->DeleteGlobalRef(jenv, *g_cachedSSLObj); XFREE(appData, NULL, DYNAMIC_TYPE_TMP_BUFFER); XFREE(g_cachedSSLObj, NULL, DYNAMIC_TYPE_TMP_BUFFER); @@ -266,6 +335,7 @@ JNIEXPORT jlong JNICALL Java_com_wolfssl_WolfSSLSession_newSSL if (wolfSSL_set_app_data( (WOLFSSL*)(uintptr_t)sslPtr, appData) != SSL_SUCCESS) { printf("error setting WOLFSSL app data in newSSL\n"); + closeInterruptPipe(appData); (*jenv)->DeleteGlobalRef(jenv, *g_cachedSSLObj); XFREE(jniSessLock, NULL, DYNAMIC_TYPE_TMP_BUFFER); XFREE(appData, NULL, DYNAMIC_TYPE_TMP_BUFFER); @@ -641,12 +711,13 @@ enum { * @return possible return values are: * WOLFJNI_IO_EVENT_FAIL * WOLFJNI_IO_EVENT_ERROR + * WOLFJNI_IO_EVENT_FD_CLOSED * WOLFJNI_IO_EVENT_TIMEOUT * WOLFJNI_IO_EVENT_RECV_READY * WOLFJNI_IO_EVENT_SEND_READY * WOLFJNI_IO_EVENT_INVALID_TIMEOUT */ -static int socketSelect(int sockfd, int timeout_ms, int rx) +static int socketSelect(SSLAppData* appData, int sockfd, int timeout_ms, int rx) { fd_set fds, errfds; fd_set* recvfds = NULL; @@ -654,20 +725,35 @@ static int socketSelect(int sockfd, int timeout_ms, int rx) int nfds = sockfd + 1; int result = 0; struct timeval timeout; + char tmpBuf[1]; /* Java Socket does not support negative timeouts, sanitize */ if (timeout_ms < 0) { return WOLFJNI_IO_EVENT_INVALID_TIMEOUT; } + if (appData == NULL) { + return WOLFJNI_IO_EVENT_ERROR; + } + #ifndef USE_WINDOWS_API do { #endif timeout.tv_sec = timeout_ms / 1000; timeout.tv_usec = (timeout_ms % 1000) * 1000; + /* file/socket descriptors */ FD_ZERO(&fds); FD_SET(sockfd, &fds); +#ifndef USE_WINDOWS_API + FD_SET(appData->interruptFds[0], &fds); + /* nfds should be set to the highest number descriptor plus 1 */ + if (appData->interruptFds[0] > sockfd) { + nfds = appData->interruptFds[0] + 1; + } +#endif /* !USE_WINDOWS_API */ + + /* error descriptors */ FD_ZERO(&errfds); FD_SET(sockfd, &errfds); @@ -692,9 +778,25 @@ static int socketSelect(int sockfd, int timeout_ms, int rx) } else { return WOLFJNI_IO_EVENT_SEND_READY; } - } else if (FD_ISSET(sockfd, &errfds)) { + } + else if (FD_ISSET(sockfd, &errfds)) { return WOLFJNI_IO_EVENT_ERROR; } +#ifndef USE_WINDOWS_API + else if (FD_ISSET(appData->interruptFds[0], &fds)) { + /* We got interrupted by our interrupt fd, due to a Java + * thread calling SSLSocket.close(). Try to read byte that + * was placed on our interruptFds[0] descriptor, but not + * an error if not there. Another read/write() may have + * already read it off. We just want to be interrupted, + * byte value does not matter. */ + do { + read(appData->interruptFds[0], tmpBuf, 1); + } while (errno == EINTR); + + return WOLFJNI_IO_EVENT_FD_CLOSED; + } +#endif /* !USE_WINDOWS_API */ } #ifndef USE_WINDOWS_API @@ -736,11 +838,17 @@ static int socketSelect(int sockfd, int timeout_ms, int rx) * WOLFJNI_IO_EVENT_POLLHUP * WOLFJNI_IO_EVENT_INVALID_TIMEOUT */ -static int socketPoll(int sockfd, int timeout_ms, int rx, int tx) +static int socketPoll(SSLAppData* appData, int sockfd, int timeout_ms, + int rx, int tx) { int ret; int timeout; - struct pollfd fds[1]; + struct pollfd fds[2]; + char tmpBuf[1]; + + if (appData == NULL) { + return WOLFJNI_IO_EVENT_ERROR; + } /* Sanitize timeout and convert from Java to poll() expectations */ timeout = timeout_ms; @@ -750,6 +858,7 @@ static int socketPoll(int sockfd, int timeout_ms, int rx, int tx) timeout = -1; } + /* fd for socket I/O */ fds[0].fd = sockfd; fds[0].events = 0; if (tx) { @@ -758,27 +867,44 @@ static int socketPoll(int sockfd, int timeout_ms, int rx, int tx) if (rx) { fds[0].events |= POLLIN; } + /* fd for interrupt / signaling SSLSocket.close() */ + fds[1].fd = appData->interruptFds[0]; + fds[1].events = POLLIN; do { - ret = poll(fds, 1, timeout); + ret = poll(fds, 2, timeout); if (ret == 0) { return WOLFJNI_IO_EVENT_TIMEOUT; - } else if (ret > 0) { - if (fds[0].revents & POLLIN || + } + else if (ret > 0) { + if (fds[1].revents & POLLIN) { + /* received data on interrupt pipe, read and return + * that descriptor is closed (closing) */ + do { + read(appData->interruptFds[0], tmpBuf, 1); + } while (errno == EINTR); + + return WOLFJNI_IO_EVENT_FD_CLOSED; + } + else if (fds[0].revents & POLLIN || fds[0].revents & POLLPRI) { /* read possible */ return WOLFJNI_IO_EVENT_RECV_READY; - } else if (fds[0].revents & POLLOUT) { /* write possible */ + } + else if (fds[0].revents & POLLOUT) { /* write possible */ return WOLFJNI_IO_EVENT_SEND_READY; - } else if (fds[0].revents & POLLNVAL) { /* fd not open */ + } + else if (fds[0].revents & POLLNVAL) { /* fd not open */ return WOLFJNI_IO_EVENT_FD_CLOSED; - } else if (fds[0].revents & POLLERR) { /* exceptional error */ + } + else if (fds[0].revents & POLLERR) { /* exceptional error */ return WOLFJNI_IO_EVENT_ERROR; - } else if (fds[0].revents & POLLHUP) { /* sock disconnected */ + } + else if (fds[0].revents & POLLHUP) { /* sock disconnected */ return WOLFJNI_IO_EVENT_POLLHUP; } } @@ -795,7 +921,9 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect { int ret = 0, err = 0, sockfd = 0; int pollRx = 0; +#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) int pollTx = 0; +#endif wolfSSL_Mutex* jniSessLock = NULL; SSLAppData* appData = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; @@ -852,14 +980,16 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect if (err == SSL_ERROR_WANT_READ) { pollRx = 1; } + #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) else if (err == SSL_ERROR_WANT_WRITE) { pollTx = 1; } + #endif #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(sockfd, (int)timeout, pollRx); + ret = socketSelect(appData, sockfd, (int)timeout, pollRx); #else - ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx); + ret = socketPoll(appData, sockfd, (int)timeout, pollRx, pollTx); #endif if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || (ret == WOLFJNI_IO_EVENT_SEND_READY)) { @@ -898,7 +1028,9 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write byte* data = NULL; int ret = SSL_FAILURE, err, sockfd; int pollRx = 0; +#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) int pollTx = 0; +#endif wolfSSL_Mutex* jniSessLock = NULL; SSLAppData* appData = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; @@ -963,14 +1095,16 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write if (err == SSL_ERROR_WANT_READ) { pollRx = 1; } + #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) else if (err == SSL_ERROR_WANT_WRITE) { pollTx = 1; } + #endif #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(sockfd, (int)timeout, pollRx); + ret = socketSelect(appData, sockfd, (int)timeout, pollRx); #else - ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx); + ret = socketPoll(appData, sockfd, (int)timeout, pollRx, pollTx); #endif if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || (ret == WOLFJNI_IO_EVENT_SEND_READY)) { @@ -1009,7 +1143,9 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read byte* data = NULL; int size = 0, ret, err, sockfd; int pollRx = 0; +#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) int pollTx = 0; +#endif wolfSSL_Mutex* jniSessLock = NULL; SSLAppData* appData = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; @@ -1071,14 +1207,16 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read if (err == SSL_ERROR_WANT_READ) { pollRx = 1; } + #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) else if (err == SSL_ERROR_WANT_WRITE) { pollTx = 1; } + #endif #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(sockfd, (int)timeout, pollRx); + ret = socketSelect(appData, sockfd, (int)timeout, pollRx); #else - ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx); + ret = socketPoll(appData, sockfd, (int)timeout, pollRx, pollTx); #endif if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || (ret == WOLFJNI_IO_EVENT_SEND_READY)) { @@ -1110,7 +1248,9 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_accept { int ret = 0, err, sockfd; int pollRx = 0; +#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) int pollTx = 0; +#endif wolfSSL_Mutex* jniSessLock = NULL; SSLAppData* appData = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; @@ -1167,14 +1307,16 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_accept if (err == SSL_ERROR_WANT_READ) { pollRx = 1; } + #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) else if (err == SSL_ERROR_WANT_WRITE) { pollTx = 1; } + #endif #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(sockfd, (int)timeout, pollRx); + ret = socketSelect(appData, sockfd, (int)timeout, pollRx); #else - ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx); + ret = socketPoll(appData, sockfd, (int)timeout, pollRx, pollTx); #endif if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || (ret == WOLFJNI_IO_EVENT_SEND_READY)) { @@ -1247,6 +1389,10 @@ JNIEXPORT void JNICALL Java_com_wolfssl_WolfSSLSession_freeSSL XFREE(g_cachedVerifyCb, NULL, DYNAMIC_TYPE_TMP_BUFFER); g_cachedVerifyCb = NULL; } +#ifndef USE_WINDOWS_API + /* close pipe() descriptors */ + closeInterruptPipe(appData); +#endif /* free appData */ XFREE(appData, NULL, DYNAMIC_TYPE_TMP_BUFFER); appData = NULL; @@ -1359,7 +1505,9 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_shutdownSSL { int ret = 0, err, sockfd; int pollRx = 0; +#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) int pollTx = 0; +#endif wolfSSL_Mutex* jniSessLock; SSLAppData* appData = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; @@ -1416,14 +1564,16 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_shutdownSSL if (err == SSL_ERROR_WANT_READ) { pollRx = 1; } + #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) else if (err == SSL_ERROR_WANT_WRITE) { pollTx = 1; } + #endif #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(sockfd, (int)timeout, pollRx); + ret = socketSelect(appData, sockfd, (int)timeout, pollRx); #else - ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx); + ret = socketPoll(appData, sockfd, (int)timeout, pollRx, pollTx); #endif if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || (ret == WOLFJNI_IO_EVENT_SEND_READY)) { @@ -1628,10 +1778,10 @@ JNIEXPORT jlong JNICALL Java_com_wolfssl_WolfSSLSession_get1Session #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) /* Default to select() on Windows or if WOLFJNI_USE_IO_SELECT */ - ret = socketSelect(sockfd, + ret = socketSelect(appData, sockfd, (int)WOLFSSL_JNI_DEFAULT_PEEK_TIMEOUT, 1); #else - ret = socketPoll(sockfd, + ret = socketPoll(appData, sockfd, (int)WOLFSSL_JNI_DEFAULT_PEEK_TIMEOUT, 1, 0); #endif if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || @@ -5251,6 +5401,33 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_hasTicket #endif } +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_interruptBlockedIO + (JNIEnv* jenv, jobject jcl, jlong sslPtr) +{ + SSLAppData* appData = NULL; + WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; + (void)jcl; + + if (jenv == NULL) { + return SSL_FAILURE; + } + +#ifndef USE_WINDOWS_API + /* get session mutex from SSL app data */ + appData = (SSLAppData*)wolfSSL_get_app_data(ssl); + if (appData == NULL) { + return WOLFSSL_FAILURE; + } + + /* signal any blocked threads in select()/poll() to wake up, so we + * don't have a deadlock when trying to lock jniSessLock next */ + writeToInterruptPipe(appData); + +#endif /* USE_WINDOWS_API */ + + return SSL_SUCCESS; +} + JNIEXPORT void JNICALL Java_com_wolfssl_WolfSSLSession_setSSLIORecv (JNIEnv* jenv, jobject jcl, jlong sslPtr) { diff --git a/native/com_wolfssl_WolfSSLSession.h b/native/com_wolfssl_WolfSSLSession.h index e7941971..9039d264 100644 --- a/native/com_wolfssl_WolfSSLSession.h +++ b/native/com_wolfssl_WolfSSLSession.h @@ -871,6 +871,14 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_useSupportedCurve JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_hasTicket (JNIEnv *, jobject, jlong); +/* + * Class: com_wolfssl_WolfSSLSession + * Method: interruptBlockedIO + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_interruptBlockedIO + (JNIEnv *, jobject, jlong); + #ifdef __cplusplus } #endif diff --git a/src/java/com/wolfssl/WolfSSLSession.java b/src/java/com/wolfssl/WolfSSLSession.java index cb5d2a05..d06c3111 100644 --- a/src/java/com/wolfssl/WolfSSLSession.java +++ b/src/java/com/wolfssl/WolfSSLSession.java @@ -354,6 +354,7 @@ private native int setTlsHmacInner(long ssl, byte[] inner, long sz, private native int set1SigAlgsList(long ssl, String list); private native int useSupportedCurve(long ssl, int name); private native int hasTicket(long session); + private native int interruptBlockedIO(long ssl); /* ------------------- session-specific methods --------------------- */ @@ -4042,6 +4043,30 @@ public synchronized void setIOSend(WolfSSLIOSendCallback callback) } } + /** + * Interrupt native I/O operations blocked inside select()/poll(). + * + * This is used by wolfJSSE when SSLSocket.close() is called, to wake up + * threads that are blocked in select()/poll(). + * + * @return WolfSSL.SSL_SUCCESS on success, negative on error. + * + * @throws IllegalStateException WolfSSLSession has been freed + */ + public synchronized int interruptBlockedIO() + throws IllegalStateException { + + confirmObjectIsActive(); + + /* Not synchronizing on sslLock, since we want to interrupt threads + * blocked on I/O operations, which will already hold sslLock */ + + WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, + WolfSSLDebug.INFO, "entered interruptBlockedIO()"); + + return interruptBlockedIO(this.sslPtr); + } + /** * Use SNI name with this session. * diff --git a/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java b/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java index 12b83256..73ec1907 100644 --- a/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java +++ b/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java @@ -2016,6 +2016,10 @@ public synchronized void close() throws IOException { WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, "shutting down SSL/TLS connection"); + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, + "signaling any blocked I/O threads to wake up"); + ssl.interruptBlockedIO(); + WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, "thread trying to get ioLock (shutdown)"); @@ -2036,12 +2040,15 @@ public synchronized void close() throws IOException { } try { + /* Use SO_LINGER value when calling + * shutdown here, since we are closing the + * socket */ if (this.socket != null) { ret = ssl.shutdownSSL( - this.socket.getSoTimeout()); + this.socket.getSoLinger()); } else { ret = ssl.shutdownSSL( - super.getSoTimeout()); + super.getSoLinger()); } WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO, @@ -2687,7 +2694,11 @@ public synchronized int read(byte[] b, int off, int len) throw e; } catch (IllegalStateException e) { - throw new IOException(e); + /* SSLSocket.close() may have already called freeSSL(), + * thus causing a 'WolfSSLSession object has been freed' + * IllegalStateException to be thrown from + * WolfSSLSession.read(). Return as a SocketException here. */ + throw new SocketException(e.getMessage()); } /* return number of bytes read */ diff --git a/src/test/com/wolfssl/provider/jsse/test/WolfSSLSocketTest.java b/src/test/com/wolfssl/provider/jsse/test/WolfSSLSocketTest.java index 712af895..5985b28a 100644 --- a/src/test/com/wolfssl/provider/jsse/test/WolfSSLSocketTest.java +++ b/src/test/com/wolfssl/provider/jsse/test/WolfSSLSocketTest.java @@ -59,6 +59,7 @@ import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSessionContext; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.SSLParameters; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; @@ -116,6 +117,8 @@ public void testSessionResumptionWithTicketEnabled(); public void testDoubleSocketClose(); public void testSocketConnectException(); + public void testSocketCloseInterruptsWrite(); + public void testSocketCloseInterruptsRead(); */ public class WolfSSLSocketTest { @@ -2746,6 +2749,254 @@ public Void call() throws Exception { System.out.println("\t... passed"); } + /* Test timeout set to 10000 ms (10 sec) in case inerrupt code is not + * working as expected, we will see the timeout as a hard error that + * this test has failed */ + @Test(timeout = 10000) + public void testSocketCloseInterruptsWrite() throws Exception { + + String protocol = null; + SSLServerSocketFactory ssf = null; + SSLServerSocket ss = null; + SSLSocketFactory sf = null; + boolean passed = false; + + System.out.print("\tTesting close/write interrupt"); + + if (WolfSSL.TLSv12Enabled()) { + protocol = "TLSv1.2"; + } else if (WolfSSL.TLSv11Enabled()) { + protocol = "TLSv1.1"; + } else if (WolfSSL.TLSv1Enabled()) { + protocol = "TLSv1.0"; + } else { + System.out.println("\t... skipped"); + return; + } + + /* create new CTX */ + this.ctx = tf.createSSLContext(protocol, ctxProvider); + + /* create SSLServerSocket first to get ephemeral port */ + ss = (SSLServerSocket)ctx.getServerSocketFactory() + .createServerSocket(0); + + final SSLSocket cs = (SSLSocket)ctx.getSocketFactory().createSocket(); + cs.connect(new InetSocketAddress(ss.getLocalPort())); + + final SSLSocket server = (SSLSocket)ss.accept(); + final CountDownLatch closeLatch = new CountDownLatch(1); + + ExecutorService es = Executors.newSingleThreadExecutor(); + Future serverFuture = es.submit(new Callable() { + @Override + public Void call() throws Exception { + try { + server.startHandshake(); + + boolean doClose = closeLatch.await(90L, TimeUnit.SECONDS); + if (!doClose) { + /* Return without closing, latch not hit within + * time limit */ + return null; + } + + /* Sleep so write thread has a chance to do some + * writing before interrupt */ + Thread.sleep(1000); + cs.setSoLinger(true, 5); + cs.close(); + + } catch (SSLException e) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail("Server thread got SSLException when not expected"); + } + return null; + } + }); + + byte[] tmpArr = new byte[1024]; + Arrays.fill(tmpArr, (byte)0xA2); + OutputStream out = cs.getOutputStream(); + + try { + try { + cs.startHandshake(); + out.write(tmpArr); + } + catch (Exception e) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail("Exception from first out.write() when not expected"); + } + + try { + /* signal server thread to try and close socket */ + closeLatch.countDown(); + + /* keep writing, we should get interrupted */ + while (true) { + out.write(tmpArr); + } + + } catch (SocketException e) { + /* We expect SocketException with this message, error if + * different than expected */ + if (!e.getMessage().contains("Socket fd closed during poll")) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail("Incorrect SocketException thrown by client"); + throw e; + } + + passed = true; + } + } + finally { + es.shutdown(); + serverFuture.get(); + if (!cs.isClosed()) { + cs.close(); + } + if (!server.isClosed()) { + server.close(); + } + if (!ss.isClosed()) { + ss.close(); + } + } + + if (passed) { + System.out.println("\t... passed"); + } + } + + /* Test timeout set to 10000 ms (10 sec) in case inerrupt code is not + * working as expected, we will see the timeout as a hard error that + * this test has failed */ + @Test(timeout = 10000) + public void testSocketCloseInterruptsRead() throws Exception { + + int ret = 0; + String protocol = null; + SSLServerSocketFactory ssf = null; + SSLServerSocket ss = null; + SSLSocketFactory sf = null; + boolean passed = false; + + System.out.print("\tTesting close/read interrupt"); + + if (WolfSSL.TLSv12Enabled()) { + protocol = "TLSv1.2"; + } else if (WolfSSL.TLSv11Enabled()) { + protocol = "TLSv1.1"; + } else if (WolfSSL.TLSv1Enabled()) { + protocol = "TLSv1.0"; + } else { + System.out.println("\t... skipped"); + return; + } + + /* create new CTX */ + this.ctx = tf.createSSLContext(protocol, ctxProvider); + + /* create SSLServerSocket first to get ephemeral port */ + ss = (SSLServerSocket)ctx.getServerSocketFactory() + .createServerSocket(0); + + final SSLSocket cs = (SSLSocket)ctx.getSocketFactory().createSocket(); + cs.connect(new InetSocketAddress(ss.getLocalPort())); + + final SSLSocket server = (SSLSocket)ss.accept(); + final CountDownLatch closeLatch = new CountDownLatch(1); + + ExecutorService es = Executors.newSingleThreadExecutor(); + Future serverFuture = es.submit(new Callable() { + @Override + public Void call() throws Exception { + try { + server.startHandshake(); + + boolean doClose = closeLatch.await(90L, TimeUnit.SECONDS); + if (!doClose) { + /* Return without closing, latch not hit within + * time limit */ + return null; + } + + /* Sleep to let client thread hit read call */ + Thread.sleep(1000); + cs.setSoLinger(true, 5); + cs.close(); + + } catch (SSLException e) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail("Server thread got SSLException when not expected"); + } + return null; + } + }); + + byte[] tmpArr = new byte[1024]; + InputStream in = cs.getInputStream(); + + try { + try { + cs.startHandshake(); + } + catch (Exception e) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail("Exception from startHandshake() when not expected"); + } + + try { + /* signal server thread to try and close socket */ + closeLatch.countDown(); + + while (true) { + ret = in.read(tmpArr, 0, tmpArr.length); + if (ret == -1) { + /* end of stream */ + break; + } + } + + } catch (SocketException e) { + /* We expect SocketException with this message, error if + * different than expected */ + if (!e.getMessage().contains("Socket is closed") && + !e.getMessage().contains("Connection already shutdown")) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail("Incorrect SocketException thrown by client"); + throw e; + } + } + + passed = true; + } + finally { + es.shutdown(); + serverFuture.get(); + if (!cs.isClosed()) { + cs.close(); + } + if (!server.isClosed()) { + server.close(); + } + if (!ss.isClosed()) { + ss.close(); + } + } + + if (passed) { + System.out.println("\t... passed"); + } + } + @Test public void testSocketMethodsAfterClose() throws Exception {