Skip to content

Commit

Permalink
use new client capability of consumable notifications
Browse files Browse the repository at this point in the history
  • Loading branch information
battermann committed Nov 20, 2024
1 parent e6c757c commit 7ea5515
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 150 deletions.
1 change: 1 addition & 0 deletions integration/integration.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ library
Testlib.App
Testlib.Assertions
Testlib.Cannon
Testlib.Cannon.ConsumableNotifications
Testlib.Certs
Testlib.Env
Testlib.HTTP
Expand Down
112 changes: 109 additions & 3 deletions integration/test/MLS/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import System.IO.Temp
import System.Posix.Files
import System.Process
import Testlib.Assertions
import Testlib.Cannon.ConsumableNotifications
import Testlib.HTTP
import Testlib.JSON
import Testlib.Prelude
Expand Down Expand Up @@ -570,6 +571,76 @@ createExternalCommit convId cid mgi = do
data MLSNotificationTag = MLSNotificationMessageTag | MLSNotificationWelcomeTag
deriving (Show, Eq, Ord)

consumingMassagesViaNewCapability :: (HasCallStack) => MLSProtocol -> MessagePackage -> Codensity App ()
consumingMassagesViaNewCapability mlsProtocol mp = Codensity $ \k -> do
conv <- getMLSConv mp.convId
-- clients that should receive the message itself
let oldClients = Set.delete mp.sender conv.members
-- clients that should receive a welcome message
let newClients = Set.delete mp.sender conv.newMembers
-- all clients that should receive some MLS notification, together with the
-- expected notification tag
let clients =
map (,MLSNotificationMessageTag) (toList oldClients)
<> map (,MLSNotificationWelcomeTag) (toList newClients)

let newUsers =
Set.delete mp.sender.user $
Set.difference
(Set.map (.user) newClients)
(Set.map (.user) oldClients)

let uidsWithClients =
fmap
((\c -> (c.user, (object ["domain" .= c.domain, "id" .= c.user], c.client))) . fst)
clients

withEventsWebSockets (fmap snd uidsWithClients) $ \chans -> do
r <- k ()

-- if the conversation is actually MLS (and not mixed), pick one client for
-- each new user and wait for its join event. In Mixed protocol, the user is
-- already in the conversation so they do not get a member-join
-- notification.
when (mlsProtocol == MLSProtocolMLS) $ do
let uidsWithChannels = zip uidsWithClients chans
let newUserChans = uidsWithChannels & filter (\((uid, _), _) -> Set.member uid newUsers) & fmap snd
let assertJoin e = do
e %. "data.event.payload.0.type" `shouldMatch` "conversation.member-join"
pure e

traverse_
( \(eventChan, ackChan) -> assertEvent eventChan assertJoin >>= ackEvent ackChan
)
newUserChans

-- at this point we know that every new user has been added to the
-- conversation
for_ (zip clients chans) $ \((cid, t), (eventChan, ackChan)) -> case t of
MLSNotificationMessageTag -> do
event <-
awaitEvent
eventChan
ackChan
( \e -> do
eventType <- e %. "data.event.payload.0.type" & asString
pure $ eventType == "conversation.mls-message-add"
)
eventData <- event %. "data.event.payload.0.data" & asByteString
void $ mlsCliConsume mp.convId conv.ciphersuite cid eventData
MLSNotificationWelcomeTag -> do
event <-
awaitEvent
eventChan
ackChan
( \e -> do
eventType <- e %. "data.event.payload.0.type" & asString
pure $ eventType == "conversation.mls-welcome"
)
eventData <- event %. "data.event.payload.0.data" & asByteString
void $ fromWelcome mp.convId conv.ciphersuite cid eventData
pure r

consumingMessages :: (HasCallStack) => MLSProtocol -> MessagePackage -> Codensity App ()
consumingMessages mlsProtocol mp = Codensity $ \k -> do
conv <- getMLSConv mp.convId
Expand All @@ -588,10 +659,8 @@ consumingMessages mlsProtocol mp = Codensity $ \k -> do
Set.difference
(Set.map (.user) newClients)
(Set.map (.user) oldClients)
let userClients = map ((\ci -> (ci.user, ci.client)) . fst) clients

-- withEventWebSockets userClients
withwebsockets (map fst clients) $ \wss -> do
withWebSockets (map fst clients) $ \wss -> do
r <- k ()

-- if the conversation is actually MLS (and not mixed), pick one client for
Expand Down Expand Up @@ -678,6 +747,43 @@ sendAndConsumeMessage mp = lowerCodensity $ do
consumingMessages MLSProtocolMLS mp
lift $ postMLSMessage mp.sender mp.message >>= getJSON 201

sendAndConsumeCommitBundle' :: (HasCallStack) => MessagePackage -> App Value
sendAndConsumeCommitBundle' = sendAndConsumeCommitBundleWithProtocol' MLSProtocolMLS

-- | Send an MLS commit bundle, wait for clients to receive it, consume it, and
-- update the test state accordingly.
sendAndConsumeCommitBundleWithProtocol' :: (HasCallStack) => MLSProtocol -> MessagePackage -> App Value
sendAndConsumeCommitBundleWithProtocol' protocol mp = do
lowerCodensity $ do
consumingMassagesViaNewCapability protocol mp
lift $ do
r <- postMLSCommitBundle mp.sender (mkBundle mp) >>= getJSON 201

-- if the sender is a new member (i.e. it's an external commit), then
-- process the welcome message directly
do
conv <- getMLSConv mp.convId
when (Set.member mp.sender conv.newMembers) $
traverse_ (fromWelcome mp.convId conv.ciphersuite mp.sender) mp.welcome

-- increment epoch and add new clients
modifyMLSState $ \mls ->
mls
{ convs =
Map.adjust
( \conv ->
conv
{ epoch = conv.epoch + 1,
members = conv.members <> conv.newMembers,
newMembers = mempty
}
)
mp.convId
mls.convs
}

pure r

sendAndConsumeCommitBundle :: (HasCallStack) => MessagePackage -> App Value
sendAndConsumeCommitBundle = sendAndConsumeCommitBundleWithProtocol MLSProtocolMLS

Expand Down
23 changes: 15 additions & 8 deletions integration/test/Performance/BigConversation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Data.Time (NominalDiffTime, diffUTCTime, getCurrentTime)
import MLS.Util
import SetupHelpers
import qualified System.CryptoBox as Cryptobox
import Testlib.Cannon.ConsumableNotifications (assertEvent, sendAck, withEventsWebSocket)
import Testlib.Prelude
import UnliftIO (pooledMapConcurrentlyN)
import UnliftIO.Temporary
Expand Down Expand Up @@ -44,10 +45,10 @@ batchForSize VeryLarge = 500

testCreateBigMLSConversation :: App ()
testCreateBigMLSConversation = withModifiedBackend def \domain -> do
let teamSize = 200
let batchSize = 50
let teamSize = 20
let batchSize = 2
putStrLn $ "Creating a team with " <> show teamSize <> " members"
(_, ownerClient, _, members, _) <- createTeamAndClients teamSize
(_, ownerClient, _, members, _) <- createTeamAndClients domain teamSize
putStrLn $ "Creating a conversation with " <> show teamSize <> " members in batches of " <> show batchSize
totalTime <-
snd <$> timeIt do
Expand All @@ -56,7 +57,7 @@ testCreateBigMLSConversation = withModifiedBackend def \domain -> do
for_ memberChunks $ \chunk -> do
(size, time) <- timeIt $ do
msg <- createAddCommit ownerClient convId chunk
void $ sendAndConsumeCommitBundle msg
void $ sendAndConsumeCommitBundle' msg
pure (BS.length msg.message)
putStrLn $ "Sent " <> show size <> " bytes in " <> show time
pure (size, time)
Expand All @@ -69,9 +70,9 @@ timeIt action = do
end <- liftIO getCurrentTime
pure (result, diffUTCTime end start)

createTeamAndClients :: Int -> App (Value, ClientIdentity, String, [Value], [ClientIdentity])
createTeamAndClients teamSize = do
(owner, tid, members) <- createTeam OwnDomain teamSize
createTeamAndClients :: String -> Int -> App (Value, ClientIdentity, String, [Value], [ClientIdentity])
createTeamAndClients domain teamSize = do
(owner, tid, members) <- createTeam domain teamSize
let genPrekeyInBox box i = do
pk <- assertCrytoboxSuccess =<< liftIO (Cryptobox.newPrekey box i)
pkBS <- liftIO $ Cryptobox.copyBytes pk.prekey
Expand All @@ -93,7 +94,13 @@ createTeamAndClients teamSize = do
acapabilities = Just ["consumable-notifications"]
}
}
createMLSClient def mlsClientOpts user
cid <- createMLSClient def mlsClientOpts user
withEventsWebSocket user cid.client $ \eventChan ackChan -> do
deliveryTag <- assertEvent eventChan \e -> do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.delivery_tag"
sendAck ackChan deliveryTag False
pure cid
ownerClient <- createClient owner
memClients <- pooledMapConcurrentlyN 64 createClient members
for_ memClients $ uploadNewKeyPackage def
Expand Down
140 changes: 1 addition & 139 deletions integration/test/Test/Events.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@ import API.Galley
import API.Gundeck
import qualified Control.Concurrent.Timeout as Timeout
import Control.Retry
import Data.ByteString.Conversion (toByteString')
import qualified Data.Text as Text
import Data.Timeout
import qualified Network.WebSockets as WS
import Notifications
import SetupHelpers
import Testlib.Cannon.ConsumableNotifications
import Testlib.Prelude hiding (assertNoEvent)
import Testlib.Printing
import UnliftIO hiding (handle)

-- FUTUREWORK: Investigate why these tests are failing without
Expand Down Expand Up @@ -303,138 +300,3 @@ testTransientEvents = do
ackEvent ackChan e

assertNoEvent eventsChan

----------------------------------------------------------------------
-- helpers

withEventsWebSockets ::
forall uid a.
(HasCallStack, MakesValue uid) =>
[(uid, String)] ->
([(TChan Value, TChan Value)] -> App a) ->
App a
withEventsWebSockets userClients k = go [] $ reverse userClients
where
go :: [(TChan Value, TChan Value)] -> [(uid, String)] -> App a
go chans [] = k chans
go chans ((uid, cid) : remaining) =
withEventsWebSocket uid cid $ \eventsChan ackChan ->
go ((eventsChan, ackChan) : chans) remaining

withEventsWebSocket :: (HasCallStack, MakesValue uid) => uid -> String -> (TChan Value -> TChan Value -> App a) -> App a
withEventsWebSocket uid cid k = do
closeWS <- newEmptyMVar
bracket (setup closeWS) (\(_, _, wsThread) -> cancel wsThread) $ \(eventsChan, ackChan, wsThread) -> do
x <- k eventsChan ackChan

-- Ensure all the acks are sent before closing the websocket
isAckChanEmpty <-
retrying
(limitRetries 5 <> constantDelay 10_000)
(\_ isEmpty -> pure $ not isEmpty)
(\_ -> atomically $ isEmptyTChan ackChan)
unless isAckChanEmpty $ do
putStrLn $ colored yellow $ "The ack chan is not empty after 50ms, some acks may not make it to the server"

void $ tryPutMVar closeWS ()

timeout 1_000_000 (wait wsThread) >>= \case
Nothing ->
putStrLn $ colored yellow $ "The websocket thread did not close after waiting for 1s"
Just () -> pure ()

pure x
where
setup :: (HasCallStack) => MVar () -> App (TChan Value, TChan Value, Async ())
setup closeWS = do
(eventsChan, ackChan) <- liftIO $ (,) <$> newTChanIO <*> newTChanIO
wsThread <- eventsWebSocket uid cid eventsChan ackChan closeWS
pure (eventsChan, ackChan, wsThread)

sendMsg :: (HasCallStack) => TChan Value -> Value -> App ()
sendMsg eventsChan msg = liftIO $ atomically $ writeTChan eventsChan msg

ackFullSync :: (HasCallStack) => TChan Value -> App ()
ackFullSync ackChan = do
sendMsg ackChan
$ object ["type" .= "ack_full_sync"]

ackEvent :: (HasCallStack) => TChan Value -> Value -> App ()
ackEvent ackChan event = do
deliveryTag <- event %. "data.delivery_tag"
sendAck ackChan deliveryTag False

sendAck :: (HasCallStack) => TChan Value -> Value -> Bool -> App ()
sendAck ackChan deliveryTag multiple = do
sendMsg ackChan
$ object
[ "type" .= "ack",
"data"
.= object
[ "delivery_tag" .= deliveryTag,
"multiple" .= multiple
]
]

assertEvent :: (HasCallStack) => TChan Value -> ((HasCallStack) => Value -> App a) -> App a
assertEvent eventsChan expectations = do
timeout 10_000_000 (atomically (readTChan eventsChan)) >>= \case
Nothing -> assertFailure "No event received for 10s"
Just e -> do
pretty <- prettyJSON e
addFailureContext ("event:\n" <> pretty)
$ expectations e

assertNoEvent :: (HasCallStack) => TChan Value -> App ()
assertNoEvent eventsChan = do
timeout 1_000_000 (atomically (readTChan eventsChan)) >>= \case
Nothing -> pure ()
Just e -> do
eventJSON <- prettyJSON e
assertFailure $ "Did not expect event: \n" <> eventJSON

consumeAllEvents :: TChan Value -> TChan Value -> App ()
consumeAllEvents eventsChan ackChan = do
timeout 1_000_000 (atomically (readTChan eventsChan)) >>= \case
Nothing -> pure ()
Just e -> do
ackEvent ackChan e
consumeAllEvents eventsChan ackChan

eventsWebSocket :: (MakesValue user) => user -> String -> TChan Value -> TChan Value -> MVar () -> App (Async ())
eventsWebSocket user clientId eventsChan ackChan closeWS = do
serviceMap <- getServiceMap =<< objDomain user
uid <- objId =<< objQidObject user
let HostPort caHost caPort = serviceHostPort serviceMap Cannon
path = "/events?client=" <> clientId
caHdrs = [(fromString "Z-User", toByteString' uid)]
app conn = do
r <-
async $ wsRead conn `catch` \(e :: WS.ConnectionException) ->
case e of
WS.CloseRequest {} -> pure ()
_ -> throwIO e
w <- async $ wsWrite conn
void $ waitAny [r, w]

wsRead conn = forever $ do
bs <- WS.receiveData conn
case decodeStrict' bs of
Just n -> atomically $ writeTChan eventsChan n
Nothing ->
error $ "Failed to decode events: " ++ show bs

wsWrite conn = forever $ do
eitherAck <- race (readMVar closeWS) (atomically $ readTChan ackChan)
case eitherAck of
Left () -> WS.sendClose conn (Text.pack "")
Right ack -> WS.sendBinaryData conn (encode ack)
liftIO
$ async
$ WS.runClientWith
caHost
(fromIntegral caPort)
path
WS.defaultConnectionOptions
caHdrs
app
Loading

0 comments on commit 7ea5515

Please sign in to comment.