From 3c448da8148cf153bebb22d3aa35681f5caff076 Mon Sep 17 00:00:00 2001 From: Artur Troian <troian.ap@gmail.com> Date: Tue, 5 Sep 2017 14:26:48 +0300 Subject: [PATCH] Ref #42 1. Session expiry 2. Topic alias 3. Publication expiry 4. No local 5. Retain as published 6. Retain handling 7. Receive maximum Split encode/decode operations by type Signed-off-by: Artur Troian <troian.ap@gmail.com> --- README.md | 14 +- clients/session.go | 498 ++++++++--------- clients/sessions.go | 601 ++++++++++++--------- configuration/init.go | 34 +- connection/ack.go | 36 +- connection/connection.go | 417 ++++++++++----- connection/flowControl.go | 81 ++- connection/netCallbacks.go | 232 ++++---- connection/receiver.go | 116 ++-- connection/transmitter.go | 333 ++++++++---- examples/tcp&ws/volantmq.go | 2 +- examples/tcp/volantmq.go | 7 +- examples/websocket/ws.go | 2 +- packet/connack.go | 27 +- packet/connack_test.go | 8 +- packet/connect.go | 155 +++--- packet/connect_test.go | 26 +- packet/disconnect.go | 68 ++- packet/disconnect_test.go | 2 +- packet/errors.go | 2 + packet/header.go | 79 +-- packet/packet.go | 43 +- packet/packetType_test.go | 10 +- packet/ping_test.go | 4 +- packet/property.go | 1009 ++++++++++++++++++++++------------- packet/puback.go | 85 +-- packet/puback_test.go | 4 +- packet/pubcomp_test.go | 4 +- packet/publish.go | 157 ++++-- packet/pubrec_test.go | 4 +- packet/pubrel_test.go | 4 +- packet/reasonCodes.go | 4 +- packet/suback.go | 3 +- packet/suback_test.go | 4 +- packet/subscribe_test.go | 4 +- packet/unsuback.go | 44 +- packet/unsuback_test.go | 4 +- packet/unsubscribe_test.go | 4 +- persistence/types/types.go | 6 + subscriber/subscriber.go | 89 +-- systree/clients.go | 9 +- systree/interfaces.go | 4 +- systree/server.go | 4 +- systree/sessions.go | 6 +- systree/types.go | 4 +- topics/mem/node.go | 100 ++-- topics/mem/regex.txt | 14 - topics/mem/topics.go | 67 +-- topics/mem/trie_test.go | 185 +++++-- topics/topics_test.go | 62 ++- topics/types/types.go | 29 +- transport/base.go | 6 +- transport/tcp.go | 2 +- transport/websocket.go | 2 +- types/types.go | 19 +- volantmq.go | 150 +++--- 56 files changed, 2848 insertions(+), 2041 deletions(-) delete mode 100644 topics/mem/regex.txt diff --git a/README.md b/README.md index ec514c8..645d539 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,8 @@ VolantMQ is a high performance MQTT broker that aims to be fully compliant with ### Features, Limitations, and Future **Features** -* [MQTT v3.1 - V3.1.1 compliant](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) +* [MQTT v3.1 - V3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) +* [MQTT V5.0](http://docs.oasis-open.org/mqtt/mqtt/v5.0/mqtt-v5.0.html), in progress refer to [TODO](#TODO) * Full support of WebSockets transport * SSL for both plain tcp and WebSockets transports * Independent auth providers for each transport @@ -27,13 +28,14 @@ VolantMQ is a high performance MQTT broker that aims to be fully compliant with * [BoltDB](https://github.com/boltdb/bolt) * In memory -**Future** -* V5.0 specification +**TODO** +* V5.0: + * Publication expiry + * Packets testing * Cluster * Bridge -### Performance - -TBD +* Benchmarking +* Plugins ### Compatibility diff --git a/clients/session.go b/clients/session.go index c78308e..de966e6 100644 --- a/clients/session.go +++ b/clients/session.go @@ -1,21 +1,13 @@ package clients import ( - "net" "sync" - "sync/atomic" "time" - "fmt" - - "unsafe" - - "github.com/VolantMQ/volantmq/auth" "github.com/VolantMQ/volantmq/connection" "github.com/VolantMQ/volantmq/packet" "github.com/VolantMQ/volantmq/persistence/types" "github.com/VolantMQ/volantmq/subscriber" - "github.com/VolantMQ/volantmq/systree" "github.com/VolantMQ/volantmq/types" ) @@ -23,157 +15,160 @@ type exitReason int const ( exitReasonClean exitReason = iota - exitReasonKeepSubscriber exitReasonShutdown exitReasonExpired ) +type switchStatus int + +const ( + swStatusSwitched switchStatus = iota + swStatusIsOnline + swStatusFinalized +) + type onSessionClose func(string, exitReason) type onSessionPersist func(string, *persistenceTypes.SessionMessages) -type onDisconnect func(string, bool, packet.ReasonCode) - -type sessionConfig struct { - id string - createdAt time.Time - onPersist onSessionPersist - onClose onSessionClose - onDisconnect onDisconnect - messenger types.TopicMessenger - clean bool -} - -type connectionConfig struct { - username string - state *persistenceTypes.SessionMessages - metric systree.Metric - conn net.Conn - auth auth.SessionPermissions - keepAlive uint16 - sendQuota uint16 - version packet.ProtocolVersion +type onDisconnect func(string, packet.ReasonCode, bool) +type onSubscriberShutdown func(subscriber.ConnectionProvider) + +type sessionEvents struct { + persist onSessionPersist + signalClose onSessionClose + signalDisconnected onDisconnect + shutdownSubscriber onSubscriberShutdown } -type setupConfig struct { - subscriber subscriber.ConnectionProvider - will *packet.Publish - expireIn *time.Duration - willDelay time.Duration +type sessionPreConfig struct { + sessionEvents + id string + createdAt time.Time + messenger types.TopicMessenger } -type session struct { - self uintptr - createdAt time.Time - id string - messenger types.TopicMessenger +type sessionReConfig struct { subscriber subscriber.ConnectionProvider - notifyDisconnect onDisconnect - onPersist onSessionPersist - onClose onSessionClose - startLock sync.Mutex - lock sync.Mutex - timerWorker sync.WaitGroup - wgStopped sync.WaitGroup - timerChan chan struct{} - connStop *types.Once - timer *time.Timer - conn *connection.Type will *packet.Publish expireIn *time.Duration willDelay time.Duration - timerStartedAt time.Time - isOnline uintptr - clean bool + killOnDisconnect bool } -func newSession(c *sessionConfig) (*session, error) { - s := &session{ - id: c.id, - createdAt: c.createdAt, - clean: c.clean, - messenger: c.messenger, - onPersist: c.onPersist, - onClose: c.onClose, - notifyDisconnect: c.onDisconnect, - connStop: &types.Once{}, - isOnline: 1, - timerChan: make(chan struct{}), - } +type session struct { + sessionEvents + *sessionReConfig + id string + idLock *sync.Mutex + messenger types.TopicMessenger + createdAt time.Time + expiringSince time.Time + lock sync.Mutex + connStop *types.Once + disconnectOnce *types.OnceWait + wgDisconnected sync.WaitGroup + conn *connection.Type + timer *time.Timer + timerLock sync.Mutex + finalized bool + isOnline chan struct{} +} - s.self = uintptr(unsafe.Pointer(s)) - close(s.timerChan) +type sessionWrap struct { + s *session + lock sync.Mutex +} - return s, nil +func (s *sessionWrap) acquire() { + s.lock.Lock() } -func (s *session) acquire() { - s.startLock.Lock() +func (s *sessionWrap) release() { + s.lock.Unlock() } -func (s *session) release() { - s.startLock.Unlock() +func (s *sessionWrap) swap(w *sessionWrap) *session { + s.s = w.s + s.s.idLock = &s.lock + return s.s } -func (s *session) configure(c *setupConfig, runExpiry bool) { - defer s.lock.Unlock() - s.lock.Lock() +func newSession(c *sessionPreConfig) *session { + s := &session{ + sessionEvents: c.sessionEvents, + id: c.id, + createdAt: c.createdAt, + messenger: c.messenger, + isOnline: make(chan struct{}), + } - s.will = c.will - s.expireIn = c.expireIn - s.willDelay = c.willDelay - s.subscriber = c.subscriber - s.isOnline = 1 + s.timer = time.AfterFunc(10*time.Second, s.timerCallback) + s.timer.Stop() + + close(s.isOnline) + return s +} +func (s *session) reconfigure(c *sessionReConfig, runExpiry bool) { + s.sessionReConfig = c + s.finalized = false if runExpiry { s.runExpiry(true) } } -func (s *session) allocConnection(c *connectionConfig) error { - var err error - - s.conn, err = connection.New(&connection.Config{ - ID: s.id, - OnDisconnect: s.onDisconnect, - Subscriber: s.subscriber, - Messenger: s.messenger, - Clean: s.clean, - ExpireIn: s.expireIn, - Username: c.username, - Auth: c.auth, - State: c.state, - Conn: c.conn, - Metric: c.metric, - KeepAlive: c.keepAlive, - SendQuota: c.sendQuota, - Version: c.version, - }) - - if err == nil { - s.connStop = &types.Once{} +func (s *session) allocConnection(c *connection.PreConfig) (present bool, err error) { + cfg := &connection.Config{ + PreConfig: c, + ID: s.id, + OnDisconnect: s.onDisconnect, + Subscriber: s.subscriber, + Messenger: s.messenger, + KillOnDisconnect: s.killOnDisconnect, + ExpireIn: s.expireIn, } - return err + s.disconnectOnce = &types.OnceWait{} + s.connStop = &types.Once{} + + s.conn, present, err = connection.New(cfg) + + return } func (s *session) start() { - s.wgStopped.Add(1) + s.isOnline = make(chan struct{}) + s.wgDisconnected.Add(1) s.conn.Start() + s.idLock.Unlock() } func (s *session) stop(reason packet.ReasonCode) *persistenceTypes.SessionState { s.connStop.Do(func() { if s.conn != nil { s.conn.Stop(reason) + s.conn = nil } }) - s.wgStopped.Wait() + s.wgDisconnected.Wait() - select { - case <-s.timerChan: - default: - close(s.timerChan) - s.timerWorker.Wait() + if !s.timer.Stop() { + s.timerLock.Lock() + s.timerLock.Unlock() // nolint: megacheck + } + + if !s.finalized { + s.signalClose(s.id, exitReasonShutdown) + s.finalized = true + } + + elapsed := time.Since(s.expiringSince) + if s.willDelay > 0 && (s.willDelay-elapsed) > 0 { + s.willDelay = s.willDelay - elapsed + } + + if s.expireIn != nil && *s.expireIn > 0 && (*s.expireIn-elapsed) > 0 { + *s.expireIn = *s.expireIn - elapsed } state := &persistenceTypes.SessionState{ @@ -195,218 +190,163 @@ func (s *session) stop(reason packet.ReasonCode) *persistenceTypes.SessionState return state } -func (s *session) toOnline() bool { - if atomic.CompareAndSwapUintptr(&s.isOnline, 0, 1) { - select { - case <-s.timerChan: - default: - close(s.timerChan) +// setOnline try switch session state from offline to online. This is necessary when +// when previous network connection has set session expiry or will delay or both +// if switch is successful then swStatusSwitched returned. +// if session has active network connection then returned value is swStatusIsOnline +// if connection has been closed and must not be used anymore then it returns swStatusFinalized +func (s *session) setOnline() switchStatus { + isOnline := false + // check session online status + s.lock.Lock() + select { + case <-s.isOnline: + default: + isOnline = true + } + s.lock.Unlock() + + status := swStatusSwitched + if !isOnline { + // session is offline. before making any further step wait disconnect procedure is done + s.wgDisconnected.Wait() + + // if stop returns false timer has been fired and there is goroutine might be running + if !s.timer.Stop() { + s.timerLock.Lock() + s.timerLock.Unlock() // nolint: megacheck } - return true + if s.finalized { + status = swStatusFinalized + } + } else { + status = swStatusIsOnline } - return false + return status } func (s *session) runExpiry(will bool) { - var expire *time.Duration + var timerPeriod time.Duration // if meet will requirements point that if will && s.will != nil && s.willDelay >= 0 { - expire = &s.willDelay + timerPeriod = s.willDelay } else { s.will = nil } - //check if we set will delay before - //if will delay bigger than session expiry interval set timer period to expireIn value - //as will message (if presented) has to be published either session expiry or will delay (which done first) - //if will delay less than session expiry set timer to will delay interval and store difference between - //will delay and session expiry to let timer restart keep tick after will - - if s.expireIn != nil && *s.expireIn != 0 { + if s.expireIn != nil { // if will delay is set before and value less than expiration // then timer should fire 2 times - if expire != nil && *expire < *s.expireIn { - *s.expireIn = *s.expireIn - s.willDelay + if (timerPeriod > 0) && (timerPeriod < *s.expireIn) { + *s.expireIn = *s.expireIn - timerPeriod } else { - expire = s.expireIn + timerPeriod = *s.expireIn *s.expireIn = 0 } } - s.timerStartedAt = time.Now() - s.timerWorker.Add(1) - s.timerChan = make(chan struct{}) - s.timer = time.NewTimer(*expire * time.Second) - go s.expiryWorker() + s.expiringSince = time.Now() + s.timer.Reset(timerPeriod * time.Second) } func (s *session) onDisconnect(p *connection.DisconnectParams) { - defer func() { - if r := recover(); r != nil { - fmt.Println(r) - } - }() + s.disconnectOnce.Do(func() { + defer s.wgDisconnected.Done() - defer atomic.CompareAndSwapUintptr(&s.isOnline, 1, 0) - defer s.wgStopped.Done() + s.lock.Lock() + close(s.isOnline) + s.lock.Unlock() + + finalize := func(err exitReason) { + s.signalClose(s.id, err) + s.finalized = true + } - signalClose := true - var shutdownReason exitReason + s.connStop.Do(func() { + s.conn = nil + }) - defer func() { - if signalClose { - s.onClose(s.id, shutdownReason) + if p.ExpireAt != nil { + s.expireIn = p.ExpireAt } - }() - s.connStop.Do(func() { - s.conn = nil - }) + // If session expiry is set to 0, the Session ends when the Network Connection is closed + if s.expireIn != nil && *s.expireIn == 0 { + s.killOnDisconnect = true + } - // valid willMsg pointer tells we have will message - if p.Will && s.will != nil { + // valid willMsg pointer tells we have will message // if session is clean send will regardless to will delay - if s.clean || s.willDelay == 0 { + if p.Will && s.will != nil && (s.killOnDisconnect || s.willDelay == 0) { s.messenger.Publish(s.will) // nolint: errcheck s.will = nil } - } - s.notifyDisconnect(s.id, !s.clean, 0) + s.signalDisconnected(s.id, p.Reason, !s.killOnDisconnect) - if s.clean { - // session is clean. Signal upper layer to wipe it - shutdownReason = exitReasonClean - } else { - // session is not clean thus more difficult case - - // persist state - s.onPersist(s.id, p.State) - - // check if remaining subscriptions exists, expiry is presented and will delay not set to 0 - if s.expireIn == nil && s.willDelay == 0 { - if !s.subscriber.HasSubscriptions() { - // above false thus no remaining subscriptions, session does not expire - // and does not require delayed will - // signal upper layer to persist state and wipe this object - shutdownReason = exitReasonShutdown - } else { - shutdownReason = exitReasonKeepSubscriber - } - } else if (s.expireIn != nil && *s.expireIn > 0) || s.willDelay > 0 { - // session has either to expire or will delay or both set + if s.killOnDisconnect || !s.subscriber.HasSubscriptions() { + s.shutdownSubscriber(s.subscriber) + s.subscriber = nil + } - // do not signal upper layer about exit - signalClose = false + if s.killOnDisconnect { + defer finalize(exitReasonClean) + } else { + // session is not clean thus more difficult case + // persist state + s.persist(s.id, p.State) + + // check if remaining subscriptions exists, expiry is presented and will delay not set to 0 + if s.expireIn == nil && s.willDelay == 0 { + // signal to shutdown session + defer finalize(exitReasonShutdown) + } else if (s.expireIn != nil && *s.expireIn > 0) || s.willDelay > 0 { + // new expiry value might be received upon disconnect message from the client + if p.ExpireAt != nil { + s.expireIn = p.ExpireAt + } - // new expiry value might be received upon disconnect message from the client - if p.ExpireAt != nil { - s.expireIn = p.ExpireAt + s.runExpiry(p.Will) } - - s.runExpiry(p.Will) } - } + }) } -func (s *session) expiryWorker() { - var shutdownReason exitReason - - defer func() { - select { - case <-s.timerChan: - default: - close(s.timerChan) - } - s.timerWorker.Done() - - if shutdownReason > 0 { - s.onClose(s.id, shutdownReason) - } - }() - - for { - select { - case <-s.timer.C: - // timer fired - // 1. check if it requires to publish will - if s.will != nil { - // publish if exists and wipe state - s.messenger.Publish(s.will) // nolint: errcheck - s.will = nil - } - - // 2. if expireIn is present this session has expiry set - if s.expireIn != nil { - // 2.a if value pointed by expireIn is non zero there is some time after will left wait - if *s.expireIn != 0 { - // restart timer and wait again - val := *s.expireIn - // clear value pointed by expireIn so when next fire comes we signal session is expired - *s.expireIn = 0 - s.timer.Reset(val) - } else { - // session has expired. WIPE IT - s.subscriber.Offline(true) - shutdownReason = exitReasonExpired - return - } - } else { - // 2.b session has processed delayed will - // if there is any subscriptions left tell upper layer to keep subscriber - // otherwise completely shutdown the session - if s.subscriber.HasSubscriptions() { - shutdownReason = exitReasonKeepSubscriber - } else { - shutdownReason = exitReasonShutdown - } - - return - } - case <-s.timerChan: - if s.timer != nil { - s.timer.Stop() - elapsed := time.Since(s.timerStartedAt) - if s.willDelay > 0 && (s.willDelay-elapsed) > 0 { - s.willDelay = s.willDelay - elapsed - } +func (s *session) timerCallback() { + defer s.timerLock.Unlock() + s.timerLock.Lock() - if s.expireIn != nil && *s.expireIn > 0 && (*s.expireIn-elapsed) > 0 { - *s.expireIn = *s.expireIn - elapsed - } - } - return - } + finalize := func(reason exitReason) { + s.signalClose(s.id, reason) + s.finalized = true } -} -type properties struct { - ExpireIn *time.Duration - WillDelay time.Duration - UserProperties interface{} - AuthData []byte - AuthMethod string - MaximumPacketSize uint32 - ReceiveMaximum uint16 - TopicAliasMaximum uint16 - RequestResponse bool - RequestProblemInfo bool -} + // 1. check for will message available + if s.will != nil { + // publish if exists and wipe state + s.messenger.Publish(s.will) // nolint: errcheck + s.will = nil + s.willDelay = 0 + } -func newProperties() *properties { - return &properties{ - ExpireIn: nil, - WillDelay: 0, - UserProperties: nil, - AuthData: []byte{}, - AuthMethod: "", - MaximumPacketSize: 2684354565, - ReceiveMaximum: 65535, - TopicAliasMaximum: 0xFFFF, - RequestResponse: false, - RequestProblemInfo: false, + if s.expireIn == nil { + // 2.a session has processed delayed will and there is nothing to do + // completely shutdown the session + defer finalize(exitReasonShutdown) + } else if *s.expireIn == 0 { + // session has expired. WIPE IT + if s.subscriber != nil { + s.shutdownSubscriber(s.subscriber) + } + defer finalize(exitReasonExpired) + } else { + // restart timer and wait again + val := *s.expireIn + // clear value pointed by expireIn so when next fire comes we signal session is expired + *s.expireIn = 0 + s.timer.Reset(val) } } diff --git a/clients/sessions.go b/clients/sessions.go index 6f0a37b..2c966f3 100644 --- a/clients/sessions.go +++ b/clients/sessions.go @@ -14,12 +14,15 @@ import ( "github.com/VolantMQ/volantmq/auth" "github.com/VolantMQ/volantmq/configuration" + "github.com/VolantMQ/volantmq/connection" "github.com/VolantMQ/volantmq/packet" "github.com/VolantMQ/volantmq/persistence/types" "github.com/VolantMQ/volantmq/routines" "github.com/VolantMQ/volantmq/subscriber" "github.com/VolantMQ/volantmq/systree" "github.com/VolantMQ/volantmq/topics/types" + "github.com/VolantMQ/volantmq/types" + "github.com/troian/easygo/netpoll" "go.uber.org/zap" ) @@ -30,48 +33,39 @@ var ( // Config manager configuration type Config struct { - // Topics manager for all the client subscriptions - TopicsMgr topicsTypes.Provider - - Persist persistenceTypes.Provider - - Systree systree.Provider - - // OnReplaceAttempt If requested we notify if there is attempt to dup session - OnReplaceAttempt func(string, bool) - - NodeName string - - // The number of seconds to wait for the CONNACK message before disconnecting. - // If not set then default to 2 seconds. - ConnectTimeout int - - // The number of seconds to keep the connection live if there's no data. - // If not set then defaults to 5 minutes. - KeepAlive int - - // AllowReplace Either allow or deny replacing of existing session if there new client with same clientID - AllowReplace bool - - OfflineQoS0 bool + TopicsMgr topicsTypes.Provider + Persist persistenceTypes.Provider + Systree systree.Provider + OnReplaceAttempt func(string, bool) + NodeName string + ConnectTimeout int + KeepAlive int + MaxPacketSize uint32 + ReceiveMax uint16 + TopicAliasMaximum uint16 + MaximumQoS packet.QosType + AvailableRetain bool + AvailableWildcardSubscription bool + AvailableSubscriptionID bool + AvailableSharedSubscription bool + OfflineQoS0 bool + AllowReplace bool + ForceKeepAlive bool } // Manager clients manager type Manager struct { - systree systree.Provider - persistence persistenceTypes.Sessions - topics topicsTypes.SubscriberInterface - onReplaceAttempt func(string, bool) - log *zap.Logger - quit chan struct{} - sessionsCount sync.WaitGroup - sessions sync.Map - subscribers sync.Map - allowReplace bool - offlineQoS0 bool + Config + persistence persistenceTypes.Sessions + log *zap.Logger + quit chan struct{} + sessionsCount sync.WaitGroup + sessions sync.Map + subscribers sync.Map + poll netpoll.EventPoll } -// StartConfig used to configure session after connection is created +// StartConfig used to reconfigure session after connection is created type StartConfig struct { Req *packet.Connect Resp *packet.ConnAck @@ -82,15 +76,12 @@ type StartConfig struct { // NewManager create new clients manager func NewManager(c *Config) (*Manager, error) { m := &Manager{ - systree: c.Systree, - topics: c.TopicsMgr, - onReplaceAttempt: c.OnReplaceAttempt, - allowReplace: c.AllowReplace, - offlineQoS0: c.OfflineQoS0, - quit: make(chan struct{}), - log: configuration.GetProdLogger().Named("sessions"), + Config: *c, + quit: make(chan struct{}), + log: configuration.GetLogger().Named("sessions"), } + m.poll, _ = netpoll.New(nil) m.persistence, _ = c.Persist.Sessions() var err error @@ -144,13 +135,10 @@ func (m *Manager) Shutdown() error { func (m *Manager) NewSession(config *StartConfig) { var id string var ses *session - var sub subscriber.ConnectionProvider var err error idGenerated := false - sessionPresent := false - username, _ := config.Req.Credentials() - sesProperties := newProperties() + var systreeConnStatus *systree.ClientConnectStatus defer func() { if err != nil { @@ -162,12 +150,6 @@ func (m *Manager) NewSession(config *StartConfig) { reason = packet.CodeRefusedServerUnavailable } config.Resp.SetReturnCode(reason) // nolint: errcheck - } else { - config.Resp.SetSessionPresent(sessionPresent) - - if idGenerated { - config.Resp.PropertySet(packet.PropertyAssignedClientIdentifier, id) // nolint: errcheck - } } if err = routines.WriteMessage(config.Conn, config.Resp); err != nil { @@ -175,186 +157,329 @@ func (m *Manager) NewSession(config *StartConfig) { } else { if ses != nil { ses.start() - m.systree.Clients().Connected( - id, - &systree.ClientConnectStatus{ - Address: config.Conn.RemoteAddr().String(), - Username: string(username), - Timestamp: time.Now().Format(time.RFC3339), - ReceiveMaximum: uint32(sesProperties.ReceiveMaximum), - MaximumPacketSize: sesProperties.MaximumPacketSize, - KeepAlive: config.Req.KeepAlive(), - GeneratedID: idGenerated, - CleanSession: config.Req.CleanStart(), - SessionPresent: sessionPresent, - Protocol: config.Req.Version(), - ConnAckCode: config.Resp.ReturnCode(), - }) + m.Systree.Clients().Connected(id, systreeConnStatus) } } + }() - if ses != nil { - ses.release() + m.checkServerStatus(config.Req.Version(), config.Resp) + + // if response has return code differs from CodeSuccess return from this point + // and send connack in deferred statement + if config.Resp.ReturnCode() != packet.CodeSuccess { + return + } + + // client might come with empty client id + if id = string(config.Req.ClientID()); len(id) == 0 { + id = m.genClientID() + idGenerated = true + } + + if ses, err = m.loadSession(id, config.Req.Version(), config.Resp); err == nil { + if systreeConnStatus, err = m.configureSession(config, ses, id, idGenerated); err != nil { + m.sessions.Delete(id) + m.sessionsCount.Done() } - }() + } +} + +func (m *Manager) loadSession(id string, v packet.ProtocolVersion, resp *packet.ConnAck) (*session, error) { + var err error + var ses *session + wrap := m.allocSession(id, time.Now()) + if ss, ok := m.sessions.LoadOrStore(id, wrap); ok { + // release lock of newly allocated session as lock from old one will be used + wrap.release() + // there is some old session exists, check if it has active network connection then stop if so + // if it is offline (either waiting for expiration to fire or will event or) switch back to online + // and use this session henceforth + oldWrap := ss.(*sessionWrap) + + // lock id to prevent other upcoming session make any changes until we done + oldWrap.acquire() + + old := oldWrap.s + + switch old.setOnline() { + case swStatusIsOnline: + // existing session has active network connection + // exempt it if allowed + m.OnReplaceAttempt(id, m.AllowReplace) + if !m.AllowReplace { + // we do not make any changes to current network connection + // response to new one with error and release both new & old sessions + err = packet.CodeRefusedIdentifierRejected + if v >= packet.ProtocolV50 { + err = packet.CodeInvalidClientID + } + oldWrap.release() + } else { + // session will be replaced with new connection + // stop current active connection + old.stop(packet.CodeSessionTakenOver) + ses = oldWrap.swap(wrap) + m.sessions.Store(id, oldWrap) + m.sessionsCount.Add(1) + } + case swStatusSwitched: + // session has been turned online successfully + ses = old + default: + ses = oldWrap.swap(wrap) + m.sessions.Store(id, oldWrap) + m.sessionsCount.Add(1) + } + } else { + ses = wrap.s + m.sessionsCount.Add(1) + } + + return ses, err +} + +func (m *Manager) checkServerStatus(v packet.ProtocolVersion, resp *packet.ConnAck) { // check first if server is not about to shutdown // if so just give reject and exit select { case <-m.quit: var reason packet.ReasonCode - switch config.Req.Version() { + switch v { case packet.ProtocolV50: reason = packet.CodeServerShuttingDown // TODO: if cluster route client to another node default: reason = packet.CodeRefusedServerUnavailable } - config.Resp.SetReturnCode(reason) // nolint: errcheck + resp.SetReturnCode(reason) // nolint: errcheck default: } +} - // if response has return code differs from CodeSuccess return from this point - // and send connack in deferred statement - if config.Resp.ReturnCode() != packet.CodeSuccess { - return +func (m *Manager) allocSession(id string, createdAt time.Time) *sessionWrap { + wrap := &sessionWrap{ + s: newSession(&sessionPreConfig{ + id: id, + createdAt: createdAt, + messenger: m.TopicsMgr, + sessionEvents: sessionEvents{ + persist: m.onSessionPersist, + signalClose: m.onSessionClose, + signalDisconnected: m.onDisconnect, + shutdownSubscriber: m.onSubscriberShutdown, + }, + })} + wrap.acquire() + wrap.s.idLock = &wrap.lock + + return wrap +} + +func (m *Manager) getWill(pkt *packet.Connect) *packet.Publish { + var willPkt *packet.Publish + if willTopic, willPayload, willQoS, willRetain, will := pkt.Will(); will { + _m, _ := packet.New(pkt.Version(), packet.PUBLISH) + willPkt = _m.(*packet.Publish) + willPkt.Set(willTopic, willPayload, willQoS, willRetain, false) // nolint: errcheck } - // client might come with empty client id - if id = string(config.Req.ClientID()); len(id) == 0 { - id = m.genClientID() - idGenerated = true + return willPkt +} + +func (m *Manager) newConnectionPreConfig(config *StartConfig) *connection.PreConfig { + username, _ := config.Req.Credentials() + + return &connection.PreConfig{ + Username: string(username), + Auth: config.Auth, + Conn: config.Conn, + KeepAlive: config.Req.KeepAlive(), + Version: config.Req.Version(), + Desc: netpoll.Must(netpoll.HandleReadOnce(config.Conn)), + MaxTxPacketSize: types.DefaultMaxPacketSize, + SendQuota: types.DefaultReceiveMax, + State: m.persistence, + StartReceiving: m.signalStartRead, + WaitForData: m.signalWaitForRead, + StopReceiving: m.signalReceiveStop, + Metric: m.Systree.Metric(), + RetainAvailable: m.AvailableRetain, + OfflineQoS0: m.OfflineQoS0, + MaxRxPacketSize: m.MaxPacketSize, + MaxRxTopicAlias: m.TopicAliasMaximum, + MaxTxTopicAlias: 0, + } +} + +func (m *Manager) configureSession(config *StartConfig, ses *session, id string, idGenerated bool) (*systree.ClientConnectStatus, error) { + sub, sessionPresent := m.getSubscriber(id, config.Req.IsClean(), config.Req.Version()) + + sConfig := &sessionReConfig{ + subscriber: sub, + will: m.getWill(config.Req), + killOnDisconnect: false, } - config.Req.PropertyForEach(func(id packet.PropertyID, val interface{}) { // nolint: errcheck - switch id { - case packet.PropertySessionExpiryInterval: - v := time.Duration(val.(uint32)) - sesProperties.ExpireIn = &v - case packet.PropertyWillDelayInterval: - sesProperties.WillDelay = time.Duration(val.(uint32)) - case packet.PropertyReceiveMaximum: - sesProperties.ReceiveMaximum = val.(uint16) - case packet.PropertyMaximumPacketSize: - sesProperties.MaximumPacketSize = val.(uint32) - case packet.PropertyTopicAliasMaximum: - sesProperties.TopicAliasMaximum = val.(uint16) - case packet.PropertyRequestProblemInfo: - sesProperties.RequestProblemInfo = val.(bool) - case packet.PropertyRequestResponseInfo: - sesProperties.RequestResponse = val.(bool) - case packet.PropertyUserProperty: - sesProperties.UserProperties = val - case packet.PropertyAuthMethod: - sesProperties.AuthMethod = val.(string) - case packet.PropertyAuthData: - sesProperties.AuthData = val.([]byte) + cConfig := m.newConnectionPreConfig(config) + + if config.Req.Version() >= packet.ProtocolV50 { + if err := readSessionProperties(config.Req, sConfig, cConfig); err != nil { + return nil, err } - }) - if ss, ok := m.sessions.Load(id); ok { - ses = ss.(*session) - ses.acquire() + ids := "" + if idGenerated { + ids = id + } + + m.writeSessionProperties(config.Resp, ids) + if err := config.Resp.PropertySet(packet.PropertyServerKeepAlive, m.KeepAlive); err != nil { + return nil, err + } } - if ses != nil && !ses.toOnline() { - if !m.allowReplace { - // duplicate prohibited. send identifier rejected - var reason packet.ReasonCode - switch config.Req.Version() { - case packet.ProtocolV50: - reason = packet.CodeInvalidClientID - default: - reason = packet.CodeRefusedIdentifierRejected - } + // MQTT v5 has different meaning of clean comparing to MQTT v3 + // - v3: if session is clean it lasts when Network connection os close + // - v5: clean means clean start and server must wipe any previously created session with same id + // but keep this one if Network Connection is closed + if (config.Req.Version() <= packet.ProtocolV311 && config.Req.IsClean()) || + (sConfig.expireIn != nil && *sConfig.expireIn == 0) { + sConfig.killOnDisconnect = true + } - config.Resp.SetReturnCode(reason) // nolint: errcheck - err = ErrReplaceNotAllowed - m.onReplaceAttempt(id, false) - ses = nil - return + ses.reconfigure(sConfig, false) + + present, err := ses.allocConnection(cConfig) + var status *systree.ClientConnectStatus + + sessionPresent = sessionPresent || present + + if err == nil { + if !config.Req.IsClean() { + m.persistence.Delete([]byte(id)) // nolint: errcheck + } + + status = &systree.ClientConnectStatus{ + Username: cConfig.Username, + Timestamp: time.Now().Format(time.RFC3339), + ReceiveMaximum: uint32(cConfig.SendQuota), + MaximumPacketSize: cConfig.MaxTxPacketSize, + GeneratedID: idGenerated, + SessionPresent: sessionPresent, + Address: config.Conn.RemoteAddr().String(), + KeepAlive: config.Req.KeepAlive(), + Protocol: config.Req.Version(), + ConnAckCode: config.Resp.ReturnCode(), + CleanSession: config.Req.IsClean(), + KillOnDisconnect: sConfig.killOnDisconnect, } - // replace allowed stop current session - ses.stop(packet.CodeSessionTakenOver) - ses.release() - m.onReplaceAttempt(id, true) - ses = nil } - var persistedMessages *persistenceTypes.SessionMessages + config.Resp.SetSessionPresent(sessionPresent) - sub, persistedMessages, sessionPresent = m.getSubscriber(id, config.Req.CleanStart(), config.Req.Version()) + return status, err +} - expireInterval := "absent" - if sesProperties.ExpireIn != nil { - expireInterval = strconv.FormatUint(uint64(*sesProperties.ExpireIn), 10) +func boolToByte(v bool) byte { + if v { + return 1 } - if !sessionPresent { - state := &systree.SessionCreatedStatus{ - ExpiryInterval: expireInterval, - WillDelay: strconv.FormatUint(uint64(sesProperties.WillDelay), 10), - Timestamp: time.Now().Format(time.RFC3339), - Clean: config.Req.CleanStart(), + return 0 +} + +func readSessionProperties(req *packet.Connect, sc *sessionReConfig, cc *connection.PreConfig) (err error) { + // [MQTT-3.1.2.11.2] + if prop := req.PropertyGet(packet.PropertySessionExpiryInterval); prop != nil { + if val, e := prop.AsInt(); e == nil { + v := time.Duration(val) + sc.expireIn = &v } + } - m.systree.Sessions().Created(id, state) + // [MQTT-3.1.2.11.3] + if prop := req.PropertyGet(packet.PropertyWillDelayInterval); prop != nil { + if val, e := prop.AsInt(); e == nil { + sc.willDelay = time.Duration(val) + } } - var willMsg *packet.Publish - if willTopic, willPayload, willQoS, willRetain, will := config.Req.Will(); will { - _m, _ := packet.NewMessage(config.Req.Version(), packet.PUBLISH) - willMsg = _m.(*packet.Publish) - willMsg.SetQoS(willQoS) // nolint: errcheck - willMsg.SetTopic(willTopic) // nolint: errcheck - willMsg.SetPayload(willPayload) - willMsg.SetRetain(willRetain) + // [MQTT-3.1.2.11.4] + if prop := req.PropertyGet(packet.PropertyReceiveMaximum); prop != nil { + if val, e := prop.AsShort(); e == nil { + cc.SendQuota = int32(val) + } } - if ses == nil { - ses, err = newSession(&sessionConfig{ - id: id, - onPersist: m.onSessionPersist, - onClose: m.onSessionClose, - onDisconnect: m.onClientDisconnect, - messenger: m.topics, - clean: config.Req.CleanStart(), - }) - if err == nil { - ses.acquire() - m.sessions.Store(id, ses) - m.sessionsCount.Add(1) + // [MQTT-3.1.2.11.5] + if prop := req.PropertyGet(packet.PropertyMaximumPacketSize); prop != nil { + if val, e := prop.AsInt(); e == nil { + cc.MaxTxPacketSize = val } } - if ses != nil { - ses.configure(&setupConfig{ - subscriber: sub, - will: willMsg, - expireIn: sesProperties.ExpireIn, - willDelay: sesProperties.WillDelay, - }, false) - - err = ses.allocConnection(&connectionConfig{ - username: string(username), - state: persistedMessages, - auth: config.Auth, - metric: m.systree.Metric(), - conn: config.Conn, - keepAlive: config.Req.KeepAlive(), - sendQuota: sesProperties.ReceiveMaximum, - version: config.Req.Version(), - }) + // [MQTT-3.1.2.11.6] + if prop := req.PropertyGet(packet.PropertyTopicAliasMaximum); prop != nil { + if val, e := prop.AsShort(); e == nil { + cc.MaxTxTopicAlias = val + } + } + + // [MQTT-3.1.2.11.10] + if prop := req.PropertyGet(packet.PropertyAuthMethod); prop != nil { + if val, e := prop.AsString(); e == nil { + cc.AuthMethod = val + } + } - if err == nil && persistedMessages != nil { - m.persistence.MessagesWipe([]byte(id)) // nolint: errcheck + // [MQTT-3.1.2.11.11] + if prop := req.PropertyGet(packet.PropertyAuthData); prop != nil { + if len(cc.AuthMethod) == 0 { + err = packet.CodeProtocolError + return + } + if val, e := prop.AsBinary(); e == nil { + cc.AuthData = val } } + + return } -func (m *Manager) getSubscriber(id string, clean bool, v packet.ProtocolVersion) (subscriber.ConnectionProvider, *persistenceTypes.SessionMessages, bool) { +func (m *Manager) writeSessionProperties(resp *packet.ConnAck, id string) { + // [MQTT-3.2.2.3.2] if server receive max less than 65536 than let client to know about + if m.ReceiveMax < types.DefaultReceiveMax { + resp.PropertySet(packet.PropertyReceiveMaximum, m.ReceiveMax) // nolint: errcheck + } + // [MQTT-3.2.2.3.3] if supported server's QoS less than 2 notify client + if m.MaximumQoS < packet.QoS2 { + resp.PropertySet(packet.PropertyMaximumQoS, byte(m.MaximumQoS)) // nolint: errcheck + } + // [MQTT-3.2.2.3.4] tell client whether retained messages supported + resp.PropertySet(packet.PropertyRetainAvailable, boolToByte(m.AvailableRetain)) // nolint: errcheck + // [MQTT-3.2.2.3.5] if server max packet size less than 268435455 than let client to know about + if m.MaxPacketSize < types.DefaultMaxPacketSize { + resp.PropertySet(packet.PropertyMaximumPacketSize, m.MaxPacketSize) // nolint: errcheck + } + // [MQTT-3.2.2.3.6] + if len(id) > 0 { + resp.PropertySet(packet.PropertyAssignedClientIdentifier, id) // nolint: errcheck + } + // [MQTT-3.2.2.3.7] + if m.TopicAliasMaximum > 0 { + resp.PropertySet(packet.PropertyTopicAliasMaximum, m.TopicAliasMaximum) // nolint: errcheck + } + // [MQTT-3.2.2.3.10] tell client whether server supports wildcard subscriptions or not + resp.PropertySet(packet.PropertyWildcardSubscriptionAvailable, boolToByte(m.AvailableWildcardSubscription)) // nolint: errcheck + // [MQTT-3.2.2.3.11] tell client whether server supports subscription identifiers or not + resp.PropertySet(packet.PropertySubscriptionIdentifierAvailable, boolToByte(m.AvailableSubscriptionID)) // nolint: errcheck + // [MQTT-3.2.2.3.12] tell client whether server supports shared subscriptions or not + resp.PropertySet(packet.PropertySharedSubscriptionAvailable, boolToByte(m.AvailableSharedSubscription)) // nolint: errcheck +} + +func (m *Manager) getSubscriber(id string, clean bool, v packet.ProtocolVersion) (subscriber.ConnectionProvider, bool) { var sub subscriber.ConnectionProvider - var state *persistenceTypes.SessionMessages present := false if clean { @@ -363,33 +488,29 @@ func (m *Manager) getSubscriber(id string, clean bool, v packet.ProtocolVersion) sub.Offline(true) m.subscribers.Delete(id) } - m.persistence.Delete([]byte(id)) // nolint: errcheck - } else { - var err error - if state, err = m.persistence.MessagesLoad([]byte(id)); err != nil && err != persistenceTypes.ErrNotFound { - m.log.Error("Couldn't load session state", zap.String("ClientID", id), zap.Error(err)) - } else if err == nil { - present = true - m.persistence.Delete([]byte(id)) // nolint: errcheck + if err := m.persistence.Delete([]byte(id)); err != nil && err != persistenceTypes.ErrNotFound { + m.log.Error("Couldn't wipe session", zap.String("ClientID", id), zap.Error(err)) } } if sb, ok := m.subscribers.Load(id); !ok { sub = subscriber.New(&subscriber.Config{ ID: id, - Topics: m.topics, + Topics: m.TopicsMgr, OnOfflinePublish: m.onPublish, - OfflineQoS0: m.offlineQoS0, + OfflineQoS0: m.OfflineQoS0, Version: v, }) m.subscribers.Store(id, sub) + m.log.Debug("Subscriber created", zap.String("ClientID", id)) } else { + m.log.Debug("Subscriber obtained", zap.String("ClientID", id)) sub = sb.(subscriber.ConnectionProvider) present = true } - return sub, state, present + return sub, present } func (m *Manager) genClientID() string { @@ -405,20 +526,35 @@ func (m *Manager) onSessionPersist(id string, state *persistenceTypes.SessionMes m.persistence.MessagesStore([]byte(id), state) // nolint: errcheck } -func (m *Manager) onClientDisconnect(id string, clean bool, reason packet.ReasonCode) { - m.systree.Clients().Disconnected(id, reason, clean) +func (m *Manager) signalStartRead(desc *netpoll.Desc, cb netpoll.CallbackFn) error { + return m.poll.Start(desc, cb) } -func (m *Manager) onSessionClose(id string, reason exitReason) { - if reason != exitReasonKeepSubscriber { - m.subscribers.Delete(id) - } +func (m *Manager) signalWaitForRead(desc *netpoll.Desc) error { + return m.poll.Resume(desc) +} - if reason == exitReasonClean { - m.persistence.Delete([]byte(id)) // nolint: errcheck - } +func (m *Manager) signalReceiveStop(desc *netpoll.Desc) error { + return m.poll.Stop(desc) +} + +func (m *Manager) onDisconnect(id string, reason packet.ReasonCode, retain bool) { + m.log.Debug("Disconnected", zap.String("ClientID", id)) + m.Systree.Clients().Disconnected(id, reason, retain) +} +func (m *Manager) onSubscriberShutdown(sub subscriber.ConnectionProvider) { + m.log.Debug("Shutdown subscriber", zap.String("ClientID", sub.ID())) + sub.Offline(true) + m.subscribers.Delete(sub.ID()) +} + +func (m *Manager) onSessionClose(id string, reason exitReason) { if reason == exitReasonClean || reason == exitReasonExpired { + if err := m.persistence.Delete([]byte(id)); err != nil && err != persistenceTypes.ErrNotFound { + m.log.Error("Couldn't wipe session", zap.String("ClientID", id), zap.Error(err)) + } + rs := "clean" if reason == exitReasonExpired { rs = "expired" @@ -429,9 +565,11 @@ func (m *Manager) onSessionClose(id string, reason exitReason) { Reason: rs, } - m.systree.Sessions().Removed(id, state) + m.Systree.Sessions().Removed(id, state) } + m.log.Debug("Session close", zap.String("ClientID", id)) + m.sessions.Delete(id) m.sessionsCount.Done() } @@ -454,39 +592,26 @@ func (m *Manager) loadSessions() error { } if state.ExpireIn != nil || state.Will != nil { - var err error - var ses *session createdAt, _ := time.Parse(time.RFC3339, state.Timestamp) - ses, err = newSession(&sessionConfig{ - id: string(id), - createdAt: createdAt, - onPersist: m.onSessionPersist, - onClose: m.onSessionClose, - onDisconnect: m.onClientDisconnect, - messenger: m.topics, - clean: false, - }) - if err == nil { - m.sessions.Store(string(id), ses) - m.sessionsCount.Add(1) - - setup := &setupConfig{ - subscriber: nil, - expireIn: state.ExpireIn, - } + ses := m.allocSession(string(id), createdAt) + setup := &sessionReConfig{ + subscriber: nil, + expireIn: state.ExpireIn, + killOnDisconnect: false, + } - if state.Will != nil { - msg, _, _ := packet.Decode(state.Version, state.Will.Message) - willMsg, _ := msg.(*packet.Publish) - setup.will = willMsg - setup.willDelay = state.Will.Delay - } - ses.configure(setup, true) - } else { - return err + if state.Will != nil { + msg, _, _ := packet.Decode(state.Version, state.Will.Message) + willMsg, _ := msg.(*packet.Publish) + setup.will = willMsg + setup.willDelay = state.Will.Delay } + ses.s.reconfigure(setup, true) + m.sessions.Store(id, ses) + m.sessionsCount.Add(1) + ses.release() } - m.systree.Sessions().Created(string(id), status) + m.Systree.Sessions().Created(string(id), status) return nil }) @@ -516,9 +641,9 @@ func (m *Manager) loadSubscribers() error { offset += total - params := &subscriber.SubscriptionParams{} + params := &topicsTypes.SubscriptionParams{} - params.Requested = packet.SubscriptionOptions(data[offset]) + params.Ops = packet.SubscriptionOptions(data[offset]) offset++ params.ID = binary.BigEndian.Uint32(data[offset:]) @@ -541,9 +666,9 @@ func (m *Manager) loadSubscribers() error { sub := subscriber.New( &subscriber.Config{ ID: id, - Topics: m.topics, + Topics: m.TopicsMgr, OnOfflinePublish: m.onPublish, - OfflineQoS0: m.offlineQoS0, + OfflineQoS0: m.OfflineQoS0, Version: t.version, }) @@ -594,7 +719,7 @@ func (m *Manager) storeSubscribers() error { for s, params := range topics { total, _ := packet.WriteLPBytes(buf[offset:], []byte(s)) offset += total - buf[offset] = byte(params.Requested) + buf[offset] = byte(params.Ops) offset++ binary.BigEndian.PutUint32(buf[offset:], params.ID) offset += 4 diff --git a/configuration/init.go b/configuration/init.go index b70a0ed..16c347a 100644 --- a/configuration/init.go +++ b/configuration/init.go @@ -7,11 +7,7 @@ import ( ) type config struct { - log struct { - Prod *zap.Logger - Dev *zap.Logger - } - + log *zap.Logger once sync.Once } @@ -25,17 +21,13 @@ var cfg config func init() { logCfg := zap.NewProductionConfig() - logDebugCfg := zap.NewProductionConfig() logCfg.DisableStacktrace = true - logDebugCfg.DisableStacktrace = true - logDebugCfg.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + logCfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel) log, _ := logCfg.Build() - dLog, _ := logDebugCfg.Build() - cfg.log.Prod = log.Named("mqtt") - cfg.log.Dev = dLog.Named("mqtt") + cfg.log = log.Named("mqtt") } // Init global MQTT config with given options @@ -43,30 +35,20 @@ func init() { func Init(ops Options) { cfg.once.Do(func() { logCfg := zap.NewProductionConfig() - logDebugCfg := zap.NewDevelopmentConfig() - logDebugCfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel) + logCfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel) logCfg.DisableStacktrace = true - logDebugCfg.DisableStacktrace = true if !ops.LogWithTs { logCfg.EncoderConfig.TimeKey = "" - logDebugCfg.EncoderConfig.TimeKey = "" } log, _ := logCfg.Build() - dLog, _ := logDebugCfg.Build() - cfg.log.Prod = log.Named("mqtt") - cfg.log.Dev = dLog.Named("mqtt") + cfg.log = log.Named("mqtt") }) } -// GetProdLogger return production logger -func GetProdLogger() *zap.Logger { - return cfg.log.Prod -} - -// GetDevLogger return development logger -func GetDevLogger() *zap.Logger { - return cfg.log.Prod +// GetLogger return production logger +func GetLogger() *zap.Logger { + return cfg.log } diff --git a/connection/ack.go b/connection/ack.go index 729fc48..004213b 100644 --- a/connection/ack.go +++ b/connection/ack.go @@ -6,51 +6,37 @@ import ( "github.com/VolantMQ/volantmq/packet" ) -type onRelease func(msg packet.Provider) +type onRelease func(o, n packet.Provider) type ackQueue struct { - lock sync.Mutex - messages map[packet.IDType]packet.Provider + messages sync.Map onRelease onRelease } func newAckQueue(cb onRelease) *ackQueue { a := ackQueue{ - messages: make(map[packet.IDType]packet.Provider), onRelease: cb, } return &a } -func (a *ackQueue) store(msg packet.Provider) { - a.lock.Lock() - defer a.lock.Unlock() - - id, _ := msg.ID() - - a.messages[id] = msg +func (a *ackQueue) store(pkt packet.Provider) { + id, _ := pkt.ID() + a.messages.Store(id, pkt) } -func (a *ackQueue) release(msg packet.Provider) { - a.lock.Lock() - defer a.lock.Unlock() - - id, _ := msg.ID() +func (a *ackQueue) release(pkt packet.Provider) { + id, _ := pkt.ID() - if e, ok := a.messages[id]; ok { - if a.onRelease != nil { - a.onRelease(e) + if value, ok := a.messages.Load(id); ok { + if orig, ok := value.(packet.Provider); ok && a.onRelease != nil { + a.onRelease(orig, pkt) } - a.messages[id] = nil - delete(a.messages, id) + a.messages.Delete(id) } } -func (a *ackQueue) get() map[packet.IDType]packet.Provider { - return a.messages -} - //func (a *ackQueue) wipe() { // a.lock.Lock() // defer a.lock.Unlock() diff --git a/connection/connection.go b/connection/connection.go index f87daa6..e0e6c06 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -15,6 +15,7 @@ package connection import ( + "container/list" "errors" "net" "sync" @@ -27,6 +28,7 @@ import ( "github.com/VolantMQ/volantmq/subscriber" "github.com/VolantMQ/volantmq/systree" "github.com/VolantMQ/volantmq/types" + "github.com/troian/easygo/netpoll" "go.uber.org/zap" ) @@ -38,9 +40,11 @@ var ( // DisconnectParams session state when stopped type DisconnectParams struct { - Will bool ExpireAt *time.Duration State *persistenceTypes.SessionMessages + Desc *netpoll.Desc + Reason packet.ReasonCode + Will bool } type onDisconnect func(*DisconnectParams) @@ -59,116 +63,164 @@ type WillConfig struct { QoS packet.QosType } +// PreConfig used by session manager to when configuring session +// a bit of ugly +// TODO(troian): try rid it off +type PreConfig struct { + StartReceiving func(*netpoll.Desc, netpoll.CallbackFn) error + StopReceiving func(*netpoll.Desc) error + WaitForData func(*netpoll.Desc) error + Username string + AuthMethod string + AuthData []byte + State persistenceTypes.ConnectionMessages + Metric systree.Metric + Conn net.Conn + Auth auth.SessionPermissions + Desc *netpoll.Desc + MaxRxPacketSize uint32 + MaxTxPacketSize uint32 + SendQuota int32 + MaxTxTopicAlias uint16 + MaxRxTopicAlias uint16 + KeepAlive uint16 + Version packet.ProtocolVersion + RetainAvailable bool + PreserveOrder bool + OfflineQoS0 bool +} + // Config is system wide configuration parameters for every session type Config struct { - ID string - Username string - State *persistenceTypes.SessionMessages - Subscriber subscriber.ConnectionProvider - Auth auth.SessionPermissions - Messenger types.TopicMessenger - Metric systree.Metric - Conn net.Conn - OnDisconnect onDisconnect - ExpireIn *time.Duration - WillDelay time.Duration - KeepAlive uint16 - SendQuota uint16 - Clean bool - PreserveOrder bool - Version packet.ProtocolVersion + *PreConfig + ID string + Subscriber subscriber.ConnectionProvider + Messenger types.TopicMessenger + OnDisconnect onDisconnect + ExpireIn *time.Duration + WillDelay time.Duration + KillOnDisconnect bool } -// Type session +// Type connection type Type struct { - conn net.Conn + *Config pubIn *ackQueue pubOut *ackQueue flowControl *packetsFlowControl tx *transmitter rx *receiver onDisconnect onDisconnect - subscriber subscriber.ConnectionProvider - messenger types.TopicMessenger - auth auth.SessionPermissions - id string - username string quit chan struct{} onStart types.Once onConnDisconnect types.Once - - retained struct { + started sync.WaitGroup + topicAlias map[uint16]string + retained struct { lock sync.Mutex list []*packet.Publish } - log *zap.Logger - sendQuota int32 - offlineQoS0 bool - clean bool - version packet.ProtocolVersion - will bool + preProcessPublish func(*packet.Publish) error + postProcessPublish func(*packet.Publish) error + + log *zap.Logger + will bool } type unacknowledgedPublish struct { msg packet.Provider } -// New allocate new sessions object -func New(c *Config) (s *Type, err error) { +func (u *unacknowledgedPublish) Size() (int, error) { + return u.msg.Size() +} + +type sizeAble interface { + Size() (int, error) +} + +// New allocate new connection object +func New(c *Config) (s *Type, present bool, err error) { s = &Type{ - id: c.ID, - username: c.Username, - auth: c.Auth, - subscriber: c.Subscriber, - messenger: c.Messenger, - version: c.Version, - clean: c.Clean, - conn: c.Conn, + Config: c, onDisconnect: c.OnDisconnect, - sendQuota: int32(c.SendQuota), quit: make(chan struct{}), + topicAlias: make(map[uint16]string), will: true, } + s.started.Add(1) s.pubIn = newAckQueue(s.onReleaseIn) s.pubOut = newAckQueue(s.onReleaseOut) - s.flowControl = newFlowControl(s.quit, c.PreserveOrder) - s.log = configuration.GetProdLogger().Named("connection." + s.id) + s.flowControl = newFlowControl(s.quit, c.SendQuota) + + s.log = configuration.GetLogger().Named("connection." + s.ID) + + if s.Version >= packet.ProtocolV50 { + s.preProcessPublish = s.preProcessPublishV50 + s.postProcessPublish = s.postProcessPublishV50 + } else { + s.preProcessPublish = func(*packet.Publish) error { return nil } + s.postProcessPublish = func(*packet.Publish) error { return nil } + } s.tx = newTransmitter(&transmitterConfig{ - quit: s.quit, - log: s.log, - id: s.id, - pubIn: s.pubIn, - pubOut: s.pubOut, - flowControl: s.flowControl, - conn: s.conn, - onDisconnect: s.onConnectionClose, + quit: s.quit, + log: s.log, + id: s.ID, + pubIn: s.pubIn, + pubOut: s.pubOut, + flowControl: s.flowControl, + conn: s.Conn, + onDisconnect: s.onConnectionClose, + maxPacketSize: c.MaxTxPacketSize, + topicAliasMax: c.MaxTxTopicAlias, }) + var keepAlive time.Duration + if c.KeepAlive > 0 { + keepAlive = time.Second * time.Duration(c.KeepAlive) + keepAlive = keepAlive + (keepAlive / 2) + } + s.rx = newReceiver(&receiverConfig{ - quit: s.quit, - conn: s.conn, - onDisconnect: s.onConnectionClose, - onPacket: s.processIncoming, - version: s.version, - will: &s.will, + keepAlive: keepAlive, + quit: s.quit, + conn: s.Conn, + waitForRead: s.onWaitForRead, + onDisconnect: s.onConnectionClose, + onPacket: s.processIncoming, + version: s.Version, + maxPacketSize: c.MaxRxPacketSize, + will: &s.will, }) - if c.KeepAlive > 0 { - s.rx.keepAlive = time.Second * time.Duration(c.KeepAlive) - s.rx.keepAlive = s.rx.keepAlive + (s.rx.keepAlive / 2) + gList := list.New() + qList := list.New() + + subscribedPublish := func(p *packet.Publish) { + if p.QoS() == packet.QoS0 { + gList.PushBack(p) + } else { + qList.PushBack(p) + } } // transmitter queue is ready, assign to subscriber new online callback // and signal it to forward messages to online callback by creating channel - s.subscriber.Online(s.onSubscribedPublish) - if !s.clean && c.State != nil { - // restore persisted state of the session if any - s.loadPersistence(c.State) // nolint: errcheck + s.Subscriber.Online(subscribedPublish) + + // restore persisted state of the session if any + if present, err = s.loadPersistence(); err != nil { + return } + s.Subscriber.OnlineRedirect(s.onSubscribedPublish) + + s.tx.gLoadList(gList) + s.tx.qLoadList(qList) + return } @@ -176,7 +228,8 @@ func New(c *Config) (s *Type, err error) { func (s *Type) Start() { s.onStart.Do(func() { s.tx.run() - s.rx.run() + s.StartReceiving(s.Desc, s.rx.run) // nolint: errcheck + s.started.Done() }) } @@ -184,7 +237,7 @@ func (s *Type) Start() { // or session is being replaced // Effective only first invoke func (s *Type) Stop(reason packet.ReasonCode) { - s.onConnectionClose(true) + s.onConnectionClose(true, reason) } func (s *Type) processIncoming(p packet.Provider) error { @@ -193,7 +246,7 @@ func (s *Type) processIncoming(p packet.Provider) error { switch pkt := p.(type) { case *packet.Publish: - resp = s.onPublish(pkt) + resp, err = s.onPublish(pkt) case *packet.Ack: resp = s.onAck(pkt) case *packet.Subscribe: @@ -202,73 +255,102 @@ func (s *Type) processIncoming(p packet.Provider) error { resp = s.onUnSubscribe(pkt) case *packet.PingReq: // For PINGREQ message, we should send back PINGRESP - mR, _ := packet.NewMessage(s.version, packet.PINGRESP) + mR, _ := packet.New(s.Version, packet.PINGRESP) resp, _ = mR.(*packet.PingResp) case *packet.Disconnect: // For DISCONNECT message, we should quit without sending Will s.will = false - //if s.version == packet.ProtocolV50 { - // // FIXME: CodeRefusedBadUsernameOrPassword has same id as CodeDisconnectWithWill - // if m.ReasonCode() == packet.CodeRefusedBadUsernameOrPassword { - // s.will = true - // } - // - // expireIn := time.Duration(0) - // if val, e := m.PropertyGet(packet.PropertySessionExpiryInterval); e == nil { - // expireIn = time.Duration(val.(uint32)) - // } - // - // // If the Session Expiry Interval in the CONNECT packet was zero, then it is a Protocol Error to set a non- - // // zero Session Expiry Interval in the DISCONNECT packet sent by the Client. If such a non-zero Session - // // Expiry Interval is received by the Server, it does not treat it as a valid DISCONNECT packet. The Server - // // uses DISCONNECT with Reason Code 0x82 (Protocol Error) as described in section 4.13. - // if s.expireIn != nil && *s.expireIn == 0 && expireIn != 0 { - // m, _ := packet.NewMessage(packet.ProtocolV50, packet.DISCONNECT) - // msg, _ := m.(*packet.Disconnect) - // msg.SetReasonCode(packet.CodeProtocolError) - // s.WriteMessage(msg, true) // nolint: errcheck - // } - //} - //return - err = errors.New("disconnect") + if s.Version == packet.ProtocolV50 { + // FIXME: CodeRefusedBadUsernameOrPassword has same id as CodeDisconnectWithWill + if pkt.ReasonCode() == packet.CodeRefusedBadUsernameOrPassword { + s.will = true + } + + err = errors.New("disconnect") + + if prop := pkt.PropertyGet(packet.PropertySessionExpiryInterval); prop != nil { + if val, ok := prop.AsInt(); ok == nil { + // If the Session Expiry Interval in the CONNECT packet was zero, then it is a Protocol Error to set a non- + // zero Session Expiry Interval in the DISCONNECT packet sent by the Client. If such a non-zero Session + // Expiry Interval is received by the Server, it does not treat it as a valid DISCONNECT packet. The Server + // uses DISCONNECT with Reason Code 0x82 (Protocol Error) as described in section 4.13. + if (s.ExpireIn != nil && *s.ExpireIn == 0) && val != 0 { + err = packet.CodeProtocolError + } else { + newExpireIn := time.Duration(val) + s.ExpireIn = &newExpireIn + } + } + } + } default: s.log.Error("Unsupported incoming message type", - zap.String("ClientID", s.id), + zap.String("ClientID", s.ID), zap.String("type", p.Type().Name())) return nil } if resp != nil { - s.tx.sendPacket(resp) + s.tx.gPush(resp) } return err } -func (s *Type) loadPersistence(state *persistenceTypes.SessionMessages) (err error) { - for _, d := range state.OutMessages { - var msg packet.Provider - if msg, _, err = packet.Decode(s.version, d); err != nil { - s.log.Error("Couldn't decode persisted message", zap.Error(err)) - err = ErrPersistence - return - } +func (s *Type) loadPersistence() (present bool, err error) { + var state *persistenceTypes.SessionMessages + present = false - s.tx.loadFront(msg) + if s.State == nil { + return present, nil } - for _, d := range state.UnAckMessages { - var msg packet.Provider - if msg, _, err = packet.Decode(s.version, d); err != nil { - s.log.Error("Couldn't decode persisted message", zap.Error(err)) - err = ErrPersistence - return + if state, err = s.State.MessagesLoad([]byte(s.ID)); err != nil && err != persistenceTypes.ErrNotFound { + s.log.Error("Couldn't load session state", zap.String("ClientID", s.ID), zap.Error(err)) + } else if err == nil { + // load first unacknowledged publishes if any to keep flow control working + for _, d := range state.UnAckMessages { + var pkt packet.Provider + if pkt, _, err = packet.Decode(s.Version, d); err != nil { + s.log.Error("Couldn't decode persisted message", zap.Error(err)) + err = ErrPersistence + return + } + + switch p := pkt.(type) { + case *packet.Publish: + id, _ := p.ID() + s.flowControl.reAcquire(id) + case *packet.Ack: + id, _ := p.ID() + s.flowControl.reAcquire(id) + } + + s.tx.qLoad(&unacknowledgedPublish{msg: pkt}) + } + + for _, d := range state.OutMessages { + var pkt packet.Provider + if pkt, _, err = packet.Decode(s.Version, d); err != nil { + s.log.Error("Couldn't decode persisted message", zap.Error(err)) + err = ErrPersistence + return + } + + if p, ok := pkt.(*packet.Publish); ok { + if p.QoS() == packet.QoS0 { + s.tx.gLoad(pkt) + } else { + s.tx.qLoad(pkt) + } + } } - s.tx.loadFront(&unacknowledgedPublish{msg: msg}) + present = true } + err = nil return } @@ -279,63 +361,118 @@ func (s *Type) loadPersistence(state *persistenceTypes.SessionMessages) (err err // For the server, when this method is called, it means there's a message that // should be published to the client on the other end of this connection. So we // will call publish() to send the message. -func (s *Type) onSubscribedPublish(p *packet.Publish) error { - _pkt, _ := packet.NewMessage(s.version, packet.PUBLISH) - pkt := _pkt.(*packet.Publish) - - // [MQTT-3.3.1-9] - // [MQTT-3.3.1-3] - pkt.Set(p.Topic(), p.Payload(), p.QoS(), false, false) // nolint: errcheck - - s.tx.queuePacket(pkt) - - return nil +func (s *Type) onSubscribedPublish(p *packet.Publish) { + if p.QoS() == packet.QoS0 { + s.tx.gPush(p) + } else { + s.tx.qPush(p) + } } // forward PUBLISH message to topics manager which takes care about subscribers -func (s *Type) publishToTopic(msg *packet.Publish) error { +func (s *Type) publishToTopic(p *packet.Publish) error { + if err := s.postProcessPublish(p); err != nil { + return err + } + + p.SetPublishID(s.Subscriber.Hash()) + // [MQTT-3.3.1.3] - if msg.Retain() { - if err := s.messenger.Retain(msg); err != nil { - s.log.Error("Error retaining message", zap.String("ClientID", s.id), zap.Error(err)) + if p.Retain() { + if err := s.Messenger.Retain(p); err != nil { + s.log.Error("Error retaining message", zap.String("ClientID", s.ID), zap.Error(err)) } // [MQTT-3.3.1-7] - if msg.QoS() == packet.QoS0 { - _m, _ := packet.NewMessage(s.version, packet.PUBLISH) + if p.QoS() == packet.QoS0 { + _m, _ := packet.New(s.Version, packet.PUBLISH) m := _m.(*packet.Publish) - m.SetQoS(msg.QoS()) // nolint: errcheck - m.SetTopic(msg.Topic()) // nolint: errcheck + m.SetQoS(p.QoS()) // nolint: errcheck + m.SetTopic(p.Topic()) // nolint: errcheck s.retained.lock.Lock() s.retained.list = append(s.retained.list, m) s.retained.lock.Unlock() } } - if err := s.messenger.Publish(msg); err != nil { - s.log.Error("Couldn't publish", zap.String("ClientID", s.id), zap.Error(err)) + if err := s.Messenger.Publish(p); err != nil { + s.log.Error("Couldn't publish", zap.String("ClientID", s.ID), zap.Error(err)) } return nil } // onReleaseIn ack process for incoming messages -func (s *Type) onReleaseIn(msg packet.Provider) { - switch m := msg.(type) { +func (s *Type) onReleaseIn(o, n packet.Provider) { + switch p := o.(type) { case *packet.Publish: - s.publishToTopic(m) // nolint: errcheck + s.publishToTopic(p) // nolint: errcheck } } // onReleaseOut process messages that required ack cycle // onAckTimeout if publish message has not been acknowledged withing specified ackTimeout // server should mark it as a dup and send again -func (s *Type) onReleaseOut(msg packet.Provider) { - switch msg.Type() { +func (s *Type) onReleaseOut(o, n packet.Provider) { + switch n.Type() { case packet.PUBACK: fallthrough case packet.PUBCOMP: - id, _ := msg.ID() - s.flowControl.release(id) + id, _ := n.ID() + if s.flowControl.release(id) { + s.tx.signalQuota() + } + } +} + +func (s *Type) onWaitForRead() error { + return s.WaitForData(s.Desc) +} + +func (s *Type) preProcessPublishV50(p *packet.Publish) error { + // v5.0 + // If the Server included Retain Available in its CONNACK response to a Client with its value set to 0 and it + // receives a PUBLISH packet with the RETAIN flag is set to 1, then it uses the DISCONNECT Reason + // Code of 0x9A (Retain not supported) as described in section 4.13. + if s.Version >= packet.ProtocolV50 && !s.RetainAvailable && p.Retain() { + return packet.CodeRetainNotSupported + } + + if prop := p.PropertyGet(packet.PropertyTopicAlias); prop != nil { + if val, ok := prop.AsShort(); ok == nil && (val == 0 || val > s.MaxRxTopicAlias) { + return packet.CodeInvalidTopicAlias + } } + + return nil +} + +func (s *Type) postProcessPublishV50(p *packet.Publish) error { + // [MQTT-3.3.2.3.4] + if prop := p.PropertyGet(packet.PropertyTopicAlias); prop != nil { + if val, ok := prop.AsShort(); ok == nil { + if len(p.Topic()) != 0 { + // renew alias with new topic + s.topicAlias[val] = p.Topic() + } else { + if topic, kk := s.topicAlias[val]; kk { + // do not check for error as topic has been validated when arrived + p.SetTopic(topic) // nolint: errcheck + } else { + return packet.CodeInvalidTopicAlias + } + } + } + } + + // [MQTT-3.3.2.3.3] + //if prop := p.PropertyGet(packet.PropertyPublicationExpiry); prop != nil { + // pkt = p.ToExpiring() + // + // //if val, ok := prop.AsInt(); ok == nil { + // //p.SetExpiry(time.Duration(val) * time.Second) + // //} + //} + + return nil } diff --git a/connection/flowControl.go b/connection/flowControl.go index 57a616f..6224e8a 100644 --- a/connection/flowControl.go +++ b/connection/flowControl.go @@ -8,36 +8,40 @@ import ( "github.com/VolantMQ/volantmq/packet" ) +var ( + errExit = errors.New("exit") + errQuotaExceeded = errors.New("quota exceeded") +) + type packetsFlowControl struct { - counter uint64 - quit chan struct{} - cond *sync.Cond - inUse map[packet.IDType]bool - sendQuota int32 - preserveOrder bool + counter uint64 + quit chan struct{} + inUse sync.Map + quota int32 } -func newFlowControl(quit chan struct{}, preserveOrder bool) *packetsFlowControl { +func newFlowControl(quit chan struct{}, quota int32) *packetsFlowControl { return &packetsFlowControl{ - inUse: make(map[packet.IDType]bool), - cond: sync.NewCond(new(sync.Mutex)), - quit: quit, - preserveOrder: preserveOrder, + quit: quit, + quota: quota, } } +func (s *packetsFlowControl) reAcquire(id packet.IDType) { + atomic.AddInt32(&s.quota, -1) + s.inUse.Store(id, true) +} + func (s *packetsFlowControl) acquire() (packet.IDType, error) { - defer s.cond.L.Unlock() - s.cond.L.Lock() + select { + case <-s.quit: + return 0, errExit + default: + } - if (s.preserveOrder && !atomic.CompareAndSwapInt32(&s.sendQuota, 0, 1)) || - (atomic.AddInt32(&s.sendQuota, -1) == 0) { - s.cond.Wait() - select { - case <-s.quit: - return 0, errors.New("exit") - default: - } + var err error + if atomic.AddInt32(&s.quota, -1) == 0 { + err = errQuotaExceeded } var id packet.IDType @@ -45,41 +49,16 @@ func (s *packetsFlowControl) acquire() (packet.IDType, error) { for count := 0; count <= 0xFFFF; count++ { s.counter++ id = packet.IDType(s.counter) - if _, ok := s.inUse[id]; !ok { - s.inUse[id] = true + if _, ok := s.inUse.LoadOrStore(id, true); !ok { break } } - return id, nil + return id, err } -//func (s *packetsFlowControl) reAcquire(id message.IDType) error { -// defer s.lock.Unlock() -// s.lock.Lock() -// -// if (s.preserveOrder && !atomic.CompareAndSwapInt32(&s.sendQuota, 0, 1)) || -// (atomic.AddInt32(&s.sendQuota, -1) == 0) { -// s.cond.Wait() -// select { -// case <-s.quit: -// return errors.New("exit") -// default: -// } -// } -// -// s.inUse[id] = true -// -// return nil -//} - -func (s *packetsFlowControl) release(id packet.IDType) { - defer func() { - atomic.AddInt32(&s.sendQuota, -1) - s.cond.Signal() - }() +func (s *packetsFlowControl) release(id packet.IDType) bool { + s.inUse.Delete(id) - defer s.cond.L.Unlock() - s.cond.L.Lock() - delete(s.inUse, id) + return atomic.AddInt32(&s.quota, 1) == 1 } diff --git a/connection/netCallbacks.go b/connection/netCallbacks.go index 364875a..6d6e0a1 100644 --- a/connection/netCallbacks.go +++ b/connection/netCallbacks.go @@ -2,12 +2,11 @@ package connection import ( "container/list" - "sync/atomic" "github.com/VolantMQ/volantmq/auth" "github.com/VolantMQ/volantmq/packet" "github.com/VolantMQ/volantmq/persistence/types" - "github.com/VolantMQ/volantmq/subscriber" + "github.com/VolantMQ/volantmq/topics/types" "go.uber.org/zap" ) @@ -30,19 +29,16 @@ func (s *Type) getState() *persistenceTypes.SessionMessages { unAckMessages := [][]byte{} var next *list.Element - for elem := s.tx.messages.Front(); elem != nil; elem = next { + for elem := s.tx.qMessages.Front(); elem != nil; elem = next { next = elem.Next() - switch m := s.tx.messages.Remove(elem).(type) { + switch m := s.tx.qMessages.Remove(elem).(type) { case *packet.Publish: - qos := m.QoS() - if qos != packet.QoS0 || (s.offlineQoS0 && qos == packet.QoS0) { - // make sure message has some IDType to prevent encode error - m.SetPacketID(0) - if buf, err := encodeMessage(m); err != nil { - s.log.Error("Couldn't encode message for persistence", zap.Error(err)) - } else { - outMessages = append(outMessages, buf) - } + // make sure message has some IDType to prevent encode error + m.SetPacketID(0) + if buf, err := encodeMessage(m); err != nil { + s.log.Error("Couldn't encode message for persistence", zap.Error(err)) + } else { + outMessages = append(outMessages, buf) } case *unacknowledgedPublish: if buf, err := encodeMessage(m.msg); err != nil { @@ -53,49 +49,80 @@ func (s *Type) getState() *persistenceTypes.SessionMessages { } } - for _, m := range s.pubOut.get() { - switch msg := m.(type) { - case *packet.Publish: - if msg.QoS() == packet.QoS1 { - msg.SetDup(true) + if s.OfflineQoS0 { + for elem := s.tx.qMessages.Front(); elem != nil; elem = next { + next = elem.Next() + switch m := s.tx.qMessages.Remove(elem).(type) { + case *packet.Publish: + if buf, err := encodeMessage(m); err != nil { + s.log.Error("Couldn't encode message for persistence", zap.Error(err)) + } else { + outMessages = append(outMessages, buf) + } } } - - if buf, err := encodeMessage(m); err != nil { - s.log.Error("Couldn't encode message for persistence", zap.Error(err)) - } else { - unAckMessages = append(unAckMessages, buf) - } } + s.pubOut.messages.Range( + func(k, v interface{}) bool { + pkt, ok := v.(packet.Provider) + if ok { + switch msg := v.(type) { + case *packet.Publish: + if msg.QoS() == packet.QoS1 { + msg.SetDup(true) + } + } + + if buf, err := encodeMessage(pkt); err != nil { + s.log.Error("Couldn't encode message for persistence", zap.Error(err)) + } else { + unAckMessages = append(unAckMessages, buf) + } + } + + return true + }) + return &persistenceTypes.SessionMessages{ OutMessages: outMessages, UnAckMessages: unAckMessages, } } -func (s *Type) onConnectionClose(will bool) { +func (s *Type) onConnectionClose(will bool, err error) { s.onConnDisconnect.Do(func() { - params := &DisconnectParams{ - Will: will, - ExpireAt: nil, - } + // make sure connection has been started before proceeding to any shutdown procedures + s.started.Wait() - //s.started.Wait() + // shutdown quit channel tells all routines finita la commedia + close(s.quit) + s.StopReceiving(s.Desc) // nolint: errcheck + s.rx.shutdown() + // clean up transmitter to allow send disconnect command to client if needed + s.tx.shutdown() - if err := s.conn.Close(); err != nil { - s.log.Error("close connection", zap.String("ClientID", s.id), zap.Error(err)) - } + // put subscriber in offline mode + s.Subscriber.Offline(s.KillOnDisconnect) - close(s.quit) + if err != nil && s.Version >= packet.ProtocolV50 { + // server wants to tell client disconnect reason + reason, _ := err.(packet.ReasonCode) + p, _ := packet.New(s.Version, packet.DISCONNECT) + pkt, _ := p.(*packet.Disconnect) + pkt.SetReasonCode(reason) - s.subscriber.Offline(s.clean) + sz, _ := pkt.Size() + buf := make([]byte, sz) + pkt.Encode(buf) // nolint: errcheck - s.tx.shutdown() - s.rx.shutdown() + if _, err = s.Conn.Write(buf); err != nil { + s.log.Info("Couldn't write disconnect message", zap.String("ClientID", s.ID), zap.Error(err)) + } + } - if !s.clean { - params.State = s.getState() + if err = s.Conn.Close(); err != nil { + s.log.Error("close connection", zap.String("ClientID", s.ID), zap.Error(err)) } // [MQTT-3.3.1-7] @@ -106,7 +133,22 @@ func (s *Type) onConnectionClose(will bool) { //} s.retained.list = []*packet.Publish{} s.retained.lock.Unlock() - s.conn = nil + s.Conn = nil + + params := &DisconnectParams{ + Will: will, + ExpireAt: s.ExpireIn, + Desc: s.Desc, + Reason: packet.CodeSuccess, + } + + if rc, ok := err.(packet.ReasonCode); ok { + params.Reason = rc + } + + if !s.KillOnDisconnect { + params.State = s.getState() + } s.onDisconnect(params) }) @@ -116,25 +158,29 @@ func (s *Type) onConnectionClose(will bool) { // On QoS == 0, we should just take the next step, no ack required // On QoS == 1, send back PUBACK, then take the next step // On QoS == 2, we need to put it in the ack queue, send back PUBREC -func (s *Type) onPublish(msg *packet.Publish) packet.Provider { +func (s *Type) onPublish(pkt *packet.Publish) (packet.Provider, error) { // check for topic access - + var err error reason := packet.CodeSuccess + if err = s.preProcessPublish(pkt); err != nil { + return nil, err + } + var resp packet.Provider // This case is for V5.0 actually as ack messages may return status. // To deal with V3.1.1 two ways left: // - ignore the message but send acks // - return error which leads to disconnect - if status := s.auth.ACL(s.id, s.username, msg.Topic(), auth.AccessTypeWrite); status == auth.StatusDeny { + if status := s.Auth.ACL(s.ID, s.Username, pkt.Topic(), auth.AccessTypeWrite); status == auth.StatusDeny { reason = packet.CodeAdministrativeAction } - switch msg.QoS() { + switch pkt.QoS() { case packet.QoS2: - resp, _ = packet.NewMessage(s.version, packet.PUBREC) + resp, _ = packet.New(s.Version, packet.PUBREC) r, _ := resp.(*packet.Ack) - id, _ := msg.ID() + id, _ := pkt.ID() r.SetPacketID(id) r.SetReason(reason) @@ -143,39 +189,31 @@ func (s *Type) onPublish(msg *packet.Publish) packet.Provider { // store incoming QoS 2 message before sending PUBREC as theoretically PUBREL // might come before store in case message store done after write PUBREC if reason < packet.CodeUnspecifiedError { - s.pubIn.store(msg) + s.pubIn.store(pkt) } case packet.QoS1: - resp, _ = packet.NewMessage(s.version, packet.PUBACK) + resp, _ = packet.New(s.Version, packet.PUBACK) r, _ := resp.(*packet.Ack) - id, _ := msg.ID() - reason := packet.CodeSuccess + id, _ := pkt.ID() r.SetPacketID(id) r.SetReason(reason) - //_, err = s.conn.WriteMessage(resp, false) - + fallthrough + case packet.QoS0: // QoS 0 + // [MQTT-4.3.1] // [MQTT-4.3.2-4] if reason < packet.CodeUnspecifiedError { - if err := s.publishToTopic(msg); err != nil { + if err = s.publishToTopic(pkt); err != nil { s.log.Error("Couldn't publish message", - zap.String("ClientID", s.id), - zap.Uint8("QoS", uint8(msg.QoS())), + zap.String("ClientID", s.ID), + zap.Uint8("QoS", uint8(pkt.QoS())), zap.Error(err)) } } - case packet.QoS0: // QoS 0 - // [MQTT-4.3.1] - if err := s.publishToTopic(msg); err != nil { - s.log.Error("Couldn't publish message", - zap.String("ClientID", s.id), - zap.Uint8("QoS", uint8(msg.QoS())), - zap.Error(err)) - } } - return resp + return resp, err } // onAck handle ack acknowledgment received from remote @@ -193,18 +231,22 @@ func (s *Type) onAck(msg packet.Provider) packet.Provider { discard := false - if s.version == packet.ProtocolV50 && mIn.Reason() >= packet.CodeUnspecifiedError { + id, _ := msg.ID() + + if s.Version == packet.ProtocolV50 && mIn.Reason() >= packet.CodeUnspecifiedError { // v5.9 [MQTT-4.9] - atomic.AddInt32(&s.sendQuota, 1) + //atomic.AddInt32(&s.SendQuota, 1) + if s.flowControl.release(id) { + s.tx.signalQuota() + } discard = true } if !discard { - resp, _ = packet.NewMessage(s.version, packet.PUBREL) + resp, _ = packet.New(s.Version, packet.PUBREL) r, _ := resp.(*packet.Ack) - id, _ := msg.ID() r.SetPacketID(id) // 2. Put PUBREL into ack queue @@ -214,7 +256,7 @@ func (s *Type) onAck(msg packet.Provider) packet.Provider { } case packet.PUBREL: // Remote has released PUBLISH - resp, _ = packet.NewMessage(s.version, packet.PUBCOMP) + resp, _ = packet.New(s.Version, packet.PUBCOMP) r, _ := resp.(*packet.Ack) id, _ := msg.ID() @@ -226,7 +268,7 @@ func (s *Type) onAck(msg packet.Provider) packet.Provider { s.pubOut.release(msg) default: s.log.Error("Unsupported ack message type", - zap.String("ClientID", s.id), + zap.String("ClientID", s.ID), zap.String("type", msg.Type().Name())) } default: @@ -237,14 +279,14 @@ func (s *Type) onAck(msg packet.Provider) packet.Provider { } func (s *Type) onSubscribe(msg *packet.Subscribe) packet.Provider { - m, _ := packet.NewMessage(s.version, packet.SUBACK) + m, _ := packet.New(s.Version, packet.SUBACK) resp, _ := m.(*packet.SubAck) id, _ := msg.ID() resp.SetPacketID(id) var retCodes []packet.ReasonCode - var retainedMessages []*packet.Publish + var retainedPublishes []*packet.Publish iter := msg.Topics().Iterator() for kv, ok := iter(); ok; kv, ok = iter() { @@ -259,33 +301,30 @@ func (s *Type) onSubscribe(msg *packet.Subscribe) packet.Provider { subsID := uint32(0) // V5.0 [MQTT-3.8.2.1.2] - if sID, err := msg.PropertyGet(packet.PropertySubscriptionIdentifier); err == nil { - subsID = sID.(uint32) + if prop := msg.PropertyGet(packet.PropertySubscriptionIdentifier); prop != nil { + if v, e := prop.AsInt(); e == nil { + subsID = v + } } - subsParams := subscriber.SubscriptionParams{ - ID: subsID, - Requested: ops, + subsParams := topicsTypes.SubscriptionParams{ + ID: subsID, + Ops: ops, } - if rQoS, retained, err := s.subscriber.Subscribe(t, &subsParams); err != nil { - // [MQTT-3.9.3]Æ’ - if s.version == packet.ProtocolV50 { + if grantedQoS, retained, err := s.Subscriber.Subscribe(t, &subsParams); err != nil { + // [MQTT-3.9.3] + if s.Version == packet.ProtocolV50 { reason = packet.CodeUnspecifiedError } else { reason = packet.QosFailure } } else { - reason = packet.ReasonCode(rQoS) - retainedMessages = append(retainedMessages, retained...) + reason = packet.ReasonCode(grantedQoS) + retainedPublishes = append(retainedPublishes, retained...) } retCodes = append(retCodes, reason) - - s.log.Debug("Subscribing", - zap.String("ClientID", s.id), - zap.String("topic", t), - zap.Uint8("result_code", uint8(reason))) } if err := resp.AddReturnCodes(retCodes); err != nil { @@ -293,14 +332,13 @@ func (s *Type) onSubscribe(msg *packet.Subscribe) packet.Provider { } // Now put retained messages into publish queue - for _, rm := range retainedMessages { - m, _ := packet.NewMessage(s.version, packet.PUBLISH) - msg, _ := m.(*packet.Publish) - - // [MQTT-3.3.1-8] - msg.Set(rm.Topic(), rm.Payload(), rm.QoS(), true, false) // nolint: errcheck - - s.tx.sendPacket(msg) + for _, rp := range retainedPublishes { + if pkt, err := rp.Clone(s.Version); err == nil { + pkt.SetRetain(true) + s.onSubscribedPublish(pkt) + } else { + s.log.Error("Couldn't clone PUBLISH message", zap.String("ClientID", s.ID), zap.Error(err)) + } } return resp @@ -318,7 +356,7 @@ func (s *Type) onUnSubscribe(msg *packet.UnSubscribe) packet.Provider { reason := packet.CodeSuccess if authorized { - if err := s.subscriber.UnSubscribe(t); err != nil { + if err := s.Subscriber.UnSubscribe(t); err != nil { s.log.Error("Couldn't unsubscribe from topic", zap.Error(err)) } else { reason = packet.CodeNoSubscriptionExisted @@ -330,7 +368,7 @@ func (s *Type) onUnSubscribe(msg *packet.UnSubscribe) packet.Provider { retCodes = append(retCodes, reason) } - m, _ := packet.NewMessage(s.version, packet.UNSUBACK) + m, _ := packet.New(s.Version, packet.UNSUBACK) resp, _ := m.(*packet.UnSubAck) id, _ := msg.ID() diff --git a/connection/receiver.go b/connection/receiver.go index 5b14e62..1be81f3 100644 --- a/connection/receiver.go +++ b/connection/receiver.go @@ -3,28 +3,32 @@ package connection import ( "bufio" "encoding/binary" - "errors" "net" "sync" "sync/atomic" "time" "github.com/VolantMQ/volantmq/packet" + "github.com/troian/easygo/netpoll" ) type receiverConfig struct { - conn net.Conn - quit chan struct{} - keepAlive time.Duration // nolint: structcheck - onPacket func(packet.Provider) error - onDisconnect func(bool) - will *bool - version packet.ProtocolVersion + conn net.Conn + keepAliveTimer *time.Timer // nolint: structcheck + will *bool + quit chan struct{} + keepAlive time.Duration // nolint: structcheck + waitForRead func() error + onPacket func(packet.Provider) error + onDisconnect func(bool, error) // nolint: megacheck + maxPacketSize uint32 + version packet.ProtocolVersion } type receiver struct { receiverConfig - wg sync.WaitGroup + wg sync.WaitGroup + //started sync.WaitGroup running uint32 recv []byte remainingRecv int @@ -34,47 +38,85 @@ func newReceiver(config *receiverConfig) *receiver { r := &receiver{ receiverConfig: *config, } + + if r.keepAlive > 0 { + r.keepAliveTimer = time.AfterFunc(r.keepAlive, r.keepAliveExpired) + } + return r } +func (r *receiver) keepAliveExpired() { + r.onDisconnect(true, nil) +} + func (r *receiver) shutdown() { r.wg.Wait() + + if r.keepAlive > 0 { + r.keepAliveTimer.Stop() + r.keepAliveTimer = nil + } } -func (r *receiver) run() { +func (r *receiver) run(event netpoll.Event) { + select { + case <-r.quit: + return + default: + } + if atomic.CompareAndSwapUint32(&r.running, 0, 1) { r.wg.Wait() r.wg.Add(1) - if r.keepAlive > 0 { - r.conn.SetReadDeadline(time.Now().Add(r.keepAlive)) // nolint: errcheck - } else { - r.conn.SetReadDeadline(time.Time{}) // nolint: errcheck + + exit := false + if event&(netpoll.EventReadHup|netpoll.EventWriteHup|netpoll.EventHup|netpoll.EventErr) != 0 { + exit = true } - go r.routine() + go r.routine(exit) } } -func (r *receiver) routine() { +func (r *receiver) routine(exit bool) { var err error - defer func() { + + signalDisconnect := func() { r.wg.Done() - r.onDisconnect(*r.will) - }() + if _, ok := err.(packet.ReasonCode); !ok { + err = nil + } + r.onDisconnect(*r.will, err) + } + + if exit { + defer signalDisconnect() + return + } buf := bufio.NewReader(r.conn) for atomic.LoadUint32(&r.running) == 1 { + if r.keepAlive > 0 { + r.keepAliveTimer.Reset(r.keepAlive) + } var pkt packet.Provider - if pkt, err = r.readPacket(buf); err != nil || pkt == nil { + if pkt, err = r.readPacket(buf); err == nil { + err = r.onPacket(pkt) + } + + if err != nil { atomic.StoreUint32(&r.running, 0) - } else { - if r.keepAlive > 0 { - r.conn.SetReadDeadline(time.Now().Add(r.keepAlive)) // nolint: errcheck - } - if err = r.onPacket(pkt); err != nil { - atomic.StoreUint32(&r.running, 0) - } + } + } + + if _, ok := err.(packet.ReasonCode); ok { + defer signalDisconnect() + } else { + r.wg.Done() + if err = r.waitForRead(); err != nil { + defer signalDisconnect() } } } @@ -85,15 +127,15 @@ func (r *receiver) readPacket(buf *bufio.Reader) (packet.Provider, error) { if len(r.recv) == 0 { var header []byte peekCount := 2 - // Let's read enough bytes to get the message header (msg type, remaining length) + // Let's read enough bytes to get the fixed header/fh (msg type/flags, remaining length) for { - // If we have read 5 bytes and still not done, then there's a problem. + // max length of fh is 5 bytes + // if we have read 5 bytes and still not done report protocol error and exit if peekCount > 5 { - return nil, errors.New("sendrecv/peekMessageSize: 4th byte of remaining length has continuation bit set") + return nil, packet.CodeProtocolError } - header, err = buf.Peek(peekCount) - if err != nil { + if header, err = buf.Peek(peekCount); err != nil { return nil, err } @@ -118,19 +160,23 @@ func (r *receiver) readPacket(buf *bufio.Reader) (packet.Provider, error) { r.recv = make([]byte, r.remainingRecv) } + if r.remainingRecv > int(r.maxPacketSize) { + return nil, packet.CodePacketTooLarge + } + offset := len(r.recv) - r.remainingRecv for offset != r.remainingRecv { var n int - n, err = buf.Read(r.recv[offset:]) - offset += n - if err != nil { + if n, err = buf.Read(r.recv[offset:]); err != nil { return nil, err } + offset += n } var pkt packet.Provider pkt, _, err = packet.Decode(r.version, r.recv) + r.recv = []byte{} r.remainingRecv = 0 diff --git a/connection/transmitter.go b/connection/transmitter.go index 1f466e2..f92d691 100644 --- a/connection/transmitter.go +++ b/connection/transmitter.go @@ -8,88 +8,129 @@ import ( "sync/atomic" "time" + "reflect" + + "math/rand" + "github.com/VolantMQ/volantmq/packet" + "github.com/VolantMQ/volantmq/types" "go.uber.org/zap" ) type transmitterConfig struct { - id string - quit chan struct{} - flowControl *packetsFlowControl - pubIn *ackQueue - pubOut *ackQueue - log *zap.Logger - conn net.Conn - onDisconnect func(bool) + id string + quit chan struct{} + flowControl *packetsFlowControl + pubIn *ackQueue + pubOut *ackQueue + log *zap.Logger + conn net.Conn + onDisconnect func(bool, error) + maxPacketSize uint32 + topicAliasMax uint16 } type transmitter struct { transmitterConfig - lock sync.Mutex - messages *list.List - available chan int - running uint32 - timer *time.Timer - wg sync.WaitGroup + available chan int + timer *time.Timer + gMessages *list.List + qMessages *list.List + wg sync.WaitGroup + onStop types.OnceWait + gLock sync.Mutex + qLock sync.Mutex + topicAlias map[string]uint16 + running uint32 + topicAliasCurrMax uint16 + qExceeded bool } func newTransmitter(config *transmitterConfig) *transmitter { p := &transmitter{ transmitterConfig: *config, - messages: list.New(), available: make(chan int, 1), timer: time.NewTimer(1 * time.Second), + topicAlias: make(map[string]uint16), + gMessages: list.New(), + qMessages: list.New(), } - p.timer.Stop() return p } func (p *transmitter) shutdown() { - p.timer.Stop() - p.wg.Wait() - select { - case <-p.available: - default: - close(p.available) - } + p.onStop.Do(func() { + atomic.StoreUint32(&p.running, 0) + p.timer.Stop() + p.wg.Wait() + + select { + case <-p.available: + default: + close(p.available) + } + }) } -func (p *transmitter) loadFront(value interface{}) { - p.lock.Lock() - p.messages.PushFront(value) - p.lock.Unlock() +func (p *transmitter) gPush(pkt packet.Provider) { + p.gLock.Lock() + p.gMessages.PushBack(pkt) + p.gLock.Unlock() p.signalAvailable() + p.run() } -//func (p *transmitter) loadBack(value interface{}) { -// p.lock.Lock() -// p.messages.PushBack(value) -// p.lock.Unlock() -// p.signalAvailable() -//} +func (p *transmitter) gLoad(pkt packet.Provider) { + p.gMessages.PushBack(pkt) + p.signalAvailable() +} -func (p *transmitter) sendPacket(pkt packet.Provider) { - p.lock.Lock() - p.messages.PushFront(pkt) - p.lock.Unlock() +func (p *transmitter) gLoadList(l *list.List) { + p.gMessages.PushBackList(l) p.signalAvailable() - p.run() } -func (p *transmitter) queuePacket(pkt packet.Provider) { - p.lock.Lock() - p.messages.PushBack(pkt) - p.lock.Unlock() +func (p *transmitter) qPush(pkt interface{}) { + p.qLock.Lock() + p.qMessages.PushBack(pkt) + p.qLock.Unlock() p.signalAvailable() p.run() } +func (p *transmitter) qLoad(pkt interface{}) { + p.qMessages.PushBack(pkt) + p.signalAvailable() +} + +func (p *transmitter) qLoadList(l *list.List) { + p.qMessages.PushBackList(l) + p.signalAvailable() +} + func (p *transmitter) signalAvailable() { select { - case p.available <- 1: + case <-p.quit: + return default: + select { + case p.available <- 1: + default: + } + } +} + +func (p *transmitter) signalQuota() { + p.qLock.Lock() + p.qExceeded = false + l := p.qMessages.Len() + p.qLock.Unlock() + + if l > 0 { + p.signalAvailable() + p.run() } } @@ -108,6 +149,96 @@ func (p *transmitter) flushBuffers(buf net.Buffers) error { return e } +func (p *transmitter) packetSize(value interface{}) (int, bool) { + var sz int + var err error + if obj, ok := value.(sizeAble); !ok { + p.log.Fatal("Object does not belong to allowed types", + zap.String("ClientID", p.id), + zap.String("Type", reflect.TypeOf(value).String())) + } else { + if sz, err = obj.Size(); err != nil { + p.log.Error("Couldn't calculate message size", zap.String("ClientID", p.id), zap.Error(err)) + return 0, false + } + } + + // ignore any packet with size bigger than negotiated + if sz > int(p.maxPacketSize) { + p.log.Warn("Ignore packet with size bigger than negotiated with client", + zap.String("ClientID", p.id), + zap.Uint32("negotiated", p.maxPacketSize), + zap.Int("actual", sz)) + return 0, false + } + + return sz, true +} + +func (p *transmitter) gAvailable() bool { + defer p.gLock.Unlock() + p.gLock.Lock() + + return p.gMessages.Len() > 0 +} + +func (p *transmitter) qAvailable() bool { + defer p.qLock.Unlock() + p.qLock.Lock() + + return !p.qExceeded && p.qMessages.Len() > 0 +} + +func (p *transmitter) gPopPacket() packet.Provider { + defer p.gLock.Unlock() + p.gLock.Lock() + + var elem *list.Element + + if elem = p.gMessages.Front(); elem == nil { + return nil + } + + value := p.gMessages.Remove(elem) + + return value.(packet.Provider) +} + +func (p *transmitter) qPopPacket() packet.Provider { + defer p.qLock.Unlock() + p.qLock.Lock() + + if elem := p.qMessages.Front(); !p.qExceeded && elem != nil { + var pkt packet.Provider + value := elem.Value + switch m := value.(type) { + case *packet.Publish: + // try acquire packet id + id, err := p.flowControl.acquire() + if err == errExit { + atomic.StoreUint32(&p.running, 0) + return nil + } + + if err == errQuotaExceeded { + p.qExceeded = true + } + + m.SetPacketID(id) + pkt = m + case *unacknowledgedPublish: + pkt = m.msg + + } + p.qMessages.Remove(elem) + p.pubOut.store(pkt) + + return pkt + } + + return nil +} + func (p *transmitter) routine() { var err error @@ -115,100 +246,102 @@ func (p *transmitter) routine() { p.wg.Done() if err != nil { - p.onDisconnect(true) + p.onDisconnect(true, nil) } }() sendBuffers := net.Buffers{} for atomic.LoadUint32(&p.running) == 1 { select { + case <-p.quit: + err = errors.New("exit") + atomic.StoreUint32(&p.running, 0) + return case <-p.timer.C: if err = p.flushBuffers(sendBuffers); err != nil { atomic.StoreUint32(&p.running, 0) return } - sendBuffers = net.Buffers{} - p.lock.Lock() - l := p.messages.Len() - p.lock.Unlock() - if l != 0 { + if p.qAvailable() || p.gAvailable() { p.signalAvailable() } else { atomic.StoreUint32(&p.running, 0) - return } - case <-p.available: - p.lock.Lock() - - var elem *list.Element + // check if there any control packets except PUBLISH QoS 1/2 + // and process them + var packets []packet.Provider + if pkt := p.gPopPacket(); pkt != nil { + packets = append(packets, pkt) + } - if elem = p.messages.Front(); elem == nil { - p.lock.Unlock() - atomic.StoreUint32(&p.running, 0) - break + if pkt := p.qPopPacket(); pkt != nil { + packets = append(packets, pkt) } - value := p.messages.Remove(p.messages.Front()) + prevLen := len(sendBuffers) + for _, pkt := range packets { + switch pkt := pkt.(type) { + case *packet.Publish: + p.setTopicAlias(pkt) + } - if p.messages.Len() != 0 { - p.signalAvailable() - } - p.lock.Unlock() - - var msg packet.Provider - switch m := value.(type) { - case *packet.Publish: - if m.QoS() != packet.QoS0 { - var id packet.IDType - if id, err = p.flowControl.acquire(); err == nil { - m.SetPacketID(id) - p.pubOut.store(m) + if sz, ok := p.packetSize(pkt); ok { + buf := make([]byte, sz) + if _, err = pkt.Encode(buf); err != nil { + p.log.Error("Message encode", zap.Error(err)) } else { - // if acquire id returned error session is about to exit. Queue message back and get away - p.lock.Lock() - p.messages.PushBack(m) - p.lock.Unlock() - err = errors.New("exit") - atomic.StoreUint32(&p.running, 0) - return + sendBuffers = append(sendBuffers, buf) } } - msg = m - case *unacknowledgedPublish: - msg = m.msg - p.pubOut.store(msg) - default: - msg = m.(packet.Provider) } - var sz int - if sz, err = msg.Size(); err != nil { - p.log.Error("Couldn't calculate message size", zap.String("ClientID", p.id), zap.Error(err)) - return - } + available := true - buf := make([]byte, sz) - _, err = msg.Encode(buf) - sendBuffers = append(sendBuffers, buf) + if p.qAvailable() || p.gAvailable() { + p.signalAvailable() + } else { + available = false + } - l := len(sendBuffers) - if l == 1 { + if prevLen == 0 { p.timer.Reset(1 * time.Millisecond) - } else if l == 5 { + } else if len(sendBuffers) >= 5 { p.timer.Stop() if err = p.flushBuffers(sendBuffers); err != nil { atomic.StoreUint32(&p.running, 0) } sendBuffers = net.Buffers{} + + if !available { + atomic.StoreUint32(&p.running, 0) + } } - case <-p.quit: - err = errors.New("exit") - atomic.StoreUint32(&p.running, 0) - return + } + } +} + +func (p *transmitter) setTopicAlias(pkt *packet.Publish) { + if p.topicAliasMax > 0 { + var ok bool + var alias uint16 + if alias, ok = p.topicAlias[pkt.Topic()]; !ok { + if p.topicAliasCurrMax < p.topicAliasMax { + p.topicAliasCurrMax++ + alias = p.topicAliasCurrMax + ok = true + } else { + alias = uint16(rand.Intn(int(p.topicAliasMax)) + 1) + } + } else { + ok = false + } + + if err := pkt.PropertySet(packet.PropertyTopicAlias, alias); err == nil && !ok { + pkt.SetTopic("") // nolint: errcheck } } } diff --git a/examples/tcp&ws/volantmq.go b/examples/tcp&ws/volantmq.go index 074fb30..8afb9a4 100644 --- a/examples/tcp&ws/volantmq.go +++ b/examples/tcp&ws/volantmq.go @@ -39,7 +39,7 @@ func main() { configuration.Init(ops) - logger := configuration.GetProdLogger().Named("example") + logger := configuration.GetLogger().Named("example") var err error diff --git a/examples/tcp/volantmq.go b/examples/tcp/volantmq.go index 1f86eaf..5e514d0 100644 --- a/examples/tcp/volantmq.go +++ b/examples/tcp/volantmq.go @@ -17,6 +17,7 @@ package main import ( "os" "os/signal" + "runtime" "syscall" "github.com/VolantMQ/volantmq" @@ -27,8 +28,8 @@ import ( "github.com/spf13/viper" "go.uber.org/zap" + "net/http" _ "net/http/pprof" - "runtime" _ "runtime/debug" ) @@ -39,7 +40,7 @@ func main() { configuration.Init(ops) - logger := configuration.GetProdLogger().Named("example") + logger := configuration.GetLogger().Named("example") var err error @@ -118,6 +119,8 @@ func main() { logger.Error("Couldn't start listener", zap.Error(err)) } + go http.ListenAndServe(":6061", nil) // nolint: errcheck + ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) sig := <-ch diff --git a/examples/websocket/ws.go b/examples/websocket/ws.go index eb19089..598551c 100644 --- a/examples/websocket/ws.go +++ b/examples/websocket/ws.go @@ -38,7 +38,7 @@ func main() { configuration.Init(ops) - logger := configuration.GetProdLogger().Named("example") + logger := configuration.GetLogger().Named("example") var err error diff --git a/packet/connack.go b/packet/connack.go index e108b54..8878c01 100644 --- a/packet/connack.go +++ b/packet/connack.go @@ -107,32 +107,32 @@ func (msg *ConnAck) decodeMessage(src []byte) (int, error) { return total, nil } -func (msg *ConnAck) encodeMessage(dst []byte) (int, error) { - total := 0 +func (msg *ConnAck) encodeMessage(to []byte) (int, error) { + offset := 0 if msg.sessionPresent { - dst[total] = 1 + to[offset] = 1 } else { - dst[total] = 0 + to[offset] = 0 } - total++ + offset++ - dst[total] = msg.returnCode.Value() - total++ + to[offset] = msg.returnCode.Value() + offset++ var err error // V5.0 [MQTT-3.1.2.11] if msg.version == ProtocolV50 { var n int - if n, err = encodeProperties(msg.properties, dst[total:]); err != nil { - return total + n, err + n, err = encodeProperties(msg.properties, to[offset:]) + offset += n + if err != nil { + return offset, err } - - total += n } - return total, err + return offset, err } func (msg *ConnAck) size() int { @@ -140,8 +140,7 @@ func (msg *ConnAck) size() int { // v5.0 [MQTT-3.1.2.11] if msg.version == ProtocolV50 { - pLen, _ := encodeProperties(msg.properties, []byte{}) - total += pLen + total += int(msg.properties.FullLen()) } return total diff --git a/packet/connack_test.go b/packet/connack_test.go index 7637b43..e0719bf 100644 --- a/packet/connack_test.go +++ b/packet/connack_test.go @@ -21,7 +21,7 @@ import ( ) func TestConnAckMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, CONNACK) + m, err := New(ProtocolV311, CONNACK) require.NoError(t, err) msg, ok := m.(*ConnAck) @@ -117,7 +117,7 @@ func TestConnAckMessageEncode(t *testing.T) { 0, // connection accepted } - m, err := NewMessage(ProtocolV311, CONNACK) + m, err := New(ProtocolV311, CONNACK) require.NoError(t, err) msg, ok := m.(*ConnAck) @@ -169,7 +169,7 @@ func TestConnAckDecodeEncodeEquiv(t *testing.T) { func TestConnAckEncodeEnsureSize(t *testing.T) { dst := make([]byte, 3) - m, err := NewMessage(ProtocolV311, CONNACK) + m, err := New(ProtocolV311, CONNACK) require.NoError(t, err) msg, ok := m.(*ConnAck) @@ -183,7 +183,7 @@ func TestConnAckEncodeEnsureSize(t *testing.T) { } func TestConnAckCodeWrite(t *testing.T) { - m, err := NewMessage(ProtocolV311, CONNACK) + m, err := New(ProtocolV311, CONNACK) require.NoError(t, err) msg, ok := m.(*ConnAck) diff --git a/packet/connect.go b/packet/connect.go index e751137..6a323e3 100644 --- a/packet/connect.go +++ b/packet/connect.go @@ -73,20 +73,20 @@ func newConnect() *Connect { // return msg.version //} -// CleanStart returns the bit that specifies the handling of the Session state. +// IsClean returns the bit that specifies the handling of the Session state. // The Client and Server can store Session state to enable reliable messaging to // continue across a sequence of Network Connections. This bit is used to control // the lifetime of the Session state. -func (msg *Connect) CleanStart() bool { - return (msg.connectFlags & maskConnFlagCleanSession) != 0 +func (msg *Connect) IsClean() bool { + return (msg.connectFlags & maskConnFlagClean) != 0 } -// SetCleanStart sets the bit that specifies the handling of the Session state. -func (msg *Connect) SetCleanStart(v bool) { +// SetClean sets the bit that specifies the handling of the Session state. +func (msg *Connect) SetClean(v bool) { if v { - msg.connectFlags |= maskConnFlagCleanSession // 0x02 // 00000010 + msg.connectFlags |= maskConnFlagClean // 0x02 // 00000010 } else { - msg.connectFlags &= ^maskConnFlagCleanSession // 0xFD // 11111101 + msg.connectFlags &= ^maskConnFlagClean // 0xFD // 11111101 } } @@ -232,131 +232,131 @@ func (msg *Connect) passwordFlag() bool { return (msg.connectFlags & maskConnFlagPassword) != 0 } -func (msg *Connect) encodeMessage(dst []byte) (int, error) { +func (msg *Connect) encodeMessage(to []byte) (int, error) { if _, ok := SupportedVersions[msg.version]; !ok { return 0, ErrInvalidProtocolVersion } - total := 0 + offset := 0 // V3.1.1 [MQTT-3.1.2.1] // V5.0 [MQTT-3.1.2.1] - n, err := WriteLPBytes(dst[total:], []byte(SupportedVersions[msg.version])) - total += n + n, err := WriteLPBytes(to[offset:], []byte(SupportedVersions[msg.version])) + offset += n if err != nil { - return total, err + return offset, err } // V3.1.1 [MQTT-3.1.2.2] // V5.0 [MQTT-3.1.2.2] - dst[total] = byte(msg.version) - total++ + to[offset] = byte(msg.version) + offset++ // V3.1.1 [MQTT-3.1.2.3] // V5.0 [MQTT-3.1.2.3] - dst[total] = msg.connectFlags - total++ + to[offset] = msg.connectFlags + offset++ // V3.1.1 [MQTT-3.1.2.10] // V5.0 [MQTT-3.1.2.10] - binary.BigEndian.PutUint16(dst[total:], msg.keepAlive) - total += 2 + binary.BigEndian.PutUint16(to[offset:], msg.keepAlive) + offset += 2 // V5.0 [MQTT-3.1.2.11] if msg.version == ProtocolV50 { - if n, err = encodeProperties(msg.properties, dst[total:]); err != nil { - return total + n, err + if n, err = encodeProperties(msg.properties, to[offset:]); err != nil { + return offset + n, err } - total += n + offset += n } // V3.1.1 [MQTT-3.1.3.1] // V5.0 [MQTT-3.1.3.1] - n, err = WriteLPBytes(dst[total:], msg.clientID) - total += n + n, err = WriteLPBytes(to[offset:], msg.clientID) + offset += n if err != nil { - return total, err + return offset, err } if msg.willFlag() { // V3.1.1 [MQTT-3.1.3.2] // V5.0 [MQTT-3.1.3.2] - n, err = WriteLPBytes(dst[total:], []byte(msg.will.topic)) - total += n + n, err = WriteLPBytes(to[offset:], []byte(msg.will.topic)) + offset += n if err != nil { - return total, err + return offset, err } // V3.1.1 [MQTT-3.1.3.3] // V5.0 [MQTT-3.1.3.3] - n, err = WriteLPBytes(dst[total:], msg.will.message) - total += n + n, err = WriteLPBytes(to[offset:], msg.will.message) + offset += n if err != nil { - return total, err + return offset, err } } if msg.usernameFlag() { // v3.1.1 [MQTT-3.1.3.4] // v5.0 [MQTT-3.1.3.4] - n, err = WriteLPBytes(dst[total:], msg.username) - total += n + n, err = WriteLPBytes(to[offset:], msg.username) + offset += n if err != nil { - return total, err + return offset, err } } if msg.passwordFlag() { // v3.1.1 [MQTT-3.1.3.5] // v5.0 [MQTT-3.1.3.5] - n, err = WriteLPBytes(dst[total:], msg.password) - total += n + n, err = WriteLPBytes(to[offset:], msg.password) + offset += n if err != nil { - return total, err + return offset, err } } - return total, nil + return offset, nil } -func (msg *Connect) decodeMessage(src []byte) (int, error) { +func (msg *Connect) decodeMessage(from []byte) (int, error) { var err error var n int - total := 0 + offset := 0 var protoName []byte // V3.1.1 [MQTT-3.1.2.1] // V5.0 [MQTT-3.1.2.1] - if protoName, n, err = ReadLPBytes(src[total:]); err != nil { - return total, err + if protoName, n, err = ReadLPBytes(from[offset:]); err != nil { + return offset, err } - total += n + offset += n // V3.1.1 [MQTT-3.1.2-1] // V5.0 [MQTT-3.1.2-1] if !utf8.Valid(protoName) { - return total, ErrProtocolInvalidName + return offset, ErrProtocolInvalidName } // V3.1.1 [MQTT-3.1.2.2] // V5.0 [MQTT-3.1.2.2] - msg.version = ProtocolVersion(src[total]) - total++ + msg.version = ProtocolVersion(from[offset]) + offset++ // V3.1.1 [MQTT-3.1.2-2] // V5.0 [MQTT-3.1.2-2] if verStr, ok := SupportedVersions[msg.version]; !ok { - return total, ErrInvalidProtocolVersion + return offset, ErrInvalidProtocolVersion } else if verStr != string(protoName) { - return total, ErrInvalidProtocolVersion + return offset, ErrInvalidProtocolVersion } // V3.1.1 [MQTT-3.1.2.3] // V5.0 [MQTT-3.1.2.3] - msg.connectFlags = src[total] - total++ + msg.connectFlags = from[offset] + offset++ // V3.1.1 [MQTT-3.1.2-3] // V5.0 [MQTT-3.1.2-3] @@ -368,7 +368,7 @@ func (msg *Connect) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedServerUnavailable } - return total, rejectCode + return offset, rejectCode } // V3.1.1 [MQTT-3.1.2-14] @@ -381,7 +381,7 @@ func (msg *Connect) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedServerUnavailable } - return total, rejectCode + return offset, rejectCode } if !msg.willFlag() && (msg.willRetain() || (msg.willQos() != QoS0)) { @@ -392,38 +392,39 @@ func (msg *Connect) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedServerUnavailable } - return total, rejectCode + return offset, rejectCode } // V3.1.1 [MQTT-3.1.2-22]. if (!msg.usernameFlag() && msg.passwordFlag()) && msg.version < ProtocolV50 { - return total, CodeRefusedBadUsernameOrPassword + return offset, CodeRefusedBadUsernameOrPassword } // V3.1.1 [MQTT-3.1.2.10] // V5.0 [MQTT-3.1.2.10] - msg.keepAlive = binary.BigEndian.Uint16(src[total:]) - total += 2 + msg.keepAlive = binary.BigEndian.Uint16(from[offset:]) + offset += 2 // v5.0 [MQTT-3.1.2.11] specifies properties in variable header if msg.version == ProtocolV50 { - if msg.properties, n, err = decodeProperties(msg.mType, src[total:]); err != nil { - return total + n, err + msg.properties, n, err = decodeProperties(msg.mType, from[offset:]) + offset += n + if err != nil { + return offset, err } - total += n } // V3.1.1 [MQTT-3.1.3.1] - msg.clientID, n, err = ReadLPBytes(src[total:]) - total += n + msg.clientID, n, err = ReadLPBytes(from[offset:]) + offset += n if err != nil { - return total, err + return offset, err } // V3.1.1 [MQTT-3.1.3-7] // If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession to 1 - if (len(msg.clientID) == 0 && !msg.CleanStart()) && msg.version < ProtocolV50 { - return total, CodeRefusedIdentifierRejected + if len(msg.clientID) == 0 && !msg.IsClean() { + return offset, CodeRefusedIdentifierRejected } // The ClientId must contain only characters 0-9, a-z, and A-Z @@ -437,7 +438,7 @@ func (msg *Connect) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedIdentifierRejected } - return total, rejectCode + return offset, rejectCode } if msg.willFlag() { @@ -445,19 +446,19 @@ func (msg *Connect) decodeMessage(src []byte) (int, error) { // V5.0 [MQTT-3.1.3.2] var buf []byte - if buf, n, err = ReadLPBytes(src[total:]); err != nil { - return total + n, err + if buf, n, err = ReadLPBytes(from[offset:]); err != nil { + return offset + n, err } - total += n + offset += n msg.will.topic = string(buf) // V3.1.1 [3.1.3.3] // V5.0 [3.1.3.3] - if buf, n, err = ReadLPBytes(src[total:]); err != nil { - return total + n, err + if buf, n, err = ReadLPBytes(from[offset:]); err != nil { + return offset + n, err } - total += n + offset += n msg.will.message = make([]byte, len(buf)) copy(msg.will.message, buf) @@ -469,22 +470,22 @@ func (msg *Connect) decodeMessage(src []byte) (int, error) { // v3.1.1 [MQTT-3.1.3.4] // v5.0 [MQTT-3.1.3.4] if msg.usernameFlag() { - if msg.username, n, err = ReadLPBytes(src[total:]); err != nil { - return total + n, err + if msg.username, n, err = ReadLPBytes(from[offset:]); err != nil { + return offset + n, err } - total += n + offset += n } // v3.1.1 [MQTT-3.1.3.5] // v5.0 [MQTT-3.1.3.5] if msg.passwordFlag() { - if msg.password, n, err = ReadLPBytes(src[total:]); err != nil { - return total + n, err + if msg.password, n, err = ReadLPBytes(from[offset:]); err != nil { + return offset + n, err } - total += n + offset += n } - return total, nil + return offset, nil } func (msg *Connect) size() int { diff --git a/packet/connect_test.go b/packet/connect_test.go index 2d175f8..5047199 100644 --- a/packet/connect_test.go +++ b/packet/connect_test.go @@ -21,7 +21,7 @@ import ( ) func newTestConnect(t *testing.T, p ProtocolVersion) *Connect { - m, err := NewMessage(p, CONNECT) + m, err := New(p, CONNECT) require.NoError(t, err) msg, ok := m.(*Connect) require.True(t, ok, "Couldn't cast message type") @@ -34,11 +34,11 @@ func TestConnectMessageFields(t *testing.T) { require.Equal(t, ProtocolV31, msg.Version(), "Incorrect version number") - msg.SetCleanStart(true) - require.True(t, msg.CleanStart(), "Error setting clean session flag") + msg.SetClean(true) + require.True(t, msg.IsClean(), "Error setting clean session flag") - msg.SetCleanStart(false) - require.False(t, msg.CleanStart(), "Error setting clean session flag") + msg.SetClean(false) + require.False(t, msg.IsClean(), "Error setting clean session flag") err := msg.SetWill("topic", []byte("message"), QoS1, true) require.NoError(t, err) @@ -447,7 +447,7 @@ func TestConnectMessageEncode(t *testing.T) { err := msg.SetWill("will", []byte("send me home"), QoS1, false) require.NoError(t, err) - msg.SetCleanStart(true) + msg.SetClean(true) err = msg.SetClientID([]byte("volantmq")) require.NoError(t, err) @@ -467,7 +467,7 @@ func TestConnectMessageEncode(t *testing.T) { // V5.0 msgBytes = []byte{ byte(CONNECT << 4), - 63, + 62, 0, // Length MSB (0) 4, // Length LSB (4) 'M', 'Q', 'T', 'T', @@ -477,7 +477,7 @@ func TestConnectMessageEncode(t *testing.T) { 10, // Keep Alive LSB (10) 0, 0, // Client ID MSB (0) - 8, // Client ID LSB (7) + 8, // Client ID LSB (8) 'v', 'o', 'l', 'a', 'n', 't', 'm', 'q', 0, // Will Topic MSB (0) 4, // Will Topic LSB (4) @@ -486,7 +486,7 @@ func TestConnectMessageEncode(t *testing.T) { 12, // Will Message LSB (12) 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', 0, // Username ID MSB (0) - 8, // Username ID LSB (7) + 8, // Username ID LSB (8) 'v', 'o', 'l', 'a', 'n', 't', 'm', 'q', 0, // Password ID MSB (0) 10, // Password ID LSB (10) @@ -498,7 +498,7 @@ func TestConnectMessageEncode(t *testing.T) { err = msg.SetWill("will", []byte("send me home"), QoS1, false) require.NoError(t, err) - msg.SetCleanStart(true) + msg.SetClean(true) err = msg.SetClientID([]byte("volantmq")) require.NoError(t, err) @@ -510,10 +510,10 @@ func TestConnectMessageEncode(t *testing.T) { dst = make([]byte, 100) n, err = msg.Encode(dst) - require.NoError(t, err, "Error decoding message.") + require.NoError(t, err, "Error decoding message") require.Equal(t, ProtocolV50, msg.Version()) - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(msgBytes), n, "Error decoding message") + require.Equal(t, msgBytes, dst[:n], "Error decoding message") } // test to ensure encoding and decoding are the same diff --git a/packet/disconnect.go b/packet/disconnect.go index 9d0784b..d910759 100644 --- a/packet/disconnect.go +++ b/packet/disconnect.go @@ -39,65 +39,68 @@ func (msg *Disconnect) SetReasonCode(c ReasonCode) { } // decode message -func (msg *Disconnect) decodeMessage(src []byte) (int, error) { - total := 0 +func (msg *Disconnect) decodeMessage(from []byte) (int, error) { + offset := 0 if msg.version == ProtocolV50 { + // [MQTT-3.14.2.1] if msg.remLen < 1 { - return total, CodeMalformedPacket + msg.reasonCode = CodeSuccess + return offset, nil } - msg.reasonCode = ReasonCode(src[total]) + msg.reasonCode = ReasonCode(from[offset]) if !msg.reasonCode.IsValidForType(msg.mType) { - return total, CodeProtocolError + return offset, CodeProtocolError } - total++ + offset++ // V5.0 [MQTT-3.14.2.2.1] - if len(src[total:]) < 1 && msg.remLen < 2 { - return total, CodeMalformedPacket + if len(from[offset:]) < 1 && msg.remLen < 2 { + return offset, CodeMalformedPacket } if msg.remLen < 2 { - total++ + offset++ } else { var err error var n int - if msg.properties, n, err = decodeProperties(msg.mType, src[total:]); err != nil { - return total + n, err + if msg.properties, n, err = decodeProperties(msg.mType, from[offset:]); err != nil { + return offset + n, err } - total += n + offset += n } } else { if msg.remLen > 0 { - return total, CodeRefusedServerUnavailable + return offset, CodeRefusedServerUnavailable } } - return total, nil + return offset, nil } -func (msg *Disconnect) encodeMessage(dst []byte) (int, error) { - total := 0 +func (msg *Disconnect) encodeMessage(to []byte) (int, error) { + offset := 0 + var err error if msg.version == ProtocolV50 { - dst[total] = byte(msg.reasonCode) - total++ - - var err error - var n int - - if n, err = encodeProperties(msg.properties, dst[total:]); err != nil { - return total + n, err + pLen := msg.properties.FullLen() + if pLen > 1 || msg.reasonCode != CodeSuccess { + to[offset] = byte(msg.reasonCode) + offset++ + + if pLen > 1 { + var n int + n, err = encodeProperties(msg.properties, to[offset:]) + offset += n + } } - - total += n } - return total, nil + return offset, err } // Len of message @@ -105,8 +108,15 @@ func (msg *Disconnect) size() int { total := 0 if msg.version == ProtocolV50 { - pLen, _ := encodeProperties(msg.properties, []byte{}) - total += 1 + pLen + pLen := msg.properties.FullLen() + // If properties exist (which indicated when pLen > 1) include in body size reason code and properties + // otherwise include only reason code if it differs from CodeSuccess + if pLen > 1 || msg.reasonCode != CodeSuccess { + total++ + if pLen > 1 { + total += int(pLen) + } + } } return total diff --git a/packet/disconnect_test.go b/packet/disconnect_test.go index 26b5b4c..a1e4f1b 100644 --- a/packet/disconnect_test.go +++ b/packet/disconnect_test.go @@ -54,7 +54,7 @@ func TestDisconnectMessageEncode(t *testing.T) { 0, } - msg, err := NewMessage(ProtocolV311, DISCONNECT) + msg, err := New(ProtocolV311, DISCONNECT) require.NoError(t, err) require.NotNil(t, msg) diff --git a/packet/errors.go b/packet/errors.go index 469d39c..2200b73 100644 --- a/packet/errors.go +++ b/packet/errors.go @@ -100,6 +100,8 @@ func (e Error) Error() string { return "Invalid arguments" case ErrInvalidUtf8: return "String is not UTF8" + case ErrInvalidProtocolVersion: + return "Invalid protocol name" } return "Unknown error" diff --git a/packet/header.go b/packet/header.go index 3c96fcf..737e0b0 100644 --- a/packet/header.go +++ b/packet/header.go @@ -38,18 +38,17 @@ const ( ) const ( - maskMessageFlags byte = 0x0F - maskConnFlagUsername byte = 0x80 - maskConnFlagPassword byte = 0x40 - maskConnFlagWillRetain byte = 0x20 - maskConnFlagWillQos byte = 0x18 - maskConnFlagWill byte = 0x04 - maskConnFlagCleanSession byte = 0x02 - maskConnFlagReserved byte = 0x01 - maskPublishFlagRetain byte = 0x01 - maskPublishFlagQoS byte = 0x06 - maskPublishFlagDup byte = 0x08 - + maskMessageFlags byte = 0x0F + maskConnFlagUsername byte = 0x80 + maskConnFlagPassword byte = 0x40 + maskConnFlagWillRetain byte = 0x20 + maskConnFlagWillQos byte = 0x18 + maskConnFlagWill byte = 0x04 + maskConnFlagClean byte = 0x02 + maskConnFlagReserved byte = 0x01 + maskPublishFlagRetain byte = 0x01 + maskPublishFlagQoS byte = 0x06 + maskPublishFlagDup byte = 0x08 maskSubscriptionQoS byte = 0x03 maskSubscriptionNL byte = 0x04 maskSubscriptionRAP byte = 0x08 @@ -99,28 +98,28 @@ func (h *header) ID() (IDType, error) { return IDType(binary.BigEndian.Uint16(h.packetID)), nil } -func (h *header) Encode(dst []byte) (int, error) { +func (h *header) Encode(to []byte) (int, error) { expectedSize, err := h.Size() if err != nil { return 0, err } - if expectedSize > len(dst) { + if expectedSize > len(to) { return expectedSize, ErrInsufficientBufferSize } - total := 0 + offset := 0 - dst[total] = byte(h.mType<<offsetPacketType) | h.mFlags - total++ + to[offset] = byte(h.mType<<offsetPacketType) | h.mFlags + offset++ - total += binary.PutUvarint(dst[total:], uint64(h.remLen)) + offset += binary.PutUvarint(to[offset:], uint64(h.remLen)) var n int - n, err = h.cb.encode(dst[total:]) - total += n - return total, err + n, err = h.cb.encode(to[offset:]) + offset += n + return offset, err } func (h *header) SetVersion(v ProtocolVersion) { @@ -138,9 +137,9 @@ func (h *header) Size() (int, error) { return h.size() + ml, nil } -func (h *header) PropertyGet(id PropertyID) (interface{}, error) { +func (h *header) PropertyGet(id PropertyID) PropertyToType { if h.version != ProtocolV50 { - return nil, ErrNotSupported + return nil } return h.properties.Get(id) @@ -154,11 +153,15 @@ func (h *header) PropertySet(id PropertyID, val interface{}) error { return h.properties.Set(h.mType, id, val) } -func (h *header) PropertyForEach(f func(PropertyID, interface{})) error { +func (h *header) PropertyForEach(f func(PropertyID, PropertyToType)) error { if h.version != ProtocolV50 { return ErrNotSupported } + if h.properties == nil { + return ErrNotSet + } + h.properties.ForEach(f) return nil @@ -223,13 +226,13 @@ func (h *header) setType(t Type) { // decode reads fixed header and remaining length // if decode successful size of decoded data provided // if error happened offset points to error place -func (h *header) decode(src []byte) (int, error) { - total := 0 +func (h *header) decode(from []byte) (int, error) { + offset := 0 // decode and validate fixed header //h.mTypeFlags = src[total] - h.mType = Type(src[total] >> offsetPacketType) - h.mFlags = src[total] & maskMessageFlags + h.mType = Type(from[offset] >> offsetPacketType) + h.mFlags = from[offset] & maskMessageFlags reject := false // [MQTT-2.2.2-1] @@ -246,33 +249,33 @@ func (h *header) decode(src []byte) (int, error) { if h.version == ProtocolV50 { rejectCode = CodeMalformedPacket } - return total, rejectCode + return offset, rejectCode } - total++ + offset++ - remLen, m := uvarint(src[total:]) + remLen, m := uvarint(from[offset:]) if m <= 0 { - return total, ErrInsufficientDataSize + return offset, ErrInsufficientDataSize } - total += m + offset += m h.remLen = int32(remLen) // verify if buffer has enough space for whole message // if not return expected size - if int(h.remLen) > len(src[total:]) { - return total + int(h.remLen), ErrInsufficientDataSize + if int(h.remLen) > len(from[offset:]) { + return offset + int(h.remLen), ErrInsufficientDataSize } var err error if h.cb.decode != nil { var msgTotal int - msgTotal, err = h.cb.decode(src[total:]) - total += msgTotal + msgTotal, err = h.cb.decode(from[offset:]) + offset += msgTotal } - return total, err + return offset, err } // uvarint decodes a uint32 from buf and returns that value and the diff --git a/packet/packet.go b/packet/packet.go index efe5d79..c0971fb 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -16,6 +16,18 @@ const ( maskConnAckSessionPresent byte = 0x01 ) +// RetainHandling describe how retained messages are handled during subscribe +type RetainHandling uint8 + +const ( + // RetainHandlingRetain publish retained messages on subscribe + RetainHandlingRetain RetainHandling = iota + // RetainHandlingIfNotExists publish retained messages on subscribe only when it's new subscription to given topic + RetainHandlingIfNotExists + // RetainHandlingDoNotRetain do not publish retained messages on subscribe + RetainHandlingDoNotRetain +) + // SubscriptionOptions as per [MQTT-3.8.3.1] type SubscriptionOptions byte @@ -52,8 +64,8 @@ func (s SubscriptionOptions) RAP() bool { // 1 = Send retained messages at subscribe only if the subscription does not currently exist // 2 = Do not send retained messages at the time of the subscribe // V5.0 ONLY -func (s SubscriptionOptions) RetainHandling() byte { - return (byte(s) & maskSubscriptionRetainHandling) >> offsetSubscriptionRetainHandling +func (s SubscriptionOptions) RetainHandling() RetainHandling { + return RetainHandling((byte(s) & maskSubscriptionRetainHandling) >> offsetSubscriptionRetainHandling) } // Provider is an interface defined for all MQTT message types. @@ -87,11 +99,11 @@ type Provider interface { // Version get protocol version used by message Version() ProtocolVersion - PropertyGet(PropertyID) (interface{}, error) + PropertyGet(PropertyID) PropertyToType PropertySet(PropertyID, interface{}) error - PropertyForEach(func(PropertyID, interface{})) error + PropertyForEach(func(PropertyID, PropertyToType)) error // decode reads the bytes in the byte slice from the argument. It returns the // total number of bytes decoded, and whether there's any errors during the @@ -120,10 +132,22 @@ type Provider interface { setType(t Type) } -// NewMessage creates a new message based on the message type. It is a shortcut to call +// New creates a new message based on the message type. It is a shortcut to call // one of the New*Message functions. If an error is returned then the message type // is invalid. -func NewMessage(v ProtocolVersion, t Type) (Provider, error) { +func New(v ProtocolVersion, t Type) (Provider, error) { + m, err := newMessage(v, t) + if err == nil { + h := m.getHeader() + if v == ProtocolV50 && (t != PINGREQ && t != PINGRESP) { + h.properties = newProperty() + } + } + + return m, err +} + +func newMessage(v ProtocolVersion, t Type) (Provider, error) { var m Provider switch t { @@ -206,8 +230,8 @@ func Decode(v ProtocolVersion, buf []byte) (msg Provider, total int, err error) // [MQTT-2.2] mType := Type(buf[0] >> offsetPacketType) - // [MQTT-2.2.1] Type.NewMessage validates message type - if msg, err = NewMessage(v, mType); err != nil { + // [MQTT-2.2.1] Type.New validates message type + if msg, err = New(v, mType); err != nil { return nil, 0, err } @@ -222,8 +246,7 @@ func Decode(v ProtocolVersion, buf []byte) (msg Provider, total int, err error) // considered valid if it's longer than 0 bytes, and doesn't contain any wildcard characters // such as + and #. func ValidTopic(topic string) bool { - return len(topic) > 0 && - utf8.Valid([]byte(topic)) && + return utf8.Valid([]byte(topic)) && !strings.Contains(topic, "#") && !strings.Contains(topic, "+") } diff --git a/packet/packetType_test.go b/packet/packetType_test.go index 1940698..4207b37 100644 --- a/packet/packetType_test.go +++ b/packet/packetType_test.go @@ -30,26 +30,26 @@ func TestMessageTypeValid(t *testing.T) { func TestMessageTypeNewMessage(t *testing.T) { tp := RESERVED - msg, err := NewMessage(ProtocolV311, tp) + msg, err := New(ProtocolV311, tp) require.EqualError(t, ErrInvalidMessageType, err.Error()) require.Nil(t, msg) tp = AUTH - msg, err = NewMessage(ProtocolV311, tp) + msg, err = New(ProtocolV311, tp) require.EqualError(t, ErrInvalidMessageType, err.Error()) require.Nil(t, msg) - msg, err = NewMessage(ProtocolV50, tp) + msg, err = New(ProtocolV50, tp) require.NoError(t, err) require.NotNil(t, msg) tp = Type(143) - msg, err = NewMessage(ProtocolV50, tp) + msg, err = New(ProtocolV50, tp) require.EqualError(t, ErrInvalidMessageType, err.Error()) require.Nil(t, msg) tp = CONNACK - msg, err = NewMessage(ProtocolV50, tp) + msg, err = New(ProtocolV50, tp) require.NoError(t, err) require.Equal(t, CONNACK, msg.Type()) } diff --git a/packet/ping_test.go b/packet/ping_test.go index 3610df0..c3feb87 100644 --- a/packet/ping_test.go +++ b/packet/ping_test.go @@ -41,7 +41,7 @@ func TestPingReqMessageEncode(t *testing.T) { 0, } - m, err := NewMessage(ProtocolV311, PINGREQ) + m, err := New(ProtocolV311, PINGREQ) require.NoError(t, err) msg, ok := m.(*PingReq) @@ -74,7 +74,7 @@ func TestPingRespMessageEncode(t *testing.T) { 0, } - m, err := NewMessage(ProtocolV311, PINGRESP) + m, err := New(ProtocolV311, PINGRESP) require.NoError(t, err) msg, ok := m.(*PingResp) diff --git a/packet/property.go b/packet/property.go index ad9c8fb..dfcf189 100644 --- a/packet/property.go +++ b/packet/property.go @@ -1,6 +1,9 @@ package packet -import "encoding/binary" +import ( + "encoding/binary" + "unicode/utf8" +) // PropertyID id as per [MQTT-2.2.2] type PropertyID uint32 @@ -19,6 +22,7 @@ const ( ErrPropertyTypeMismatch ErrPropertyDuplicate ErrPropertyUnsupported + ErrPropertyWrongType ) // Error @@ -36,15 +40,96 @@ func (e PropertyError) Error() string { return "property: duplicate of id not allowed" case ErrPropertyUnsupported: return "property: value type is unsupported" + case ErrPropertyWrongType: + return "property: value type differs from expected" default: return "property: unknown error" } } -// KVPair user defined properties -type KVPair struct { - Key string - Value string +// StringPair user defined properties +type StringPair struct { + K string + V string +} + +// PropertyToType represent property value as requested type +type PropertyToType interface { + Type() PropertyType + AsByte() (byte, error) + AsShort() (uint16, error) + AsInt() (uint32, error) + AsString() (string, error) + AsStringPair() (StringPair, error) + AsStringPairs() ([]StringPair, error) + AsBinary() ([]byte, error) +} + +type propertyToType struct { + t PropertyType + v interface{} +} + +var _ PropertyToType = (*propertyToType)(nil) + +func (t *propertyToType) Type() PropertyType { + return t.t +} + +func (t *propertyToType) AsByte() (byte, error) { + if v, ok := t.v.(byte); ok { + return v, nil + } + + return 0, ErrPropertyWrongType +} + +func (t *propertyToType) AsShort() (uint16, error) { + if v, ok := t.v.(uint16); ok { + return v, nil + } + + return 0, ErrPropertyWrongType +} + +func (t *propertyToType) AsInt() (uint32, error) { + if v, ok := t.v.(uint32); ok { + return v, nil + } + + return 0, ErrPropertyWrongType +} + +func (t *propertyToType) AsString() (string, error) { + if v, ok := t.v.(string); ok { + return v, nil + } + + return "", ErrPropertyWrongType +} + +func (t *propertyToType) AsStringPair() (StringPair, error) { + if v, ok := t.v.(StringPair); ok { + return v, nil + } + + return StringPair{}, ErrPropertyWrongType +} + +func (t *propertyToType) AsStringPairs() ([]StringPair, error) { + if v, ok := t.v.([]StringPair); ok { + return v, nil + } + + return []StringPair{}, ErrPropertyWrongType +} + +func (t *propertyToType) AsBinary() ([]byte, error) { + if v, ok := t.v.([]byte); ok { + return v, nil + } + + return []byte{}, ErrPropertyWrongType } // property implements Property @@ -96,22 +181,31 @@ const ( ) var propertyAllowedMessageTypes = map[PropertyID]map[Type]bool{ - PropertyPayloadFormat: {PUBLISH: false}, - PropertyPublicationExpiry: {PUBLISH: false}, - PropertyContentType: {PUBLISH: false}, - PropertyResponseTopic: {PUBLISH: false}, - PropertyCorrelationData: {PUBLISH: false}, - PropertySubscriptionIdentifier: {PUBLISH: true, SUBSCRIBE: false}, - PropertySessionExpiryInterval: {CONNECT: false, DISCONNECT: false}, - PropertyAssignedClientIdentifier: {CONNACK: false}, - PropertyServerKeepAlive: {CONNACK: false}, - PropertyAuthMethod: {CONNECT: false, CONNACK: false, AUTH: false}, - PropertyAuthData: {CONNECT: false, CONNACK: false, AUTH: false}, - PropertyRequestProblemInfo: {CONNECT: false}, - PropertyWillDelayInterval: {CONNECT: false}, - PropertyRequestResponseInfo: {CONNECT: false}, - PropertyResponseInfo: {CONNACK: false}, - PropertyServerReverence: {CONNACK: false, DISCONNECT: false}, + PropertyPayloadFormat: {PUBLISH: false}, + PropertyPublicationExpiry: {PUBLISH: false}, + PropertyContentType: {PUBLISH: false}, + PropertyResponseTopic: {PUBLISH: false}, + PropertyCorrelationData: {PUBLISH: false}, + PropertySubscriptionIdentifier: {PUBLISH: true, SUBSCRIBE: false}, + PropertySessionExpiryInterval: {CONNECT: false, DISCONNECT: false}, + PropertyAssignedClientIdentifier: {CONNACK: false}, + PropertyServerKeepAlive: {CONNACK: false}, + PropertyAuthMethod: {CONNECT: false, CONNACK: false, AUTH: false}, + PropertyAuthData: {CONNECT: false, CONNACK: false, AUTH: false}, + PropertyWillDelayInterval: {CONNECT: false}, + PropertyRequestProblemInfo: {CONNECT: false}, + PropertyRequestResponseInfo: {CONNECT: false}, + PropertyResponseInfo: {CONNACK: false}, + PropertyServerReverence: {CONNACK: false, DISCONNECT: false}, + PropertyReceiveMaximum: {CONNECT: false, CONNACK: false}, + PropertyTopicAliasMaximum: {CONNECT: false, CONNACK: false}, + PropertyTopicAlias: {PUBLISH: false}, + PropertyMaximumQoS: {CONNACK: false}, + PropertyRetainAvailable: {CONNACK: false}, + PropertyMaximumPacketSize: {CONNECT: false, CONNACK: false}, + PropertyWildcardSubscriptionAvailable: {CONNACK: false}, + PropertySubscriptionIdentifierAvailable: {CONNACK: false}, + PropertySharedSubscriptionAvailable: {CONNACK: false}, PropertyReasonString: { CONNACK: false, PUBACK: false, @@ -122,62 +216,20 @@ var propertyAllowedMessageTypes = map[PropertyID]map[Type]bool{ UNSUBACK: false, DISCONNECT: false, AUTH: false}, - PropertyReceiveMaximum: {CONNECT: false, CONNACK: false}, - PropertyTopicAliasMaximum: {CONNECT: false, CONNACK: false}, - PropertyTopicAlias: {PUBLISH: false}, - PropertyMaximumQoS: {CONNACK: false}, - PropertyRetainAvailable: {CONNACK: false}, PropertyUserProperty: { - CONNECT: false, - CONNACK: false, - PUBLISH: false, - PUBACK: false, - PUBREC: false, - PUBREL: false, - PUBCOMP: false, - SUBACK: false, - UNSUBACK: false, - DISCONNECT: false, - AUTH: false}, - PropertyMaximumPacketSize: {CONNECT: false, CONNACK: false}, - PropertyWildcardSubscriptionAvailable: {CONNACK: false}, - PropertySubscriptionIdentifierAvailable: {CONNACK: false}, - PropertySharedSubscriptionAvailable: {CONNACK: false}, + CONNECT: true, + CONNACK: true, + PUBLISH: true, + PUBACK: true, + PUBREC: true, + PUBREL: true, + PUBCOMP: true, + SUBACK: true, + UNSUBACK: true, + DISCONNECT: true, + AUTH: true}, } -//var propertyTypeMap = map[PropertyID]struct { -// p PropertyType -// n interface{} -//}{ -// PropertyPayloadFormat: {p: PropertyTypeByte, n: uint8(0)}, -// PropertyPublicationExpiry: {p: PropertyTypeInt, n: uint32(0)}, -// PropertyContentType: {p: PropertyTypeString, n: ""}, -// PropertyResponseTopic: {p: PropertyTypeString, n: ""}, -// PropertyCorrelationData: {p: PropertyTypeBinary, n: []byte{}}, -// PropertySubscriptionIdentifier: {p: PropertyTypeVarInt, n: uint32(0)}, -// PropertySessionExpiryInterval: {p: PropertyTypeInt, n: uint32(0)}, -// PropertyAssignedClientIdentifier: {p: PropertyTypeString, n: ""}, -// PropertyServerKeepAlive: {p: PropertyTypeShort, n: uint16(0)}, -// PropertyAuthMethod: {p: PropertyTypeString, n: ""}, -// PropertyAuthData: {p: PropertyTypeBinary, n: []byte{}}, -// PropertyRequestProblemInfo: {p: PropertyTypeByte, n: uint8(0)}, -// PropertyWillDelayInterval: {p: PropertyTypeInt, n: uint32(0)}, -// PropertyRequestResponseInfo: {p: PropertyTypeByte, n: uint8(0)}, -// PropertyResponseInfo: {p: PropertyTypeString, n: ""}, -// PropertyServerReverence: {p: PropertyTypeString, n: ""}, -// PropertyReasonString: {p: PropertyTypeString, n: ""}, -// PropertyReceiveMaximum: {p: PropertyTypeShort, n: uint16(0)}, -// PropertyTopicAliasMaximum: {p: PropertyTypeShort, n: uint16(0)}, -// PropertyTopicAlias: {p: PropertyTypeShort, n: uint16(0)}, -// PropertyMaximumQoS: {p: PropertyTypeByte, n: uint8(0)}, -// PropertyRetainAvailable: {p: PropertyTypeByte, n: uint8(0)}, -// PropertyUserProperty: {p: PropertyTypeString, n: ""}, -// PropertyMaximumPacketSize: {p: PropertyTypeInt, n: uint32(0)}, -// PropertyWildcardSubscriptionAvailable: {p: PropertyTypeByte, n: uint8(0)}, -// PropertySubscriptionIdentifierAvailable: {p: PropertyTypeByte, n: uint8(0)}, -// PropertySharedSubscriptionAvailable: {p: PropertyTypeByte, n: uint8(0)}, -//} - var propertyTypeMap = map[PropertyID]PropertyType{ PropertyPayloadFormat: PropertyTypeByte, PropertyPublicationExpiry: PropertyTypeInt, @@ -201,41 +253,41 @@ var propertyTypeMap = map[PropertyID]PropertyType{ PropertyTopicAlias: PropertyTypeShort, PropertyMaximumQoS: PropertyTypeByte, PropertyRetainAvailable: PropertyTypeByte, - PropertyUserProperty: PropertyTypeString, + PropertyUserProperty: PropertyTypeStringPair, PropertyMaximumPacketSize: PropertyTypeInt, PropertyWildcardSubscriptionAvailable: PropertyTypeByte, PropertySubscriptionIdentifierAvailable: PropertyTypeByte, PropertySharedSubscriptionAvailable: PropertyTypeByte, } -var propertyTypeDup = map[PropertyID]bool{ - PropertyPayloadFormat: false, - PropertyPublicationExpiry: false, - PropertyContentType: false, - PropertyResponseTopic: false, - PropertyCorrelationData: false, - PropertySubscriptionIdentifier: false, - PropertySessionExpiryInterval: false, - PropertyAssignedClientIdentifier: false, - PropertyServerKeepAlive: false, - PropertyAuthMethod: false, - PropertyAuthData: false, - PropertyRequestProblemInfo: false, - PropertyWillDelayInterval: false, - PropertyRequestResponseInfo: false, - PropertyResponseInfo: false, - PropertyServerReverence: false, - PropertyReasonString: false, - PropertyReceiveMaximum: false, - PropertyTopicAliasMaximum: false, - PropertyTopicAlias: false, - PropertyMaximumQoS: false, - PropertyRetainAvailable: false, - PropertyUserProperty: true, - PropertyMaximumPacketSize: false, - PropertyWildcardSubscriptionAvailable: false, - PropertySubscriptionIdentifierAvailable: false, - PropertySharedSubscriptionAvailable: false, +var propertyDecodeType = map[PropertyType]func(*property, PropertyID, []byte) (int, error){ + PropertyTypeByte: decodeByte, + PropertyTypeShort: decodeShort, + PropertyTypeInt: decodeInt, + PropertyTypeVarInt: decodeVarInt, + PropertyTypeString: decodeString, + PropertyTypeStringPair: decodeStringPair, + PropertyTypeBinary: decodeBinary, +} + +var propertyEncodeType = map[PropertyType]func(id PropertyID, val interface{}, to []byte) (int, error){ + PropertyTypeByte: encodeByte, + PropertyTypeShort: encodeShort, + PropertyTypeInt: encodeInt, + PropertyTypeVarInt: encodeVarInt, + PropertyTypeString: encodeString, + PropertyTypeStringPair: encodeStringPair, + PropertyTypeBinary: encodeBinary, +} + +var propertyCalcLen = map[PropertyType]func(id PropertyID, val interface{}) (int, error){ + PropertyTypeByte: calcLenByte, + PropertyTypeShort: calcLenShort, + PropertyTypeInt: calcLenInt, + PropertyTypeVarInt: calcLenVarInt, + PropertyTypeString: calcLenString, + PropertyTypeStringPair: calcLenStringPair, + PropertyTypeBinary: calcLenBinary, } func newProperty() *property { @@ -247,13 +299,13 @@ func newProperty() *property { } // DupAllowed check if property id allows keys duplication -func (p PropertyID) DupAllowed() bool { - d, ok := propertyTypeDup[p] +func (p PropertyID) DupAllowed(t Type) bool { + d, ok := propertyAllowedMessageTypes[p] if !ok { return false } - return d + return d[t] } // IsValid check if property id is valid spec value @@ -297,57 +349,36 @@ func (p *property) Set(t Type, id PropertyID, val interface{}) error { return ErrPropertyPacketTypeMismatch } - // Todo: check type allowed for id - - // calculate property size - switch valueType := val.(type) { - case uint8: - p.len++ - case uint16: - p.len += 2 - case uint32: - p.len += 4 - case string: - p.len += uint32(len(valueType)) - case []string: - for i := range valueType { - p.len += uint32(len(valueType[i])) - } - case []byte: - p.len += uint32(len(valueType)) - case [][]byte: - for i := range valueType { - p.len += uint32(len(valueType[i])) - } - case []uint16: - p.len += uint32(len(valueType)) - case []uint32: - p.len += uint32(len(valueType)) - default: - return ErrPropertyUnsupported - } - + fn := propertyCalcLen[propertyTypeMap[id]] + l, _ := fn(id, val) + p.len += uint32(l) p.properties[id] = val return nil } // Get property value -func (p *property) Get(id PropertyID) (interface{}, error) { - if p, ok := p.properties[id]; ok { - return p, nil +func (p *property) Get(id PropertyID) PropertyToType { + if val, ok := p.properties[id]; ok { + t := propertyTypeMap[id] + return &propertyToType{v: val, t: t} } - return nil, ErrPropertyNotFound + return nil } // ForEach iterate over existing properties -func (p *property) ForEach(f func(PropertyID, interface{})) { +func (p *property) ForEach(f func(PropertyID, PropertyToType)) { for k, v := range p.properties { - f(k, v) + t := propertyTypeMap[k] + f(k, &propertyToType{v: v, t: t}) } } +func writePrefixID(id PropertyID, b []byte) int { + return binary.PutUvarint(b, uint64(id)) +} + func decodeProperties(t Type, buf []byte) (*property, int, error) { p := newProperty() @@ -356,11 +387,6 @@ func decodeProperties(t Type, buf []byte) (*property, int, error) { return nil, total, err } - // If properties are empty return only size of decoded property header - if len(p.properties) == 0 { - return nil, total, nil - } - return p, total, nil } @@ -368,7 +394,6 @@ func encodeProperties(p *property, dst []byte) (int, error) { if p == nil { if len(dst) > 0 { dst[0] = 0 - } return 1, nil } @@ -376,279 +401,505 @@ func encodeProperties(p *property, dst []byte) (int, error) { return p.encode(dst) } -func (p *property) decode(t Type, buf []byte) (int, error) { - total := 0 +func (p *property) decode(t Type, from []byte) (int, error) { + offset := 0 // property length is encoded as variable byte integer - pLen, lCount := uvarint(buf) + pLen, lCount := uvarint(from) if lCount <= 0 { - return 0, CodeMalformedPacket + offset += -lCount + return offset, CodeMalformedPacket } - total += lCount + offset += lCount + + var err error for pLen != 0 { - pidVal, pidCount := uvarint(buf[total:]) - if pidCount <= 0 { - return total, CodeMalformedPacket - } + slice := from[offset:] - total += pidCount + idVal, count := uvarint(slice) + if count <= 0 { + return offset - count, CodeMalformedPacket + } - id := PropertyID(pidVal) + id := PropertyID(idVal) if !id.IsValidPacketType(t) { - return total, CodeMalformedPacket + return offset, CodeMalformedPacket } - if _, ok := p.properties[id]; ok && !id.DupAllowed() { - return total, CodeProtocolError + if _, ok := p.properties[id]; ok && !id.DupAllowed(t) { + return offset, CodeProtocolError } - total++ + if decodeFunc, ok := propertyDecodeType[propertyTypeMap[id]]; ok { + var decodeCount int + decodeCount, err = decodeFunc(p, id, slice[count:]) + count += decodeCount + offset += count + if err != nil { + return offset, err + } + } else { + return offset, CodeProtocolError + } - count := 0 + p.len += uint32(count) + pLen -= uint32(count) + } - switch propertyTypeMap[id] { - case PropertyTypeByte: - if len(buf[total+count:]) < 1 { - return total + count, CodeMalformedPacket - } + return offset, nil +} - v := buf[total+count] - count++ - if _, ok := p.properties[id]; ok { - return total + count, CodeMalformedPacket - } - if _, ok := p.properties[id]; !ok { - p.properties[id] = []byte{v} - } else { - p.properties[id] = append(p.properties[id].([]byte), v) - } - case PropertyTypeShort: - if len(buf[total+count:]) < 2 { - return total + count, CodeMalformedPacket - } +func (p *property) encode(to []byte) (int, error) { + pLen := p.FullLen() + if int(pLen) > len(to) { + return 0, ErrInsufficientBufferSize + } - v := binary.BigEndian.Uint16(buf[total+count:]) - count += 2 + if pLen == 1 { + return 1, nil + } - if _, ok := p.properties[id]; !ok { - p.properties[id] = []uint16{v} - } else { - p.properties[id] = append(p.properties[id].([]uint16), v) - } - case PropertyTypeInt: - if len(buf[total+count:]) < 4 { - return total + count, CodeMalformedPacket - } + var offset int + var err error + // Encode variable length header + total := binary.PutUvarint(to, uint64(p.len)) - v := binary.BigEndian.Uint32(buf[total:]) - count += 4 + for k, v := range p.properties { + fn := propertyEncodeType[propertyTypeMap[k]] + offset, err = fn(k, v, to[total:]) + total += offset - if _, ok := p.properties[id]; !ok { - p.properties[id] = []uint32{v} - } else { - p.properties[id] = append(p.properties[id].([]uint32), v) - } - case PropertyTypeVarInt: - v, cnt := uvarint(buf[total+count:]) - if cnt <= 0 { - return total + count, CodeMalformedPacket - } - count += cnt + if err != nil { + break + } + } - if _, ok := p.properties[id]; !ok { - p.properties[id] = []uint32{v} - } else { - p.properties[id] = append(p.properties[id].([]uint32), v) - } - case PropertyTypeString: - v, n, err := ReadLPBytes(buf[total+count:]) - if err != nil { - return total + count, CodeMalformedPacket - } - count += n + return total, err +} - if _, ok := p.properties[id]; !ok { - p.properties[id] = []string{string(v)} - } else { - p.properties[id] = append(p.properties[id].([]string), string(v)) - } - case PropertyTypeStringPair: - k, n, err := ReadLPBytes(buf[total+count:]) - if err != nil { - return total + count, CodeMalformedPacket - } - count += n +func calcLenByte(id PropertyID, val interface{}) (int, error) { + l := 0 + calc := func() int { + return 1 + uvarintCalc(uint32(id)) + } - v, n, err := ReadLPBytes(buf[total+count:]) - if err != nil { - return total + count, CodeMalformedPacket - } - count += n + switch valueType := val.(type) { + case uint8: + l = calc() + case []uint8: + for range valueType { + l += calc() + } + default: + return 0, nil + } - pair := KVPair{ - Key: string(k), - Value: string(v), - } + return l, nil +} - if _, ok := p.properties[id]; !ok { - p.properties[id] = []KVPair{pair} - } else { - p.properties[id] = append(p.properties[id].([]KVPair), pair) - } - case PropertyTypeBinary: - b, n, err := ReadLPBytes(buf[total+count:]) - if err != nil { - return total + count, CodeMalformedPacket - } - count += n +func calcLenShort(id PropertyID, val interface{}) (int, error) { + l := 0 - tmp := make([]byte, len(b)) + calc := func() int { + return 2 + uvarintCalc(uint32(id)) + } - copy(tmp, b) - if _, ok := p.properties[id]; !ok { - p.properties[id] = [][]byte{tmp} - } else { - p.properties[id] = append(p.properties[id].([][]byte), tmp) - } + switch valueType := val.(type) { + case uint16: + l = calc() + case []uint16: + for range valueType { + l += calc() } + default: + return 0, nil + } - p.len += uint32(count) - pLen -= uint32(count) - total += count + return l, nil +} + +func calcLenInt(id PropertyID, val interface{}) (int, error) { + l := 0 + + calc := func() int { + return 4 + uvarintCalc(uint32(id)) + } + + switch valueType := val.(type) { + case uint32: + l = calc() + case []uint32: + for range valueType { + l += calc() + } + default: + return 0, nil } - return total, nil + return l, nil } -func (p *property) encode(buf []byte) (int, error) { - pLen, pSizeCount := p.Len() - if int(pLen)+pSizeCount > len(buf) { - return 0, ErrInsufficientBufferSize +func calcLenVarInt(id PropertyID, val interface{}) (int, error) { + l := 0 + + calc := func(v uint32) int { + return uvarintCalc(v) + uvarintCalc(uint32(id)) } - total := 0 + switch valueType := val.(type) { + case uint32: + l = calc(valueType) + case []uint32: + for _, v := range valueType { + l += calc(v) + } + default: + return 0, nil + } - // Encode variable length header - total += binary.PutUvarint(buf, uint64(p.len)) + return l, nil +} - writePrefixID := func(id PropertyID, b []byte) int { - offset := 0 - b[offset] = byte(id) - offset++ +func calcLenString(id PropertyID, val interface{}) (int, error) { + l := 0 - return offset + calc := func(n int) int { + return 2 + n + uvarintCalc(uint32(id)) } - for k, v := range p.properties { - switch propertyTypeMap[k] { - case PropertyTypeByte: - switch valueType := v.(type) { - case uint8: - total += writePrefixID(k, buf) - buf[total] = valueType - total++ - case []uint8: - for i := range valueType { - total += writePrefixID(k, buf) - buf[total] = valueType[i] - total++ - } - } - case PropertyTypeShort: - switch valueType := v.(type) { - case uint16: - total += writePrefixID(k, buf) - binary.BigEndian.PutUint16(buf[total:], valueType) - total += 2 - case []uint16: - for i := range valueType { - total += writePrefixID(k, buf) - binary.BigEndian.PutUint16(buf[total:], valueType[i]) - total += 2 - } - } - case PropertyTypeInt: - switch valueType := v.(type) { - case uint32: - total += writePrefixID(k, buf) - binary.BigEndian.PutUint32(buf[total:], valueType) - total += 4 - case []uint32: - for i := range valueType { - total += writePrefixID(k, buf) - binary.BigEndian.PutUint32(buf[total:], valueType[i]) - total += 4 - } - } - case PropertyTypeVarInt: - switch valueType := v.(type) { - case uint32: - total += writePrefixID(k, buf) - total += binary.PutUvarint(buf[total:], uint64(valueType)) - case []uint32: - for i := range valueType { - total += writePrefixID(k, buf) - total += binary.PutUvarint(buf[total:], uint64(valueType[i])) - } - } - case PropertyTypeString: - switch valueType := v.(type) { - case string: - total += writePrefixID(k, buf) - total += copy(buf[total:], []byte(valueType)) - case []string: - for i := range valueType { - total += writePrefixID(k, buf) - total += copy(buf[total:], []byte(valueType[i])) - } - } - case PropertyTypeStringPair: - switch valueType := v.(type) { - case KVPair: - total += writePrefixID(k, buf) - n, err := WriteLPBytes(buf[total:], []byte(valueType.Key)) - if err != nil { - return total, err - } - total += n - - n, err = WriteLPBytes(buf[total:], []byte(valueType.Key)) - if err != nil { - return total, err - } - total += n - case []KVPair: - for i := range valueType { - total += writePrefixID(k, buf) - n, err := WriteLPBytes(buf[total:], []byte(valueType[i].Key)) - if err != nil { - return total, err - } - total += n - - n, err = WriteLPBytes(buf[total:], []byte(valueType[i].Key)) - if err != nil { - return total, err - } - total += n - - } - } - case PropertyTypeBinary: - switch valueType := v.(type) { - case []byte: - total += writePrefixID(k, buf) - total += copy(buf[total:], valueType) - case [][]byte: - for i := range valueType { - total += writePrefixID(k, buf) - total += copy(buf[total:], valueType[i]) - } - } + switch valueType := val.(type) { + case string: + l = calc(len(valueType)) + case []string: + for _, v := range valueType { + l += calc(len(v)) + } + default: + return 0, nil + } + + return l, nil +} + +func calcLenBinary(id PropertyID, val interface{}) (int, error) { + l := 0 + + calc := func(n int) int { + return 2 + n + uvarintCalc(uint32(id)) + } + + switch valueType := val.(type) { + case []byte: + l = calc(len(valueType)) + case [][]string: + for _, v := range valueType { + l += calc(len(v)) + } + default: + return 0, nil + } + + return l, nil +} + +func calcLenStringPair(id PropertyID, val interface{}) (int, error) { + l := 0 + + calc := func(k, v int) int { + return 4 + k + v + uvarintCalc(uint32(id)) + } + + switch valueType := val.(type) { + case StringPair: + l = calc(len(valueType.K), len(valueType.V)) + case []StringPair: + for _, v := range valueType { + l += calc(len(v.K), len(v.V)) + } + default: + return 0, nil + } + + return l, nil +} + +func decodeByte(p *property, id PropertyID, from []byte) (int, error) { + offset := 0 + if len(from[offset:]) < 1 { + return offset, CodeMalformedPacket + } + + p.properties[id] = from[offset] + offset++ + + return offset, nil +} + +func decodeShort(p *property, id PropertyID, from []byte) (int, error) { + offset := 0 + if len(from[offset:]) < 2 { + return offset, CodeMalformedPacket + } + + v := binary.BigEndian.Uint16(from[offset:]) + offset += 2 + + p.properties[id] = v + + return offset, nil +} + +func decodeInt(p *property, id PropertyID, from []byte) (int, error) { + offset := 0 + if len(from[offset:]) < 4 { + return offset, CodeMalformedPacket + } + + v := binary.BigEndian.Uint32(from[offset:]) + offset += 4 + + p.properties[id] = v + + return offset, nil +} + +func decodeVarInt(p *property, id PropertyID, from []byte) (int, error) { + offset := 0 + + v, cnt := uvarint(from[offset:]) + if cnt <= 0 { + return offset, CodeMalformedPacket + } + offset += cnt + + p.properties[id] = v + + return offset, nil +} + +func decodeString(p *property, id PropertyID, from []byte) (int, error) { + offset := 0 + + v, n, err := ReadLPBytes(from[offset:]) + if err != nil || !utf8.Valid(v) { + return offset, CodeMalformedPacket + } + + offset += n + + p.properties[id] = string(v) + + return offset, nil +} + +func decodeStringPair(p *property, id PropertyID, from []byte) (int, error) { + var k []byte + var v []byte + var n int + var err error + + k, n, err = ReadLPBytes(from) + offset := n + if err != nil || !utf8.Valid(k) { + return offset, CodeMalformedPacket + } + + v, n, err = ReadLPBytes(from[offset:]) + offset += n + + if err != nil || !utf8.Valid(v) { + return offset, CodeMalformedPacket + } + + if _, ok := p.properties[id]; !ok { + p.properties[id] = []StringPair{} + } + + p.properties[id] = append(p.properties[id].([]StringPair), StringPair{K: string(k), V: string(v)}) + + return offset, nil +} + +func decodeBinary(p *property, id PropertyID, from []byte) (int, error) { + offset := 0 + + b, n, err := ReadLPBytes(from[offset:]) + if err != nil { + return offset, CodeMalformedPacket + } + offset += n + + tmp := make([]byte, len(b)) + + copy(tmp, b) + + p.properties[id] = tmp + + return offset, nil +} + +func encodeByte(id PropertyID, val interface{}, to []byte) (int, error) { + offset := 0 + + encode := func(v uint8, to []byte) int { + off := writePrefixID(id, to) + + to[off] = v + off++ + + return off + } + + switch valueType := val.(type) { + case uint8: + offset += encode(valueType, to[offset:]) + case []uint8: + for _, v := range valueType { + offset += encode(v, to[offset:]) + } + } + + return offset, nil +} + +func encodeShort(id PropertyID, val interface{}, to []byte) (int, error) { + offset := 0 + + encode := func(v uint16, to []byte) int { + off := writePrefixID(id, to) + binary.BigEndian.PutUint16(to[off:], v) + off += 2 + + return off + } + + switch valueType := val.(type) { + case uint16: + offset += encode(valueType, to[offset:]) + case []uint16: + for _, v := range valueType { + offset += encode(v, to[offset:]) + } + } + + return offset, nil +} + +func encodeInt(id PropertyID, val interface{}, to []byte) (int, error) { + offset := 0 + + encode := func(v uint32, to []byte) int { + off := writePrefixID(id, to) + binary.BigEndian.PutUint32(to[off:], v) + off += 4 + + return off + } + + switch valueType := val.(type) { + case uint32: + offset += encode(valueType, to[offset:]) + case []uint32: + for _, v := range valueType { + offset += encode(v, to[offset:]) + } + } + + return offset, nil +} + +func encodeVarInt(id PropertyID, val interface{}, to []byte) (int, error) { + offset := 0 + + encode := func(v uint32, to []byte) int { + off := writePrefixID(id, to) + off += binary.PutUvarint(to[off:], uint64(v)) + + return off + } + + switch valueType := val.(type) { + case uint32: + offset += encode(valueType, to[offset:]) + case []uint32: + for _, v := range valueType { + offset += encode(v, to[offset:]) + } + } + + return offset, nil +} + +func encodeString(id PropertyID, val interface{}, to []byte) (int, error) { + offset := 0 + + encode := func(v string, to []byte) int { + off := writePrefixID(id, to) + count, _ := WriteLPBytes(to[off:], []byte(v)) + off += count + + return off + } + + switch valueType := val.(type) { + case string: + offset += encode(valueType, to[offset:]) + case []string: + for _, v := range valueType { + offset += encode(v, to[offset:]) + } + } + + return offset, nil +} + +func encodeStringPair(id PropertyID, val interface{}, to []byte) (int, error) { + offset := 0 + + encode := func(v StringPair, to []byte) int { + off := writePrefixID(id, to) + + n, _ := WriteLPBytes(to[off:], []byte(v.K)) + off += n + + n, _ = WriteLPBytes(to[off:], []byte(v.V)) + off += n + + return off + } + + switch valueType := val.(type) { + case StringPair: + offset += encode(valueType, to[offset:]) + case []StringPair: + for _, v := range valueType { + offset += encode(v, to[offset:]) } } - return total, nil + return offset, nil +} + +func encodeBinary(id PropertyID, val interface{}, to []byte) (int, error) { + offset := 0 + + encode := func(v []byte, to []byte) int { + off := writePrefixID(id, to) + count, _ := WriteLPBytes(to[off:], v) + off += count + + return off + } + + switch valueType := val.(type) { + case []byte: + offset += encode(valueType, to[offset:]) + case [][]byte: + for _, v := range valueType { + offset += encode(v, to[offset:]) + } + } + return offset, nil } diff --git a/packet/puback.go b/packet/puback.go index b99db0e..264dda5 100644 --- a/packet/puback.go +++ b/packet/puback.go @@ -60,71 +60,76 @@ func (msg *Ack) Reason() ReasonCode { return msg.reasonCode } -func (msg *Ack) decodeMessage(src []byte) (int, error) { - total := 0 - - total += msg.decodePacketID(src[total:]) +func (msg *Ack) decodeMessage(from []byte) (int, error) { + offset := msg.decodePacketID(from) if msg.version == ProtocolV50 { - msg.reasonCode = ReasonCode(src[total]) - if !msg.reasonCode.IsValidForType(msg.mType) { - return total, CodeMalformedPacket + // [MQTT-3.4.2.1] + if len(from[offset:]) == 0 { + msg.reasonCode = CodeSuccess + return offset, nil } - total++ - // v5 [MQTT-3.1.2.11] specifies properties in variable header - var err error - var n int - if msg.properties, n, err = decodeProperties(msg.mType, src[total:]); err != nil { - return total + n, err + msg.reasonCode = ReasonCode(from[offset]) + if !msg.reasonCode.IsValidForType(msg.mType) { + return offset, CodeMalformedPacket + } + offset++ + + if len(from[offset:]) > 0 { + // v5 [MQTT-3.1.2.11] specifies properties in variable header + var err error + var n int + if msg.properties, n, err = decodeProperties(msg.mType, from[offset:]); err != nil { + return offset + n, err + } + offset += n } - total += n } - return total, nil + return offset, nil } -func (msg *Ack) encodeMessage(dst []byte) (int, error) { +func (msg *Ack) encodeMessage(to []byte) (int, error) { // [MQTT-2.3.1] if len(msg.packetID) == 0 { return 0, ErrPackedIDZero } - total := 0 - - total += msg.encodePacketID(dst[total:]) + offset := msg.encodePacketID(to) + var err error if msg.version == ProtocolV50 { - if !msg.reasonCode.IsValidForType(msg.mType) { - return total, ErrInvalidReturnCode + pLen := msg.properties.FullLen() + if pLen > 1 || msg.reasonCode != CodeSuccess { + to[offset] = byte(msg.reasonCode) + offset++ + + if pLen > 1 { + var n int + n, err = encodeProperties(msg.properties, to[offset:]) + offset += n + } } - - dst[total] = byte(msg.reasonCode) - total++ - - // v5 [MQTT-3.1.2.11] specifies properties in variable header - var err error - var n int - if n, err = encodeProperties(msg.properties, dst[total:]); err != nil { - return total + n, err - } - - total += n } - return total, nil + return offset, err } func (msg *Ack) size() int { + // include size of PacketID total := 2 if msg.version == ProtocolV50 { - // V5.0 [MQTT-3.4.2.1] - total++ - - // v5.0 [MQTT-3.1.2.11] - pLen, _ := encodeProperties(msg.properties, []byte{}) - total += pLen + pLen := msg.properties.FullLen() + // If properties exist (which indicated when pLen > 1) include in body size reason code and properties + // otherwise include only reason code if it differs from CodeSuccess + if pLen > 1 || msg.reasonCode != CodeSuccess { + total++ + if pLen > 1 { + total += int(pLen) + } + } } return total diff --git a/packet/puback_test.go b/packet/puback_test.go index 9214ab9..ba25d22 100644 --- a/packet/puback_test.go +++ b/packet/puback_test.go @@ -21,7 +21,7 @@ import ( ) func TestPubAckMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, PUBACK) + m, err := New(ProtocolV311, PUBACK) require.NoError(t, err) msg, ok := m.(*Ack) @@ -74,7 +74,7 @@ func TestPubAckMessageEncode(t *testing.T) { 7, // packet ID LSB (7) } - m, err := NewMessage(ProtocolV311, PUBACK) + m, err := New(ProtocolV311, PUBACK) require.NoError(t, err) msg, ok := m.(*Ack) diff --git a/packet/pubcomp_test.go b/packet/pubcomp_test.go index ffb4292..7f09764 100644 --- a/packet/pubcomp_test.go +++ b/packet/pubcomp_test.go @@ -21,7 +21,7 @@ import ( ) func TestPubCompMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, PUBCOMP) + m, err := New(ProtocolV311, PUBCOMP) require.NoError(t, err) msg, ok := m.(*Ack) @@ -75,7 +75,7 @@ func TestPubCompMessageEncode(t *testing.T) { 7, // packet ID LSB (7) } - m, err := NewMessage(ProtocolV311, PUBCOMP) + m, err := New(ProtocolV311, PUBCOMP) require.NoError(t, err) msg, ok := m.(*Ack) diff --git a/packet/publish.go b/packet/publish.go index 9a43945..49e1b21 100644 --- a/packet/publish.go +++ b/packet/publish.go @@ -19,8 +19,9 @@ package packet type Publish struct { header - payload []byte - topic string + payload []byte + topic string + publishID uintptr } var _ Provider = (*Publish)(nil) @@ -29,6 +30,66 @@ func newPublish() *Publish { return &Publish{} } +// Clone packet +// qos, topic, payload, retain and properties +func (msg *Publish) Clone(v ProtocolVersion) (*Publish, error) { + // message version should be same as session as encode/decode depends on it + _pkt, _ := New(msg.version, PUBLISH) + pkt, _ := _pkt.(*Publish) + + // [MQTT-3.3.1-9] + // [MQTT-3.3.1-3] + pkt.Set(msg.topic, msg.Payload(), msg.QoS(), msg.Retain(), false) // nolint: errcheck + + if msg.version == ProtocolV50 && v == ProtocolV50 { + // [MQTT-3.3.2-4] forward Payload Format + if prop, ok := msg.properties.properties[PropertyPayloadFormat]; ok { + if err := pkt.properties.Set(msg.mType, PropertyPayloadFormat, prop); err != nil { + return nil, err + } + } + + // [MQTT-1892 3.3.2-15] forward Response Topic + if prop, ok := msg.properties.properties[PropertyResponseTopic]; ok { + if err := pkt.properties.Set(msg.mType, PropertyResponseTopic, prop); err != nil { + return nil, err + } + } + + // [MQTT-1908 3.3.2-16] forward Correlation Data + if prop, ok := msg.properties.properties[PropertyCorrelationData]; ok { + if err := pkt.properties.Set(msg.mType, PropertyCorrelationData, prop); err != nil { + return nil, err + } + } + + // [MQTT-3.3.2-17] forward User Property + if prop, ok := msg.properties.properties[PropertyUserProperty]; ok { + if err := pkt.properties.Set(msg.mType, PropertyUserProperty, prop); err != nil { + return nil, err + } + } + + // [MQTT-3.3.2-20] forward Content Type + if prop, ok := msg.properties.properties[PropertyContentType]; ok { + if err := pkt.properties.Set(msg.mType, PropertyContentType, prop); err != nil { + return nil, err + } + } + } + return pkt, nil +} + +// PublishID get publish ID to check No Local +func (msg *Publish) PublishID() uintptr { + return msg.publishID +} + +// SetPublishID internally used publish id to allow No Local option +func (msg *Publish) SetPublishID(id uintptr) { + msg.publishID = id +} + // Set topic/payload/qos/retained/bool func (msg *Publish) Set(t string, p []byte, q QosType, r bool, d bool) error { if !ValidTopic(t) { @@ -122,7 +183,7 @@ func (msg *Publish) Topic() string { // SetTopic sets the the topic name that identifies the information channel to which // payload data is published. An error is returned if ValidTopic() is falbase. func (msg *Publish) SetTopic(v string) error { - if !ValidTopic(v) { + if (msg.version < ProtocolV50 && len(v) == 0) || !ValidTopic(v) { return ErrInvalidTopic } @@ -146,11 +207,11 @@ func (msg *Publish) SetPacketID(v IDType) { msg.setPacketID(v) } -func (msg *Publish) decodeMessage(src []byte) (int, error) { +func (msg *Publish) decodeMessage(from []byte) (int, error) { var err error var n int var buf []byte - total := 0 + offset := 0 if !msg.QoS().IsValid() { var rejectCode ReasonCode @@ -160,7 +221,7 @@ func (msg *Publish) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedServerUnavailable } - return total, rejectCode + return offset, rejectCode } // [MQTT-3.3.1-2] @@ -172,17 +233,25 @@ func (msg *Publish) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedServerUnavailable } - return total, rejectCode + return offset, rejectCode } - buf, n, err = ReadLPBytes(src[total:]) - total += n + // [MQTT-3.3.2.1] + buf, n, err = ReadLPBytes(from[offset:]) + offset += n if err != nil { - return total, err + return offset, err } - if !ValidTopic(string(buf)) { - return total, ErrInvalidTopic + if len(buf) == 0 && msg.version < ProtocolV50 { + return offset, CodeRefusedServerUnavailable + } else if !ValidTopic(string(buf)) { + rejectCode := CodeRefusedServerUnavailable + if msg.version == ProtocolV50 { + rejectCode = CodeInvalidTopicName + } + + return offset, rejectCode } msg.topic = string(buf) @@ -190,18 +259,32 @@ func (msg *Publish) decodeMessage(src []byte) (int, error) { // The packet identifier field is only present in the PUBLISH packets where the // QoS level is 1 or 2 if msg.QoS() != QoS0 { - total += msg.decodePacketID(src[total:]) + offset += msg.decodePacketID(from[offset:]) } if msg.version == ProtocolV50 { - msg.properties, n, err = decodeProperties(msg.Type(), buf[total:]) - total += n + msg.properties, n, err = decodeProperties(msg.Type(), from[offset:]) + offset += n if err != nil { - return total, err + return offset, err + } + + // if packet does not have topic set there must be topic alias set in properties + if len(msg.topic) == 0 { + reject := CodeProtocolError + if prop := msg.PropertyGet(PropertyTopicAlias); prop != nil { + if val, ok := prop.AsShort(); ok == nil && val > 0 { + reject = CodeSuccess + } + } + + if reject != CodeSuccess { + return offset, reject + } } } - pLen := int(msg.remLen) - total + pLen := int(msg.remLen) - offset // check payload len is not malformed if pLen < 0 { @@ -212,11 +295,11 @@ func (msg *Publish) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedServerUnavailable } - return total, rejectCode + return offset, rejectCode } // check payload is not malformed - if len(src[total:]) < pLen { + if len(from[offset:]) < pLen { var rejectCode ReasonCode if msg.version == ProtocolV50 { rejectCode = CodeMalformedPacket @@ -224,19 +307,19 @@ func (msg *Publish) decodeMessage(src []byte) (int, error) { rejectCode = CodeRefusedServerUnavailable } - return total, rejectCode + return offset, rejectCode } if pLen > 0 { msg.payload = make([]byte, pLen) - copy(msg.payload, src[total:total+pLen]) - total += pLen + copy(msg.payload, from[offset:offset+pLen]) + offset += pLen } - return total, nil + return offset, nil } -func (msg *Publish) encodeMessage(dst []byte) (int, error) { +func (msg *Publish) encodeMessage(to []byte) (int, error) { if !ValidTopic(msg.topic) { return 0, ErrInvalidTopic } @@ -256,42 +339,44 @@ func (msg *Publish) encodeMessage(dst []byte) (int, error) { var err error var n int - total := 0 + offset := 0 - if n, err = WriteLPBytes(dst[total:], []byte(msg.topic)); err != nil { - return total, err + // [MQTT-3.3.2.1] + if n, err = WriteLPBytes(to[offset:], []byte(msg.topic)); err != nil { + return offset, err } - total += n + offset += n + // [MQTT-3.3.2.2] if msg.QoS() != QoS0 { - total += msg.encodePacketID(dst[total:]) + offset += msg.encodePacketID(to[offset:]) } // V5.0 [MQTT-3.1.2.11] if msg.version == ProtocolV50 { - if n, err = encodeProperties(msg.properties, dst[total:]); err != nil { - return total + n, err + if n, err = msg.properties.encode(to[offset:]); err != nil { + return offset + n, err } - total += n + offset += n } - total += copy(dst[total:], msg.payload) + offset += copy(to[offset:], msg.payload) - return total, nil + return offset, nil } func (msg *Publish) size() int { total := 2 + len(msg.topic) + len(msg.payload) if msg.QoS() != 0 { + // QoS1/2 packets must include packet id total += 2 } // v5.0 [MQTT-3.1.2.11] if msg.version == ProtocolV50 { - pLen, _ := encodeProperties(msg.properties, []byte{}) - total += pLen + total += int(msg.properties.FullLen()) } return total diff --git a/packet/pubrec_test.go b/packet/pubrec_test.go index e9fc071..8dd5305 100644 --- a/packet/pubrec_test.go +++ b/packet/pubrec_test.go @@ -21,7 +21,7 @@ import ( ) func TestPubRecMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, PUBREC) + m, err := New(ProtocolV311, PUBREC) require.NoError(t, err) msg, ok := m.(*Ack) @@ -73,7 +73,7 @@ func TestPubRecMessageEncode(t *testing.T) { 7, // packet ID LSB (7) } - m, err := NewMessage(ProtocolV311, PUBREC) + m, err := New(ProtocolV311, PUBREC) require.NoError(t, err) msg, ok := m.(*Ack) diff --git a/packet/pubrel_test.go b/packet/pubrel_test.go index b5d6dc7..89eae04 100644 --- a/packet/pubrel_test.go +++ b/packet/pubrel_test.go @@ -21,7 +21,7 @@ import ( ) func TestPubRelMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, PUBREL) + m, err := New(ProtocolV311, PUBREL) require.NoError(t, err) msg, ok := m.(*Ack) @@ -74,7 +74,7 @@ func TestPubRelMessageEncode(t *testing.T) { 7, // packet ID LSB (7) } - m, err := NewMessage(ProtocolV311, PUBREL) + m, err := New(ProtocolV311, PUBREL) require.NoError(t, err) msg, ok := m.(*Ack) diff --git a/packet/reasonCodes.go b/packet/reasonCodes.go index 0620015..0672416 100644 --- a/packet/reasonCodes.go +++ b/packet/reasonCodes.go @@ -163,7 +163,7 @@ var packetTypeCodeMap = map[Type]map[ReasonCode]struct { CodeInvalidTopicName: {iss: CodeIssuerBoth, desc: "The topic name is valid, but is not accepted"}, CodePacketTooLarge: {iss: CodeIssuerBoth, desc: "The packet size is too large"}, CodeReceiveMaximumExceeded: {iss: CodeIssuerBoth, desc: "The Client or Server has received more than Receive Maximum publication for which it has not sent PUBACK or PUBCOMP"}, - CodeInvalidTopicAlias: {iss: CodeIssuerBoth, desc: "The Client or Server has received a PUBLISH packet containing a Topic Alias which is greater than the Maximum Topic Alias it sent in the CONNECT or CONNACK packet"}, + CodeInvalidTopicAlias: {iss: CodeIssuerBoth, desc: "Invalid topic alias"}, CodeMessageRateTooHigh: {iss: CodeIssuerBoth, desc: "The rate of publish is too high"}, CodeQuotaExceeded: {iss: CodeIssuerBoth, desc: "An implementation imposed limit has been exceeded"}, CodeAdministrativeAction: {iss: CodeIssuerBoth, desc: "The Connection is closed due to an administrative action"}, @@ -226,7 +226,7 @@ var codeDescMap = map[ReasonCode]string{ CodePacketIDInUse: "", CodePacketIDNotFound: "", CodeReceiveMaximumExceeded: "", - CodeInvalidTopicAlias: "", + CodeInvalidTopicAlias: "Invalid topic alias", CodePacketTooLarge: "", CodeMessageRateTooHigh: "", CodeQuotaExceeded: "", diff --git a/packet/suback.go b/packet/suback.go index 0aeaded..a494f85 100644 --- a/packet/suback.go +++ b/packet/suback.go @@ -123,8 +123,7 @@ func (msg *SubAck) size() int { total := 2 + len(msg.returnCodes) // v5.0 [MQTT-3.1.2.11] if msg.version == ProtocolV50 { - pLen, _ := encodeProperties(msg.properties, []byte{}) - total += pLen + total += int(msg.properties.FullLen()) } return total diff --git a/packet/suback_test.go b/packet/suback_test.go index 74a777b..c04afab 100644 --- a/packet/suback_test.go +++ b/packet/suback_test.go @@ -21,7 +21,7 @@ import ( ) func TestSubAckMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, SUBACK) + m, err := New(ProtocolV311, SUBACK) require.NoError(t, err) msg, ok := m.(*SubAck) @@ -92,7 +92,7 @@ func TestSubAckMessageEncode(t *testing.T) { 0x80, // return code 4 } - m, err := NewMessage(ProtocolV311, SUBACK) + m, err := New(ProtocolV311, SUBACK) require.NoError(t, err) msg, ok := m.(*SubAck) diff --git a/packet/subscribe_test.go b/packet/subscribe_test.go index b092951..466400d 100644 --- a/packet/subscribe_test.go +++ b/packet/subscribe_test.go @@ -24,7 +24,7 @@ import ( ) func TestSubscribeMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, SUBSCRIBE) + m, err := New(ProtocolV311, SUBSCRIBE) require.NoError(t, err) msg, ok := m.(*Subscribe) @@ -123,7 +123,7 @@ func TestSubscribeMessageEncode(t *testing.T) { 2, // QoS } - m, err := NewMessage(ProtocolV311, SUBSCRIBE) + m, err := New(ProtocolV311, SUBSCRIBE) require.NoError(t, err) msg, ok := m.(*Subscribe) diff --git a/packet/unsuback.go b/packet/unsuback.go index c309b0c..e6a673d 100644 --- a/packet/unsuback.go +++ b/packet/unsuback.go @@ -62,47 +62,39 @@ func (msg *UnSubAck) AddReturnCode(ret ReasonCode) error { } // decode message -func (msg *UnSubAck) decodeMessage(src []byte) (int, error) { - total := msg.decodePacketID(src) +func (msg *UnSubAck) decodeMessage(from []byte) (int, error) { + offset := msg.decodePacketID(from) - if msg.version == ProtocolV50 && (int(msg.remLen)-total) > 0 { + if msg.version == ProtocolV50 && (int(msg.remLen)-offset) > 0 { var n int var err error - if msg.properties, n, err = decodeProperties(msg.Type(), src[total:]); err != nil { - return total + n, err + if msg.properties, n, err = decodeProperties(msg.Type(), from[offset:]); err != nil { + return offset + n, err } - total += n + offset += n } - return total, nil + return offset, nil } -func (msg *UnSubAck) encodeMessage(dst []byte) (int, error) { +func (msg *UnSubAck) encodeMessage(to []byte) (int, error) { // [MQTT-2.3.1] if len(msg.packetID) == 0 { return 0, ErrPackedIDZero } - total := msg.encodePacketID(dst) + offset := msg.encodePacketID(to) + var err error if msg.version == ProtocolV50 { var n int - var err error - - if n, err = encodeProperties(msg.properties, []byte{}); err != nil { - return total, err - } - if n > 1 { - if n, err = encodeProperties(msg.properties, dst[total:]); err != nil { - return total + n, err - } - total += n - } + n, err = encodeProperties(msg.properties, to[offset:]) + offset += n } - return total, nil + return offset, err } func (msg *UnSubAck) size() int { @@ -110,12 +102,10 @@ func (msg *UnSubAck) size() int { total := 2 if msg.version == ProtocolV50 { - pLen, _ := encodeProperties(msg.properties, []byte{}) - total += pLen - - if pLen > 1 { - total += pLen - } + total += int(msg.properties.FullLen()) + //if pLen := msg.properties.FullLen(); pLen > 1 { + // total += int(pLen) + //} } return total diff --git a/packet/unsuback_test.go b/packet/unsuback_test.go index b31a530..3275078 100644 --- a/packet/unsuback_test.go +++ b/packet/unsuback_test.go @@ -21,7 +21,7 @@ import ( ) func TestUnSubAckMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, UNSUBACK) + m, err := New(ProtocolV311, UNSUBACK) require.NoError(t, err) msg, ok := m.(*UnSubAck) @@ -74,7 +74,7 @@ func TestUnSubAckMessageEncode(t *testing.T) { 7, // packet ID LSB (7) } - m, err := NewMessage(ProtocolV311, UNSUBACK) + m, err := New(ProtocolV311, UNSUBACK) require.NoError(t, err) msg, ok := m.(*UnSubAck) diff --git a/packet/unsubscribe_test.go b/packet/unsubscribe_test.go index 7416031..27d6d55 100644 --- a/packet/unsubscribe_test.go +++ b/packet/unsubscribe_test.go @@ -21,7 +21,7 @@ import ( ) func TestUnSubscribeMessageFields(t *testing.T) { - m, err := NewMessage(ProtocolV311, UNSUBSCRIBE) + m, err := New(ProtocolV311, UNSUBSCRIBE) require.NoError(t, err) msg, ok := m.(*UnSubscribe) @@ -114,7 +114,7 @@ func TestUnSubscribeMessageEncode(t *testing.T) { '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', } - m, err := NewMessage(ProtocolV311, UNSUBSCRIBE) + m, err := New(ProtocolV311, UNSUBSCRIBE) require.NoError(t, err) msg, ok := m.(*UnSubscribe) diff --git a/persistence/types/types.go b/persistence/types/types.go index 2900826..42daee8 100644 --- a/persistence/types/types.go +++ b/persistence/types/types.go @@ -78,6 +78,12 @@ type Retained interface { Wipe() error } +// ConnectionMessages interface for connection to handle messages +type ConnectionMessages interface { + MessagesLoad([]byte) (*SessionMessages, error) + MessagesStore([]byte, *SessionMessages) error +} + // Sessions interface allows operating with sessions inside backend type Sessions interface { StatesIterate(func([]byte, *SessionState) error) error diff --git a/subscriber/subscriber.go b/subscriber/subscriber.go index 72a1470..15646a4 100644 --- a/subscriber/subscriber.go +++ b/subscriber/subscriber.go @@ -11,36 +11,26 @@ import ( // ConnectionProvider passed to present network connection type ConnectionProvider interface { + ID() string Subscriptions() Subscriptions - Subscribe(string, *SubscriptionParams) (packet.QosType, []*packet.Publish, error) + Subscribe(string, *topicsTypes.SubscriptionParams) (packet.QosType, []*packet.Publish, error) UnSubscribe(string) error HasSubscriptions() bool Online(c OnlinePublish) + OnlineRedirect(c OnlinePublish) Offline(bool) + Hash() uintptr Version() packet.ProtocolVersion } // OnlinePublish invoked when subscriber respective to sessions receive message -type OnlinePublish func(*packet.Publish) error +type OnlinePublish func(*packet.Publish) // OfflinePublish invoked when subscriber respective to sessions receive message type OfflinePublish func(string, *packet.Publish) -// SubscriptionParams parameters of the subscription -type SubscriptionParams struct { - // Subscription id - // V5.0 ONLY - ID uint32 - - // Requested QoS requested by subscriber - Requested packet.SubscriptionOptions - - // Granted QoS granted by topics manager - Granted packet.QosType -} - // Subscriptions contains active subscriptions with respective subscription parameters -type Subscriptions map[string]*SubscriptionParams +type Subscriptions map[string]*topicsTypes.SubscriptionParams // Config subscriber config options type Config struct { @@ -84,6 +74,11 @@ func New(c *Config) *Type { return p } +// ID get subscriber id +func (s *Type) ID() string { + return s.id +} + // Hash returns address of the provider struct. // Used by topics provider as a key to subscriber object func (s *Type) Hash() uintptr { @@ -116,10 +111,9 @@ func (s *Type) Subscriptions() Subscriptions { } // Subscribe to given topic -func (s *Type) Subscribe(topic string, params *SubscriptionParams) (packet.QosType, []*packet.Publish, error) { - q, r, err := s.topics.Subscribe(topic, params.Requested.QoS(), s, params.ID) +func (s *Type) Subscribe(topic string, params *topicsTypes.SubscriptionParams) (packet.QosType, []*packet.Publish, error) { + q, r, err := s.topics.Subscribe(topic, s, params) - params.Granted = q s.subscriptions[topic] = params return q, r, err @@ -135,22 +129,24 @@ func (s *Type) UnSubscribe(topic string) error { // Publish message accordingly to subscriber state // online: forward message to session // offline: persist message -func (s *Type) Publish(m *packet.Publish, grantedQoS packet.QosType, ids []uint32) error { - // message version should be same as session as encode/decode depends on it - mP, _ := packet.NewMessage(s.version, packet.PUBLISH) - msg, _ := mP.(*packet.Publish) - - // TODO: copy properties for V5.0 - msg.SetDup(false) - msg.SetQoS(m.QoS()) // nolint: errcheck - msg.SetTopic(m.Topic()) // nolint: errcheck - msg.SetRetain(false) - msg.SetPayload(m.Payload()) - - msg.PropertySet(packet.PropertySubscriptionIdentifier, ids) // nolint: errcheck - - if msg.QoS() != packet.QoS0 { - msg.SetPacketID(0) +func (s *Type) Publish(p *packet.Publish, grantedQoS packet.QosType, ops packet.SubscriptionOptions, ids []uint32) error { + pkt, err := p.Clone(s.version) + if err != nil { + return err + } + + if len(ids) > 0 { + if err = pkt.PropertySet(packet.PropertySubscriptionIdentifier, ids); err != nil { + return err + } + } + + if !ops.RAP() { + pkt.SetRetain(false) + } + + if pkt.QoS() != packet.QoS0 { + pkt.SetPacketID(0) } switch grantedQoS { @@ -160,8 +156,8 @@ func (s *Type) Publish(m *packet.Publish, grantedQoS packet.QosType, ids []uint3 // Message published to the same topic is downgraded by the Server to QoS 1 for delivery to the // Client, so that Client might receive duplicate copies of the Message. case packet.QoS1: - if msg.QoS() == packet.QoS2 { - msg.SetQoS(packet.QoS1) // nolint: errcheck + if pkt.QoS() == packet.QoS2 { + pkt.SetQoS(packet.QoS1) // nolint: errcheck } // If the subscribing Client has been granted maximum QoS 0, then an Application Message @@ -175,11 +171,13 @@ func (s *Type) Publish(m *packet.Publish, grantedQoS packet.QosType, ids []uint3 case <-s.isOnline: // if session is offline forward message to persisted storage // only with QoS1 and QoS2 and QoS0 if set by config - qos := msg.QoS() + qos := pkt.QoS() if qos != packet.QoS0 || (s.offlineQoS0 && qos == packet.QoS0) { defer s.wgOffline.Done() + s.publishLock.RLock() s.wgOffline.Add(1) - s.publishOffline(s.id, msg) + s.publishLock.RUnlock() + s.publishOffline(s.id, pkt) } default: // forward message to publish queue @@ -187,7 +185,7 @@ func (s *Type) Publish(m *packet.Publish, grantedQoS packet.QosType, ids []uint3 s.publishLock.RLock() s.wgOnline.Add(1) s.publishLock.RUnlock() - return s.publishOnline(msg) + s.publishOnline(pkt) } return nil @@ -196,9 +194,18 @@ func (s *Type) Publish(m *packet.Publish, grantedQoS packet.QosType, ids []uint3 // Online moves subscriber to online state // since this moment all of publishes are forwarded to provided callback func (s *Type) Online(c OnlinePublish) { + s.wgOffline.Wait() + defer s.publishLock.Unlock() + s.publishLock.Lock() s.publishOnline = c s.isOnline = make(chan struct{}) - s.wgOffline.Wait() +} + +// OnlineRedirect set new online publish callback +func (s *Type) OnlineRedirect(c OnlinePublish) { + defer s.publishLock.Unlock() + s.publishLock.Lock() + s.publishOnline = c } // Offline put session offline diff --git a/systree/clients.go b/systree/clients.go index 3a76d7e..52a2659 100644 --- a/systree/clients.go +++ b/systree/clients.go @@ -20,6 +20,7 @@ type ClientConnectStatus struct { KeepAlive uint16 GeneratedID bool CleanSession bool + KillOnDisconnect bool SessionPresent bool PreserveOrder bool MaximumQoS packet.QosType @@ -55,7 +56,7 @@ func (t *clients) Connected(id string, status *ClientConnectStatus) { } // notify client connected - nm, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + nm, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) notifyMsg, _ := nm.(*packet.Publish) notifyMsg.SetRetain(false) notifyMsg.SetQoS(packet.QoS0) // nolint: errcheck @@ -74,7 +75,7 @@ func (t *clients) Connected(id string, status *ClientConnectStatus) { } // notify remove previous disconnect if any - nm, _ = packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + nm, _ = packet.New(packet.ProtocolV311, packet.PUBLISH) notifyMsg, _ = nm.(*packet.Publish) notifyMsg.SetRetain(false) notifyMsg.SetQoS(packet.QoS0) // nolint: errcheck @@ -89,7 +90,7 @@ func (t *clients) Connected(id string, status *ClientConnectStatus) { func (t *clients) Disconnected(id string, reason packet.ReasonCode, retain bool) { atomic.AddUint64(&t.curr.val, ^uint64(0)) - nm, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + nm, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) notifyMsg, _ := nm.(*packet.Publish) notifyMsg.SetRetain(false) notifyMsg.SetQoS(packet.QoS0) // nolint: errcheck @@ -113,7 +114,7 @@ func (t *clients) Disconnected(id string, reason packet.ReasonCode, retain bool) } // remove connected retained message - nm, _ = packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + nm, _ = packet.New(packet.ProtocolV311, packet.PUBLISH) notifyMsg, _ = nm.(*packet.Publish) notifyMsg.SetRetain(false) notifyMsg.SetQoS(packet.QoS0) // nolint: errcheck diff --git a/systree/interfaces.go b/systree/interfaces.go index 34cffcb..b7b5c3c 100644 --- a/systree/interfaces.go +++ b/systree/interfaces.go @@ -41,8 +41,8 @@ type Sessions interface { // Clients Statistic of sessions type Clients interface { - Connected(string, *ClientConnectStatus) - Disconnected(string, packet.ReasonCode, bool) + Connected(id string, status *ClientConnectStatus) + Disconnected(id string, reason packet.ReasonCode, retain bool) } // TopicsStat statistic of topics diff --git a/systree/server.go b/systree/server.go index f0a5ab2..5c2ec31 100644 --- a/systree/server.go +++ b/systree/server.go @@ -32,7 +32,7 @@ func newServer(topicPrefix string, dynRetains, staticRetains *[]types.RetainObje version: "1.0.0", } - m, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + m, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) msg, _ := m.(*packet.Publish) msg.SetQoS(packet.QoS0) // nolint: errcheck msg.SetTopic(topicPrefix + "/version") // nolint: errcheck @@ -42,7 +42,7 @@ func newServer(topicPrefix string, dynRetains, staticRetains *[]types.RetainObje *dynRetains = append(*dynRetains, b.currTime) *staticRetains = append(*staticRetains, msg) - m, _ = packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + m, _ = packet.New(packet.ProtocolV311, packet.PUBLISH) msg, _ = m.(*packet.Publish) msg.SetQoS(packet.QoS0) // nolint: errcheck msg.SetTopic(topicPrefix + "/capabilities") // nolint: errcheck diff --git a/systree/sessions.go b/systree/sessions.go index 3df3840..44ab519 100644 --- a/systree/sessions.go +++ b/systree/sessions.go @@ -47,7 +47,7 @@ func (t *sessions) Created(id string, status *SessionCreatedStatus) { if t.topicsManager != nil { // notify client connected - nm, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + nm, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) notifyMsg, _ := nm.(*packet.Publish) notifyMsg.SetRetain(false) notifyMsg.SetQoS(packet.QoS0) // nolint: errcheck @@ -69,7 +69,7 @@ func (t *sessions) Created(id string, status *SessionCreatedStatus) { func (t *sessions) Removed(id string, status *SessionDeletedStatus) { atomic.AddUint64(&t.curr.val, ^uint64(0)) if t.topicsManager != nil { - nm, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + nm, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) notifyMsg, _ := nm.(*packet.Publish) notifyMsg.SetRetain(false) notifyMsg.SetQoS(packet.QoS0) // nolint: errcheck @@ -77,7 +77,7 @@ func (t *sessions) Removed(id string, status *SessionDeletedStatus) { t.topicsManager.Retain(notifyMsg) // nolint: errcheck - nm, _ = packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + nm, _ = packet.New(packet.ProtocolV311, packet.PUBLISH) notifyMsg, _ = nm.(*packet.Publish) notifyMsg.SetRetain(false) notifyMsg.SetQoS(packet.QoS0) // nolint: errcheck diff --git a/systree/types.go b/systree/types.go index 097e3c9..4554cd5 100644 --- a/systree/types.go +++ b/systree/types.go @@ -87,7 +87,7 @@ func (m *dynamicValue) Topic() string { func (m *dynamicValue) Retained() *packet.Publish { if m.retained == nil { - np, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + np, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) m.retained, _ = np.(*packet.Publish) m.retained.SetTopic(m.topic) // nolint: errcheck m.retained.SetQoS(packet.QoS0) // nolint: errcheck @@ -101,7 +101,7 @@ func (m *dynamicValue) Retained() *packet.Publish { func (m *dynamicValue) Publish() *packet.Publish { if m.publish == nil { - np, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + np, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) m.publish, _ = np.(*packet.Publish) m.publish.SetTopic(m.topic) // nolint: errcheck m.publish.SetQoS(packet.QoS0) // nolint: errcheck diff --git a/topics/mem/node.go b/topics/mem/node.go index a39e6bf..42cfdfa 100644 --- a/topics/mem/node.go +++ b/topics/mem/node.go @@ -10,15 +10,30 @@ import ( ) type subscribedEntry struct { - s topicsTypes.Subscriber - id uint32 - grantedQoS packet.QosType + s topicsTypes.Subscriber + p *topicsTypes.SubscriptionParams } type subscribedEntries map[uintptr]*subscribedEntry +func (s *subscribedEntry) acquire() *publishEntry { + s.s.Acquire() + pe := &publishEntry{ + s: s.s, + qos: s.p.Granted, + ops: s.p.Ops, + } + + if s.p.ID > 0 { + pe.ids = []uint32{s.p.ID} + } + + return pe +} + type publishEntry struct { s topicsTypes.Subscriber + ops packet.SubscriptionOptions qos packet.QosType ids []uint32 } @@ -30,7 +45,7 @@ type node struct { subs subscribedEntries parent *node children map[string]*node - getSubscribers func(p *publishEntries) + getSubscribers func(uintptr, *publishEntries) } func newNode(overlap bool, parent *node) *node { @@ -82,22 +97,25 @@ func (mT *provider) leafSearchNode(levels []string) *node { return root } -func (mT *provider) subscriptionInsert(filter string, qos packet.QosType, sub topicsTypes.Subscriber, id uint32) { +func (mT *provider) subscriptionInsert(filter string, sub topicsTypes.Subscriber, p *topicsTypes.SubscriptionParams) bool { levels := strings.Split(filter, "/") root := mT.leafInsertNode(levels) // Let's see if the subscriber is already on the list and just update QoS if so // Otherwise create new entry + exists := false if s, ok := root.subs[sub.Hash()]; !ok { root.subs[sub.Hash()] = &subscribedEntry{ - s: sub, - grantedQoS: qos, - id: id, + s: sub, + p: p, } } else { - s.grantedQoS = qos + s.p = p + exists = true } + + return exists } func (mT *provider) subscriptionRemove(topic string, sub topicsTypes.Subscriber) error { @@ -144,38 +162,38 @@ func (mT *provider) subscriptionRemove(topic string, sub topicsTypes.Subscriber) return err } -func subscriptionRecurseSearch(root *node, levels []string, p *publishEntries) { +func subscriptionRecurseSearch(root *node, levels []string, publishID uintptr, p *publishEntries) { if len(levels) == 0 { // leaf level of the topic // get all subscribers and return - root.getSubscribers(p) + root.getSubscribers(publishID, p) if n, ok := root.children[topicsTypes.MWC]; ok { - n.getSubscribers(p) + n.getSubscribers(publishID, p) } } else { if n, ok := root.children[topicsTypes.MWC]; ok && len(levels[0]) != 0 { - n.getSubscribers(p) + n.getSubscribers(publishID, p) } if n, ok := root.children[levels[0]]; ok { - subscriptionRecurseSearch(n, levels[1:], p) + subscriptionRecurseSearch(n, levels[1:], publishID, p) } if n, ok := root.children[topicsTypes.SWC]; ok { - subscriptionRecurseSearch(n, levels[1:], p) + subscriptionRecurseSearch(n, levels[1:], publishID, p) } } } -func (mT *provider) subscriptionSearch(topic string, p *publishEntries) { +func (mT *provider) subscriptionSearch(topic string, publishID uintptr, p *publishEntries) { root := mT.root levels := strings.Split(topic, "/") level := levels[0] if !strings.HasPrefix(level, "$") { - subscriptionRecurseSearch(root, levels, p) + subscriptionRecurseSearch(root, levels, publishID, p) } else if n, ok := root.children[level]; ok { - subscriptionRecurseSearch(n, levels[1:], p) + subscriptionRecurseSearch(n, levels[1:], publishID, p) } } @@ -277,48 +295,34 @@ func (sn *node) allRetained(retained *[]*packet.Publish) { } } -func (sn *node) overlappingSubscribers(p *publishEntries) { +func (sn *node) overlappingSubscribers(publishID uintptr, p *publishEntries) { for id, sub := range sn.subs { if s, ok := (*p)[id]; ok { - if sub.id > 0 { - s[0].ids = append(s[0].ids, sub.id) + if sub.p.ID > 0 { + s[0].ids = append(s[0].ids, sub.p.ID) } - if s[0].qos < sub.grantedQoS { - s[0].qos = sub.grantedQoS + if s[0].qos < sub.p.Granted { + s[0].qos = sub.p.Granted } } else { - sub.s.Acquire() - pe := &publishEntry{ - s: sub.s, - qos: sub.grantedQoS, - } - - if sub.id > 0 { - pe.ids = []uint32{sub.id} + if !sub.p.Ops.NL() || id != publishID { + pe := sub.acquire() + (*p)[id] = append((*p)[id], pe) } - - (*p)[id] = append((*p)[id], pe) } } } -func (sn *node) nonOverlappingSubscribers(p *publishEntries) { +func (sn *node) nonOverlappingSubscribers(publishID uintptr, p *publishEntries) { for id, sub := range sn.subs { - sub.s.Acquire() - pe := &publishEntry{ - s: sub.s, - qos: sub.grantedQoS, - } - - if sub.id > 0 { - pe.ids = []uint32{sub.id} - } - - if _, ok := (*p)[id]; ok { - (*p)[id] = append((*p)[id], pe) - } else { - (*p)[id] = []*publishEntry{pe} + if !sub.p.Ops.NL() || id != publishID { + pe := sub.acquire() + if _, ok := (*p)[id]; ok { + (*p)[id] = append((*p)[id], pe) + } else { + (*p)[id] = []*publishEntry{pe} + } } } } diff --git a/topics/mem/regex.txt b/topics/mem/regex.txt deleted file mode 100644 index 4f2c9b3..0000000 --- a/topics/mem/regex.txt +++ /dev/null @@ -1,14 +0,0 @@ -/ -# -/# -+ -/+ -$SYS/bla/bla/ -$SYS/bla/bla -$SYS/bla/bla# -$SYS/bla/bla/# -$SYS/bla/bla/+ -$SYS/bla/bla+ -$SYS/bla/bla - -(?=.)^(([^+#]*|\+)(\/([^+#]*|\+))*(\/#)?|#)$ \ No newline at end of file diff --git a/topics/mem/topics.go b/topics/mem/topics.go index fa00781..d2d4a03 100644 --- a/topics/mem/topics.go +++ b/topics/mem/topics.go @@ -29,27 +29,17 @@ import ( type provider struct { // Sub/unSub mutex smu sync.RWMutex - // Subscription tree - root *node - - stat systree.TopicsStat - - persist persistenceTypes.Retained - - log struct { - prod *zap.Logger - dev *zap.Logger - } - + root *node + stat systree.TopicsStat + persist persistenceTypes.Retained + log *zap.Logger onCleanUnsubscribe func([]string) wgPublisher sync.WaitGroup wgPublisherStarted sync.WaitGroup - - inbound chan *packet.Publish - inRetained chan types.RetainObject - - allowOverlapping bool + inbound chan *packet.Publish + inRetained chan types.RetainObject + allowOverlapping bool } var _ topicsTypes.Provider = (*provider)(nil) @@ -68,8 +58,7 @@ func NewMemProvider(config *topicsTypes.MemConfig) (topicsTypes.Provider, error) } p.root = newNode(p.allowOverlapping, nil) - p.log.prod = configuration.GetProdLogger().Named("topics").Named(config.Name) - p.log.dev = configuration.GetDevLogger().Named("topics").Named(config.Name) + p.log = configuration.GetLogger().Named("topics").Named(config.Name) if p.persist != nil { entries, err := p.persist.Load() @@ -81,16 +70,16 @@ func NewMemProvider(config *topicsTypes.MemConfig) (topicsTypes.Provider, error) v := packet.ProtocolVersion(d[0]) msg, _, err := packet.Decode(v, d[1:]) if err != nil { - p.log.prod.Error("Couldn't decode retained message", zap.Error(err)) + p.log.Error("Couldn't decode retained message", zap.Error(err)) } else { if m, ok := msg.(*packet.Publish); ok { - p.log.dev.Debug("Loading retained message", + p.log.Debug("Loading retained message", zap.String("topic", m.Topic()), zap.Int8("QoS", int8(m.QoS()))) p.Retain(m) // nolint: errcheck } else { - p.log.prod.Warn("Unsupported retained message type", zap.String("type", m.Type().Name())) + p.log.Warn("Unsupported retained message type", zap.String("type", m.Type().Name())) } } } @@ -109,26 +98,22 @@ func NewMemProvider(config *topicsTypes.MemConfig) (topicsTypes.Provider, error) return p, nil } -func (mT *provider) Subscribe(filter string, q packet.QosType, s topicsTypes.Subscriber, id uint32) (packet.QosType, []*packet.Publish, error) { - if !q.IsValid() { - return packet.QosFailure, nil, packet.ErrInvalidQoS - } - - if s == nil { - return packet.QosFailure, nil, topicsTypes.ErrInvalidSubscriber - } - +func (mT *provider) Subscribe(filter string, s topicsTypes.Subscriber, p *topicsTypes.SubscriptionParams) (packet.QosType, []*packet.Publish, error) { defer mT.smu.Unlock() mT.smu.Lock() - mT.subscriptionInsert(filter, q, s, id) + p.Granted = p.Ops.QoS() + exists := mT.subscriptionInsert(filter, s, p) var r []*packet.Publish // [MQTT-3.3.1-5] - mT.retainSearch(filter, &r) + rh := p.Ops.RetainHandling() + if (rh == packet.RetainHandlingRetain) || ((rh == packet.RetainHandlingIfNotExists) && !exists) { + mT.retainSearch(filter, &r) + } - return q, r, nil + return p.Granted, r, nil } func (mT *provider) UnSubscribe(topic string, sub topicsTypes.Subscriber) error { @@ -187,11 +172,11 @@ func (mT *provider) Close() error { // Skip retained QoS0 messages if m.QoS() != packet.QoS0 { if sz, err := m.Size(); err != nil { - mT.log.prod.Error("Couldn't get retained message size", zap.Error(err)) + mT.log.Error("Couldn't get retained message size", zap.Error(err)) } else { buf := make([]byte, sz) if _, err = m.Encode(buf); err != nil { - mT.log.prod.Error("Couldn't encode retained message", zap.Error(err)) + mT.log.Error("Couldn't encode retained message", zap.Error(err)) } else { encoded = append(encoded, buf) } @@ -199,9 +184,9 @@ func (mT *provider) Close() error { } } if len(encoded) > 0 { - mT.log.dev.Debug("Storing retained messages", zap.Int("amount", len(encoded))) + mT.log.Debug("Storing retained messages", zap.Int("amount", len(encoded))) if err := mT.persist.Store(encoded); err != nil { - mT.log.prod.Error("Couldn't persist retained messages", zap.Error(err)) + mT.log.Error("Couldn't persist retained messages", zap.Error(err)) } } } @@ -252,12 +237,12 @@ func (mT *provider) publisher() { pubEntries := publishEntries{} mT.smu.Lock() - mT.subscriptionSearch(msg.Topic(), &pubEntries) + mT.subscriptionSearch(msg.Topic(), msg.PublishID(), &pubEntries) for _, pub := range pubEntries { for _, e := range pub { - if err := e.s.Publish(msg, e.qos, e.ids); err != nil { - mT.log.prod.Error("Publish error", zap.Error(err)) + if err := e.s.Publish(msg, e.qos, e.ops, e.ids); err != nil { + mT.log.Error("Publish error", zap.Error(err)) } e.s.Release() } diff --git a/topics/mem/trie_test.go b/topics/mem/trie_test.go index 79e6738..18ef642 100644 --- a/topics/mem/trie_test.go +++ b/topics/mem/trie_test.go @@ -37,11 +37,14 @@ func TestMatch1(t *testing.T) { prov := allocProvider(t) sub := &subscriber.Type{} - prov.Subscribe("sport/tennis/player1/#", packet.QoS1, sub, 0) // nolint: errcheck + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } + prov.Subscribe("sport/tennis/player1/#", sub, p) // nolint: errcheck subscribers := publishEntries{} - prov.subscriptionSearch("sport/tennis/player1/anzel", &subscribers) + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -50,11 +53,14 @@ func TestMatch2(t *testing.T) { sub := &subscriber.Type{} - prov.Subscribe("sport/tennis/player1/#", packet.QoS2, sub, 0) // nolint: errcheck + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } + prov.Subscribe("sport/tennis/player1/#", sub, p) // nolint: errcheck subscribers := publishEntries{} - prov.subscriptionSearch("sport/tennis/player1/anzel", &subscribers) + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -63,10 +69,14 @@ func TestSNodeMatch3(t *testing.T) { sub := &subscriber.Type{} - prov.Subscribe("sport/tennis/#", packet.QoS2, sub, 0) // nolint: errcheck + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } + + prov.Subscribe("sport/tennis/#", sub, p) // nolint: errcheck subscribers := publishEntries{} - prov.subscriptionSearch("sport/tennis/player1/anzel", &subscribers) + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -74,53 +84,56 @@ func TestMatch4(t *testing.T) { prov := allocProvider(t) sub := &subscriber.Type{} - prov.Subscribe("#", packet.QoS2, sub, 0) // nolint: errcheck + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } + prov.Subscribe("#", sub, p) // nolint: errcheck subscribers := publishEntries{} - prov.subscriptionSearch("sport/tennis/player1/anzel", &subscribers) + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") subscribers = publishEntries{} - prov.subscriptionSearch("/sport/tennis/player1/anzel", &subscribers) + prov.subscriptionSearch("/sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") err := prov.subscriptionRemove("#", sub) require.NoError(t, err) subscribers = publishEntries{} - prov.subscriptionSearch("#", &subscribers) + prov.subscriptionSearch("#", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") - prov.subscriptionInsert("/#", packet.QoS2, sub, 0) + prov.subscriptionInsert("/#", sub, p) subscribers = publishEntries{} - prov.subscriptionSearch("bla", &subscribers) + prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") subscribers = publishEntries{} - prov.subscriptionSearch("/bla", &subscribers) + prov.subscriptionSearch("/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") err = prov.subscriptionRemove("/#", sub) require.NoError(t, err) - prov.subscriptionInsert("bla/bla/#", packet.QoS2, sub, 0) + prov.subscriptionInsert("bla/bla/#", sub, p) subscribers = publishEntries{} - prov.subscriptionSearch("bla", &subscribers) + prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") subscribers = publishEntries{} - prov.subscriptionSearch("bla/bla", &subscribers) + prov.subscriptionSearch("bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") subscribers = publishEntries{} - prov.subscriptionSearch("bla/bla/bla", &subscribers) + prov.subscriptionSearch("bla/bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") subscribers = publishEntries{} - prov.subscriptionSearch("bla/bla/bla/bla", &subscribers) + prov.subscriptionSearch("bla/bla/bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") } @@ -129,11 +142,15 @@ func TestMatch5(t *testing.T) { sub1 := &subscriber.Type{} sub2 := &subscriber.Type{} - prov.subscriptionInsert("sport/tennis/+/+/#", packet.QoS1, sub1, 0) - prov.subscriptionInsert("sport/tennis/player1/anzel", packet.QoS1, sub2, 0) + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } + + prov.subscriptionInsert("sport/tennis/+/+/#", sub1, p) + prov.subscriptionInsert("sport/tennis/player1/anzel", sub2, p) subscribers := publishEntries{} - prov.subscriptionSearch("sport/tennis/player1/anzel", &subscribers) + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 2, len(subscribers)) } @@ -142,12 +159,15 @@ func TestMatch6(t *testing.T) { prov := allocProvider(t) sub1 := &subscriber.Type{} sub2 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } - prov.subscriptionInsert("sport/tennis/+/+/+/+/#", packet.QoS1, sub1, 0) - prov.subscriptionInsert("sport/tennis/player1/anzel", packet.QoS1, sub2, 0) + prov.subscriptionInsert("sport/tennis/+/+/+/+/#", sub1, p) + prov.subscriptionInsert("sport/tennis/player1/anzel", sub2, p) subscribers := publishEntries{} - prov.subscriptionSearch("sport/tennis/player1/anzel/bla/bla", &subscribers) + prov.subscriptionSearch("sport/tennis/player1/anzel/bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -156,13 +176,18 @@ func TestMatch7(t *testing.T) { sub1 := &subscriber.Type{} sub2 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } + + prov.subscriptionInsert("sport/tennis/#", sub1, p) - prov.subscriptionInsert("sport/tennis/#", packet.QoS2, sub1, 0) + p.Ops = packet.SubscriptionOptions(packet.QoS1) - prov.subscriptionInsert("sport/tennis", packet.QoS1, sub2, 0) + prov.subscriptionInsert("sport/tennis", sub2, p) subscribers := publishEntries{} - prov.subscriptionSearch("sport/tennis/player1/anzel", &subscribers) + prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) require.Equal(t, sub1, subscribers[sub1.Hash()][0].s) } @@ -170,13 +195,16 @@ func TestMatch7(t *testing.T) { func TestMatch8(t *testing.T) { prov := allocProvider(t) - sub1 := &subscriber.Type{} + sub := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } - prov.subscriptionInsert("+/+", packet.QoS2, sub1, 0) + prov.subscriptionInsert("+/+", sub, p) subscribers := publishEntries{} - prov.subscriptionSearch("/finance", &subscribers) + prov.subscriptionSearch("/finance", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -184,12 +212,15 @@ func TestMatch9(t *testing.T) { prov := allocProvider(t) sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } - prov.subscriptionInsert("/+", packet.QoS2, sub1, 0) + prov.subscriptionInsert("/+", sub1, p) subscribers := publishEntries{} - prov.subscriptionSearch("/finance", &subscribers) + prov.subscriptionSearch("/finance", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -197,44 +228,50 @@ func TestMatch10(t *testing.T) { prov := allocProvider(t) sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } - prov.subscriptionInsert("+", packet.QoS2, sub1, 0) + prov.subscriptionInsert("+", sub1, p) subscribers := publishEntries{} - prov.subscriptionSearch("/finance", &subscribers) + prov.subscriptionSearch("/finance", 0, &subscribers) require.Equal(t, 0, len(subscribers)) } func TestInsertRemove(t *testing.T) { prov := allocProvider(t) sub := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } - prov.subscriptionInsert("#", packet.QoS2, sub, 0) + prov.subscriptionInsert("#", sub, p) subscribers := publishEntries{} - prov.subscriptionSearch("bla", &subscribers) + prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 1, len(subscribers)) subscribers = publishEntries{} - prov.subscriptionSearch("/bla", &subscribers) + prov.subscriptionSearch("/bla", 0, &subscribers) require.Equal(t, 0, len(subscribers)) err := prov.subscriptionRemove("#", sub) require.NoError(t, err) subscribers = publishEntries{} - prov.subscriptionSearch("#", &subscribers) + prov.subscriptionSearch("#", 0, &subscribers) require.Equal(t, 0, len(subscribers)) - prov.subscriptionInsert("/#", packet.QoS2, sub, 0) + prov.subscriptionInsert("/#", sub, p) subscribers = publishEntries{} - prov.subscriptionSearch("bla", &subscribers) + prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 0, len(subscribers)) subscribers = publishEntries{} - prov.subscriptionSearch("/bla", &subscribers) + prov.subscriptionSearch("/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers)) err = prov.subscriptionRemove("#", sub) @@ -249,7 +286,11 @@ func TestInsert1(t *testing.T) { topic := "sport/tennis/player1/#" sub1 := &subscriber.Type{} - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) require.Equal(t, 1, len(prov.root.children)) require.Equal(t, 0, len(prov.root.subs)) @@ -288,8 +329,11 @@ func TestSNodeInsert2(t *testing.T) { topic := "#" sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) + prov.subscriptionInsert(topic, sub1, p) require.Equal(t, 1, len(prov.root.children)) require.Equal(t, 0, len(prov.root.subs)) @@ -311,8 +355,11 @@ func TestSNodeInsert3(t *testing.T) { topic := "+/tennis/#" sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) + prov.subscriptionInsert(topic, sub1, p) require.Equal(t, 1, len(prov.root.children)) require.Equal(t, 0, len(prov.root.subs)) @@ -346,8 +393,11 @@ func TestSNodeInsert4(t *testing.T) { topic := "/finance" sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) + prov.subscriptionInsert(topic, sub1, p) require.Equal(t, 1, len(prov.root.children)) require.Equal(t, 0, len(prov.root.subs)) @@ -375,9 +425,12 @@ func TestSNodeInsertDup(t *testing.T) { topic := "/finance" sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) + prov.subscriptionInsert(topic, sub1, p) + prov.subscriptionInsert(topic, sub1, p) require.Equal(t, 1, len(prov.root.children)) require.Equal(t, 0, len(prov.root.subs)) @@ -406,8 +459,11 @@ func TestSNodeRemove1(t *testing.T) { topic := "sport/tennis/player1/#" sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) + prov.subscriptionInsert(topic, sub1, p) err := prov.subscriptionRemove(topic, sub1) require.NoError(t, err) @@ -421,8 +477,11 @@ func TestSNodeRemove2(t *testing.T) { topic := "sport/tennis/player1/#" sub1 := &subscriber.Type{} + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) + prov.subscriptionInsert(topic, sub1, p) err := prov.subscriptionRemove("sport/tennis/player1", sub1) require.EqualError(t, topicsTypes.ErrNotFound, err.Error()) @@ -435,8 +494,12 @@ func TestSNodeRemove3(t *testing.T) { sub1 := &subscriber.Type{} sub2 := &subscriber.Type{} - prov.subscriptionInsert(topic, packet.QoS1, sub1, 0) - prov.subscriptionInsert(topic, packet.QoS1, sub2, 0) + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } + + prov.subscriptionInsert(topic, sub1, p) + prov.subscriptionInsert(topic, sub2, p) err := prov.subscriptionRemove("sport/tennis/player1/#", nil) require.NoError(t, err) @@ -452,13 +515,17 @@ func TestRetain1(t *testing.T) { prov.retain(m) } - _, rMsg, _ := prov.Subscribe("#", packet.QoS1, sub, 0) + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } + + _, rMsg, _ := prov.Subscribe("#", sub, p) require.Equal(t, 0, len(rMsg)) - _, rMsg, _ = prov.Subscribe("$SYS", packet.QoS1, sub, 0) + _, rMsg, _ = prov.Subscribe("$SYS", sub, p) require.Equal(t, 0, len(rMsg)) - _, rMsg, _ = prov.Subscribe("$SYS/#", packet.QoS1, sub, 0) + _, rMsg, _ = prov.Subscribe("$SYS/#", sub, p) require.Equal(t, len(retainedSystree), len(rMsg)) } @@ -473,13 +540,17 @@ func TestRetain2(t *testing.T) { msg := newPublishMessageLarge("sport/tennis/player1/ricardo", packet.QoS1) prov.retain(msg) - prov.Subscribe("#", packet.QoS1, sub, 0) // nolint: errcheck + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS1), + } + + prov.Subscribe("#", sub, p) // nolint: errcheck var rMsg []*packet.Publish prov.retainSearch("#", &rMsg) require.Equal(t, 1, len(rMsg)) - _, rMsg, _ = prov.Subscribe("$SYS/#", packet.QoS1, sub, 0) + _, rMsg, _ = prov.Subscribe("$SYS/#", sub, p) require.Equal(t, len(retainedSystree), len(rMsg)) } @@ -596,7 +667,7 @@ func TestRNodeMatch(t *testing.T) { } func newPublishMessageLarge(topic string, qos packet.QosType) *packet.Publish { - m, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) + m, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) msg := m.(*packet.Publish) diff --git a/topics/topics_test.go b/topics/topics_test.go index 1486953..4f6585b 100644 --- a/topics/topics_test.go +++ b/topics/topics_test.go @@ -49,31 +49,39 @@ func TestTopicsOpenCloseProvider(t *testing.T) { } } -func TestTopicsSubscribeInvalidQoS(t *testing.T) { - for _, p := range testProviders { - prov, err := New(p.config) - require.NoError(t, err) - - _, _, err = prov.Subscribe("test", packet.QosType(3), nil, 0) - require.Error(t, packet.ErrInvalidQoS, err.Error()) - - err = prov.Close() - require.NoError(t, err) - } -} - -func TestTopicsSubscribeInvalidMessage(t *testing.T) { - for _, p := range testProviders { - prov, err := New(p.config) - require.NoError(t, err) - - _, _, err = prov.Subscribe("test", packet.QosType(3), nil, 0) - require.Error(t, packet.ErrInvalidQoS, err.Error()) +//func TestTopicsSubscribeInvalidQoS(t *testing.T) { +// for _, p := range testProviders { +// prov, err := New(p.config) +// require.NoError(t, err) +// +// p := &topicsTypes.SubscriptionParams{ +// Ops: packet.SubscriptionOptions(packet.QosType(3)), +// } +// +// _, _, err = prov.Subscribe("test", nil, p) +// require.Error(t, packet.ErrInvalidQoS, err.Error()) +// +// err = prov.Close() +// require.NoError(t, err) +// } +//} - err = prov.Close() - require.NoError(t, err) - } -} +//func TestTopicsSubscribeInvalidMessage(t *testing.T) { +// for _, p := range testProviders { +// prov, err := New(p.config) +// require.NoError(t, err) +// +// p := &topicsTypes.SubscriptionParams{ +// Ops: packet.SubscriptionOptions(packet.QosType(3)), +// } +// +// _, _, err = prov.Subscribe("test", nil, p) +// require.Error(t, packet.ErrInvalidQoS, err.Error()) +// +// err = prov.Close() +// require.NoError(t, err) +// } +//} func TestTopicsSubscription(t *testing.T) { for _, p := range testProviders { @@ -81,7 +89,11 @@ func TestTopicsSubscription(t *testing.T) { require.NoError(t, err) sub1 := &subscriber.Type{} - qos, _, err := prov.Subscribe("sports/tennis/+/stats", packet.QoS2, sub1, 0) + p := &topicsTypes.SubscriptionParams{ + Ops: packet.SubscriptionOptions(packet.QoS2), + } + + qos, _, err := prov.Subscribe("sports/tennis/+/stats", sub1, p) require.NoError(t, err) require.Equal(t, packet.QoS2, qos) diff --git a/topics/types/types.go b/topics/types/types.go index 5606c89..35c23e4 100644 --- a/topics/types/types.go +++ b/topics/types/types.go @@ -69,7 +69,7 @@ var ( type Subscriber interface { Acquire() Release() - Publish(*packet.Publish, packet.QosType, []uint32) error + Publish(*packet.Publish, packet.QosType, packet.SubscriptionOptions, []uint32) error Hash() uintptr } @@ -78,7 +78,7 @@ type Subscribers []Subscriber // Provider interface type Provider interface { - Subscribe(string, packet.QosType, Subscriber, uint32) (packet.QosType, []*packet.Publish, error) + Subscribe(string, Subscriber, *SubscriptionParams) (packet.QosType, []*packet.Publish, error) UnSubscribe(string, Subscriber) error Publish(interface{}) error Retain(types.RetainObject) error @@ -89,21 +89,34 @@ type Provider interface { // SubscriberInterface used by subscriber to handle messages type SubscriberInterface interface { Publish(interface{}) error - Subscribe(string, packet.QosType, Subscriber, uint32) (packet.QosType, []*packet.Publish, error) + Subscribe(string, Subscriber, *SubscriptionParams) (packet.QosType, []*packet.Publish, error) UnSubscribe(string, Subscriber) error Retain(types.RetainObject) error Retained(string) ([]*packet.Publish, error) } +// SubscriptionParams parameters of the subscription +type SubscriptionParams struct { + // Subscription id + // V5.0 ONLY + ID uint32 + + // Ops requested subscription options + Ops packet.SubscriptionOptions + + // Granted QoS granted by topics manager + Granted packet.QosType +} + var ( // ErrMultiLevel multi-level wildcard - ErrMultiLevel = errors.New("Multi-level wildcard found in topic and it's not at the last level") + ErrMultiLevel = errors.New("multi-level wildcard found in topic and it's not at the last level") // ErrInvalidSubscriber invalid subscriber object - ErrInvalidSubscriber = errors.New("Subscriber cannot be nil") + ErrInvalidSubscriber = errors.New("subscriber cannot be nil") // ErrInvalidWildcardPlus Wildcard character '+' must occupy entire topic level - ErrInvalidWildcardPlus = errors.New("Wildcard character '+' must occupy entire topic level") + ErrInvalidWildcardPlus = errors.New("wildcard character '+' must occupy entire topic level") // ErrInvalidWildcardSharp Wildcard character '#' must occupy entire topic level - ErrInvalidWildcardSharp = errors.New("Wildcard character '#' must occupy entire topic level") + ErrInvalidWildcardSharp = errors.New("wildcard character '#' must occupy entire topic level") // ErrInvalidWildcard Wildcard characters '#' and '+' must occupy entire topic level - ErrInvalidWildcard = errors.New("Wildcard characters '#' and '+' must occupy entire topic level") + ErrInvalidWildcard = errors.New("wildcard characters '#' and '+' must occupy entire topic level") ) diff --git a/transport/base.go b/transport/base.go index 1704a44..8045fa3 100644 --- a/transport/base.go +++ b/transport/base.go @@ -119,7 +119,7 @@ func (c *baseConfig) handleConnection(conn conn) { conn.SetReadDeadline(time.Time{}) // nolint: errcheck switch r := req.(type) { case *packet.Connect: - m, _ := packet.NewMessage(req.Version(), packet.CONNACK) + m, _ := packet.New(req.Version(), packet.CONNACK) resp, _ := m.(*packet.ConnAck) var reason packet.ReasonCode @@ -135,10 +135,6 @@ func (c *baseConfig) handleConnection(conn conn) { if status := c.config.AuthManager.Password(string(user), string(pass)); status == auth.StatusAllow { reason = packet.CodeSuccess - if r.KeepAlive() == 0 { - r.SetKeepAlive(uint16(c.KeepAlive)) - resp.PropertySet(packet.PropertyServerKeepAlive, uint16(c.KeepAlive)) // nolint: errcheck - } } else { reason = packet.CodeRefusedBadUsernameOrPassword if req.Version() == packet.ProtocolV50 { diff --git a/transport/tcp.go b/transport/tcp.go index e3c2d2f..3298a86 100644 --- a/transport/tcp.go +++ b/transport/tcp.go @@ -44,7 +44,7 @@ func NewTCP(config *ConfigTCP, internal *InternalConfig) (Provider, error) { l.protocol = config.Scheme l.InternalConfig = *internal l.config = *config.transport - l.log = configuration.GetProdLogger().Named("server.transport.tcp") + l.log = configuration.GetLogger().Named("server.transport.tcp") var err error diff --git a/transport/websocket.go b/transport/websocket.go index d235f51..287068f 100644 --- a/transport/websocket.go +++ b/transport/websocket.go @@ -70,7 +70,7 @@ func NewWS(config *ConfigWS, internal *InternalConfig) (Provider, error) { l.protocol = "ws" l.InternalConfig = *internal l.config = *config.transport - l.log = configuration.GetProdLogger().Named("server.transport.ws") + l.log = configuration.GetLogger().Named("server.transport.ws") if len(config.Path) == 0 { config.Path = "/" diff --git a/types/types.go b/types/types.go index 36051ca..6aed5cf 100644 --- a/types/types.go +++ b/types/types.go @@ -16,10 +16,12 @@ type LogInterface struct { // Default configs const ( - DefaultKeepAlive = 300 // DefaultKeepAlive default keep - DefaultConnectTimeout = 2 // DefaultConnectTimeout connect timeout - DefaultAckTimeout = 20 // DefaultAckTimeout ack timeout - DefaultTimeoutRetries = 3 // DefaultTimeoutRetries retries + DefaultKeepAlive = 60 // DefaultKeepAlive default keep + DefaultConnectTimeout = 2 // DefaultConnectTimeout connect timeout + DefaultMaxPacketSize = 268435455 + DefaultReceiveMax = 65535 + DefaultAckTimeout = 20 // DefaultAckTimeout ack timeout + DefaultTimeoutRetries = 3 // DefaultTimeoutRetries retries MinKeepAlive = 30 DefaultSessionsProvider = "mem" // DefaultSessionsProvider default session provider DefaultAuthenticator = "mockSuccess" // DefaultAuthenticator default auth provider @@ -66,7 +68,7 @@ type Once struct { // // If f panics, Do considers it to have returned; future calls of Do return // without calling f. -func (o *OnceWait) Do(f func()) { +func (o *OnceWait) Do(f func()) bool { o.lock.Lock() res := atomic.CompareAndSwapUintptr(&o.val, 0, 1) if res { @@ -80,6 +82,8 @@ func (o *OnceWait) Do(f func()) { } else { o.wait.Wait() } + + return res } // Do calls the function f if and only if Do is being called for the @@ -99,8 +103,11 @@ func (o *OnceWait) Do(f func()) { // // If f panics, Do considers it to have returned; future calls of Do return // without calling f. -func (o *Once) Do(f func()) { +func (o *Once) Do(f func()) bool { if atomic.CompareAndSwapUintptr(&o.val, 0, 1) { f() + return true } + + return false } diff --git a/volantmq.go b/volantmq.go index 62137fc..260cb7a 100644 --- a/volantmq.go +++ b/volantmq.go @@ -21,6 +21,7 @@ import ( "github.com/VolantMQ/volantmq/transport" "github.com/VolantMQ/volantmq/types" "github.com/pborman/uuid" + "go.uber.org/zap" ) var ( @@ -68,6 +69,9 @@ type ServerConfig struct { // If not set than defaults to 0x3 and 0x04 AllowedVersions map[packet.ProtocolVersion]bool + // MaxPacketSize + MaxPacketSize uint32 + // AllowOverlappingSubscriptions tells server how to handle overlapping subscriptions from within one client // if true server will send only one publish with max subscribed QoS even there are n subscriptions // if false server will send as many publishes as amount of subscriptions matching publish topic exists @@ -85,7 +89,11 @@ type ServerConfig struct { // If not set than default is false AllowDuplicates bool + // WithSystree WithSystree bool + + // ForceKeepAlive + ForceKeepAlive bool } // NewServerConfig with default values. It's highly recommended to use that function to allocate config @@ -100,13 +108,15 @@ func NewServerConfig() *ServerConfig { AllowOverlappingSubscriptions: true, RewriteNodeName: false, WithSystree: true, - SystreeUpdateInterval: 5, + SystreeUpdateInterval: 0, KeepAlive: types.DefaultKeepAlive, ConnectTimeout: types.DefaultConnectTimeout, + MaxPacketSize: types.DefaultMaxPacketSize, TransportStatus: func(id string, status string) {}, AllowedVersions: map[packet.ProtocolVersion]bool{ packet.ProtocolV31: true, packet.ProtocolV311: true, + packet.ProtocolV50: true, }, } } @@ -127,45 +137,24 @@ type Server interface { // server is a library implementation of the MQTT server that, as best it can, complies // with the MQTT 3.1/3.1.1 and 5.0 specs. type server struct { - config *ServerConfig - // authMgr is the authentication manager that we are going to use for authenticating - // incoming connections - authMgr *auth.Manager - - // sessionsMgr is the sessions manager for keeping track of the sessions + config *ServerConfig + authMgr *auth.Manager sessionsMgr *clients.Manager - - log types.LogInterface - - // topicsMgr is the topics manager for keeping track of subscriptions - topicsMgr topicsTypes.Provider - - persist persistenceTypes.Provider - - sysTree systree.Provider - - // The quit channel for the server. If the server detects that this channel - // is closed, then it's a signal for it to shutdown as well. - quit chan struct{} - - lock sync.Mutex - - onClose sync.Once - - transports struct { + log *zap.Logger + topicsMgr topicsTypes.Provider + persist persistenceTypes.Provider + sysTree systree.Provider + quit chan struct{} + lock sync.Mutex + onClose sync.Once + transports struct { list map[int]transport.Provider wg sync.WaitGroup } - systree struct { - done chan bool - wgStarted sync.WaitGroup - wgStopped sync.WaitGroup - timer *time.Ticker + publishes []systree.DynamicValue + timer *time.Timer } - - // nodes cluster nodes - //nodes map[string]subscriber.Provider } // NewServer allocate server object @@ -180,8 +169,7 @@ func NewServer(config *ServerConfig) (Server, error) { s.config = config - s.log.Prod = configuration.GetProdLogger().Named("server") - s.log.Dev = configuration.GetDevLogger().Named("server") + s.log = configuration.GetLogger().Named("server") s.quit = make(chan struct{}) s.transports.list = make(map[int]transport.Provider) @@ -192,7 +180,7 @@ func NewServer(config *ServerConfig) (Server, error) { } if s.config.Persistence == nil { - return nil, errors.New("Persistence provider cannot be nil") + return nil, errors.New("persistence provider cannot be nil") } if s.persist, err = persistence.New(s.config.Persistence); err != nil { @@ -230,9 +218,9 @@ func NewServer(config *ServerConfig) (Server, error) { var persisRetained persistenceTypes.Retained var retains []types.RetainObject - var dynPublishes []systree.DynamicValue + //var dynPublishes []systree.DynamicValue - if s.sysTree, retains, dynPublishes, err = systree.NewTree("$SYS/servers/" + s.config.NodeName); err != nil { + if s.sysTree, retains, s.systree.publishes, err = systree.NewTree("$SYS/servers/" + s.config.NodeName); err != nil { return nil, err } @@ -258,22 +246,27 @@ func NewServer(config *ServerConfig) (Server, error) { } if s.config.SystreeUpdateInterval > 0 { - s.systree.wgStarted.Add(1) - s.systree.wgStopped.Add(1) - go s.systreeUpdater(dynPublishes, s.config.SystreeUpdateInterval*time.Second) - s.systree.wgStarted.Wait() + s.systree.timer = time.AfterFunc(s.config.SystreeUpdateInterval*time.Second, s.systreeUpdater) } } mConfig := &clients.Config{ - TopicsMgr: s.topicsMgr, - ConnectTimeout: s.config.ConnectTimeout, - Persist: s.persist, - AllowReplace: s.config.AllowDuplicates, - OnReplaceAttempt: s.config.OnDuplicate, - OfflineQoS0: s.config.OfflineQoS0, - Systree: s.sysTree, - NodeName: s.config.NodeName, + TopicsMgr: s.topicsMgr, + ConnectTimeout: s.config.ConnectTimeout, + Persist: s.persist, + Systree: s.sysTree, + AllowReplace: s.config.AllowDuplicates, + OnReplaceAttempt: s.config.OnDuplicate, + NodeName: s.config.NodeName, + OfflineQoS0: s.config.OfflineQoS0, + AvailableRetain: true, + AvailableSubscriptionID: false, + AvailableSharedSubscription: false, + AvailableWildcardSubscription: true, + TopicAliasMaximum: 0xFFFF, + ReceiveMax: types.DefaultReceiveMax, + MaxPacketSize: types.DefaultMaxPacketSize, + MaximumQoS: packet.QoS2, } if s.sessionsMgr, err = clients.NewManager(mConfig); err != nil { @@ -301,7 +294,7 @@ func (s *server) ListenAndServe(config interface{}) error { case *transport.ConfigWS: l, err = transport.NewWS(c, &internalConfig) default: - return errors.New("Invalid listener type") + return errors.New("invalid listener type") } if err != nil { @@ -313,7 +306,7 @@ func (s *server) ListenAndServe(config interface{}) error { if _, ok := s.transports.list[l.Port()]; ok { l.Close() // nolint: errcheck - return errors.New("Already exists") + return errors.New("already exists") } s.transports.list[l.Port()] = l @@ -342,19 +335,11 @@ func (s *server) Close() error { defer s.lock.Unlock() s.lock.Lock() - // shutdown systree updater - if s.systree.timer != nil { - s.systree.timer.Stop() - s.systree.done <- true - s.systree.wgStopped.Wait() - close(s.systree.done) - } - // We then close all net.Listener, which will force Accept() to return if it's // blocked waiting for new connections. for _, l := range s.transports.list { if err := l.Close(); err != nil { - s.log.Prod.Error(err.Error()) + s.log.Error(err.Error()) } } @@ -374,33 +359,28 @@ func (s *server) Close() error { if s.topicsMgr != nil { s.topicsMgr.Close() // nolint: errcheck, gas } + + // shutdown systree updater + if s.systree.timer != nil { + s.systree.timer.Stop() + } + }) return nil } -func (s *server) systreeUpdater(publishes []systree.DynamicValue, period time.Duration) { - defer s.systree.wgStopped.Done() - - s.systree.done = make(chan bool) - s.systree.timer = time.NewTicker(period) - s.systree.wgStarted.Done() - - for { - select { - case <-s.systree.timer.C: - for _, m := range publishes { - _m := m.Publish() - _msg, _ := packet.NewMessage(packet.ProtocolV311, packet.PUBLISH) - msg, _ := _msg.(*packet.Publish) - - msg.SetPayload(_m.Payload()) - msg.SetTopic(_m.Topic()) // nolint: errcheck - msg.SetQoS(_m.QoS()) // nolint: errcheck - s.topicsMgr.Publish(msg) // nolint: errcheck - } - case <-s.systree.done: - return - } +func (s *server) systreeUpdater() { + for _, m := range s.systree.publishes { + _m := m.Publish() + _msg, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) + msg, _ := _msg.(*packet.Publish) + + msg.SetPayload(_m.Payload()) + msg.SetTopic(_m.Topic()) // nolint: errcheck + msg.SetQoS(_m.QoS()) // nolint: errcheck + s.topicsMgr.Publish(msg) // nolint: errcheck } + + s.systree.timer.Reset(s.config.SystreeUpdateInterval * time.Second) }