Skip to content

Commit

Permalink
Merge pull request #287 from blinklabs-io/feat/muxer-multiple-protoco…
Browse files Browse the repository at this point in the history
…l-instances

feat: allow running multiple instances of a protocol in muxer
  • Loading branch information
agaffney authored May 29, 2023
2 parents d827282 + 1b08ef7 commit 93b6409
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 16 deletions.
1 change: 1 addition & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func TestDialFailClose(t *testing.T) {

func TestDoubleClose(t *testing.T) {
mockConn := ouroboros_mock.NewConnection(
ouroboros_mock.ProtocolRoleClient,
[]ouroboros_mock.ConversationEntry{
ouroboros_mock.ConversationEntryHandshakeRequestGeneric,
ouroboros_mock.ConversationEntryHandshakeResponse,
Expand Down
19 changes: 17 additions & 2 deletions internal/test/ouroboros_mock/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ import (
"github.com/blinklabs-io/gouroboros/muxer"
)

// ProtocolRole is an enum of the protocol roles
type ProtocolRole uint

// Protocol roles
const (
ProtocolRoleNone ProtocolRole = 0 // Default (invalid) protocol role
ProtocolRoleClient ProtocolRole = 1 // Client protocol role
ProtocolRoleServer ProtocolRole = 2 // Server protocol role
)

// Connection mocks an Ouroboros connection
type Connection struct {
mockConn net.Conn
Expand All @@ -35,15 +45,20 @@ type Connection struct {
}

// NewConnection returns a new Connection with the provided conversation entries
func NewConnection(conversation []ConversationEntry) net.Conn {
func NewConnection(protocolRole ProtocolRole, conversation []ConversationEntry) net.Conn {
c := &Connection{
conversation: conversation,
}
c.conn, c.mockConn = net.Pipe()
// Start a muxer on the mocked side of the connection
c.muxer = muxer.New(c.mockConn)
// The muxer is for the opposite end of the connection, so we flip the protocol role
muxerProtocolRole := muxer.ProtocolRoleResponder
if protocolRole == ProtocolRoleServer {
muxerProtocolRole = muxer.ProtocolRoleInitiator
}
// We use ProtocolUnknown to catch all inbound messages when no other protocols are registered
_, c.muxerRecvChan, _ = c.muxer.RegisterProtocol(muxer.ProtocolUnknown)
_, c.muxerRecvChan, _ = c.muxer.RegisterProtocol(muxer.ProtocolUnknown, muxerProtocolRole)
c.muxer.Start()
// Start async muxer error handler
go func() {
Expand Down
1 change: 1 addition & 0 deletions internal/test/ouroboros_mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
// Basic test of conversation mock functionality
func TestBasic(t *testing.T) {
mockConn := NewConnection(
ProtocolRoleClient,
[]ConversationEntry{
ConversationEntryHandshakeRequestGeneric,
ConversationEntryHandshakeResponse,
Expand Down
51 changes: 38 additions & 13 deletions muxer/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ const (
DiffusionModeInitiatorAndResponder DiffusionMode = 3 // Initiator and responder (full duplex) mode
)

// ProtocolRole is an enum of the protocol roles
type ProtocolRole uint

// Protocol roles
const (
ProtocolRoleNone ProtocolRole = 0 // Default (invalid) protocol role
ProtocolRoleInitiator ProtocolRole = 1 // Initiator (client) protocol role
ProtocolRoleResponder ProtocolRole = 2 // Responder (server) protocol role
)

// Muxer wraps a connection to allow running multiple mini-protocols over a single connection
type Muxer struct {
errorChan chan error
Expand All @@ -50,8 +60,8 @@ type Muxer struct {
startChan chan bool
doneChan chan bool
waitGroup sync.WaitGroup
protocolSenders map[uint16]chan *Segment
protocolReceivers map[uint16]chan *Segment
protocolSenders map[uint16]map[ProtocolRole]chan *Segment
protocolReceivers map[uint16]map[ProtocolRole]chan *Segment
diffusionMode DiffusionMode
onceStart sync.Once
onceStop sync.Once
Expand All @@ -64,8 +74,8 @@ func New(conn net.Conn) *Muxer {
startChan: make(chan bool, 1),
doneChan: make(chan bool),
errorChan: make(chan error, 10),
protocolSenders: make(map[uint16]chan *Segment),
protocolReceivers: make(map[uint16]chan *Segment),
protocolSenders: make(map[uint16]map[ProtocolRole]chan *Segment),
protocolReceivers: make(map[uint16]map[ProtocolRole]chan *Segment),
}
m.waitGroup.Add(1)
go m.readLoop()
Expand Down Expand Up @@ -95,8 +105,10 @@ func (m *Muxer) Stop() {
m.waitGroup.Wait()
// Close protocol receive channels
// We rely on the individual mini-protocols to close the sender channel
for _, recvChan := range m.protocolReceivers {
close(recvChan)
for _, protocolRoles := range m.protocolReceivers {
for _, recvChan := range protocolRoles {
close(recvChan)
}
}
// Close ErrorChan to signify to consumer that we're shutting down
close(m.errorChan)
Expand Down Expand Up @@ -124,13 +136,17 @@ func (m *Muxer) sendError(err error) {

// RegisterProtocol registers the provided protocol ID with the muxer. It returns a channel for sending,
// a channel for receiving, and a channel to know when the muxer is shutting down
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment, chan bool) {
func (m *Muxer) RegisterProtocol(protocolId uint16, protocolRole ProtocolRole) (chan *Segment, chan *Segment, chan bool) {
// Generate channels
senderChan := make(chan *Segment, 10)
receiverChan := make(chan *Segment, 10)
// Record channels in protocol sender/receiver maps
m.protocolSenders[protocolId] = senderChan
m.protocolReceivers[protocolId] = receiverChan
if _, ok := m.protocolSenders[protocolId]; !ok {
m.protocolSenders[protocolId] = make(map[ProtocolRole]chan *Segment)
m.protocolReceivers[protocolId] = make(map[ProtocolRole]chan *Segment)
}
m.protocolSenders[protocolId][protocolRole] = senderChan
m.protocolReceivers[protocolId][protocolRole] = receiverChan
// Start Goroutine to handle outbound messages
m.waitGroup.Add(1)
go func() {
Expand Down Expand Up @@ -216,15 +232,24 @@ func (m *Muxer) readLoop() {
return
}
// Send message payload to proper receiver
recvChan := m.protocolReceivers[msg.GetProtocolId()]
if recvChan == nil {
protocolRole := ProtocolRoleResponder
if msg.IsResponse() {
protocolRole = ProtocolRoleInitiator
}
protocolRoles, ok := m.protocolReceivers[msg.GetProtocolId()]
if !ok {
// Try the "unknown protocol" receiver if we didn't find an explicit one
recvChan = m.protocolReceivers[ProtocolUnknown]
if recvChan == nil {
protocolRoles, ok = m.protocolReceivers[ProtocolUnknown]
if !ok {
m.sendError(fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId()))
return
}
}
recvChan := protocolRoles[protocolRole]
if recvChan == nil {
m.sendError(fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId()))
return
}
if recvChan != nil {
recvChan <- msg
}
Expand Down
2 changes: 2 additions & 0 deletions protocol/handshake/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

func TestBasicHandshake(t *testing.T) {
mockConn := ouroboros_mock.NewConnection(
ouroboros_mock.ProtocolRoleClient,
[]ouroboros_mock.ConversationEntry{
ouroboros_mock.ConversationEntryHandshakeRequestGeneric,
ouroboros_mock.ConversationEntryHandshakeResponse,
Expand Down Expand Up @@ -53,6 +54,7 @@ func TestBasicHandshake(t *testing.T) {

func TestDoubleStart(t *testing.T) {
mockConn := ouroboros_mock.NewConnection(
ouroboros_mock.ProtocolRoleClient,
[]ouroboros_mock.ConversationEntry{
ouroboros_mock.ConversationEntryHandshakeRequestGeneric,
ouroboros_mock.ConversationEntryHandshakeResponse,
Expand Down
7 changes: 6 additions & 1 deletion protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ const (
// ProtocolRole is an enum of the protocol roles
type ProtocolRole uint

// Protocol roles
const (
ProtocolRoleNone ProtocolRole = 0 // Default (invalid) protocol role
ProtocolRoleClient ProtocolRole = 1 // Client protocol role
Expand Down Expand Up @@ -110,7 +111,11 @@ func New(config ProtocolConfig) *Protocol {
func (p *Protocol) Start() {
p.onceStart.Do(func() {
// Register protocol with muxer
p.muxerSendChan, p.muxerRecvChan, p.muxerDoneChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId)
muxerProtocolRole := muxer.ProtocolRoleInitiator
if p.config.Role == ProtocolRoleServer {
muxerProtocolRole = muxer.ProtocolRoleResponder
}
p.muxerSendChan, p.muxerRecvChan, p.muxerDoneChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId, muxerProtocolRole)
// Create buffers and channels
p.recvBuffer = bytes.NewBuffer(nil)
p.sendQueueChan = make(chan Message, 50)
Expand Down

0 comments on commit 93b6409

Please sign in to comment.