Skip to content

Commit

Permalink
Pass tokio runtime handle to ws2::Chat::new
Browse files Browse the repository at this point in the history
  • Loading branch information
akonradi-signal authored Dec 6, 2024
1 parent bad358e commit 975f9b3
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ public void testConstructRequest() throws Exception {

@Test
public void testConnectUnauth() throws Exception {
// Use the presence of the proxy server environment setting to know whether we should make
// network requests in our tests.
final String PROXY_SERVER = TestEnvironment.get("LIBSIGNAL_TESTING_PROXY_SERVER");
Assume.assumeNotNull(PROXY_SERVER);
// Use the presence of the environment setting to know whether we should
// make network requests in our tests.
final String ENABLE_TEST = TestEnvironment.get("LIBSIGNAL_TESTING_RUN_NONHERMETIC_TESTS");
Assume.assumeNotNull(ENABLE_TEST);

final Network net = new Network(Network.Environment.STAGING, USER_AGENT);
final UnauthenticatedChatService chat = net.createUnauthChatService(null);
Expand Down
8 changes: 5 additions & 3 deletions node/ts/test/NetTest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
//

import { assert, config, expect, use } from 'chai';
import { config, expect, use } from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import * as sinon from 'sinon';
import * as sinonChai from 'sinon-chai';
Expand Down Expand Up @@ -237,9 +237,11 @@ describe('chat service api', () => {
await connectChatUnauthenticated(net);
}).timeout(10000);

it('can connect through a proxy server', async () => {
it('can connect through a proxy server', async function () {
const PROXY_SERVER = process.env.LIBSIGNAL_TESTING_PROXY_SERVER;
assert(PROXY_SERVER, 'checked above');
if (!PROXY_SERVER) {
this.skip();
}

// The default TLS proxy config doesn't support staging, so we connect to production.
const net = new Net({
Expand Down
68 changes: 38 additions & 30 deletions rust/bridge/shared/types/src/net/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ impl RefUnwindSafe for AuthenticatedChatConnection {}

enum MaybeChatConnection {
Running(ChatConnection),
WaitingForListener(chat::PendingChatConnection),
WaitingForListener(tokio::runtime::Handle, chat::PendingChatConnection),
TemporarilyEvicted,
}

Expand All @@ -254,7 +254,11 @@ impl UnauthenticatedChatConnection {
let inner = establish_chat_connection(connection_manager, None).await?;
log::info!("connected unauthenticated chat");
Ok(Self {
inner: MaybeChatConnection::WaitingForListener(inner).into(),
inner: MaybeChatConnection::WaitingForListener(
tokio::runtime::Handle::current(),
inner,
)
.into(),
})
}
}
Expand All @@ -274,7 +278,11 @@ impl AuthenticatedChatConnection {
.await?;
log::info!("connected authenticated chat");
Ok(Self {
inner: MaybeChatConnection::WaitingForListener(pending).into(),
inner: MaybeChatConnection::WaitingForListener(
tokio::runtime::Handle::current(),
pending,
)
.into(),
})
}
}
Expand Down Expand Up @@ -305,40 +313,36 @@ pub trait BridgeChatConnection {
fn info(&self) -> ConnectionInfo;
}

impl<C: AsRef<tokio::sync::RwLock<MaybeChatConnection>>> BridgeChatConnection for C {
impl<C: AsRef<tokio::sync::RwLock<MaybeChatConnection>> + Sync> BridgeChatConnection for C {
fn init_listener(&self, listener: Box<dyn ChatListener>) {
init_listener(&mut self.as_ref().blocking_write(), listener)
}

fn send(
async fn send(
&self,
message: Request,
timeout: Duration,
) -> impl Future<Output = Result<ChatResponse, ChatServiceError>> + Send {
let guard = self.as_ref().blocking_read();
async move {
let MaybeChatConnection::Running(inner) = &*guard else {
panic!("listener was not set")
};
inner.send(message, timeout).await
}
) -> Result<ChatResponse, ChatServiceError> {
let guard = self.as_ref().read().await;
let MaybeChatConnection::Running(inner) = &*guard else {
panic!("listener was not set")
};
inner.send(message, timeout).await
}

fn disconnect(&self) -> impl Future<Output = ()> + Send {
let guard = self.as_ref().blocking_read();
async move {
let MaybeChatConnection::Running(inner) = &*guard else {
panic!("listener was not set")
};
inner.disconect().await
}
async fn disconnect(&self) {
let guard = self.as_ref().read().await;
let MaybeChatConnection::Running(inner) = &*guard else {
panic!("listener was not set")
};
inner.disconect().await
}

fn info(&self) -> ConnectionInfo {
let guard = self.as_ref().blocking_read();
match &*guard {
MaybeChatConnection::Running(chat_connection) => chat_connection.connection_info(),
MaybeChatConnection::WaitingForListener(pending_chat_connection) => {
MaybeChatConnection::WaitingForListener(_, pending_chat_connection) => {
pending_chat_connection.connection_info()
}
MaybeChatConnection::TemporarilyEvicted => unreachable!("unobservable state"),
Expand All @@ -347,16 +351,20 @@ impl<C: AsRef<tokio::sync::RwLock<MaybeChatConnection>>> BridgeChatConnection fo
}

fn init_listener(connection: &mut MaybeChatConnection, listener: Box<dyn ChatListener>) {
let pending = match std::mem::replace(connection, MaybeChatConnection::TemporarilyEvicted) {
MaybeChatConnection::Running(chat_connection) => {
*connection = MaybeChatConnection::Running(chat_connection);
panic!("listener already set")
}
MaybeChatConnection::WaitingForListener(pending_chat_connection) => pending_chat_connection,
MaybeChatConnection::TemporarilyEvicted => panic!("should be a temporary state"),
};
let (tokio_runtime, pending) =
match std::mem::replace(connection, MaybeChatConnection::TemporarilyEvicted) {
MaybeChatConnection::Running(chat_connection) => {
*connection = MaybeChatConnection::Running(chat_connection);
panic!("listener already set")
}
MaybeChatConnection::WaitingForListener(tokio_runtime, pending_chat_connection) => {
(tokio_runtime, pending_chat_connection)
}
MaybeChatConnection::TemporarilyEvicted => panic!("should be a temporary state"),
};

*connection = MaybeChatConnection::Running(ChatConnection::finish_connect(
tokio_runtime,
pending,
listener.into_event_listener(),
))
Expand Down
31 changes: 6 additions & 25 deletions rust/net/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,29 +554,6 @@ pub struct AuthenticatedChatHeaders {
pub type ChatServiceRoute = UnresolvedWebsocketServiceRoute;

impl ChatConnection {
pub async fn connect_with(
connect: &tokio::sync::RwLock<ConnectState>,
resolver: &DnsResolver,
http_route_provider: impl RouteProvider<Route = UnresolvedHttpsServiceRoute>,
confirmation_header_name: Option<HeaderName>,
user_agent: &UserAgent,
ws_config: self::ws2::Config,
listener: self::ws2::EventListener,
auth: Option<AuthenticatedChatHeaders>,
) -> Result<Self, ChatServiceError> {
let pending = Self::start_connect_with(
connect,
resolver,
http_route_provider,
confirmation_header_name,
user_agent,
ws_config,
auth,
)
.await?;
Ok(Self::finish_connect(pending, listener))
}

pub async fn start_connect_with(
connect: &tokio::sync::RwLock<ConnectState>,
resolver: &DnsResolver,
Expand Down Expand Up @@ -650,14 +627,18 @@ impl ChatConnection {
}
}

pub fn finish_connect(pending: PendingChatConnection, listener: ws2::EventListener) -> Self {
pub fn finish_connect(
tokio_runtime: tokio::runtime::Handle,
pending: PendingChatConnection,
listener: ws2::EventListener,
) -> Self {
let PendingChatConnection {
connection,
connection_info,
ws_config,
} = pending;
Self {
inner: crate::chat::ws2::Chat::new(connection, ws_config, listener),
inner: crate::chat::ws2::Chat::new(tokio_runtime, connection, ws_config, listener),
connection_info,
}
}
Expand Down
14 changes: 12 additions & 2 deletions rust/net/src/chat/ws2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@ pub struct Responder {
pub type EventListener = Box<dyn FnMut(ListenerEvent) + Send>;

impl Chat {
pub fn new<T>(transport: T, config: Config, listener: EventListener) -> Self
pub fn new<T>(
tokio_runtime: tokio::runtime::Handle,
transport: T,
config: Config,
listener: EventListener,
) -> Self
where
T: Stream<Item = Result<tungstenite::Message, tungstenite::Error>>
+ Sink<tungstenite::Message, Error = tungstenite::Error>
Expand All @@ -156,6 +161,8 @@ impl Chat {
remote_idle_timeout,
} = config;

// Enable access to tokio types like Sleep, but only for the duration of this call.
let _enable_tokio_types = tokio_runtime.enter();
Self::new_inner(
(
transport,
Expand All @@ -167,6 +174,7 @@ impl Chat {
),
initial_request_id,
listener,
tokio_runtime,
)
}

Expand Down Expand Up @@ -267,6 +275,7 @@ impl Chat {
into_inner_connection: impl IntoInnerConnection,
initial_request_id: u64,
listener: EventListener,
tokio_runtime: tokio::runtime::Handle,
) -> Self {
let (request_tx, request_rx) = mpsc::channel(1);
let (response_tx, response_rx) = mpsc::unbounded_channel();
Expand Down Expand Up @@ -300,7 +309,7 @@ impl Chat {
requests_in_flight,
};

let task = tokio::spawn(spawned_task_body(
let task = tokio_runtime.spawn(spawned_task_body(
connection,
listener,
response_tx.downgrade(),
Expand Down Expand Up @@ -1207,6 +1216,7 @@ mod test {
},
initial_request_id,
listener,
tokio::runtime::Handle::current(),
);

(chat, (outgoing_events_rx, incoming_events_tx))
Expand Down
4 changes: 2 additions & 2 deletions swift/Tests/LibSignalClientTests/ChatServiceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ final class ChatServiceTests: TestCaseBase {
}

func testConnectUnauth() async throws {
// Use the presence of the proxy server environment setting to know whether we should make network requests in our tests.
guard ProcessInfo.processInfo.environment["LIBSIGNAL_TESTING_PROXY_SERVER"] != nil else {
// Use the presence of the environment setting to know whether we should make network requests in our tests.
guard ProcessInfo.processInfo.environment["LIBSIGNAL_TESTING_RUN_NONHERMETIC_TESTS"] != nil else {
throw XCTSkip()
}

Expand Down

0 comments on commit 975f9b3

Please sign in to comment.