Skip to content

Commit

Permalink
merge: #2980
Browse files Browse the repository at this point in the history
2980: Make crdt tests stable r=paulocsanz a=paulocsanz

Since we used websockets the port had to be available when starting the test, causing a failure if for some reason it isn't.

So we moved that logic from websockets to Streams and Sinks.

- Abstract CRDT broadcast logic over a `Sink<Message>` and `Stream<Message>` instead of using WebSockets directly
- Create a `Sink<Message>` and a `Stream<Message>` in tests from a tokio::broadcast channel pair
- Create a `Sink<Vec<u8>>` and a `Stream<Vec<u8>>` in tests from the `Sink<Message>/Stream<Message>`, to properly integrate with y_sync test machinery
- Create server and client structs over those sinks/streams to emulate the websocket ones

Co-authored-by: Paulo Cabral <[email protected]>
  • Loading branch information
si-bors-ng[bot] and paulocsanz authored Nov 28, 2023
2 parents 1efabfc + 45f4fef commit 23db00a
Show file tree
Hide file tree
Showing 4 changed files with 364 additions and 219 deletions.
1 change: 1 addition & 0 deletions lib/sdf-server/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ rust_test(
"//third-party/rust:yrs",
"//third-party/rust:tower",
"//third-party/rust:futures",
"//third-party/rust:futures-lite",
":sdf-server",
],
crate_root = "tests/api.rs",
Expand Down
6 changes: 5 additions & 1 deletion lib/sdf-server/src/server/service/ws.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use axum::{
http::StatusCode, response::IntoResponse, response::Response, routing::get, Json, Router,
};
use dal::TransactionsError;
use dal::{TransactionsError, WsEventError};
use si_data_pg::{PgError, PgPoolError};
use thiserror::Error;

Expand All @@ -14,11 +14,15 @@ pub enum WsError {
#[error(transparent)]
Crdt(#[from] CrdtError),
#[error(transparent)]
Nats(#[from] si_data_nats::Error),
#[error(transparent)]
Pg(#[from] PgError),
#[error(transparent)]
PgPool(#[from] PgPoolError),
#[error(transparent)]
Transactions(#[from] TransactionsError),
#[error("wsevent error: {0}")]
WsEvent(#[from] WsEventError),
}

pub mod crdt;
Expand Down
174 changes: 94 additions & 80 deletions lib/sdf-server/src/server/service/ws/crdt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use axum::{
response::IntoResponse,
};
use dal::{WorkspacePk, WsEventError};
use futures::{SinkExt, StreamExt};
use futures::{Sink, SinkExt, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use si_data_nats::{NatsClient, NatsError};
use si_data_nats::{NatsClient, NatsError, Subscriber};
use std::{collections::hash_map::Entry, collections::HashMap, sync::Arc};
use telemetry::prelude::*;
use thiserror::Error;
use tokio::{sync::Mutex, task::JoinSet};
use tokio::{sync::broadcast, sync::Mutex, task::JoinSet};
use y::{YSink, YStream};
use y_sync::net::BroadcastGroup;

Expand All @@ -30,13 +30,15 @@ pub mod y;
pub enum CrdtError {
#[error("axum error: {0}")]
Axum(#[from] axum::Error),
#[error("broadcast error: {0}")]
Broadcast(#[from] broadcast::error::SendError<Message>),
#[error("nats error: {0}")]
Nats(#[from] si_data_nats::Error),
#[error("Shutdown recv error: {0}")]
Recv(#[from] tokio::sync::broadcast::error::RecvError),
#[error("serde json error: {0}")]
Serde(#[from] serde_json::Error),
#[error("failed to subscribe to subject {1}")]
#[error("failed to subscribe to subject: {0} {1}")]
Subscribe(#[source] NatsError, String),
#[error("wsevent error: {0}")]
WsEvent(#[from] WsEventError),
Expand All @@ -61,95 +63,107 @@ pub async fn crdt(
State(shutdown_broadcast): State<ShutdownBroadcast>,
State(broadcast_groups): State<BroadcastGroups>,
) -> Result<impl IntoResponse, WsError> {
Ok(crdt_inner(
nats,
wsu,
claim.workspace_pk,
id,
shutdown_broadcast,
broadcast_groups,
)
.await?)
let workspace_pk = claim.workspace_pk;
let channel_name = format!("crdt-{workspace_pk}-{id}");
let subscription = nats.subscribe(&channel_name).await?;
let ws_subscription = nats.subscribe(&channel_name).await?;
let shutdown = shutdown_broadcast.subscribe();

Ok(wsu.on_upgrade(move |socket| async move {
let (sink, stream) = socket.split();
crdt_handle(
sink,
stream,
nats,
broadcast_groups,
channel_name,
subscription,
ws_subscription,
workspace_pk,
id,
shutdown,
)
.await
}))
}

#[allow(clippy::unused_async)]
pub async fn crdt_inner(
#[allow(clippy::too_many_arguments)]
pub async fn crdt_handle<W, R>(
mut sink: W,
mut stream: R,
nats: NatsClient,
wsu: WebSocketUpgrade,
broadcast_groups: BroadcastGroups,
channel_name: String,
subscription: Subscriber,
mut ws_subscription: Subscriber,
workspace_pk: WorkspacePk,
id: String,
shutdown_broadcast: ShutdownBroadcast,
broadcast_groups: BroadcastGroups,
) -> CrdtResult<impl IntoResponse> {
let channel_name = format!("crdt-{workspace_pk}-{id}");

let mut shutdown = shutdown_broadcast.subscribe();
let subscription = nats.subscribe(&channel_name).await?;
let mut ws_subscription = nats.subscribe(&channel_name).await?;
mut shutdown: broadcast::Receiver<()>,
) where
W: Sink<Message> + Unpin + Send + 'static,
R: Stream<Item = Result<Message, axum::Error>> + Unpin + Send + 'static,
CrdtError: From<<W as Sink<Message>>::Error>,
{
let mut tasks = JoinSet::new();

tasks.spawn(async move {
while let Some(message) = ws_subscription.next().await {
sink.send(Message::Binary(message.payload().to_owned()))
.await?;
}

Ok(wsu.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
let mut tasks = JoinSet::new();
let result: CrdtResult<()> = Ok(());
result
});

tasks.spawn(async move {
while let Some(message) = ws_subscription.next().await {
sink.send(Message::Binary(message.payload().to_owned()))
.await?;
let (ws_nats, ws_channel_name) = (nats.clone(), channel_name.clone());
tasks.spawn(async move {
while let Some(msg) = stream.next().await {
if let Message::Binary(vec) = msg? {
ws_nats.publish(&ws_channel_name, vec).await?;
}
}

let result: CrdtResult<()> = Ok(());
result
});

let (ws_nats, ws_channel_name) = (nats.clone(), channel_name.clone());
tasks.spawn(async move {
while let Some(msg) = stream.next().await {
if let Message::Binary(vec) = msg? {
ws_nats.publish(&ws_channel_name, vec).await?;
}
Ok(())
});

tasks.spawn(async move { Ok(shutdown.recv().await?) });

let sink = Arc::new(Mutex::new(YSink::new(nats, channel_name)));
let stream = YStream::new(subscription);

let bcast: Arc<BroadcastGroup> = match broadcast_groups
.lock()
.await
.entry(format!("{workspace_pk}-{id}"))
{
Entry::Occupied(e) => e.get().clone(),
Entry::Vacant(e) => e
.insert(Arc::new(BroadcastGroup::new(Default::default(), 32).await))
.clone(),
};

let sub = bcast.subscribe(sink, stream);
tokio::select! {
result = sub.completed() => {
match result {
Ok(_) => info!("broadcasting for channel finished successfully"),
Err(e) => error!("broadcasting for channel finished abruptly: {}", e),
}

Ok(())
});

tasks.spawn(async move { Ok(shutdown.recv().await?) });

let sink = Arc::new(Mutex::new(YSink::new(nats, channel_name)));
let stream = YStream::new(subscription);

let bcast: Arc<BroadcastGroup> = match broadcast_groups
.lock()
.await
.entry(format!("{workspace_pk}-{id}"))
{
Entry::Occupied(e) => e.get().clone(),
Entry::Vacant(e) => e
.insert(Arc::new(BroadcastGroup::new(Default::default(), 32).await))
.clone(),
};

let sub = bcast.subscribe(sink, stream);
tokio::select! {
result = sub.completed() => {
match result {
Ok(_) => info!("broadcasting for channel finished successfully"),
Err(e) => error!("broadcasting for channel finished abruptly: {}", e),
}
Some(result) = tasks.join_next() => {
match result {
Ok(Err(err)) => {
error!("Task failed: {err}");
}
}
Some(result) = tasks.join_next() => {
match result {
Ok(Err(err)) => {
error!("Task failed: {err}");
}
Err(err) => {
error!("Unable to join task: {err}");
}
Ok(Ok(())) => {},
Err(err) => {
error!("Unable to join task: {err}");
}
Ok(Ok(())) => {},
}
else => {},
}
else => {},
}

tasks.shutdown().await;
}))
tasks.shutdown().await;
}
Loading

0 comments on commit 23db00a

Please sign in to comment.