From 1b08ef7ac6edfca656fa1728f88aa9960e9afde9 Mon Sep 17 00:00:00 2001 From: Andrew Gaffney Date: Sun, 28 May 2023 20:18:08 -0500 Subject: [PATCH] feat: allow running multiple instances of a protocol in muxer Fixes #282 --- connection_test.go | 1 + internal/test/ouroboros_mock/connection.go | 19 +++++++- internal/test/ouroboros_mock/mock_test.go | 1 + muxer/muxer.go | 51 ++++++++++++++++------ protocol/handshake/client_test.go | 2 + protocol/protocol.go | 7 ++- 6 files changed, 65 insertions(+), 16 deletions(-) diff --git a/connection_test.go b/connection_test.go index d9d8bfe2..3b8632f9 100644 --- a/connection_test.go +++ b/connection_test.go @@ -38,6 +38,7 @@ func TestDialFailClose(t *testing.T) { func TestDoubleClose(t *testing.T) { mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, []ouroboros_mock.ConversationEntry{ ouroboros_mock.ConversationEntryHandshakeRequestGeneric, ouroboros_mock.ConversationEntryHandshakeResponse, diff --git a/internal/test/ouroboros_mock/connection.go b/internal/test/ouroboros_mock/connection.go index 9f37a5ce..e1422819 100644 --- a/internal/test/ouroboros_mock/connection.go +++ b/internal/test/ouroboros_mock/connection.go @@ -25,6 +25,16 @@ import ( "github.com/blinklabs-io/gouroboros/muxer" ) +// ProtocolRole is an enum of the protocol roles +type ProtocolRole uint + +// Protocol roles +const ( + ProtocolRoleNone ProtocolRole = 0 // Default (invalid) protocol role + ProtocolRoleClient ProtocolRole = 1 // Client protocol role + ProtocolRoleServer ProtocolRole = 2 // Server protocol role +) + // Connection mocks an Ouroboros connection type Connection struct { mockConn net.Conn @@ -35,15 +45,20 @@ type Connection struct { } // NewConnection returns a new Connection with the provided conversation entries -func NewConnection(conversation []ConversationEntry) net.Conn { +func NewConnection(protocolRole ProtocolRole, conversation []ConversationEntry) net.Conn { c := &Connection{ conversation: conversation, } c.conn, c.mockConn = net.Pipe() // Start a muxer on the mocked side of the connection c.muxer = muxer.New(c.mockConn) + // The muxer is for the opposite end of the connection, so we flip the protocol role + muxerProtocolRole := muxer.ProtocolRoleResponder + if protocolRole == ProtocolRoleServer { + muxerProtocolRole = muxer.ProtocolRoleInitiator + } // We use ProtocolUnknown to catch all inbound messages when no other protocols are registered - _, c.muxerRecvChan, _ = c.muxer.RegisterProtocol(muxer.ProtocolUnknown) + _, c.muxerRecvChan, _ = c.muxer.RegisterProtocol(muxer.ProtocolUnknown, muxerProtocolRole) c.muxer.Start() // Start async muxer error handler go func() { diff --git a/internal/test/ouroboros_mock/mock_test.go b/internal/test/ouroboros_mock/mock_test.go index 28fb847e..2ec690a3 100644 --- a/internal/test/ouroboros_mock/mock_test.go +++ b/internal/test/ouroboros_mock/mock_test.go @@ -23,6 +23,7 @@ import ( // Basic test of conversation mock functionality func TestBasic(t *testing.T) { mockConn := NewConnection( + ProtocolRoleClient, []ConversationEntry{ ConversationEntryHandshakeRequestGeneric, ConversationEntryHandshakeResponse, diff --git a/muxer/muxer.go b/muxer/muxer.go index 1217879e..51556724 100644 --- a/muxer/muxer.go +++ b/muxer/muxer.go @@ -42,6 +42,16 @@ const ( DiffusionModeInitiatorAndResponder DiffusionMode = 3 // Initiator and responder (full duplex) mode ) +// ProtocolRole is an enum of the protocol roles +type ProtocolRole uint + +// Protocol roles +const ( + ProtocolRoleNone ProtocolRole = 0 // Default (invalid) protocol role + ProtocolRoleInitiator ProtocolRole = 1 // Initiator (client) protocol role + ProtocolRoleResponder ProtocolRole = 2 // Responder (server) protocol role +) + // Muxer wraps a connection to allow running multiple mini-protocols over a single connection type Muxer struct { errorChan chan error @@ -50,8 +60,8 @@ type Muxer struct { startChan chan bool doneChan chan bool waitGroup sync.WaitGroup - protocolSenders map[uint16]chan *Segment - protocolReceivers map[uint16]chan *Segment + protocolSenders map[uint16]map[ProtocolRole]chan *Segment + protocolReceivers map[uint16]map[ProtocolRole]chan *Segment diffusionMode DiffusionMode onceStart sync.Once onceStop sync.Once @@ -64,8 +74,8 @@ func New(conn net.Conn) *Muxer { startChan: make(chan bool, 1), doneChan: make(chan bool), errorChan: make(chan error, 10), - protocolSenders: make(map[uint16]chan *Segment), - protocolReceivers: make(map[uint16]chan *Segment), + protocolSenders: make(map[uint16]map[ProtocolRole]chan *Segment), + protocolReceivers: make(map[uint16]map[ProtocolRole]chan *Segment), } m.waitGroup.Add(1) go m.readLoop() @@ -95,8 +105,10 @@ func (m *Muxer) Stop() { m.waitGroup.Wait() // Close protocol receive channels // We rely on the individual mini-protocols to close the sender channel - for _, recvChan := range m.protocolReceivers { - close(recvChan) + for _, protocolRoles := range m.protocolReceivers { + for _, recvChan := range protocolRoles { + close(recvChan) + } } // Close ErrorChan to signify to consumer that we're shutting down close(m.errorChan) @@ -124,13 +136,17 @@ func (m *Muxer) sendError(err error) { // RegisterProtocol registers the provided protocol ID with the muxer. It returns a channel for sending, // a channel for receiving, and a channel to know when the muxer is shutting down -func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment, chan bool) { +func (m *Muxer) RegisterProtocol(protocolId uint16, protocolRole ProtocolRole) (chan *Segment, chan *Segment, chan bool) { // Generate channels senderChan := make(chan *Segment, 10) receiverChan := make(chan *Segment, 10) // Record channels in protocol sender/receiver maps - m.protocolSenders[protocolId] = senderChan - m.protocolReceivers[protocolId] = receiverChan + if _, ok := m.protocolSenders[protocolId]; !ok { + m.protocolSenders[protocolId] = make(map[ProtocolRole]chan *Segment) + m.protocolReceivers[protocolId] = make(map[ProtocolRole]chan *Segment) + } + m.protocolSenders[protocolId][protocolRole] = senderChan + m.protocolReceivers[protocolId][protocolRole] = receiverChan // Start Goroutine to handle outbound messages m.waitGroup.Add(1) go func() { @@ -216,15 +232,24 @@ func (m *Muxer) readLoop() { return } // Send message payload to proper receiver - recvChan := m.protocolReceivers[msg.GetProtocolId()] - if recvChan == nil { + protocolRole := ProtocolRoleResponder + if msg.IsResponse() { + protocolRole = ProtocolRoleInitiator + } + protocolRoles, ok := m.protocolReceivers[msg.GetProtocolId()] + if !ok { // Try the "unknown protocol" receiver if we didn't find an explicit one - recvChan = m.protocolReceivers[ProtocolUnknown] - if recvChan == nil { + protocolRoles, ok = m.protocolReceivers[ProtocolUnknown] + if !ok { m.sendError(fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())) return } } + recvChan := protocolRoles[protocolRole] + if recvChan == nil { + m.sendError(fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())) + return + } if recvChan != nil { recvChan <- msg } diff --git a/protocol/handshake/client_test.go b/protocol/handshake/client_test.go index d9b20c59..f43db8ba 100644 --- a/protocol/handshake/client_test.go +++ b/protocol/handshake/client_test.go @@ -24,6 +24,7 @@ import ( func TestBasicHandshake(t *testing.T) { mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, []ouroboros_mock.ConversationEntry{ ouroboros_mock.ConversationEntryHandshakeRequestGeneric, ouroboros_mock.ConversationEntryHandshakeResponse, @@ -53,6 +54,7 @@ func TestBasicHandshake(t *testing.T) { func TestDoubleStart(t *testing.T) { mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, []ouroboros_mock.ConversationEntry{ ouroboros_mock.ConversationEntryHandshakeRequestGeneric, ouroboros_mock.ConversationEntryHandshakeResponse, diff --git a/protocol/protocol.go b/protocol/protocol.go index dc78cb67..eb9cd06e 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -75,6 +75,7 @@ const ( // ProtocolRole is an enum of the protocol roles type ProtocolRole uint +// Protocol roles const ( ProtocolRoleNone ProtocolRole = 0 // Default (invalid) protocol role ProtocolRoleClient ProtocolRole = 1 // Client protocol role @@ -110,7 +111,11 @@ func New(config ProtocolConfig) *Protocol { func (p *Protocol) Start() { p.onceStart.Do(func() { // Register protocol with muxer - p.muxerSendChan, p.muxerRecvChan, p.muxerDoneChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId) + muxerProtocolRole := muxer.ProtocolRoleInitiator + if p.config.Role == ProtocolRoleServer { + muxerProtocolRole = muxer.ProtocolRoleResponder + } + p.muxerSendChan, p.muxerRecvChan, p.muxerDoneChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId, muxerProtocolRole) // Create buffers and channels p.recvBuffer = bytes.NewBuffer(nil) p.sendQueueChan = make(chan Message, 50)