From 0c8edadd13482bb284423bdc3ba01d0c4b541192 Mon Sep 17 00:00:00 2001 From: stuart Date: Mon, 18 Mar 2024 10:42:37 +0100 Subject: [PATCH 1/6] move a function --- be1-go/hub/hub.go | 56 ---------------------------------- be1-go/hub/message_handling.go | 55 +++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/be1-go/hub/hub.go b/be1-go/hub/hub.go index 34beb2dc80..6761633ea0 100644 --- a/be1-go/hub/hub.go +++ b/be1-go/hub/hub.go @@ -10,7 +10,6 @@ import ( "popstellar/inbox" jsonrpc "popstellar/message" "popstellar/message/answer" - "popstellar/message/messagedata" "popstellar/message/query" "popstellar/message/query/method" "popstellar/message/query/method/message" @@ -523,61 +522,6 @@ func (h *Hub) sendHeartbeatToServers() { h.serverSockets.SendToAll(buf) } -// createLao creates a new LAO using the data in the publish parameter. -func (h *Hub) createLao(msg message.Message, laoCreate messagedata.LaoCreate, - socket socket.Socket, -) error { - laoChannelPath := rootPrefix + laoCreate.ID - - if _, ok := h.channelByID.Get(laoChannelPath); ok { - return answer.NewDuplicateResourceError("failed to create lao: duplicate lao path: %q", laoChannelPath) - } - - senderBuf, err := base64.URLEncoding.DecodeString(msg.Sender) - if err != nil { - return answer.NewInvalidMessageFieldError("failed to decode public key of the sender: %v", err) - } - - // Check if the sender of the LAO creation message is the organizer - senderPubKey := crypto.Suite.Point() - err = senderPubKey.UnmarshalBinary(senderBuf) - if err != nil { - return answer.NewInvalidMessageFieldError("failed to unmarshal public key of the sender: %v", err) - } - - organizerBuf, err := base64.URLEncoding.DecodeString(laoCreate.Organizer) - if err != nil { - return answer.NewInvalidMessageFieldError("failed to decode public key of the organizer: %v", err) - } - - organizerPubKey := crypto.Suite.Point() - err = organizerPubKey.UnmarshalBinary(organizerBuf) - if err != nil { - return answer.NewInvalidMessageFieldError("failed to unmarshal public key of the organizer: %v", err) - } - - // Check if the sender and organizer fields of the create lao message are equal - if !organizerPubKey.Equal(senderPubKey) { - return answer.NewAccessDeniedError("sender's public key does not match the organizer field: %q != %q", senderPubKey, organizerPubKey) - } - - // Check if the sender of the LAO creation message is the owner - if h.GetPubKeyOwner() != nil && !h.GetPubKeyOwner().Equal(senderPubKey) { - return answer.NewAccessDeniedError("sender's public key does not match the owner's: %q != %q", senderPubKey, h.GetPubKeyOwner()) - } - - laoCh, err := h.laoFac(laoChannelPath, h, msg, h.log, senderPubKey, socket) - if err != nil { - return answer.NewInvalidMessageFieldError("failed to create the LAO: %v", err) - } - - h.log.Info().Msgf("storing new channel '%s' %v", laoChannelPath, msg) - - h.NotifyNewChannel(laoChannelPath, laoCh, socket) - - return nil -} - // GetPubKeyOwner implements channel.HubFunctionalities func (h *Hub) GetPubKeyOwner() kyber.Point { return h.pubKeyOwner diff --git a/be1-go/hub/message_handling.go b/be1-go/hub/message_handling.go index 1c9e190277..b5093271ed 100644 --- a/be1-go/hub/message_handling.go +++ b/be1-go/hub/message_handling.go @@ -611,3 +611,58 @@ func (h *Hub) loopOverMessages(messages *map[string][]json.RawMessage, senderSoc } return tempBlacklist, nil } + +// createLao creates a new LAO using the data in the publish parameter. +func (h *Hub) createLao(msg message.Message, laoCreate messagedata.LaoCreate, + socket socket.Socket, +) error { + laoChannelPath := rootPrefix + laoCreate.ID + + if _, ok := h.channelByID.Get(laoChannelPath); ok { + return answer.NewDuplicateResourceError("failed to create lao: duplicate lao path: %q", laoChannelPath) + } + + senderBuf, err := base64.URLEncoding.DecodeString(msg.Sender) + if err != nil { + return answer.NewInvalidMessageFieldError("failed to decode public key of the sender: %v", err) + } + + // Check if the sender of the LAO creation message is the organizer + senderPubKey := crypto.Suite.Point() + err = senderPubKey.UnmarshalBinary(senderBuf) + if err != nil { + return answer.NewInvalidMessageFieldError("failed to unmarshal public key of the sender: %v", err) + } + + organizerBuf, err := base64.URLEncoding.DecodeString(laoCreate.Organizer) + if err != nil { + return answer.NewInvalidMessageFieldError("failed to decode public key of the organizer: %v", err) + } + + organizerPubKey := crypto.Suite.Point() + err = organizerPubKey.UnmarshalBinary(organizerBuf) + if err != nil { + return answer.NewInvalidMessageFieldError("failed to unmarshal public key of the organizer: %v", err) + } + + // Check if the sender and organizer fields of the create lao message are equal + if !organizerPubKey.Equal(senderPubKey) { + return answer.NewAccessDeniedError("sender's public key does not match the organizer field: %q != %q", senderPubKey, organizerPubKey) + } + + // Check if the sender of the LAO creation message is the owner + if h.GetPubKeyOwner() != nil && !h.GetPubKeyOwner().Equal(senderPubKey) { + return answer.NewAccessDeniedError("sender's public key does not match the owner's: %q != %q", senderPubKey, h.GetPubKeyOwner()) + } + + laoCh, err := h.laoFac(laoChannelPath, h, msg, h.log, senderPubKey, socket) + if err != nil { + return answer.NewInvalidMessageFieldError("failed to create the LAO: %v", err) + } + + h.log.Info().Msgf("storing new channel '%s' %v", laoChannelPath, msg) + + h.NotifyNewChannel(laoChannelPath, laoCh, socket) + + return nil +} From 9fa48e074f18fd53f8ddc8ce0bf29606cdd12dbc Mon Sep 17 00:00:00 2001 From: stuart Date: Mon, 18 Mar 2024 13:27:19 +0100 Subject: [PATCH 2/6] rename hub factory function --- be1-go/cli/cli.go | 2 +- be1-go/cli/cli_test.go | 4 +-- be1-go/hub/hub.go | 4 +-- be1-go/hub/hub_test.go | 58 +++++++++++++++++------------------ be1-go/network/server_test.go | 2 +- be1-go/popcha/server_test.go | 2 +- 6 files changed, 36 insertions(+), 36 deletions(-) diff --git a/be1-go/cli/cli.go b/be1-go/cli/cli.go index c45cd3143d..3f447128d1 100644 --- a/be1-go/cli/cli.go +++ b/be1-go/cli/cli.go @@ -85,7 +85,7 @@ func Serve(cliCtx *cli.Context) error { ownerKey(serverConfig.PublicKey, &point) // create user hub - h, err := hub.NewHub(point, serverConfig.ClientAddress, serverConfig.ServerAddress, log.With().Str("role", "server").Logger(), + h, err := hub.New(point, serverConfig.ClientAddress, serverConfig.ServerAddress, log.With().Str("role", "server").Logger(), lao.NewChannel) if err != nil { return xerrors.Errorf("failed create the hub: %v", err) diff --git a/be1-go/cli/cli_test.go b/be1-go/cli/cli_test.go index 6a043b6164..68c916888d 100644 --- a/be1-go/cli/cli_test.go +++ b/be1-go/cli/cli_test.go @@ -36,7 +36,7 @@ const ( func TestConnectToSocket(t *testing.T) { log := zerolog.New(io.Discard) - oh, err := hub.NewHub(crypto.Suite.Point(), "", "", log, lao.NewChannel) + oh, err := hub.New(crypto.Suite.Point(), "", "", log, lao.NewChannel) require.NoError(t, err) oh.Start() @@ -46,7 +46,7 @@ func TestConnectToSocket(t *testing.T) { time.Sleep(1 * time.Second) - wh, err := hub.NewHub(crypto.Suite.Point(), "", "", log, lao.NewChannel) + wh, err := hub.New(crypto.Suite.Point(), "", "", log, lao.NewChannel) require.NoError(t, err) wDone := make(chan struct{}) wh.Start() diff --git a/be1-go/hub/hub.go b/be1-go/hub/hub.go index 6761633ea0..e533a541e5 100644 --- a/be1-go/hub/hub.go +++ b/be1-go/hub/hub.go @@ -118,8 +118,8 @@ type Hub struct { blacklist state.ThreadSafeSlice[string] } -// NewHub returns a new Hub. -func NewHub(pubKeyOwner kyber.Point, clientServerAddress string, serverServerAddress string, log zerolog.Logger, +// New returns a new Hub. +func New(pubKeyOwner kyber.Point, clientServerAddress string, serverServerAddress string, log zerolog.Logger, laoFac channel.LaoFactory, ) (*Hub, error) { schemaValidator, err := validation.NewSchemaValidator(log) diff --git a/be1-go/hub/hub_test.go b/be1-go/hub/hub_test.go index a414153458..f2b3751494 100644 --- a/be1-go/hub/hub_test.go +++ b/be1-go/hub/hub_test.go @@ -32,7 +32,7 @@ import ( func Test_Add_Server_Socket(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) sock := &fakeSocket{id: "fakeID"} @@ -47,7 +47,7 @@ func Test_Create_LAO_Bad_Key(t *testing.T) { fakeChannelFac := &fakeChannelFac{c: &fakeChannel{}} - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -122,7 +122,7 @@ func Test_Create_LAO_Different_Sender_And_Organizer_Keys(t *testing.T) { fakeChannelFac := &fakeChannelFac{c: &fakeChannel{}} - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -196,7 +196,7 @@ func Test_Create_LAO_No_Key(t *testing.T) { fakeChannelFac := &fakeChannelFac{c: &fakeChannel{}} - hub, err := NewHub(nil, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(nil, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -272,7 +272,7 @@ func Test_Create_LAO_Bad_MessageID(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -350,7 +350,7 @@ func Test_Create_LAO_Bad_Signature(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -427,7 +427,7 @@ func Test_Create_LAO_Data_Not_Base64(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -503,7 +503,7 @@ func Test_Create_Invalid_Json_Schema(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) type N0thing struct { @@ -575,7 +575,7 @@ func Test_Create_Invalid_Lao_Id(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -652,7 +652,7 @@ func Test_Create_LAO(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) now := time.Now().Unix() @@ -742,7 +742,7 @@ func Test_Wrong_Root_Publish(t *testing.T) { c := &fakeChannel{} - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "/root" @@ -824,7 +824,7 @@ func Test_Handle_Answer(t *testing.T) { var output bytes.Buffer - hub, err := NewHub(keypair.public, "", "", zerolog.New(&output), fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", zerolog.New(&output), fakeChannelFac.newChannel) require.NoError(t, err) result := struct { @@ -944,7 +944,7 @@ func Test_Handle_Publish_From_Client(t *testing.T) { c := &fakeChannel{} - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "XXX" @@ -1011,7 +1011,7 @@ func Test_Handle_Publish_From_Server(t *testing.T) { c := &fakeChannel{} - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "XXX" @@ -1078,7 +1078,7 @@ func Test_Receive_Publish_Twice(t *testing.T) { c := &fakeChannel{} - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "XXX" @@ -1155,7 +1155,7 @@ func Test_Create_LAO_GetMessagesById_Result(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) name := "LAO X" @@ -1256,7 +1256,7 @@ func Test_Create_LAO_GetMessagesById_Wrong_MessageID(t *testing.T) { c: &fakeChannel{}, } - hub, err := NewHub(keypair.public, "", "", nolog, fakeChannelFac.newChannel) + hub, err := New(keypair.public, "", "", nolog, fakeChannelFac.newChannel) require.NoError(t, err) name := "LAO X" @@ -1343,7 +1343,7 @@ func Test_Handle_Subscribe(t *testing.T) { c := &fakeChannel{} - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "XXX" @@ -1406,7 +1406,7 @@ func TestServer_Handle_Unsubscribe(t *testing.T) { c := &fakeChannel{} - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "XXX" @@ -1480,7 +1480,7 @@ func TestServer_Handle_Catchup(t *testing.T) { msgs: fakeMessages, } - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "XXX" @@ -1532,7 +1532,7 @@ func TestServer_Handle_Catchup(t *testing.T) { func Test_Get_Server_Number(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) sock1 := &fakeSocket{id: "fakeID1"} @@ -1552,7 +1552,7 @@ func Test_Send_And_Handle_Message(t *testing.T) { c := &fakeChannel{} - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) laoID := "XXX" @@ -1618,7 +1618,7 @@ func Test_Send_And_Handle_Message(t *testing.T) { func Test_Send_Heartbeat_Message(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) sock := &fakeSocket{} @@ -1654,7 +1654,7 @@ func Test_Send_Heartbeat_Message(t *testing.T) { func Test_Handle_Heartbeat(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) hub.hubInbox.StoreMessage("/root", msg1) @@ -1712,7 +1712,7 @@ func Test_Handle_Heartbeat(t *testing.T) { func Test_Handle_GetMessagesById(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) sock := &fakeSocket{} @@ -1766,7 +1766,7 @@ func Test_Handle_GetMessagesById(t *testing.T) { func Test_Send_GreetServer_Message(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "ws://localhost:9000/client", "ws://localhost:9001/server", nolog, nil) + hub, err := New(keypair.public, "ws://localhost:9000/client", "ws://localhost:9001/server", nolog, nil) require.NoError(t, err) pkServ, err := hub.pubKeyServ.MarshalBinary() @@ -1792,7 +1792,7 @@ func Test_Send_GreetServer_Message(t *testing.T) { func Test_Handle_GreetServer_First_Time(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "ws://localhost:9000/client", "ws://localhost:9001/server", nolog, nil) + hub, err := New(keypair.public, "ws://localhost:9000/client", "ws://localhost:9001/server", nolog, nil) require.NoError(t, err) pkServ, err := hub.pubKeyServ.MarshalBinary() @@ -1844,7 +1844,7 @@ func Test_Handle_GreetServer_First_Time(t *testing.T) { func Test_Handle_GreetServer_Already_Greeted(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "ws://localhost:9000/client", "ws://localhost:9001/server", nolog, nil) + hub, err := New(keypair.public, "ws://localhost:9000/client", "ws://localhost:9001/server", nolog, nil) require.NoError(t, err) sock := &fakeSocket{} @@ -1891,7 +1891,7 @@ func Test_Handle_GreetServer_Already_Greeted(t *testing.T) { func Test_Handle_GreetServer_Already_Received(t *testing.T) { keypair := generateKeyPair(t) - hub, err := NewHub(keypair.public, "", "", nolog, nil) + hub, err := New(keypair.public, "", "", nolog, nil) require.NoError(t, err) serverInfo1 := method.ServerInfo{ diff --git a/be1-go/network/server_test.go b/be1-go/network/server_test.go index a253b91272..5d1a9b8773 100644 --- a/be1-go/network/server_test.go +++ b/be1-go/network/server_test.go @@ -15,7 +15,7 @@ import ( func TestServerStartAndShutdown(t *testing.T) { log := zerolog.New(io.Discard) - h, err := hub.NewHub(crypto.Suite.Point(), "", "", log, nil) + h, err := hub.New(crypto.Suite.Point(), "", "", log, nil) require.NoErrorf(t, err, "could not create hub") srv := NewServer(h, "", 0, "testsocket", log) diff --git a/be1-go/popcha/server_test.go b/be1-go/popcha/server_test.go index 665c383c89..b827948132 100644 --- a/be1-go/popcha/server_test.go +++ b/be1-go/popcha/server_test.go @@ -54,7 +54,7 @@ func genString(r *rand.Rand, s int) string { func TestAuthServerStartAndShutdown(t *testing.T) { l := popstellar.Logger - h, err := hub.NewHub(crypto.Suite.Point(), "", "", l, nil) + h, err := hub.New(crypto.Suite.Point(), "", "", l, nil) require.NoError(t, err, "could not create hub") s, err := NewAuthServer(h, "localhost", 2003, l) From 1c78b4c72c534accb7e7d7ab1af776dc454c141c Mon Sep 17 00:00:00 2001 From: stuart Date: Mon, 18 Mar 2024 13:58:49 +0100 Subject: [PATCH 3/6] refactor hub --- .../{message_handling.go => client_query.go} | 555 +++++------------- be1-go/hub/hub.go | 31 +- be1-go/hub/server_answer.go | 167 ++++++ be1-go/hub/server_query.go | 148 +++++ 4 files changed, 460 insertions(+), 441 deletions(-) rename be1-go/hub/{message_handling.go => client_query.go} (59%) create mode 100644 be1-go/hub/server_answer.go create mode 100644 be1-go/hub/server_query.go diff --git a/be1-go/hub/message_handling.go b/be1-go/hub/client_query.go similarity index 59% rename from be1-go/hub/message_handling.go rename to be1-go/hub/client_query.go index b5093271ed..ace3221d00 100644 --- a/be1-go/hub/message_handling.go +++ b/be1-go/hub/client_query.go @@ -3,218 +3,136 @@ package hub import ( "encoding/base64" "encoding/json" + "go.dedis.ch/kyber/v3/sign/schnorr" + "golang.org/x/xerrors" "popstellar/crypto" - "popstellar/hub/state" - jsonrpc "popstellar/message" "popstellar/message/answer" "popstellar/message/messagedata" - "popstellar/message/query" "popstellar/message/query/method" "popstellar/message/query/method/message" "popstellar/network/socket" "popstellar/validation" - - "github.com/rs/zerolog/log" - - "go.dedis.ch/kyber/v3/sign/schnorr" - - "golang.org/x/exp/slices" - "golang.org/x/xerrors" ) -const ( - publishError = "failed to publish: %v" - wrongMessageIdError = "message_id is wrong: expected %q found %q" - maxRetry = 10 -) +func (h *Hub) handleSubscribe(socket socket.Socket, byteMessage []byte) (int, error) { + var subscribe method.Subscribe -// handleRootChannelPublishMessage handles an incoming publish message on the root channel. -func (h *Hub) handleRootChannelPublishMessage(sock socket.Socket, publish method.Publish) error { - jsonData, err := base64.URLEncoding.DecodeString(publish.Params.Message.Data) + err := json.Unmarshal(byteMessage, &subscribe) if err != nil { - err := answer.NewInvalidMessageFieldError("failed to decode message data: %v", err) - - return err + return -1, xerrors.Errorf("failed to unmarshal subscribe message: %v", err) } - // validate message data against the json schema - err = h.schemaValidator.VerifyJSON(jsonData, validation.Data) + channel, err := h.getChan(subscribe.Params.Channel) if err != nil { - err := answer.NewInvalidMessageFieldError("failed to validate message against json schema: %v", err) - return err + return subscribe.ID, xerrors.Errorf("failed to get subscribe channel: %v", err) } - // get object#action - object, action, err := messagedata.GetObjectAndAction(jsonData) + err = channel.Subscribe(socket, subscribe) if err != nil { - err := answer.NewInvalidMessageFieldError("failed to get object#action: %v", err) - return err + return subscribe.ID, xerrors.Errorf(publishError, err) } - // must be "lao#create" - if object != messagedata.LAOObject || action != messagedata.LAOActionCreate { - err := answer.NewInvalidMessageFieldError("only lao#create is allowed on root, "+ - "but found %s#%s", object, action) - return err - } + return subscribe.ID, nil +} - var laoCreate messagedata.LaoCreate +func (h *Hub) handleUnsubscribe(socket socket.Socket, byteMessage []byte) (int, error) { + var unsubscribe method.Unsubscribe - err = publish.Params.Message.UnmarshalData(&laoCreate) + err := json.Unmarshal(byteMessage, &unsubscribe) if err != nil { - h.log.Err(err).Msg("failed to unmarshal lao#create") - return err + return -1, xerrors.Errorf("failed to unmarshal unsubscribe message: %v", err) } - err = laoCreate.Verify() + channel, err := h.getChan(unsubscribe.Params.Channel) if err != nil { - h.log.Err(err).Msg("invalid lao#create message " + err.Error()) - return err + return unsubscribe.ID, xerrors.Errorf("failed to get unsubscribe channel: %v", err) } - err = h.createLao(publish.Params.Message, laoCreate, sock) + err = channel.Unsubscribe(socket.ID(), unsubscribe) if err != nil { - h.log.Err(err).Msg("failed to create lao") - return err + return unsubscribe.ID, xerrors.Errorf("failed to unsubscribe: %v", err) } - h.hubInbox.StoreMessage(publish.Params.Channel, publish.Params.Message) - return nil + return unsubscribe.ID, nil } -// handleRootChannelPublishMessage handles an incoming publish message on the root channel. -func (h *Hub) handleRootChannelBroadcastMessage(sock socket.Socket, - broadcast method.Broadcast, -) error { - jsonData, err := base64.URLEncoding.DecodeString(broadcast.Params.Message.Data) - if err != nil { - err := xerrors.Errorf("failed to decode message data: %v", err) - sock.SendError(nil, err) - return err - } - - // validate message data against the json schema - err = h.schemaValidator.VerifyJSON(jsonData, validation.Data) - if err != nil { - err := xerrors.Errorf("failed to validate message against json schema: %v", err) - sock.SendError(nil, err) - return err - } +func (h *Hub) handleCatchup(socket socket.Socket, + byteMessage []byte, +) ([]message.Message, int, error) { + var catchup method.Catchup - // get object#action - object, action, err := messagedata.GetObjectAndAction(jsonData) + err := json.Unmarshal(byteMessage, &catchup) if err != nil { - err := xerrors.Errorf("failed to get object#action: %v", err) - sock.SendError(nil, err) - return err - } - - // must be "lao#create" - if object != messagedata.LAOObject || action != messagedata.LAOActionCreate { - err := xerrors.Errorf("only lao#create is allowed on root, but found %s#%s", - object, action) - sock.SendError(nil, err) - return err + return nil, -1, xerrors.Errorf("failed to unmarshal catchup message: %v", err) } - var laoCreate messagedata.LaoCreate - - err = broadcast.Params.Message.UnmarshalData(&laoCreate) - if err != nil { - h.log.Err(err).Msg("failed to unmarshal lao#create") - sock.SendError(nil, err) - return err + if catchup.Params.Channel == rootChannel { + return h.handleRootCatchup(socket, byteMessage) } - err = laoCreate.Verify() + channel, err := h.getChan(catchup.Params.Channel) if err != nil { - h.log.Err(err).Msg("invalid lao#create message") - sock.SendError(nil, err) - return err + return nil, catchup.ID, xerrors.Errorf("failed to get catchup channel: %v", err) } - err = h.createLao(broadcast.Params.Message, laoCreate, sock) + msg := channel.Catchup(catchup) if err != nil { - h.log.Err(err).Msg("failed to create lao") - sock.SendError(nil, err) - return err + return nil, catchup.ID, xerrors.Errorf("failed to catchup: %v", err) } - h.hubInbox.StoreMessage(broadcast.Params.Channel, broadcast.Params.Message) - return nil + return msg, catchup.ID, nil } -// handleRootCatchup handles an incoming catchup message on the root channel -func (h *Hub) handleRootCatchup(senderSocket socket.Socket, - byteMessage []byte, -) ([]message.Message, int, error) { - var catchup method.Catchup +func (h *Hub) handleBroadcast(socket socket.Socket, byteMessage []byte) error { + var broadcast method.Broadcast - err := json.Unmarshal(byteMessage, &catchup) + err := json.Unmarshal(byteMessage, &broadcast) if err != nil { - return nil, -1, xerrors.Errorf("failed to unmarshal catchup message: %v", err) - } - - if catchup.Params.Channel != rootChannel { - return nil, catchup.ID, xerrors.Errorf("server catchup message can only " + - "be sent on /root channel") + return xerrors.Errorf("failed to unmarshal publish message: %v", err) } - messages := h.hubInbox.GetRootMessages() + signature := broadcast.Params.Message.Signature + messageID := broadcast.Params.Message.MessageID + data := broadcast.Params.Message.Data - return messages, catchup.ID, nil -} + expectedMessageID := messagedata.Hash(data, signature) + if expectedMessageID != messageID { + return xerrors.Errorf(wrongMessageIdError, + expectedMessageID, messageID) + } -// handleAnswer handles the answer to a message sent by the server -func (h *Hub) handleAnswer(senderSocket socket.Socket, byteMessage []byte) error { - var answerMsg answer.Answer + _, ok := h.hubInbox.GetMessage(broadcast.Params.Message.MessageID) + if ok { + h.log.Info().Msg("message was already received") + return nil + } + h.hubInbox.StoreMessage(broadcast.Params.Channel, broadcast.Params.Message) - err := json.Unmarshal(byteMessage, &answerMsg) if err != nil { - return xerrors.Errorf("failed to unmarshal answer: %v", err) + return xerrors.Errorf("failed to broadcast message: %v", err) } - if answerMsg.Result == nil { - h.log.Warn().Msg("received an error, nothing to handle") - // don't send any error to avoid infinite error loop as a server will - // send an error to another server that will create another error - return nil - } - if answerMsg.Result.IsEmpty() { - h.log.Info().Msg("result isn't an answer to a query, nothing to handle") + if broadcast.Params.Channel == rootChannel { + err := h.handleRootChannelBroadcastMessage(socket, broadcast) + if err != nil { + return xerrors.Errorf(rootChannelErr, err) + } return nil } - err = h.queries.SetQueryReceived(*answerMsg.ID) + channel, err := h.getChan(broadcast.Params.Channel) if err != nil { - return xerrors.Errorf("failed to set query state: %v", err) + return xerrors.Errorf(getChannelErr, err) } - err = h.handleGetMessagesByIdAnswer(senderSocket, answerMsg) + err = channel.Broadcast(broadcast, socket) if err != nil { - return err + return xerrors.Errorf(publishError, err) } return nil } -func (h *Hub) handleGetMessagesByIdAnswer(senderSocket socket.Socket, answerMsg answer.Answer) error { - var err error - messages := answerMsg.Result.GetMessagesByChannel() - tempBlacklist := make([]string, 0) - // Loops over the messages to process them until it succeeds or reaches - // the max number of attempts - for i := 0; i < maxRetry; i++ { - tempBlacklist, err = h.loopOverMessages(&messages, senderSocket) - if err == nil && len(tempBlacklist) == 0 { - return nil - } - } - // Add contents from tempBlacklist to h.blacklist - h.blacklist.Append(tempBlacklist...) - return xerrors.Errorf("failed to process messages: %v", err) -} - func (h *Hub) handlePublish(socket socket.Socket, byteMessage []byte) (int, error) { var publish method.Publish @@ -282,336 +200,145 @@ func (h *Hub) handlePublish(socket socket.Socket, byteMessage []byte) (int, erro return publish.ID, nil } -func (h *Hub) handleBroadcast(socket socket.Socket, byteMessage []byte) error { - var broadcast method.Broadcast - - err := json.Unmarshal(byteMessage, &broadcast) +// handleRootChannelPublishMessage handles an incoming publish message on the root channel. +func (h *Hub) handleRootChannelPublishMessage(sock socket.Socket, publish method.Publish) error { + jsonData, err := base64.URLEncoding.DecodeString(publish.Params.Message.Data) if err != nil { - return xerrors.Errorf("failed to unmarshal publish message: %v", err) - } - - signature := broadcast.Params.Message.Signature - messageID := broadcast.Params.Message.MessageID - data := broadcast.Params.Message.Data - - expectedMessageID := messagedata.Hash(data, signature) - if expectedMessageID != messageID { - return xerrors.Errorf(wrongMessageIdError, - expectedMessageID, messageID) - } + err := answer.NewInvalidMessageFieldError("failed to decode message data: %v", err) - _, ok := h.hubInbox.GetMessage(broadcast.Params.Message.MessageID) - if ok { - h.log.Info().Msg("message was already received") - return nil + return err } - h.hubInbox.StoreMessage(broadcast.Params.Channel, broadcast.Params.Message) + // validate message data against the json schema + err = h.schemaValidator.VerifyJSON(jsonData, validation.Data) if err != nil { - return xerrors.Errorf("failed to broadcast message: %v", err) - } - - if broadcast.Params.Channel == rootChannel { - err := h.handleRootChannelBroadcastMessage(socket, broadcast) - if err != nil { - return xerrors.Errorf(rootChannelErr, err) - } - return nil + err := answer.NewInvalidMessageFieldError("failed to validate message against json schema: %v", err) + return err } - channel, err := h.getChan(broadcast.Params.Channel) + // get object#action + object, action, err := messagedata.GetObjectAndAction(jsonData) if err != nil { - return xerrors.Errorf(getChannelErr, err) + err := answer.NewInvalidMessageFieldError("failed to get object#action: %v", err) + return err } - err = channel.Broadcast(broadcast, socket) - if err != nil { - return xerrors.Errorf(publishError, err) + // must be "lao#create" + if object != messagedata.LAOObject || action != messagedata.LAOActionCreate { + err := answer.NewInvalidMessageFieldError("only lao#create is allowed on root, "+ + "but found %s#%s", object, action) + return err } - return nil -} - -func (h *Hub) handleSubscribe(socket socket.Socket, byteMessage []byte) (int, error) { - var subscribe method.Subscribe + var laoCreate messagedata.LaoCreate - err := json.Unmarshal(byteMessage, &subscribe) + err = publish.Params.Message.UnmarshalData(&laoCreate) if err != nil { - return -1, xerrors.Errorf("failed to unmarshal subscribe message: %v", err) + h.log.Err(err).Msg("failed to unmarshal lao#create") + return err } - channel, err := h.getChan(subscribe.Params.Channel) + err = laoCreate.Verify() if err != nil { - return subscribe.ID, xerrors.Errorf("failed to get subscribe channel: %v", err) + h.log.Err(err).Msg("invalid lao#create message " + err.Error()) + return err } - err = channel.Subscribe(socket, subscribe) + err = h.createLao(publish.Params.Message, laoCreate, sock) if err != nil { - return subscribe.ID, xerrors.Errorf(publishError, err) + h.log.Err(err).Msg("failed to create lao") + return err } - return subscribe.ID, nil + h.hubInbox.StoreMessage(publish.Params.Channel, publish.Params.Message) + return nil } -func (h *Hub) handleUnsubscribe(socket socket.Socket, byteMessage []byte) (int, error) { - var unsubscribe method.Unsubscribe - - err := json.Unmarshal(byteMessage, &unsubscribe) +// handleRootChannelPublishMessage handles an incoming publish message on the root channel. +func (h *Hub) handleRootChannelBroadcastMessage(sock socket.Socket, + broadcast method.Broadcast, +) error { + jsonData, err := base64.URLEncoding.DecodeString(broadcast.Params.Message.Data) if err != nil { - return -1, xerrors.Errorf("failed to unmarshal unsubscribe message: %v", err) + err := xerrors.Errorf("failed to decode message data: %v", err) + sock.SendError(nil, err) + return err } - channel, err := h.getChan(unsubscribe.Params.Channel) + // validate message data against the json schema + err = h.schemaValidator.VerifyJSON(jsonData, validation.Data) if err != nil { - return unsubscribe.ID, xerrors.Errorf("failed to get unsubscribe channel: %v", err) + err := xerrors.Errorf("failed to validate message against json schema: %v", err) + sock.SendError(nil, err) + return err } - err = channel.Unsubscribe(socket.ID(), unsubscribe) + // get object#action + object, action, err := messagedata.GetObjectAndAction(jsonData) if err != nil { - return unsubscribe.ID, xerrors.Errorf("failed to unsubscribe: %v", err) + err := xerrors.Errorf("failed to get object#action: %v", err) + sock.SendError(nil, err) + return err } - return unsubscribe.ID, nil -} - -func (h *Hub) handleCatchup(socket socket.Socket, - byteMessage []byte, -) ([]message.Message, int, error) { - var catchup method.Catchup - - err := json.Unmarshal(byteMessage, &catchup) - if err != nil { - return nil, -1, xerrors.Errorf("failed to unmarshal catchup message: %v", err) + // must be "lao#create" + if object != messagedata.LAOObject || action != messagedata.LAOActionCreate { + err := xerrors.Errorf("only lao#create is allowed on root, but found %s#%s", + object, action) + sock.SendError(nil, err) + return err } - if catchup.Params.Channel == rootChannel { - return h.handleRootCatchup(socket, byteMessage) - } + var laoCreate messagedata.LaoCreate - channel, err := h.getChan(catchup.Params.Channel) + err = broadcast.Params.Message.UnmarshalData(&laoCreate) if err != nil { - return nil, catchup.ID, xerrors.Errorf("failed to get catchup channel: %v", err) + h.log.Err(err).Msg("failed to unmarshal lao#create") + sock.SendError(nil, err) + return err } - msg := channel.Catchup(catchup) + err = laoCreate.Verify() if err != nil { - return nil, catchup.ID, xerrors.Errorf("failed to catchup: %v", err) + h.log.Err(err).Msg("invalid lao#create message") + sock.SendError(nil, err) + return err } - return msg, catchup.ID, nil -} - -func (h *Hub) handleHeartbeat(socket socket.Socket, - byteMessage []byte, -) error { - var heartbeat method.Heartbeat - - err := json.Unmarshal(byteMessage, &heartbeat) + err = h.createLao(broadcast.Params.Message, laoCreate, sock) if err != nil { - return xerrors.Errorf("failed to unmarshal heartbeat message: %v", err) - } - - receivedIds := heartbeat.Params - - missingIds := getMissingIds(receivedIds, h.hubInbox.GetIDsTable(), &h.blacklist) - - if len(missingIds) > 0 { - err = h.sendGetMessagesByIdToServer(socket, missingIds) - if err != nil { - return xerrors.Errorf("failed to send getMessagesById message: %v", err) - } + h.log.Err(err).Msg("failed to create lao") + sock.SendError(nil, err) + return err } + h.hubInbox.StoreMessage(broadcast.Params.Channel, broadcast.Params.Message) return nil } -func (h *Hub) handleGetMessagesById(socket socket.Socket, +// handleRootCatchup handles an incoming catchup message on the root channel +func (h *Hub) handleRootCatchup(senderSocket socket.Socket, byteMessage []byte, -) (map[string][]message.Message, int, error) { - var getMessagesById method.GetMessagesById - - err := json.Unmarshal(byteMessage, &getMessagesById) - if err != nil { - return nil, 0, xerrors.Errorf("failed to unmarshal getMessagesById message: %v", err) - } - - missingMessages, err := h.getMissingMessages(getMessagesById.Params) - if err != nil { - return nil, getMessagesById.ID, xerrors.Errorf("failed to retrieve messages: %v", err) - } - - return missingMessages, getMessagesById.ID, nil -} - -func (h *Hub) handleGreetServer(socket socket.Socket, byteMessage []byte) error { - var greetServer method.GreetServer +) ([]message.Message, int, error) { + var catchup method.Catchup - err := json.Unmarshal(byteMessage, &greetServer) + err := json.Unmarshal(byteMessage, &catchup) if err != nil { - return xerrors.Errorf("failed to unmarshal greetServer message: %v", err) + return nil, -1, xerrors.Errorf("failed to unmarshal catchup message: %v", err) } - // store information about the server - err = h.peers.AddPeerInfo(socket.ID(), greetServer.Params) - if err != nil { - return xerrors.Errorf("failed to add peer info: %v", err) + if catchup.Params.Channel != rootChannel { + return nil, catchup.ID, xerrors.Errorf("server catchup message can only " + + "be sent on /root channel") } - if h.peers.IsPeerGreeted(socket.ID()) { - return nil - } + messages := h.hubInbox.GetRootMessages() - err = h.SendGreetServer(socket) - if err != nil { - return xerrors.Errorf("failed to send greetServer message: %v", err) - } - return nil + return messages, catchup.ID, nil } //-----------------------Helper methods for message handling--------------------------- -// getMissingIds compares two maps of channel Ids associated to slices of message Ids to -// determine the missing Ids from the storedIds map with respect to the receivedIds map -func getMissingIds(receivedIds map[string][]string, storedIds map[string][]string, blacklist *state.ThreadSafeSlice[string]) map[string][]string { - missingIds := make(map[string][]string) - for channelId, receivedMessageIds := range receivedIds { - for _, messageId := range receivedMessageIds { - blacklisted := blacklist.Contains(messageId) - storedIdsForChannel, channelKnown := storedIds[channelId] - if blacklisted { - break - } - if channelKnown { - contains := slices.Contains(storedIdsForChannel, messageId) - if !contains { - missingIds[channelId] = append(missingIds[channelId], messageId) - } - } else { - missingIds[channelId] = append(missingIds[channelId], messageId) - } - } - } - return missingIds -} - -// getMissingMessages retrieves the missing messages from the inbox given their Ids -func (h *Hub) getMissingMessages(missingIds map[string][]string) (map[string][]message.Message, error) { - missingMsgs := make(map[string][]message.Message) - for channelId, messageIds := range missingIds { - for _, messageId := range messageIds { - msg, exists := h.hubInbox.GetMessage(messageId) - if !exists { - return nil, xerrors.Errorf("Message %s not found in hub inbox", messageId) - } - missingMsgs[channelId] = append(missingMsgs[channelId], *msg) - } - } - return missingMsgs, nil -} - -// handleReceivedMessage handle a message obtained by the server receiving a -// getMessagesById result -func (h *Hub) handleReceivedMessage(socket socket.Socket, messageData message.Message, targetChannel string) error { - signature := messageData.Signature - messageID := messageData.MessageID - data := messageData.Data - log.Info().Msgf("Received message on %s", targetChannel) - - expectedMessageID := messagedata.Hash(data, signature) - if expectedMessageID != messageID { - return xerrors.Errorf(wrongMessageIdError, - expectedMessageID, messageID) - } - - publish := method.Publish{ - Base: query.Base{ - JSONRPCBase: jsonrpc.JSONRPCBase{ - JSONRPC: "2.0", - }, - Method: "publish", - }, - - Params: struct { - Channel string `json:"channel"` - Message message.Message `json:"message"` - }{ - Channel: targetChannel, - Message: messageData, - }, - } - _, stored := h.hubInbox.GetMessage(publish.Params.Message.MessageID) - if stored { - h.log.Info().Msgf("Already stored message %s", publish.Params.Message.MessageID) - return nil - } - - if publish.Params.Channel == rootChannel { - err := h.handleRootChannelPublishMessage(socket, publish) - if err != nil { - return xerrors.Errorf(rootChannelErr, err) - } - return nil - } - - channel, err := h.getChan(publish.Params.Channel) - if err != nil { - return xerrors.Errorf(getChannelErr, err) - } - - err = channel.Publish(publish, socket) - if err != nil { - return xerrors.Errorf(publishError, err) - } - - h.hubInbox.StoreMessage(publish.Params.Channel, publish.Params.Message) - return nil -} - -// loopOverMessages loops over the messages received from a getMessagesById answer to process them -// and update the list of messages to process during the next iteration with those that fail -func (h *Hub) loopOverMessages(messages *map[string][]json.RawMessage, senderSocket socket.Socket) ([]string, error) { - var errMsg string - tempBlacklist := make([]string, 0) - for channel, messageArray := range *messages { - newMessageArray := make([]json.RawMessage, 0) - - // Try to process each message - for _, msg := range messageArray { - var messageData message.Message - err := json.Unmarshal(msg, &messageData) - if err != nil { - h.log.Error().Msgf("failed to unmarshal message during getMessagesById answer handling: %v", err) - continue - } - - if h.blacklist.Contains(messageData.MessageID) { - break - } - - err = h.handleReceivedMessage(senderSocket, messageData, channel) - if err != nil { - h.log.Error().Msgf("failed to handle message received from getMessagesById answer: %v", err) - newMessageArray = append(newMessageArray, msg) // if there's an error, keep the message - errMsg += err.Error() - - // Add the ID of the failed message to the blacklist - tempBlacklist = append(tempBlacklist, messageData.MessageID) - } - } - // Update the list of messages to process during the next iteration - (*messages)[channel] = newMessageArray - // if no messages left for the channel, remove the channel from the map - if len(newMessageArray) == 0 { - delete(*messages, channel) - } - } - - if errMsg != "" { - return tempBlacklist, xerrors.New(errMsg) - } - return tempBlacklist, nil -} - // createLao creates a new LAO using the data in the publish parameter. func (h *Hub) createLao(msg message.Message, laoCreate messagedata.LaoCreate, socket socket.Socket, diff --git a/be1-go/hub/hub.go b/be1-go/hub/hub.go index e533a541e5..b48b305dec 100644 --- a/be1-go/hub/hub.go +++ b/be1-go/hub/hub.go @@ -45,6 +45,10 @@ const ( // heartbeatDelay represents the number of seconds // between heartbeat messages heartbeatDelay = 30 * time.Second + + publishError = "failed to publish: %v" + wrongMessageIdError = "message_id is wrong: expected %q found %q" + maxRetry = 10 ) var suite = crypto.Suite @@ -476,33 +480,6 @@ func (h *Hub) handleIncomingMessage(incomingMessage *socket.IncomingMessage) err } } -// sendGetMessagesByIdToServer sends a getMessagesById message to a server -func (h *Hub) sendGetMessagesByIdToServer(socket socket.Socket, missingIds map[string][]string) error { - queryId := h.queries.GetNextID() - - getMessagesById := method.GetMessagesById{ - Base: query.Base{ - JSONRPCBase: jsonrpc.JSONRPCBase{ - JSONRPC: "2.0", - }, - Method: "get_messages_by_id", - }, - ID: queryId, - Params: missingIds, - } - - buf, err := json.Marshal(getMessagesById) - if err != nil { - return xerrors.Errorf("failed to marshal getMessagesById query: %v", err) - } - - socket.Send(buf) - - h.queries.AddQuery(queryId, getMessagesById) - - return nil -} - // sendHeartbeatToServers sends a heartbeat message to all servers func (h *Hub) sendHeartbeatToServers() { heartbeatMessage := method.Heartbeat{ diff --git a/be1-go/hub/server_answer.go b/be1-go/hub/server_answer.go new file mode 100644 index 0000000000..64fe4d2941 --- /dev/null +++ b/be1-go/hub/server_answer.go @@ -0,0 +1,167 @@ +package hub + +import ( + "encoding/json" + "github.com/rs/zerolog/log" + "golang.org/x/xerrors" + message2 "popstellar/message" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query" + "popstellar/message/query/method" + "popstellar/message/query/method/message" + "popstellar/network/socket" +) + +// handleAnswer handles the answer to a message sent by the server +func (h *Hub) handleAnswer(senderSocket socket.Socket, byteMessage []byte) error { + var answerMsg answer.Answer + + err := json.Unmarshal(byteMessage, &answerMsg) + if err != nil { + return xerrors.Errorf("failed to unmarshal answer: %v", err) + } + + if answerMsg.Result == nil { + h.log.Warn().Msg("received an error, nothing to handle") + // don't send any error to avoid infinite error loop as a server will + // send an error to another server that will create another error + return nil + } + if answerMsg.Result.IsEmpty() { + h.log.Info().Msg("result isn't an answer to a query, nothing to handle") + return nil + } + + err = h.queries.SetQueryReceived(*answerMsg.ID) + if err != nil { + return xerrors.Errorf("failed to set query state: %v", err) + } + + err = h.handleGetMessagesByIdAnswer(senderSocket, answerMsg) + if err != nil { + return err + } + + return nil +} + +func (h *Hub) handleGetMessagesByIdAnswer(senderSocket socket.Socket, answerMsg answer.Answer) error { + var err error + messages := answerMsg.Result.GetMessagesByChannel() + tempBlacklist := make([]string, 0) + // Loops over the messages to process them until it succeeds or reaches + // the max number of attempts + for i := 0; i < maxRetry; i++ { + tempBlacklist, err = h.loopOverMessages(&messages, senderSocket) + if err == nil && len(tempBlacklist) == 0 { + return nil + } + } + // Add contents from tempBlacklist to h.blacklist + h.blacklist.Append(tempBlacklist...) + return xerrors.Errorf("failed to process messages: %v", err) +} + +// loopOverMessages loops over the messages received from a getMessagesById answer to process them +// and update the list of messages to process during the next iteration with those that fail +func (h *Hub) loopOverMessages(messages *map[string][]json.RawMessage, senderSocket socket.Socket) ([]string, error) { + var errMsg string + tempBlacklist := make([]string, 0) + for channel, messageArray := range *messages { + newMessageArray := make([]json.RawMessage, 0) + + // Try to process each message + for _, msg := range messageArray { + var messageData message.Message + err := json.Unmarshal(msg, &messageData) + if err != nil { + h.log.Error().Msgf("failed to unmarshal message during getMessagesById answer handling: %v", err) + continue + } + + if h.blacklist.Contains(messageData.MessageID) { + break + } + + err = h.handleReceivedMessage(senderSocket, messageData, channel) + if err != nil { + h.log.Error().Msgf("failed to handle message received from getMessagesById answer: %v", err) + newMessageArray = append(newMessageArray, msg) // if there's an error, keep the message + errMsg += err.Error() + + // Add the ID of the failed message to the blacklist + tempBlacklist = append(tempBlacklist, messageData.MessageID) + } + } + // Update the list of messages to process during the next iteration + (*messages)[channel] = newMessageArray + // if no messages left for the channel, remove the channel from the map + if len(newMessageArray) == 0 { + delete(*messages, channel) + } + } + + if errMsg != "" { + return tempBlacklist, xerrors.New(errMsg) + } + return tempBlacklist, nil +} + +// handleReceivedMessage handle a message obtained by the server receiving a +// getMessagesById result +func (h *Hub) handleReceivedMessage(socket socket.Socket, messageData message.Message, targetChannel string) error { + signature := messageData.Signature + messageID := messageData.MessageID + data := messageData.Data + log.Info().Msgf("Received message on %s", targetChannel) + + expectedMessageID := messagedata.Hash(data, signature) + if expectedMessageID != messageID { + return xerrors.Errorf(wrongMessageIdError, + expectedMessageID, messageID) + } + + publish := method.Publish{ + Base: query.Base{ + JSONRPCBase: message2.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: "publish", + }, + + Params: struct { + Channel string `json:"channel"` + Message message.Message `json:"message"` + }{ + Channel: targetChannel, + Message: messageData, + }, + } + _, stored := h.hubInbox.GetMessage(publish.Params.Message.MessageID) + if stored { + h.log.Info().Msgf("Already stored message %s", publish.Params.Message.MessageID) + return nil + } + + if publish.Params.Channel == rootChannel { + err := h.handleRootChannelPublishMessage(socket, publish) + if err != nil { + return xerrors.Errorf(rootChannelErr, err) + } + return nil + } + + channel, err := h.getChan(publish.Params.Channel) + if err != nil { + return xerrors.Errorf(getChannelErr, err) + } + + err = channel.Publish(publish, socket) + if err != nil { + return xerrors.Errorf(publishError, err) + } + + h.hubInbox.StoreMessage(publish.Params.Channel, publish.Params.Message) + return nil +} diff --git a/be1-go/hub/server_query.go b/be1-go/hub/server_query.go new file mode 100644 index 0000000000..5f5c3749aa --- /dev/null +++ b/be1-go/hub/server_query.go @@ -0,0 +1,148 @@ +package hub + +import ( + "encoding/json" + "golang.org/x/exp/slices" + "golang.org/x/xerrors" + "popstellar/hub/state" + message2 "popstellar/message" + "popstellar/message/query" + "popstellar/message/query/method" + "popstellar/message/query/method/message" + "popstellar/network/socket" +) + +func (h *Hub) handleGreetServer(socket socket.Socket, byteMessage []byte) error { + var greetServer method.GreetServer + + err := json.Unmarshal(byteMessage, &greetServer) + if err != nil { + return xerrors.Errorf("failed to unmarshal greetServer message: %v", err) + } + + // store information about the server + err = h.peers.AddPeerInfo(socket.ID(), greetServer.Params) + if err != nil { + return xerrors.Errorf("failed to add peer info: %v", err) + } + + if h.peers.IsPeerGreeted(socket.ID()) { + return nil + } + + err = h.SendGreetServer(socket) + if err != nil { + return xerrors.Errorf("failed to send greetServer message: %v", err) + } + return nil +} + +func (h *Hub) handleHeartbeat(socket socket.Socket, + byteMessage []byte, +) error { + var heartbeat method.Heartbeat + + err := json.Unmarshal(byteMessage, &heartbeat) + if err != nil { + return xerrors.Errorf("failed to unmarshal heartbeat message: %v", err) + } + + receivedIds := heartbeat.Params + + missingIds := getMissingIds(receivedIds, h.hubInbox.GetIDsTable(), &h.blacklist) + + if len(missingIds) > 0 { + err = h.sendGetMessagesByIdToServer(socket, missingIds) + if err != nil { + return xerrors.Errorf("failed to send getMessagesById message: %v", err) + } + } + + return nil +} + +func (h *Hub) handleGetMessagesById(socket socket.Socket, + byteMessage []byte, +) (map[string][]message.Message, int, error) { + var getMessagesById method.GetMessagesById + + err := json.Unmarshal(byteMessage, &getMessagesById) + if err != nil { + return nil, 0, xerrors.Errorf("failed to unmarshal getMessagesById message: %v", err) + } + + missingMessages, err := h.getMissingMessages(getMessagesById.Params) + if err != nil { + return nil, getMessagesById.ID, xerrors.Errorf("failed to retrieve messages: %v", err) + } + + return missingMessages, getMessagesById.ID, nil +} + +//-----------------------Helper methods for message handling--------------------------- + +// getMissingIds compares two maps of channel Ids associated to slices of message Ids to +// determine the missing Ids from the storedIds map with respect to the receivedIds map +func getMissingIds(receivedIds map[string][]string, storedIds map[string][]string, blacklist *state.ThreadSafeSlice[string]) map[string][]string { + missingIds := make(map[string][]string) + for channelId, receivedMessageIds := range receivedIds { + for _, messageId := range receivedMessageIds { + blacklisted := blacklist.Contains(messageId) + storedIdsForChannel, channelKnown := storedIds[channelId] + if blacklisted { + break + } + if channelKnown { + contains := slices.Contains(storedIdsForChannel, messageId) + if !contains { + missingIds[channelId] = append(missingIds[channelId], messageId) + } + } else { + missingIds[channelId] = append(missingIds[channelId], messageId) + } + } + } + return missingIds +} + +// getMissingMessages retrieves the missing messages from the inbox given their Ids +func (h *Hub) getMissingMessages(missingIds map[string][]string) (map[string][]message.Message, error) { + missingMsgs := make(map[string][]message.Message) + for channelId, messageIds := range missingIds { + for _, messageId := range messageIds { + msg, exists := h.hubInbox.GetMessage(messageId) + if !exists { + return nil, xerrors.Errorf("Message %s not found in hub inbox", messageId) + } + missingMsgs[channelId] = append(missingMsgs[channelId], *msg) + } + } + return missingMsgs, nil +} + +// sendGetMessagesByIdToServer sends a getMessagesById message to a server +func (h *Hub) sendGetMessagesByIdToServer(socket socket.Socket, missingIds map[string][]string) error { + queryId := h.queries.GetNextID() + + getMessagesById := method.GetMessagesById{ + Base: query.Base{ + JSONRPCBase: message2.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: "get_messages_by_id", + }, + ID: queryId, + Params: missingIds, + } + + buf, err := json.Marshal(getMessagesById) + if err != nil { + return xerrors.Errorf("failed to marshal getMessagesById query: %v", err) + } + + socket.Send(buf) + + h.queries.AddQuery(queryId, getMessagesById) + + return nil +} From ec50faffd2751f6ac75a64989cce5596049bde25 Mon Sep 17 00:00:00 2001 From: stuart Date: Mon, 18 Mar 2024 19:16:47 +0100 Subject: [PATCH 4/6] Remove useless method from client and server and fix tests --- be1-go/hub/hub.go | 5 ----- be1-go/hub/hub_test.go | 9 ++------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/be1-go/hub/hub.go b/be1-go/hub/hub.go index b48b305dec..6efecd30a9 100644 --- a/be1-go/hub/hub.go +++ b/be1-go/hub/hub.go @@ -430,15 +430,10 @@ func (h *Hub) handleMessageFromServer(incomingMessage *socket.IncomingMessage) e case query.MethodPublish: id, handlerErr = h.handlePublish(socket, byteMessage) h.sendHeartbeatToServers() - case query.MethodSubscribe: - id, handlerErr = h.handleSubscribe(socket, byteMessage) - case query.MethodUnsubscribe: - id, handlerErr = h.handleUnsubscribe(socket, byteMessage) case query.MethodHeartbeat: handlerErr = h.handleHeartbeat(socket, byteMessage) case query.MethodGetMessagesById: msgsByChannel, id, handlerErr = h.handleGetMessagesById(socket, byteMessage) - default: err = answer.NewErrorf(-2, "unexpected method: '%s'", queryBase.Method) socket.SendError(nil, err) diff --git a/be1-go/hub/hub_test.go b/be1-go/hub/hub_test.go index f2b3751494..6931e43966 100644 --- a/be1-go/hub/hub_test.go +++ b/be1-go/hub/hub_test.go @@ -1386,7 +1386,7 @@ func Test_Handle_Subscribe(t *testing.T) { require.Equal(t, subscribe, c.subscribe) // check that there is no errors with messages from witness too - hub.handleMessageFromServer(&socket.IncomingMessage{ + hub.handleMessageFromClient(&socket.IncomingMessage{ Socket: sock, Message: publishBuf, }) @@ -1456,12 +1456,7 @@ func TestServer_Handle_Unsubscribe(t *testing.T) { }) // check the socket - require.NoError(t, sock.err) - require.Equal(t, unsubscribe.ID, sock.resultID) - - // check that the channel has been called with the publish message - require.Equal(t, unsubscribe, c.unsubscribe) - require.Equal(t, sock.id, c.socketID) + require.Error(t, sock.err) } // Check that if the server receives a catchup message, it will call the From 33ed22614009b351525fb374c9148a94c9b21fa8 Mon Sep 17 00:00:00 2001 From: stuart Date: Tue, 19 Mar 2024 12:22:34 +0100 Subject: [PATCH 5/6] Rename files --- .../hub/{client_query.go => from_client.go} | 0 .../hub/{server_answer.go => from_server.go} | 137 ++++++++++++++++ be1-go/hub/server_query.go | 148 ------------------ 3 files changed, 137 insertions(+), 148 deletions(-) rename be1-go/hub/{client_query.go => from_client.go} (100%) rename be1-go/hub/{server_answer.go => from_server.go} (55%) delete mode 100644 be1-go/hub/server_query.go diff --git a/be1-go/hub/client_query.go b/be1-go/hub/from_client.go similarity index 100% rename from be1-go/hub/client_query.go rename to be1-go/hub/from_client.go diff --git a/be1-go/hub/server_answer.go b/be1-go/hub/from_server.go similarity index 55% rename from be1-go/hub/server_answer.go rename to be1-go/hub/from_server.go index 64fe4d2941..6c19cd357e 100644 --- a/be1-go/hub/server_answer.go +++ b/be1-go/hub/from_server.go @@ -3,7 +3,9 @@ package hub import ( "encoding/json" "github.com/rs/zerolog/log" + "golang.org/x/exp/slices" "golang.org/x/xerrors" + "popstellar/hub/state" message2 "popstellar/message" "popstellar/message/answer" "popstellar/message/messagedata" @@ -13,6 +15,73 @@ import ( "popstellar/network/socket" ) +func (h *Hub) handleGreetServer(socket socket.Socket, byteMessage []byte) error { + var greetServer method.GreetServer + + err := json.Unmarshal(byteMessage, &greetServer) + if err != nil { + return xerrors.Errorf("failed to unmarshal greetServer message: %v", err) + } + + // store information about the server + err = h.peers.AddPeerInfo(socket.ID(), greetServer.Params) + if err != nil { + return xerrors.Errorf("failed to add peer info: %v", err) + } + + if h.peers.IsPeerGreeted(socket.ID()) { + return nil + } + + err = h.SendGreetServer(socket) + if err != nil { + return xerrors.Errorf("failed to send greetServer message: %v", err) + } + return nil +} + +func (h *Hub) handleHeartbeat(socket socket.Socket, + byteMessage []byte, +) error { + var heartbeat method.Heartbeat + + err := json.Unmarshal(byteMessage, &heartbeat) + if err != nil { + return xerrors.Errorf("failed to unmarshal heartbeat message: %v", err) + } + + receivedIds := heartbeat.Params + + missingIds := getMissingIds(receivedIds, h.hubInbox.GetIDsTable(), &h.blacklist) + + if len(missingIds) > 0 { + err = h.sendGetMessagesByIdToServer(socket, missingIds) + if err != nil { + return xerrors.Errorf("failed to send getMessagesById message: %v", err) + } + } + + return nil +} + +func (h *Hub) handleGetMessagesById(socket socket.Socket, + byteMessage []byte, +) (map[string][]message.Message, int, error) { + var getMessagesById method.GetMessagesById + + err := json.Unmarshal(byteMessage, &getMessagesById) + if err != nil { + return nil, 0, xerrors.Errorf("failed to unmarshal getMessagesById message: %v", err) + } + + missingMessages, err := h.getMissingMessages(getMessagesById.Params) + if err != nil { + return nil, getMessagesById.ID, xerrors.Errorf("failed to retrieve messages: %v", err) + } + + return missingMessages, getMessagesById.ID, nil +} + // handleAnswer handles the answer to a message sent by the server func (h *Hub) handleAnswer(senderSocket socket.Socket, byteMessage []byte) error { var answerMsg answer.Answer @@ -165,3 +234,71 @@ func (h *Hub) handleReceivedMessage(socket socket.Socket, messageData message.Me h.hubInbox.StoreMessage(publish.Params.Channel, publish.Params.Message) return nil } + +//-----------------------Helper methods for message handling--------------------------- + +// getMissingIds compares two maps of channel Ids associated to slices of message Ids to +// determine the missing Ids from the storedIds map with respect to the receivedIds map +func getMissingIds(receivedIds map[string][]string, storedIds map[string][]string, blacklist *state.ThreadSafeSlice[string]) map[string][]string { + missingIds := make(map[string][]string) + for channelId, receivedMessageIds := range receivedIds { + for _, messageId := range receivedMessageIds { + blacklisted := blacklist.Contains(messageId) + storedIdsForChannel, channelKnown := storedIds[channelId] + if blacklisted { + break + } + if channelKnown { + contains := slices.Contains(storedIdsForChannel, messageId) + if !contains { + missingIds[channelId] = append(missingIds[channelId], messageId) + } + } else { + missingIds[channelId] = append(missingIds[channelId], messageId) + } + } + } + return missingIds +} + +// getMissingMessages retrieves the missing messages from the inbox given their Ids +func (h *Hub) getMissingMessages(missingIds map[string][]string) (map[string][]message.Message, error) { + missingMsgs := make(map[string][]message.Message) + for channelId, messageIds := range missingIds { + for _, messageId := range messageIds { + msg, exists := h.hubInbox.GetMessage(messageId) + if !exists { + return nil, xerrors.Errorf("Message %s not found in hub inbox", messageId) + } + missingMsgs[channelId] = append(missingMsgs[channelId], *msg) + } + } + return missingMsgs, nil +} + +// sendGetMessagesByIdToServer sends a getMessagesById message to a server +func (h *Hub) sendGetMessagesByIdToServer(socket socket.Socket, missingIds map[string][]string) error { + queryId := h.queries.GetNextID() + + getMessagesById := method.GetMessagesById{ + Base: query.Base{ + JSONRPCBase: message2.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: "get_messages_by_id", + }, + ID: queryId, + Params: missingIds, + } + + buf, err := json.Marshal(getMessagesById) + if err != nil { + return xerrors.Errorf("failed to marshal getMessagesById query: %v", err) + } + + socket.Send(buf) + + h.queries.AddQuery(queryId, getMessagesById) + + return nil +} diff --git a/be1-go/hub/server_query.go b/be1-go/hub/server_query.go deleted file mode 100644 index 5f5c3749aa..0000000000 --- a/be1-go/hub/server_query.go +++ /dev/null @@ -1,148 +0,0 @@ -package hub - -import ( - "encoding/json" - "golang.org/x/exp/slices" - "golang.org/x/xerrors" - "popstellar/hub/state" - message2 "popstellar/message" - "popstellar/message/query" - "popstellar/message/query/method" - "popstellar/message/query/method/message" - "popstellar/network/socket" -) - -func (h *Hub) handleGreetServer(socket socket.Socket, byteMessage []byte) error { - var greetServer method.GreetServer - - err := json.Unmarshal(byteMessage, &greetServer) - if err != nil { - return xerrors.Errorf("failed to unmarshal greetServer message: %v", err) - } - - // store information about the server - err = h.peers.AddPeerInfo(socket.ID(), greetServer.Params) - if err != nil { - return xerrors.Errorf("failed to add peer info: %v", err) - } - - if h.peers.IsPeerGreeted(socket.ID()) { - return nil - } - - err = h.SendGreetServer(socket) - if err != nil { - return xerrors.Errorf("failed to send greetServer message: %v", err) - } - return nil -} - -func (h *Hub) handleHeartbeat(socket socket.Socket, - byteMessage []byte, -) error { - var heartbeat method.Heartbeat - - err := json.Unmarshal(byteMessage, &heartbeat) - if err != nil { - return xerrors.Errorf("failed to unmarshal heartbeat message: %v", err) - } - - receivedIds := heartbeat.Params - - missingIds := getMissingIds(receivedIds, h.hubInbox.GetIDsTable(), &h.blacklist) - - if len(missingIds) > 0 { - err = h.sendGetMessagesByIdToServer(socket, missingIds) - if err != nil { - return xerrors.Errorf("failed to send getMessagesById message: %v", err) - } - } - - return nil -} - -func (h *Hub) handleGetMessagesById(socket socket.Socket, - byteMessage []byte, -) (map[string][]message.Message, int, error) { - var getMessagesById method.GetMessagesById - - err := json.Unmarshal(byteMessage, &getMessagesById) - if err != nil { - return nil, 0, xerrors.Errorf("failed to unmarshal getMessagesById message: %v", err) - } - - missingMessages, err := h.getMissingMessages(getMessagesById.Params) - if err != nil { - return nil, getMessagesById.ID, xerrors.Errorf("failed to retrieve messages: %v", err) - } - - return missingMessages, getMessagesById.ID, nil -} - -//-----------------------Helper methods for message handling--------------------------- - -// getMissingIds compares two maps of channel Ids associated to slices of message Ids to -// determine the missing Ids from the storedIds map with respect to the receivedIds map -func getMissingIds(receivedIds map[string][]string, storedIds map[string][]string, blacklist *state.ThreadSafeSlice[string]) map[string][]string { - missingIds := make(map[string][]string) - for channelId, receivedMessageIds := range receivedIds { - for _, messageId := range receivedMessageIds { - blacklisted := blacklist.Contains(messageId) - storedIdsForChannel, channelKnown := storedIds[channelId] - if blacklisted { - break - } - if channelKnown { - contains := slices.Contains(storedIdsForChannel, messageId) - if !contains { - missingIds[channelId] = append(missingIds[channelId], messageId) - } - } else { - missingIds[channelId] = append(missingIds[channelId], messageId) - } - } - } - return missingIds -} - -// getMissingMessages retrieves the missing messages from the inbox given their Ids -func (h *Hub) getMissingMessages(missingIds map[string][]string) (map[string][]message.Message, error) { - missingMsgs := make(map[string][]message.Message) - for channelId, messageIds := range missingIds { - for _, messageId := range messageIds { - msg, exists := h.hubInbox.GetMessage(messageId) - if !exists { - return nil, xerrors.Errorf("Message %s not found in hub inbox", messageId) - } - missingMsgs[channelId] = append(missingMsgs[channelId], *msg) - } - } - return missingMsgs, nil -} - -// sendGetMessagesByIdToServer sends a getMessagesById message to a server -func (h *Hub) sendGetMessagesByIdToServer(socket socket.Socket, missingIds map[string][]string) error { - queryId := h.queries.GetNextID() - - getMessagesById := method.GetMessagesById{ - Base: query.Base{ - JSONRPCBase: message2.JSONRPCBase{ - JSONRPC: "2.0", - }, - Method: "get_messages_by_id", - }, - ID: queryId, - Params: missingIds, - } - - buf, err := json.Marshal(getMessagesById) - if err != nil { - return xerrors.Errorf("failed to marshal getMessagesById query: %v", err) - } - - socket.Send(buf) - - h.queries.AddQuery(queryId, getMessagesById) - - return nil -} From c22e274fee5044fa51abca13b7a1d5483825296f Mon Sep 17 00:00:00 2001 From: stuart Date: Thu, 28 Mar 2024 15:39:45 +0100 Subject: [PATCH 6/6] Move code, remove workers and fix tests --- be1-go/channel/channel.go | 4 +- be1-go/hub/from_client.go | 72 +++++++++++++++ be1-go/hub/from_server.go | 89 ++++++++++++++++++- be1-go/hub/hub.go | 180 +------------------------------------- be1-go/hub/hub_test.go | 4 +- 5 files changed, 165 insertions(+), 184 deletions(-) diff --git a/be1-go/channel/channel.go b/be1-go/channel/channel.go index 6e5e479201..1832e9f088 100644 --- a/be1-go/channel/channel.go +++ b/be1-go/channel/channel.go @@ -92,8 +92,8 @@ func (s *Sockets) Delete(ID string) bool { // HubFunctionalities defines the functions needed by a channel from the hub. type HubFunctionalities interface { - GetPubKeyOwner() kyber.Point - GetPubKeyServ() kyber.Point + GetPubKeyOwner() kyber.Point // Become useless + GetPubKeyServ() kyber.Point // Become useless Sign([]byte) ([]byte, error) GetSchemaValidator() validation.SchemaValidator NotifyNewChannel(channelID string, channel Channel, socket socket.Socket) diff --git a/be1-go/hub/from_client.go b/be1-go/hub/from_client.go index ace3221d00..1a3801ad40 100644 --- a/be1-go/hub/from_client.go +++ b/be1-go/hub/from_client.go @@ -6,14 +6,86 @@ import ( "go.dedis.ch/kyber/v3/sign/schnorr" "golang.org/x/xerrors" "popstellar/crypto" + jsonrpc "popstellar/message" "popstellar/message/answer" "popstellar/message/messagedata" + "popstellar/message/query" "popstellar/message/query/method" "popstellar/message/query/method/message" "popstellar/network/socket" "popstellar/validation" ) +// handleMessageFromClient handles an incoming message from an end user. +func (h *Hub) handleMessageFromClient(incomingMessage *socket.IncomingMessage) error { + socket := incomingMessage.Socket + byteMessage := incomingMessage.Message + + // validate against json schema + err := h.schemaValidator.VerifyJSON(byteMessage, validation.GenericMessage) + if err != nil { + schemaErr := xerrors.Errorf("message is not valid against json schema: %v", err) + socket.SendError(nil, schemaErr) + return schemaErr + } + + rpctype, err := jsonrpc.GetType(byteMessage) + if err != nil { + rpcErr := xerrors.Errorf("failed to get rpc type: %v", err) + socket.SendError(nil, rpcErr) + return rpcErr + } + + if rpctype != jsonrpc.RPCTypeQuery { + rpcErr := xerrors.New("rpc message sent by a client should be a query") + socket.SendError(nil, rpcErr) + return rpcErr + } + + var queryBase query.Base + + err = json.Unmarshal(byteMessage, &queryBase) + if err != nil { + err := answer.NewErrorf(-4, "failed to unmarshal incoming message: %v", err) + socket.SendError(nil, err) + return err + } + + var id int + var msgs []message.Message + var handlerErr error + + switch queryBase.Method { + case query.MethodPublish: + id, handlerErr = h.handlePublish(socket, byteMessage) + h.sendHeartbeatToServers() + case query.MethodSubscribe: + id, handlerErr = h.handleSubscribe(socket, byteMessage) + case query.MethodUnsubscribe: + id, handlerErr = h.handleUnsubscribe(socket, byteMessage) + case query.MethodCatchUp: + msgs, id, handlerErr = h.handleCatchup(socket, byteMessage) + default: + err = answer.NewInvalidResourceError("unexpected method: '%s'", queryBase.Method) + socket.SendError(nil, err) + return err + } + + if handlerErr != nil { + socket.SendError(&id, handlerErr) + return err + } + + if queryBase.Method == query.MethodCatchUp { + socket.SendResult(id, msgs, nil) + return nil + } + + socket.SendResult(id, nil, nil) + + return nil +} + func (h *Hub) handleSubscribe(socket socket.Socket, byteMessage []byte) (int, error) { var subscribe method.Subscribe diff --git a/be1-go/hub/from_server.go b/be1-go/hub/from_server.go index 6c19cd357e..0cde41d270 100644 --- a/be1-go/hub/from_server.go +++ b/be1-go/hub/from_server.go @@ -6,15 +6,98 @@ import ( "golang.org/x/exp/slices" "golang.org/x/xerrors" "popstellar/hub/state" - message2 "popstellar/message" + jsonrpc "popstellar/message" "popstellar/message/answer" "popstellar/message/messagedata" "popstellar/message/query" "popstellar/message/query/method" "popstellar/message/query/method/message" "popstellar/network/socket" + "popstellar/validation" ) +// handleMessageFromServer handles an incoming message from a server. +func (h *Hub) handleMessageFromServer(incomingMessage *socket.IncomingMessage) error { + socket := incomingMessage.Socket + byteMessage := incomingMessage.Message + + // validate against json schema + err := h.schemaValidator.VerifyJSON(byteMessage, validation.GenericMessage) + if err != nil { + schemaErr := xerrors.Errorf("message is not valid against json schema: %v", err) + socket.SendError(nil, schemaErr) + return schemaErr + } + + rpctype, err := jsonrpc.GetType(byteMessage) + if err != nil { + rpcErr := xerrors.Errorf("failed to get rpc type: %v", err) + socket.SendError(nil, rpcErr) + return rpcErr + } + + // check type (answer or query) + if rpctype == jsonrpc.RPCTypeAnswer { + err = h.handleAnswer(socket, byteMessage) + if err != nil { + err = answer.NewErrorf(-4, "failed to handle answer message: %v", err) + socket.SendError(nil, err) + return err + } + + return nil + } + + if rpctype != jsonrpc.RPCTypeQuery { + rpcErr := xerrors.New("jsonRPC is of unknown type") + socket.SendError(nil, rpcErr) + return rpcErr + } + + var queryBase query.Base + + err = json.Unmarshal(byteMessage, &queryBase) + if err != nil { + err := answer.NewErrorf(-4, "failed to unmarshal incoming message: %v", err) + socket.SendError(nil, err) + return err + } + + id := -1 + var msgsByChannel map[string][]message.Message + var handlerErr error + + switch queryBase.Method { + case query.MethodGreetServer: + handlerErr = h.handleGreetServer(socket, byteMessage) + case query.MethodHeartbeat: + handlerErr = h.handleHeartbeat(socket, byteMessage) + case query.MethodGetMessagesById: + msgsByChannel, id, handlerErr = h.handleGetMessagesById(socket, byteMessage) + default: + err = answer.NewErrorf(-2, "unexpected method: '%s'", queryBase.Method) + socket.SendError(nil, err) + return err + } + + if handlerErr != nil { + err := answer.NewErrorf(-4, "failed to handle method: %v", handlerErr) + socket.SendError(&id, err) + return err + } + + if queryBase.Method == query.MethodGetMessagesById { + socket.SendResult(id, nil, msgsByChannel) + return nil + } + + if id != -1 { + socket.SendResult(id, nil, nil) + } + + return nil +} + func (h *Hub) handleGreetServer(socket socket.Socket, byteMessage []byte) error { var greetServer method.GreetServer @@ -193,7 +276,7 @@ func (h *Hub) handleReceivedMessage(socket socket.Socket, messageData message.Me publish := method.Publish{ Base: query.Base{ - JSONRPCBase: message2.JSONRPCBase{ + JSONRPCBase: jsonrpc.JSONRPCBase{ JSONRPC: "2.0", }, Method: "publish", @@ -282,7 +365,7 @@ func (h *Hub) sendGetMessagesByIdToServer(socket socket.Socket, missingIds map[s getMessagesById := method.GetMessagesById{ Base: query.Base{ - JSONRPCBase: message2.JSONRPCBase{ + JSONRPCBase: jsonrpc.JSONRPCBase{ JSONRPC: "2.0", }, Method: "get_messages_by_id", diff --git a/be1-go/hub/hub.go b/be1-go/hub/hub.go index 6efecd30a9..06926f3a70 100644 --- a/be1-go/hub/hub.go +++ b/be1-go/hub/hub.go @@ -1,7 +1,6 @@ package hub import ( - "context" "encoding/base64" "encoding/json" "popstellar/channel" @@ -9,10 +8,8 @@ import ( "popstellar/hub/state" "popstellar/inbox" jsonrpc "popstellar/message" - "popstellar/message/answer" "popstellar/message/query" "popstellar/message/query/method" - "popstellar/message/query/method/message" "popstellar/network/socket" "popstellar/validation" "strings" @@ -22,7 +19,6 @@ import ( "github.com/rs/zerolog" "go.dedis.ch/kyber/v3" "go.dedis.ch/kyber/v3/sign/schnorr" - "golang.org/x/sync/semaphore" "golang.org/x/xerrors" ) @@ -98,8 +94,6 @@ type Hub struct { stop chan struct{} - workers *semaphore.Weighted - log zerolog.Logger laoFac channel.LaoFactory @@ -146,7 +140,6 @@ func New(pubKeyOwner kyber.Point, clientServerAddress string, serverServerAddres secKeyServ: secServ, schemaValidator: schemaValidator, stop: make(chan struct{}), - workers: semaphore.NewWeighted(numWorkers), log: log, laoFac: laoFac, serverSockets: channel.NewSockets(), @@ -181,18 +174,10 @@ func (h *Hub) Start() { for { select { case incomingMessage := <-h.messageChan: - ok := h.workers.TryAcquire(1) - if !ok { - h.log.Warn().Msg("worker pool full, waiting...") - h.workers.Acquire(context.Background(), 1) + err := h.handleIncomingMessage(&incomingMessage) + if err != nil { + h.log.Err(err).Msg("problem handling incoming message") } - - go func() { - err := h.handleIncomingMessage(&incomingMessage) - if err != nil { - h.log.Err(err).Msg("problem handling incoming message") - } - }() case id := <-h.closedSockets: h.channelByID.ForEach(func(c channel.Channel) { // dummy Unsubscribe message because it's only used for logging... @@ -209,8 +194,6 @@ func (h *Hub) Start() { // Stop implements hub.Hub func (h *Hub) Stop() { close(h.stop) - h.log.Info().Msg("waiting for existing workers to finish...") - h.workers.Acquire(context.Background(), numWorkers) } // Receiver implements hub.Hub @@ -303,166 +286,9 @@ func (h *Hub) getChan(channelPath string) (channel.Channel, error) { return channel, nil } -// handleMessageFromClient handles an incoming message from an end user. -func (h *Hub) handleMessageFromClient(incomingMessage *socket.IncomingMessage) error { - socket := incomingMessage.Socket - byteMessage := incomingMessage.Message - - // validate against json schema - err := h.schemaValidator.VerifyJSON(byteMessage, validation.GenericMessage) - if err != nil { - schemaErr := xerrors.Errorf("message is not valid against json schema: %v", err) - socket.SendError(nil, schemaErr) - return schemaErr - } - - rpctype, err := jsonrpc.GetType(byteMessage) - if err != nil { - rpcErr := xerrors.Errorf("failed to get rpc type: %v", err) - socket.SendError(nil, rpcErr) - return rpcErr - } - - if rpctype != jsonrpc.RPCTypeQuery { - rpcErr := xerrors.New("rpc message sent by a client should be a query") - socket.SendError(nil, rpcErr) - return rpcErr - } - - var queryBase query.Base - - err = json.Unmarshal(byteMessage, &queryBase) - if err != nil { - err := answer.NewErrorf(-4, "failed to unmarshal incoming message: %v", err) - socket.SendError(nil, err) - return err - } - - var id int - var msgs []message.Message - var handlerErr error - - switch queryBase.Method { - case query.MethodPublish: - id, handlerErr = h.handlePublish(socket, byteMessage) - h.sendHeartbeatToServers() - case query.MethodSubscribe: - id, handlerErr = h.handleSubscribe(socket, byteMessage) - case query.MethodUnsubscribe: - id, handlerErr = h.handleUnsubscribe(socket, byteMessage) - case query.MethodCatchUp: - msgs, id, handlerErr = h.handleCatchup(socket, byteMessage) - default: - err = answer.NewInvalidResourceError("unexpected method: '%s'", queryBase.Method) - socket.SendError(nil, err) - return err - } - - if handlerErr != nil { - socket.SendError(&id, handlerErr) - return err - } - - if queryBase.Method == query.MethodCatchUp { - socket.SendResult(id, msgs, nil) - return nil - } - - socket.SendResult(id, nil, nil) - - return nil -} - -// handleMessageFromServer handles an incoming message from a server. -func (h *Hub) handleMessageFromServer(incomingMessage *socket.IncomingMessage) error { - socket := incomingMessage.Socket - byteMessage := incomingMessage.Message - - // validate against json schema - err := h.schemaValidator.VerifyJSON(byteMessage, validation.GenericMessage) - if err != nil { - schemaErr := xerrors.Errorf("message is not valid against json schema: %v", err) - socket.SendError(nil, schemaErr) - return schemaErr - } - - rpctype, err := jsonrpc.GetType(byteMessage) - if err != nil { - rpcErr := xerrors.Errorf("failed to get rpc type: %v", err) - socket.SendError(nil, rpcErr) - return rpcErr - } - - // check type (answer or query) - if rpctype == jsonrpc.RPCTypeAnswer { - err = h.handleAnswer(socket, byteMessage) - if err != nil { - err = answer.NewErrorf(-4, "failed to handle answer message: %v", err) - socket.SendError(nil, err) - return err - } - - return nil - } - - if rpctype != jsonrpc.RPCTypeQuery { - rpcErr := xerrors.New("jsonRPC is of unknown type") - socket.SendError(nil, rpcErr) - return rpcErr - } - - var queryBase query.Base - - err = json.Unmarshal(byteMessage, &queryBase) - if err != nil { - err := answer.NewErrorf(-4, "failed to unmarshal incoming message: %v", err) - socket.SendError(nil, err) - return err - } - - id := -1 - var msgsByChannel map[string][]message.Message - var handlerErr error - - switch queryBase.Method { - case query.MethodGreetServer: - handlerErr = h.handleGreetServer(socket, byteMessage) - case query.MethodPublish: - id, handlerErr = h.handlePublish(socket, byteMessage) - h.sendHeartbeatToServers() - case query.MethodHeartbeat: - handlerErr = h.handleHeartbeat(socket, byteMessage) - case query.MethodGetMessagesById: - msgsByChannel, id, handlerErr = h.handleGetMessagesById(socket, byteMessage) - default: - err = answer.NewErrorf(-2, "unexpected method: '%s'", queryBase.Method) - socket.SendError(nil, err) - return err - } - - if handlerErr != nil { - err := answer.NewErrorf(-4, "failed to handle method: %v", handlerErr) - socket.SendError(&id, err) - return err - } - - if queryBase.Method == query.MethodGetMessagesById { - socket.SendResult(id, nil, msgsByChannel) - return nil - } - - if id != -1 { - socket.SendResult(id, nil, nil) - } - - return nil -} - // handleIncomingMessage handles an incoming message based on the socket it // originates from. func (h *Hub) handleIncomingMessage(incomingMessage *socket.IncomingMessage) error { - defer h.workers.Release(1) - h.log.Info().Str("msg", string(incomingMessage.Message)).Msg("handle incoming message") switch incomingMessage.Socket.Type() { diff --git a/be1-go/hub/hub_test.go b/be1-go/hub/hub_test.go index 6931e43966..b09156d060 100644 --- a/be1-go/hub/hub_test.go +++ b/be1-go/hub/hub_test.go @@ -1058,7 +1058,7 @@ func Test_Handle_Publish_From_Server(t *testing.T) { sock := &fakeSocket{} // check that there is no errors with messages from witness - hub.handleMessageFromServer(&socket.IncomingMessage{ + hub.handleMessageFromClient(&socket.IncomingMessage{ Socket: sock, Message: publishBuf, }) @@ -1125,7 +1125,7 @@ func Test_Receive_Publish_Twice(t *testing.T) { sock := &fakeSocket{} // Receive message from a server - hub.handleMessageFromServer(&socket.IncomingMessage{ + hub.handleMessageFromClient(&socket.IncomingMessage{ Socket: sock, Message: publishBuf, })