-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from cloudstruct/feature/protocol-state-machine
Muxer improvements and handshake protocol state machine
- Loading branch information
Showing
7 changed files
with
325 additions
and
222 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package muxer | ||
|
||
import ( | ||
"time" | ||
) | ||
|
||
const ( | ||
MESSAGE_PROTOCOL_ID_RESPONSE_FLAG = 0x8000 | ||
) | ||
|
||
type MessageHeader struct { | ||
Timestamp uint32 | ||
ProtocolId uint16 | ||
PayloadLength uint16 | ||
} | ||
|
||
type Message struct { | ||
MessageHeader | ||
Payload []byte | ||
} | ||
|
||
func NewMessage(protocolId uint16, payload []byte, isResponse bool) *Message { | ||
header := MessageHeader{ | ||
Timestamp: uint32(time.Now().UnixNano() & 0xffffffff), | ||
ProtocolId: protocolId, | ||
} | ||
if isResponse { | ||
header.ProtocolId = header.ProtocolId + MESSAGE_PROTOCOL_ID_RESPONSE_FLAG | ||
} | ||
header.PayloadLength = uint16(len(payload)) | ||
msg := &Message{ | ||
MessageHeader: header, | ||
Payload: payload, | ||
} | ||
return msg | ||
} | ||
|
||
func (s *MessageHeader) IsRequest() bool { | ||
return (s.ProtocolId & MESSAGE_PROTOCOL_ID_RESPONSE_FLAG) == 0 | ||
} | ||
|
||
func (s *MessageHeader) IsResponse() bool { | ||
return (s.ProtocolId & MESSAGE_PROTOCOL_ID_RESPONSE_FLAG) > 0 | ||
} | ||
|
||
func (s *MessageHeader) GetProtocolId() uint16 { | ||
if s.ProtocolId >= MESSAGE_PROTOCOL_ID_RESPONSE_FLAG { | ||
return s.ProtocolId - MESSAGE_PROTOCOL_ID_RESPONSE_FLAG | ||
} | ||
return s.ProtocolId | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package muxer | ||
|
||
import ( | ||
"bytes" | ||
"encoding/binary" | ||
"fmt" | ||
"io" | ||
) | ||
|
||
type Muxer struct { | ||
conn io.ReadWriteCloser | ||
ErrorChan chan error | ||
protocolSenders map[uint16]chan *Message | ||
protocolReceivers map[uint16]chan *Message | ||
} | ||
|
||
func New(conn io.ReadWriteCloser) *Muxer { | ||
m := &Muxer{ | ||
conn: conn, | ||
ErrorChan: make(chan error, 10), | ||
protocolSenders: make(map[uint16]chan *Message), | ||
protocolReceivers: make(map[uint16]chan *Message), | ||
} | ||
go m.readLoop() | ||
return m | ||
} | ||
|
||
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Message) { | ||
// Generate channels | ||
senderChan := make(chan *Message, 10) | ||
receiverChan := make(chan *Message, 10) | ||
// Record channels in protocol sender/receiver maps | ||
m.protocolSenders[protocolId] = senderChan | ||
m.protocolReceivers[protocolId] = receiverChan | ||
// Start Goroutine to handle outbound messages | ||
go func() { | ||
for { | ||
msg := <-senderChan | ||
m.Send(msg) | ||
} | ||
}() | ||
return senderChan, receiverChan | ||
} | ||
|
||
func (m *Muxer) Send(msg *Message) error { | ||
buf := &bytes.Buffer{} | ||
err := binary.Write(buf, binary.BigEndian, msg.MessageHeader) | ||
if err != nil { | ||
return err | ||
} | ||
buf.Write(msg.Payload) | ||
_, err = m.conn.Write(buf.Bytes()) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
func (m *Muxer) readLoop() { | ||
for { | ||
header := MessageHeader{} | ||
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil { | ||
m.ErrorChan <- err | ||
} | ||
msg := &Message{ | ||
MessageHeader: header, | ||
Payload: make([]byte, header.PayloadLength), | ||
} | ||
// We use ReadFull because it guarantees to read the expected number of bytes or | ||
// return an error | ||
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil { | ||
m.ErrorChan <- err | ||
} | ||
// Send message payload to proper receiver | ||
recvChan := m.protocolReceivers[msg.GetProtocolId()] | ||
if recvChan == nil { | ||
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId()) | ||
} else { | ||
m.protocolReceivers[msg.GetProtocolId()] <- msg | ||
} | ||
} | ||
} |
Oops, something went wrong.