Skip to content

Commit

Permalink
Merge pull request #2 from cloudstruct/feature/protocol-state-machine
Browse files Browse the repository at this point in the history
Muxer improvements and handshake protocol state machine
  • Loading branch information
agaffney authored Dec 13, 2021
2 parents 9c18ec9 + 2316c97 commit c51b45f
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 222 deletions.
100 changes: 0 additions & 100 deletions handshake/handshake.go

This file was deleted.

107 changes: 0 additions & 107 deletions muxer.go

This file was deleted.

51 changes: 51 additions & 0 deletions muxer/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package muxer

import (
"time"
)

const (
MESSAGE_PROTOCOL_ID_RESPONSE_FLAG = 0x8000
)

type MessageHeader struct {
Timestamp uint32
ProtocolId uint16
PayloadLength uint16
}

type Message struct {
MessageHeader
Payload []byte
}

func NewMessage(protocolId uint16, payload []byte, isResponse bool) *Message {
header := MessageHeader{
Timestamp: uint32(time.Now().UnixNano() & 0xffffffff),
ProtocolId: protocolId,
}
if isResponse {
header.ProtocolId = header.ProtocolId + MESSAGE_PROTOCOL_ID_RESPONSE_FLAG
}
header.PayloadLength = uint16(len(payload))
msg := &Message{
MessageHeader: header,
Payload: payload,
}
return msg
}

func (s *MessageHeader) IsRequest() bool {
return (s.ProtocolId & MESSAGE_PROTOCOL_ID_RESPONSE_FLAG) == 0
}

func (s *MessageHeader) IsResponse() bool {
return (s.ProtocolId & MESSAGE_PROTOCOL_ID_RESPONSE_FLAG) > 0
}

func (s *MessageHeader) GetProtocolId() uint16 {
if s.ProtocolId >= MESSAGE_PROTOCOL_ID_RESPONSE_FLAG {
return s.ProtocolId - MESSAGE_PROTOCOL_ID_RESPONSE_FLAG
}
return s.ProtocolId
}
82 changes: 82 additions & 0 deletions muxer/muxer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package muxer

import (
"bytes"
"encoding/binary"
"fmt"
"io"
)

type Muxer struct {
conn io.ReadWriteCloser
ErrorChan chan error
protocolSenders map[uint16]chan *Message
protocolReceivers map[uint16]chan *Message
}

func New(conn io.ReadWriteCloser) *Muxer {
m := &Muxer{
conn: conn,
ErrorChan: make(chan error, 10),
protocolSenders: make(map[uint16]chan *Message),
protocolReceivers: make(map[uint16]chan *Message),
}
go m.readLoop()
return m
}

func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Message) {
// Generate channels
senderChan := make(chan *Message, 10)
receiverChan := make(chan *Message, 10)
// Record channels in protocol sender/receiver maps
m.protocolSenders[protocolId] = senderChan
m.protocolReceivers[protocolId] = receiverChan
// Start Goroutine to handle outbound messages
go func() {
for {
msg := <-senderChan
m.Send(msg)
}
}()
return senderChan, receiverChan
}

func (m *Muxer) Send(msg *Message) error {
buf := &bytes.Buffer{}
err := binary.Write(buf, binary.BigEndian, msg.MessageHeader)
if err != nil {
return err
}
buf.Write(msg.Payload)
_, err = m.conn.Write(buf.Bytes())
if err != nil {
return err
}
return nil
}

func (m *Muxer) readLoop() {
for {
header := MessageHeader{}
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
m.ErrorChan <- err
}
msg := &Message{
MessageHeader: header,
Payload: make([]byte, header.PayloadLength),
}
// We use ReadFull because it guarantees to read the expected number of bytes or
// return an error
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil {
m.ErrorChan <- err
}
// Send message payload to proper receiver
recvChan := m.protocolReceivers[msg.GetProtocolId()]
if recvChan == nil {
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())
} else {
m.protocolReceivers[msg.GetProtocolId()] <- msg
}
}
}
Loading

0 comments on commit c51b45f

Please sign in to comment.