From b359ea58b1571c6bcd87e898aada94fa5f000ce8 Mon Sep 17 00:00:00 2001 From: Artur Troian Date: Fri, 6 Oct 2017 12:01:03 +0300 Subject: [PATCH] Tons of fixes --- .gitignore | 4 +- .pre-commit-config.yaml | 3 +- .travis.yml | 2 +- auth/basic.go | 9 + auth/manager.go | 4 +- clients/container.go | 52 ++ clients/expiry.go | 99 ++++ clients/session.go | 479 ++++++++-------- clients/sessions.go | 1068 +++++++++++++++++++----------------- connection/ack.go | 2 +- connection/connection.go | 874 ++++++++++++++++++----------- connection/flowControl.go | 19 +- connection/keepAlive.go | 15 + connection/netCallbacks.go | 262 +++------ connection/options.go | 130 +++++ connection/receiver.go | 103 ++-- connection/transmitter.go | 186 ++++--- go.test.sh | 2 +- packet/auth.go | 12 +- packet/connack.go | 11 +- packet/connack_test.go | 40 +- packet/connect.go | 25 +- packet/connect_test.go | 64 +-- packet/disconnect.go | 10 + packet/disconnect_test.go | 30 +- packet/errors.go | 2 +- packet/header.go | 36 +- packet/header_test.go | 68 +-- packet/packet.go | 58 +- packet/packetType.go | 4 +- packet/ping_test.go | 48 +- packet/pingreq.go | 7 + packet/pingresp.go | 9 +- packet/property.go | 7 +- packet/puback.go | 28 +- packet/puback_test.go | 28 +- packet/pubcomp_test.go | 28 +- packet/publish.go | 46 +- packet/publish_test.go | 24 +- packet/pubrec_test.go | 28 +- packet/pubrel_test.go | 28 +- packet/reasonCodes.go | 9 +- packet/suback.go | 7 + packet/suback_test.go | 28 +- packet/subscribe.go | 7 + packet/subscribe_test.go | 34 +- packet/unsuback.go | 17 + packet/unsuback_test.go | 28 +- packet/unsubscribe.go | 7 + packet/unsubscribe_test.go | 28 +- subscriber/subscriber.go | 152 ++--- systree/clients.go | 5 +- systree/sessions.go | 4 +- systree/tree.go | 13 +- topics/mem/node.go | 34 +- topics/mem/topics.go | 13 +- topics/mem/trie_test.go | 56 +- topics/types/types.go | 1 - transport/base.go | 79 +-- transport/connWS.go | 3 +- transport/tcp.go | 4 - transport/websocket.go | 6 +- types/types.go | 5 +- volantmq.go | 30 +- 64 files changed, 2545 insertions(+), 1979 deletions(-) create mode 100644 auth/basic.go create mode 100644 clients/container.go create mode 100644 clients/expiry.go create mode 100644 connection/keepAlive.go create mode 100644 connection/options.go diff --git a/.gitignore b/.gitignore index 5515686..be40369 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,6 @@ _testmain.go persistence/examples/bolt/test.db examples/tcp/persist.db coverage.txt -vendor \ No newline at end of file +vendor + +persist.db \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ef9999b..cc835a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,6 @@ - repo: https://github.com/troian/pre-commit-golang sha: c43dbbf0704e4c722e2b64672aaee689145e17a6 hooks: - - id: go-imports - id: go-build - id: go-metalinter args: @@ -11,3 +10,5 @@ - --dupl-threshold=100 - --disable=gotype - --disable=govendor + - --vendor + - --enable=goimports diff --git a/.travis.yml b/.travis.yml index ff0a851..1fbf5e8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,7 @@ jobs: os: linux script: - go build -race -v -gcflags "-N -l" ./examples/... - - gometalinter --fast --exclude=corefoundation.go --deadline=360s --enable-gc --sort=path --vendor --cyclo-over=40 --dupl-threshold=100 --disable=gotype ./... + - gometalinter --fast --exclude=corefoundation.go --deadline=360s --enable-gc --sort=path --vendor --cyclo-over=40 --dupl-threshold=100 --disable=gotype --vendor --enable=goimports./... - ./go.test.sh after_success: - bash <(curl -s https://codecov.io/bash) diff --git a/auth/basic.go b/auth/basic.go new file mode 100644 index 0000000..4b132f7 --- /dev/null +++ b/auth/basic.go @@ -0,0 +1,9 @@ +package auth + +type Simple interface { + Password(u, p string) Status +} + +type Anonymous interface { + Is() Status +} diff --git a/auth/manager.go b/auth/manager.go index b0dbde8..52c16a7 100644 --- a/auth/manager.go +++ b/auth/manager.go @@ -16,11 +16,11 @@ var providers = make(map[string]Provider) // Register auth provider func Register(name string, provider Provider) error { if name == "" && provider == nil { - return errors.New("Invalid args") + return errors.New("invalid args") } if _, dup := providers[name]; dup { - return errors.New("Already exists") + return errors.New("already exists") } providers[name] = provider diff --git a/clients/container.go b/clients/container.go new file mode 100644 index 0000000..4ec7364 --- /dev/null +++ b/clients/container.go @@ -0,0 +1,52 @@ +package clients + +import ( + "sync" + "sync/atomic" + + "github.com/VolantMQ/volantmq/subscriber" +) + +type container struct { + lock sync.Mutex + ses atomic.Value + expiry atomic.Value + sub *subscriber.Type +} + +func (s *container) acquire() { + s.lock.Lock() +} + +func (s *container) release() { + s.lock.Unlock() +} + +func (s *container) session() *session { + return s.ses.Load().(*session) +} + +func (s *container) swap(w *container) *container { + s.ses = w.ses + + ses := s.ses.Load().(*session) + ses.idLock = &s.lock + + return s +} + +func (s *container) subscriber(cleanStart bool, c subscriber.Config) (*subscriber.Type, bool) { + if cleanStart && s.sub != nil { + s.sub.Offline(true) + s.sub = nil + } + + if s.sub == nil { + s.sub = subscriber.New(c) + cleanStart = true + } else { + cleanStart = false + } + + return s.sub, !cleanStart +} diff --git a/clients/expiry.go b/clients/expiry.go new file mode 100644 index 0000000..4e98e0d --- /dev/null +++ b/clients/expiry.go @@ -0,0 +1,99 @@ +package clients + +import ( + "sync" + "time" + + "github.com/VolantMQ/volantmq/packet" + "github.com/VolantMQ/volantmq/types" +) + +type expiryEvent interface { + sessionTimer(string, bool) +} + +type expiryConfig struct { + expiryEvent + id string + createdAt time.Time + messenger types.TopicMessenger + will *packet.Publish + expireIn *uint32 + willDelay uint32 +} + +type expiry struct { + expiryConfig + expiringSince time.Time + timerLock sync.Mutex + timer *time.Timer +} + +func newExpiry(c expiryConfig) *expiry { + return &expiry{ + expiryConfig: c, + } +} + +func (s *expiry) start() { + var timerPeriod uint32 + + // if meet will requirements point that + if s.will != nil && s.willDelay > 0 { + timerPeriod = s.willDelay + } else { + s.will = nil + } + + if s.expireIn != nil { + // if will delay is set before and value less than expiration + // then timer should fire 2 times + if (timerPeriod > 0) && (timerPeriod < *s.expireIn) { + *s.expireIn = *s.expireIn - timerPeriod + } else { + timerPeriod = *s.expireIn + *s.expireIn = 0 + } + } + + s.expiringSince = time.Now() + s.timer = time.NewTimer(time.Duration(timerPeriod) * time.Second) +} + +func (s *expiry) cancel() { + if !s.timer.Stop() { + s.timerLock.Lock() + s.timerLock.Unlock() // nolint: megacheck + } +} + +func (s *expiry) timerCallback() { + defer s.timerLock.Unlock() + s.timerLock.Lock() + + // 1. check for will message available + if s.will != nil { + // publish if exists and wipe state + s.messenger.Publish(s.will) // nolint: errcheck + s.will = nil + s.willDelay = 0 + } + + if s.expireIn == nil { + // 2.a session has processed delayed will and there is nothing to do + // completely shutdown the session + s.sessionTimer(s.id, false) + } else if *s.expireIn == 0 { + // session has expired. WIPE IT + //if s.subscriber != nil { + // s.shutdownSubscriber(s.subscriber) + //} + s.sessionTimer(s.id, true) + } else { + // restart timer and wait again + val := *s.expireIn + // clear value pointed by expireIn so when next fire comes we signal session is expired + *s.expireIn = 0 + s.timer.Reset(time.Duration(val) * time.Second) + } +} diff --git a/clients/session.go b/clients/session.go index 55b3a46..04a87fc 100644 --- a/clients/session.go +++ b/clients/session.go @@ -1,351 +1,330 @@ package clients import ( + "container/list" "sync" "time" - "strconv" - "github.com/VolantMQ/persistence" "github.com/VolantMQ/volantmq/connection" "github.com/VolantMQ/volantmq/packet" "github.com/VolantMQ/volantmq/subscriber" + "github.com/VolantMQ/volantmq/topics/types" "github.com/VolantMQ/volantmq/types" + "go.uber.org/zap" ) -type exitReason int - -const ( - exitReasonClean exitReason = iota - exitReasonShutdown - exitReasonExpired -) - -type switchStatus int - -const ( - swStatusSwitched switchStatus = iota - swStatusIsOnline - swStatusFinalized -) - -type onSessionClose func(string, exitReason) -type onDisconnect func(string, packet.ReasonCode, bool) -type onSubscriberShutdown func(subscriber.ConnectionProvider) - -type sessionEvents struct { - signalClose onSessionClose - signalDisconnected onDisconnect - shutdownSubscriber onSubscriberShutdown +type sessionEvents interface { + sessionOffline(string, bool, *expiry) + connectionClosed(string, packet.ReasonCode, bool) + subscriberShutdown(string, subscriber.SessionProvider) } type sessionPreConfig struct { - sessionEvents - id string - createdAt time.Time - messenger types.TopicMessenger -} - -type sessionReConfig struct { - subscriber subscriber.ConnectionProvider - will *packet.Publish - expireIn *uint32 - willDelay uint32 - killOnDisconnect bool + id string + createdAt time.Time + messenger types.TopicMessenger + conn connection.Session + persistence persistence.Packets } -type session struct { +type sessionConfig struct { sessionEvents - *sessionReConfig - id string - idLock *sync.Mutex - messenger types.TopicMessenger - createdAt time.Time - expiringSince time.Time - lock sync.Mutex - connStop *types.Once - disconnectOnce *types.OnceWait - wgDisconnected sync.WaitGroup - conn *connection.Type - timer *time.Timer - timerLock sync.Mutex - finalized bool - isOnline chan struct{} + subscriber subscriber.SessionProvider + will *packet.Publish + expireIn *uint32 + willDelay uint32 + durable bool + version packet.ProtocolVersion } -type sessionWrap struct { - s *session - lock sync.Mutex +type session struct { + sessionPreConfig + log *zap.Logger + idLock *sync.Mutex + lock sync.Mutex + connStop types.Once + sessionConfig } -func (s *sessionWrap) acquire() { - s.lock.Lock() +type temporaryPublish struct { + gList *list.List + qList *list.List } -func (s *sessionWrap) release() { - s.lock.Unlock() +func newTmpPublish() *temporaryPublish { + return &temporaryPublish{ + gList: list.New(), + qList: list.New(), + } } -func (s *sessionWrap) swap(w *sessionWrap) *session { - s.s = w.s - s.s.idLock = &s.lock - return s.s +func (t *temporaryPublish) Publish(id string, p *packet.Publish) { + if p.QoS() == packet.QoS0 { + t.gList.PushBack(p) + } else { + t.qList.PushBack(p) + } } -func newSession(c *sessionPreConfig) *session { +func newSession(c sessionPreConfig) *session { s := &session{ - sessionEvents: c.sessionEvents, - id: c.id, - createdAt: c.createdAt, - messenger: c.messenger, - isOnline: make(chan struct{}), + sessionPreConfig: c, } - s.timer = time.AfterFunc(10*time.Second, s.timerCallback) - s.timer.Stop() - - close(s.isOnline) return s } -func (s *session) reconfigure(c *sessionReConfig, runExpiry bool) { - s.sessionReConfig = c - s.finalized = false - if runExpiry { - s.runExpiry(true) - } -} - -func (s *session) allocConnection(c *connection.PreConfig) error { - cfg := &connection.Config{ - PreConfig: c, - ID: s.id, - OnDisconnect: s.onDisconnect, - Subscriber: s.subscriber, - Messenger: s.messenger, - KillOnDisconnect: s.killOnDisconnect, - ExpireIn: s.expireIn, - } +func (s *session) configure(c sessionConfig, clean bool) { + s.sessionConfig = c - s.disconnectOnce = &types.OnceWait{} - s.connStop = &types.Once{} + s.conn.SetOptions(connection.AttachSession(s)) - var err error - s.conn, err = connection.New(cfg) - - return err + if !clean { + tmp := newTmpPublish() + s.subscriber.Online(tmp) + s.persistence.PacketsForEach([]byte(s.id), s.conn) + s.subscriber.Online(s.conn) + s.conn.LoadRemaining(tmp.gList, tmp.qList) + } else { + s.subscriber.Online(s.conn) + } } func (s *session) start() { - s.isOnline = make(chan struct{}) - s.wgDisconnected.Add(1) - s.conn.Start() s.idLock.Unlock() } func (s *session) stop(reason packet.ReasonCode) *persistence.SessionState { - s.connStop.Do(func() { - if s.conn != nil { - s.conn.Stop(reason) - s.conn = nil - } - }) - - s.wgDisconnected.Wait() + s.conn.Stop(reason) - if !s.timer.Stop() { - s.timerLock.Lock() - s.timerLock.Unlock() // nolint: megacheck - } - - if !s.finalized { - s.signalClose(s.id, exitReasonShutdown) - s.finalized = true - } + //s.wgDisconnected.Wait() state := &persistence.SessionState{ Timestamp: s.createdAt.Format(time.RFC3339), } - if s.expireIn != nil || (s.willDelay > 0 && s.will != nil) { - state.Expire = &persistence.SessionDelays{ - Since: s.expiringSince.Format(time.RFC3339), - } + //if s.expireIn != nil || (s.willDelay > 0 && s.will != nil) { + // state.Expire = &persistence.SessionDelays{ + // Since: s.expiringSince.Format(time.RFC3339), + // } + // + // elapsed := uint32(time.Since(s.expiringSince) / time.Second) + // + // if (s.willDelay > 0 && s.will != nil) && (s.willDelay-elapsed) > 0 { + // s.willDelay = s.willDelay - elapsed + // s.will.SetPacketID(0) + // if buf, err := packet.Encode(s.will); err != nil { + // + // } else { + // state.Expire.WillIn = strconv.Itoa(int(s.willDelay)) + // state.Expire.WillData = buf + // } + // } + // + // if s.expireIn != nil && *s.expireIn > 0 && (*s.expireIn-elapsed) > 0 { + // *s.expireIn = *s.expireIn - elapsed + // } + //} - elapsed := uint32(time.Since(s.expiringSince) / time.Second) + return state +} - if (s.willDelay > 0 && s.will != nil) && (s.willDelay-elapsed) > 0 { - s.willDelay = s.willDelay - elapsed - s.will.SetPacketID(0) - if buf, err := packet.Encode(s.will); err != nil { +func (s *session) SignalPublish(pkt *packet.Publish) error { + pkt.SetPublishID(s.subscriber.Hash()) - } else { - state.Expire.WillIn = strconv.Itoa(int(s.willDelay)) - state.Expire.WillData = buf - } + // [MQTT-3.3.1.3] + if pkt.Retain() { + if err := s.messenger.Retain(pkt); err != nil { + s.log.Error("Error retaining message", zap.String("ClientID", s.id), zap.Error(err)) } - if s.expireIn != nil && *s.expireIn > 0 && (*s.expireIn-elapsed) > 0 { - *s.expireIn = *s.expireIn - elapsed + // [MQTT-3.3.1-7] + if pkt.QoS() == packet.QoS0 { + retained := packet.NewPublish(s.version) + retained.SetQoS(pkt.QoS()) // nolint: errcheck + retained.SetTopic(pkt.Topic()) // nolint: errcheck + //s.retained.list = append(s.retained.list, m) } } - return state + if err := s.messenger.Publish(pkt); err != nil { + s.log.Error("Couldn't publish", zap.String("ClientID", s.id), zap.Error(err)) + } + + return nil } -// setOnline try switch session state from offline to online. This is necessary when -// when previous network connection has set session expiry or will delay or both -// if switch is successful then swStatusSwitched returned. -// if session has active network connection then returned value is swStatusIsOnline -// if connection has been closed and must not be used anymore then it returns swStatusFinalized -func (s *session) setOnline() switchStatus { - isOnline := false - // check session online status - s.lock.Lock() - select { - case <-s.isOnline: - default: - isOnline = true - } - s.lock.Unlock() +func (s *session) SignalSubscribe(pkt *packet.Subscribe) (packet.Provider, error) { + m, _ := packet.New(s.version, packet.SUBACK) + resp, _ := m.(*packet.SubAck) + + id, _ := pkt.ID() + resp.SetPacketID(id) - status := swStatusSwitched - if !isOnline { - // session is offline. before making any further step wait disconnect procedure is done - s.wgDisconnected.Wait() + var retCodes []packet.ReasonCode + var retainedPublishes []*packet.Publish - // if stop returns false timer has been fired and there is goroutine might be running - if !s.timer.Stop() { - s.timerLock.Lock() - s.timerLock.Unlock() // nolint: megacheck + pkt.RangeTopics(func(t string, ops packet.SubscriptionOptions) { + reason := packet.CodeSuccess // nolint: ineffassign + //authorized := true + // TODO: check permissions here + + //if authorized { + subsID := uint32(0) + + // V5.0 [MQTT-3.8.2.1.2] + if prop := pkt.PropertyGet(packet.PropertySubscriptionIdentifier); prop != nil { + if v, e := prop.AsInt(); e == nil { + subsID = v + } } - if s.finalized { - status = swStatusFinalized + subsParams := topicsTypes.SubscriptionParams{ + ID: subsID, + Ops: ops, } - } else { - status = swStatusIsOnline - } - return status -} + if grantedQoS, retained, err := s.subscriber.Subscribe(t, &subsParams); err != nil { + // [MQTT-3.9.3] + if s.version == packet.ProtocolV50 { + reason = packet.CodeUnspecifiedError + } else { + reason = packet.QosFailure + } + } else { + reason = packet.ReasonCode(grantedQoS) + retainedPublishes = append(retainedPublishes, retained...) + } -func (s *session) runExpiry(will bool) { - var timerPeriod uint32 + retCodes = append(retCodes, reason) + }) - // if meet will requirements point that - if will && s.will != nil && s.willDelay > 0 { - timerPeriod = s.willDelay - } else { - s.will = nil + if err := resp.AddReturnCodes(retCodes); err != nil { + return nil, err } - if s.expireIn != nil { - // if will delay is set before and value less than expiration - // then timer should fire 2 times - if (timerPeriod > 0) && (timerPeriod < *s.expireIn) { - *s.expireIn = *s.expireIn - timerPeriod + // Now put retained messages into publish queue + for _, rp := range retainedPublishes { + if pkt, err := rp.Clone(s.version); err == nil { + pkt.SetRetain(true) + s.conn.Publish(s.id, pkt) } else { - timerPeriod = *s.expireIn - *s.expireIn = 0 + s.log.Error("Couldn't clone PUBLISH message", zap.String("ClientID", s.id), zap.Error(err)) } } - s.expiringSince = time.Now() - s.timer.Reset(time.Duration(timerPeriod) * time.Second) + return resp, nil } -func (s *session) onDisconnect(p *connection.DisconnectParams) { - s.disconnectOnce.Do(func() { - defer s.wgDisconnected.Done() +func (s *session) SignalUnSubscribe(pkt *packet.UnSubscribe) (packet.Provider, error) { + var retCodes []packet.ReasonCode - s.lock.Lock() - close(s.isOnline) - s.lock.Unlock() + for _, t := range pkt.Topics() { + // TODO: check permissions here + authorized := true + reason := packet.CodeSuccess - finalize := func(err exitReason) { - s.signalClose(s.id, err) - s.finalized = true + if authorized { + if err := s.subscriber.UnSubscribe(t); err != nil { + s.log.Error("Couldn't unsubscribe from topic", zap.Error(err)) + reason = packet.CodeNoSubscriptionExisted + } + } else { + reason = packet.CodeNotAuthorized } - s.connStop.Do(func() { - s.conn = nil - }) + retCodes = append(retCodes, reason) + } - if p.ExpireAt != nil { - s.expireIn = p.ExpireAt - } + m, _ := packet.New(s.version, packet.UNSUBACK) + resp, _ := m.(*packet.UnSubAck) - // If session expiry is set to 0, the Session ends when the Network Connection is closed - if s.expireIn != nil && *s.expireIn == 0 { - s.killOnDisconnect = true - } + id, _ := pkt.ID() + resp.SetPacketID(id) + resp.AddReturnCodes(retCodes) // nolint: errcheck - // valid willMsg pointer tells we have will message - // if session is clean send will regardless to will delay - if p.Will && s.will != nil && (s.killOnDisconnect || s.willDelay == 0) { - s.messenger.Publish(s.will) // nolint: errcheck - s.will = nil - } + return resp, nil +} + +func (s *session) SignalDisconnect(pkt *packet.Disconnect) (packet.Provider, error) { + var err error - s.signalDisconnected(s.id, p.Reason, !s.killOnDisconnect) + err = packet.CodeSuccess - if s.killOnDisconnect || !s.subscriber.HasSubscriptions() { - s.shutdownSubscriber(s.subscriber) - s.subscriber = nil + if s.version == packet.ProtocolV50 { + // FIXME: CodeRefusedBadUsernameOrPassword has same id as CodeDisconnectWithWill + if pkt.ReasonCode() != packet.CodeRefusedBadUsernameOrPassword { + s.will = nil } - if s.killOnDisconnect { - defer finalize(exitReasonClean) - } else { - // check if remaining subscriptions exists, expiry is presented and will delay not set to 0 - if s.expireIn == nil && s.willDelay == 0 { - // signal to shutdown session - defer finalize(exitReasonShutdown) - } else if (s.expireIn != nil && *s.expireIn > 0) || s.willDelay > 0 { - // new expiry value might be received upon disconnect message from the client - if p.ExpireAt != nil { - s.expireIn = p.ExpireAt + if prop := pkt.PropertyGet(packet.PropertySessionExpiryInterval); prop != nil { + if val, ok := prop.AsInt(); ok == nil { + // If the Session Expiry Interval in the CONNECT packet was zero, then it is a Protocol Error to set a non- + // zero Session Expiry Interval in the DISCONNECT packet sent by the Client. If such a non-zero Session + // Expiry Interval is received by the Server, it does not treat it as a valid DISCONNECT packet. The Server + // uses DISCONNECT with Reason Code 0x82 (Protocol Error) as described in section 4.13. + if (s.expireIn != nil && *s.expireIn == 0) && val != 0 { + err = packet.CodeProtocolError + } else { + s.expireIn = &val } - - s.runExpiry(p.Will) } } - }) + } else { + s.will = nil + } + + return nil, err } -func (s *session) timerCallback() { - defer s.timerLock.Unlock() - s.timerLock.Lock() +// SignalOffline put subscriber in offline mode +func (s *session) SignalOffline() { + s.subscriber.Offline(!s.durable) +} - finalize := func(reason exitReason) { - s.signalClose(s.id, reason) - s.finalized = true +// SignalConnectionClose net connection has been closed +func (s *session) SignalConnectionClose(params connection.DisconnectParams) { + // If session expiry is set to 0, the Session ends when the Network Connection is closed + if s.expireIn != nil && *s.expireIn == 0 { + s.durable = true } - // 1. check for will message available - if s.will != nil { - // publish if exists and wipe state - s.messenger.Publish(s.will) // nolint: errcheck + // valid willMsg pointer tells we have will message + // if session is clean send will regardless to will delay + if s.will != nil && s.willDelay == 0 { + if err := s.messenger.Publish(s.will); err != nil { + s.log.Error("Publish will message", zap.String("ClientID", s.id), zap.Error(err)) + } s.will = nil - s.willDelay = 0 } - if s.expireIn == nil { - // 2.a session has processed delayed will and there is nothing to do - // completely shutdown the session - defer finalize(exitReasonShutdown) - } else if *s.expireIn == 0 { - // session has expired. WIPE IT - if s.subscriber != nil { - s.shutdownSubscriber(s.subscriber) + s.connectionClosed(s.id, params.Reason, !s.durable) + + if s.durable && len(params.Packets) > 0 { + s.persistence.PacketsStore([]byte(s.id), params.Packets) + } + + if !s.durable || !s.subscriber.HasSubscriptions() { + s.subscriberShutdown(s.id, s.subscriber) + s.subscriber = nil + } + + var exp *expiry + + if params.Reason != packet.CodeSessionTakenOver { + if s.willDelay > 0 || (s.expireIn != nil && *s.expireIn > 0) { + exp = newExpiry( + expiryConfig{ + id: s.id, + createdAt: s.createdAt, + messenger: s.messenger, + will: s.will, + expireIn: s.expireIn, + willDelay: s.willDelay, + }) } - defer finalize(exitReasonExpired) - } else { - // restart timer and wait again - val := *s.expireIn - // clear value pointed by expireIn so when next fire comes we signal session is expired - *s.expireIn = 0 - s.timer.Reset(time.Duration(val) * time.Second) } + + s.sessionOffline(s.id, s.durable, exp) } diff --git a/clients/sessions.go b/clients/sessions.go index 748aecb..3bfc6f7 100644 --- a/clients/sessions.go +++ b/clients/sessions.go @@ -1,27 +1,25 @@ package clients import ( - "crypto/rand" - "encoding/base64" "encoding/binary" "errors" - "io" + "fmt" "net" "strconv" "sync" + "sync/atomic" "time" - "unsafe" "github.com/VolantMQ/persistence" "github.com/VolantMQ/volantmq/auth" "github.com/VolantMQ/volantmq/configuration" "github.com/VolantMQ/volantmq/connection" "github.com/VolantMQ/volantmq/packet" - "github.com/VolantMQ/volantmq/routines" "github.com/VolantMQ/volantmq/subscriber" "github.com/VolantMQ/volantmq/systree" "github.com/VolantMQ/volantmq/topics/types" "github.com/VolantMQ/volantmq/types" + "github.com/gosuri/uiprogress" "github.com/troian/easygo/netpoll" "go.uber.org/zap" ) @@ -38,6 +36,7 @@ type Config struct { Persist persistence.Provider Systree systree.Provider OnReplaceAttempt func(string, bool) + AllowedVersions map[packet.ProtocolVersion]bool NodeName string ConnectTimeout int KeepAlive int @@ -54,16 +53,21 @@ type Config struct { ForceKeepAlive bool } +type preloadConfig struct { + exp *expiryConfig + sub *subscriberConfig +} + // Manager clients manager type Manager struct { - Config persistence persistence.Sessions log *zap.Logger quit chan struct{} sessionsCount sync.WaitGroup sessions sync.Map - subscribers sync.Map - poll netpoll.EventPoll + ePoll netpoll.EventPoll + delayedWills []*packet.Publish + Config } // StartConfig used to reconfigure session after connection is created @@ -74,6 +78,18 @@ type StartConfig struct { Auth auth.SessionPermissions } +type containerInfo struct { + ses *session + sub *subscriber.Type + present bool +} + +type loadContext struct { + bar *uiprogress.Bar + preloadConfigs map[string]*preloadConfig + delayedWills []packet.Provider +} + // NewManager create new clients manager func NewManager(c *Config) (*Manager, error) { m := &Manager{ @@ -82,32 +98,49 @@ func NewManager(c *Config) (*Manager, error) { log: configuration.GetLogger().Named("sessions"), } - m.poll, _ = netpoll.New(nil) + m.ePoll, _ = netpoll.New(nil) m.persistence, _ = c.Persist.Sessions() var err error - m.log.Info("Loading sessions. Might take a while") + pCount := m.persistence.Count() + if pCount > 0 { + m.log.Info("Loading sessions. Might take a while") - // load sessions for fill systree - // those sessions having either will delay or expire are created with and timer started - if err = m.loadSessions(); err != nil { - return nil, err - } + uiprogress.Start() + bar := uiprogress.AddBar(int(pCount)).AppendCompleted().PrependElapsed() - //if err = m.loadSubscribers(); err != nil { - // return nil, err - //} + bar.PrependFunc(func(b *uiprogress.Bar) string { + return fmt.Sprintf("Session load (%d/%d)", b.Current(), int(pCount)) + }) + + context := &loadContext{ + bar: bar, + preloadConfigs: make(map[string]*preloadConfig), + } + + // load sessions for fill systree + // those sessions having either will delay or expire are created with and timer started + err = m.persistence.LoadForEach(m, context) + + uiprogress.Stop() + + if err != nil { + return nil, err + } - m.log.Info("Sessions loaded") + m.log.Info("Sessions loaded") + } else { + m.log.Info("No persisted sessions") + } - //m.persistence.StatesWipe() // nolint: errcheck - //m.persistence.SubscriptionsWipe() // nolint: errcheck + m.configurePersistedSubscribers() + m.processDelayedWills() return m, nil } -// Shutdown clients manager +// Shutdown sessions manager // gracefully shutdown by stopping all active sessions and persist states func (m *Manager) Shutdown() error { select { @@ -118,126 +151,255 @@ func (m *Manager) Shutdown() error { } m.sessions.Range(func(k, v interface{}) bool { - wrap := v.(*sessionWrap) - state := wrap.s.stop(packet.CodeServerShuttingDown) + wrap := v.(*container) + ses := wrap.ses.Load().(*session) + state := ses.stop(packet.CodeServerShuttingDown) m.persistence.StateStore([]byte(k.(string)), state) // nolint: errcheck return true }) m.sessionsCount.Wait() - m.storeSubscribers() // nolint: errcheck + m.encodeSubscribers() // nolint: errcheck return nil } -// NewSession create new session with provided established connection -// This is god function. Might be try split it -func (m *Manager) NewSession(config *StartConfig) { - var id string - var ses *session +// LoadSession load persisted session. Invoked by persisted provider +func (m *Manager) LoadSession(context interface{}, id []byte, state *persistence.SessionState) error { + sID := string(id) + ctx := context.(*loadContext) + + defer ctx.bar.Incr() + + if len(state.Errors) != 0 { + m.log.Error("Session load", zap.String("ClientID", sID), zap.Errors("errors", state.Errors)) + // if err := m.persistence.SubscriptionsDelete(id); err != nil && err != persistence.ErrNotFound { + // m.log.Error("Persisted subscriber delete", zap.Error(err)) + // } + + return nil + } + var err error - idGenerated := false - var systreeConnStatus *systree.ClientConnectStatus + status := &systree.SessionCreatedStatus{ + Clean: false, + Timestamp: state.Timestamp, + } - defer func() { - if err != nil { - var reason packet.ReasonCode - switch config.Req.Version() { - case packet.ProtocolV50: - reason = packet.CodeUnspecifiedError + if err = m.decodeSessionExpiry(ctx, sID, state); err != nil { + m.log.Error("Decode subscriber", zap.String("ClientID", sID), zap.Error(err)) + } + + if err = m.decodeSubscriber(ctx, sID, state.Subscriptions); err != nil { + m.log.Error("Decode subscriber", zap.String("ClientID", sID), zap.Error(err)) + if err = m.persistence.SubscriptionsDelete(id); err != nil && err != persistence.ErrNotFound { + m.log.Error("Persisted subscriber delete", zap.Error(err)) + } + } + + if cfg, ok := ctx.preloadConfigs[sID]; ok && cfg.exp != nil { + status.WillDelay = strconv.FormatUint(uint64(cfg.exp.willDelay), 10) + if cfg.exp.expireIn != nil { + status.ExpiryInterval = strconv.FormatUint(uint64(*cfg.exp.expireIn), 10) + } + } + + m.Systree.Sessions().Created(sID, status) + return nil +} + +// Handle incoming connection +func (m *Manager) Handle(conn net.Conn, auth auth.SessionPermissions) error { + cn := connection.New( + connection.OnAuth(m.onAuth), + connection.EPoll(m.ePoll), + connection.NetConn(conn), + connection.TxQuota(types.DefaultReceiveMax), + connection.RxQuota(types.DefaultReceiveMax), + connection.Metric(m.Systree.Metric().Packets()), + connection.RetainAvailable(m.AvailableRetain), + connection.OfflineQoS0(m.OfflineQoS0), + connection.MaxTxPacketSize(types.DefaultMaxPacketSize), + connection.MaxRxPacketSize(m.MaxPacketSize), + connection.MaxRxTopicAlias(m.TopicAliasMaximum), + connection.MaxTxTopicAlias(0), + connection.KeepAlive(m.ConnectTimeout), + ) + + var connParams *connection.ConnectParams + var ack *packet.ConnAck + if ch, err := cn.Accept(); err == nil { + for dl := range ch { + var resp packet.Provider + switch obj := dl.(type) { + case *connection.ConnectParams: + connParams = obj + resp, err = m.processConnect(cn, connParams) + case connection.AuthParams: + resp, err = m.processAuth(connParams, obj) + case error: + err = obj default: - reason = packet.CodeRefusedServerUnavailable + err = errors.New("unknown") + } + + if err != nil || resp == nil { + cn.Stop(err) + cn = nil + return nil + } else { + if resp.Type() == packet.AUTH { + cn.Send(resp) + } else { + ack = resp.(*packet.ConnAck) + break + } } - config.Resp.SetReturnCode(reason) // nolint: errcheck } + } - if err = routines.WriteMessage(config.Conn, config.Resp); err != nil { - m.log.Error("Couldn't write CONNACK", zap.String("ClientID", id), zap.Error(err)) - } else { - if ses != nil { - ses.start() - m.Systree.Clients().Connected(id, systreeConnStatus) + m.newSession(cn, connParams, ack) + + return nil +} + +func (m *Manager) processConnect(cn connection.Initial, params *connection.ConnectParams) (packet.Provider, error) { + var resp packet.Provider + + if allowed, ok := m.AllowedVersions[params.Version]; !ok || !allowed { + reason := packet.CodeRefusedUnacceptableProtocolVersion + if params.Version == packet.ProtocolV50 { + reason = packet.CodeUnsupportedProtocol + } + + return nil, reason + } + + if len(params.AuthMethod) > 0 { + // TODO(troian): verify method is allowed + // resp = packet.NewAuth(params.Version) + } else { + var reason packet.ReasonCode + // if status := c.config.AuthManager.Password(string(user), string(pass)); status == auth.StatusAllow { + // reason = packet.CodeSuccess + // } else { + // reason = packet.CodeRefusedBadUsernameOrPassword + // if req.Version() == packet.ProtocolV50 { + // reason = packet.CodeBadUserOrPassword + // } + // } + + pkt := packet.NewConnAck(params.Version) + pkt.SetReturnCode(reason) + resp = pkt + } + + return resp, nil +} + +func (m *Manager) processAuth(params *connection.ConnectParams, auth connection.AuthParams) (packet.Provider, error) { + var resp packet.Provider + + return resp, nil +} + +// newSession create new session with provided established connection +func (m *Manager) newSession(cn connection.Initial, params *connection.ConnectParams, ack *packet.ConnAck) { + var ses *session + var err error + + defer func() { + keepAlive := int(params.KeepAlive) + if m.ForceKeepAlive || params.KeepAlive > 0 { + if m.ForceKeepAlive { + keepAlive = int(m.KeepAlive) } } - }() - m.checkServerStatus(config.Req.Version(), config.Resp) + if cn.Acknowledge(ack, connection.KeepAlive(keepAlive)) { + ses.start() + status := &systree.ClientConnectStatus{ + Username: string(params.Username), + Timestamp: time.Now().Format(time.RFC3339), + ReceiveMaximum: uint32(params.SendQuota), + MaximumPacketSize: params.MaxTxPacketSize, + GeneratedID: params.IDGen, + // SessionPresent: sessionPresent, + // Address: config.Conn.RemoteAddr().String(), + KeepAlive: uint16(keepAlive), + Protocol: params.Version, + ConnAckCode: ack.ReturnCode(), + CleanSession: params.CleanStart, + Durable: params.Durable, + } + + m.Systree.Clients().Connected(params.ID, status) + } + }() // if response has return code differs from CodeSuccess return from this point // and send connack in deferred statement - if config.Resp.ReturnCode() != packet.CodeSuccess { + if ack.ReturnCode() != packet.CodeSuccess { return } - // client might come with empty client id - if id = string(config.Req.ClientID()); len(id) == 0 { - id = m.genClientID() - idGenerated = true - } + if params.Version >= packet.ProtocolV50 { + ids := "" + if params.IDGen { + ids = params.ID + } - if ses, err = m.loadSession(id, config.Req.Version(), config.Resp); err == nil { - if systreeConnStatus, err = m.configureSession(config, ses, id, idGenerated); err != nil { - m.sessions.Delete(id) - m.sessionsCount.Done() + if err = m.writeSessionProperties(ack, ids); err != nil { + reason := packet.CodeUnspecifiedError + if params.Version <= packet.ProtocolV50 { + reason = packet.CodeRefusedServerUnavailable + } + ack.SetReturnCode(reason) + return } } -} -func (m *Manager) loadSession(id string, v packet.ProtocolVersion, resp *packet.ConnAck) (*session, error) { - var err error + var info *containerInfo + if info, err = m.loadContainer(cn.Session(), params); err == nil { + ses = info.ses + config := sessionConfig{ + sessionEvents: m, + expireIn: params.ExpireIn, + willDelay: params.WillDelay, + will: params.Will, + durable: params.Durable, + version: params.Version, + subscriber: info.sub, + } - var ses *session - wrap := m.allocSession(id, time.Now()) - if ss, ok := m.sessions.LoadOrStore(id, wrap); ok { - // release lock of newly allocated session as lock from old one will be used - wrap.release() - // there is some old session exists, check if it has active network connection then stop if so - // if it is offline (either waiting for expiration to fire or will event or) switch back to online - // and use this session henceforth - oldWrap := ss.(*sessionWrap) + ses.configure(config, params.CleanStart) - // lock id to prevent other upcoming session make any changes until we done - oldWrap.acquire() + if info.present { + // if session has present persistence wipe stored messages to prevent duplicate sending + m.persistence.PacketsDelete([]byte(params.ID)) + } - old := oldWrap.s + ack.SetSessionPresent(info.present) + } else { + var reason packet.ReasonCode - switch old.setOnline() { - case swStatusIsOnline: - // existing session has active network connection - // exempt it if allowed - m.OnReplaceAttempt(id, m.AllowReplace) - if !m.AllowReplace { - // we do not make any changes to current network connection - // response to new one with error and release both new & old sessions - err = packet.CodeRefusedIdentifierRejected - if v >= packet.ProtocolV50 { - err = packet.CodeInvalidClientID - } - oldWrap.release() - } else { - // session will be replaced with new connection - // stop current active connection - old.stop(packet.CodeSessionTakenOver) - ses = oldWrap.swap(wrap) - m.sessions.Store(id, oldWrap) - m.sessionsCount.Add(1) + if r, ok := err.(packet.ReasonCode); ok { + reason = r + } else { + reason = packet.CodeUnspecifiedError + if params.Version <= packet.ProtocolV50 { + reason = packet.CodeRefusedServerUnavailable } - case swStatusSwitched: - // session has been turned online successfully - ses = old - default: - ses = oldWrap.swap(wrap) - m.sessions.Store(id, oldWrap) - m.sessionsCount.Add(1) } - } else { - ses = wrap.s - m.sessionsCount.Add(1) + + ack.SetReturnCode(reason) } +} - return ses, err +func (m *Manager) onAuth(id string, params *connection.AuthParams) (packet.Provider, error) { + return nil, nil } func (m *Manager) checkServerStatus(v packet.ProtocolVersion, resp *packet.ConnAck) { @@ -258,442 +420,349 @@ func (m *Manager) checkServerStatus(v packet.ProtocolVersion, resp *packet.ConnA } } -func (m *Manager) allocSession(id string, createdAt time.Time) *sessionWrap { - wrap := &sessionWrap{ - s: newSession(&sessionPreConfig{ - id: id, - createdAt: createdAt, - messenger: m.TopicsMgr, - sessionEvents: sessionEvents{ - signalClose: m.onSessionClose, - signalDisconnected: m.onDisconnect, - shutdownSubscriber: m.onSubscriberShutdown, - }, - })} +func (m *Manager) allocContainer(id string, createdAt time.Time, cn connection.Session) *container { + ses := newSession(sessionPreConfig{ + id: id, + createdAt: createdAt, + conn: cn, + messenger: m.TopicsMgr, + persistence: m.persistence, + }) + + wrap := &container{} + ses.idLock = &wrap.lock + wrap.ses.Store(ses) wrap.acquire() - wrap.s.idLock = &wrap.lock return wrap } -func (m *Manager) getWill(pkt *packet.Connect) *packet.Publish { - var willPkt *packet.Publish - if willTopic, willPayload, willQoS, willRetain, will := pkt.Will(); will { - _m, _ := packet.New(pkt.Version(), packet.PUBLISH) - willPkt = _m.(*packet.Publish) - willPkt.Set(willTopic, willPayload, willQoS, willRetain, false) // nolint: errcheck - } - - return willPkt -} - -func (m *Manager) newConnectionPreConfig(config *StartConfig) *connection.PreConfig { - username, _ := config.Req.Credentials() - - return &connection.PreConfig{ - Username: string(username), - Auth: config.Auth, - Conn: config.Conn, - KeepAlive: config.Req.KeepAlive(), - Version: config.Req.Version(), - Desc: netpoll.Must(netpoll.HandleReadOnce(config.Conn)), - MaxTxPacketSize: types.DefaultMaxPacketSize, - SendQuota: types.DefaultReceiveMax, - State: m.persistence, - EventPoll: m.poll, - Metric: m.Systree.Metric(), - RetainAvailable: m.AvailableRetain, - OfflineQoS0: m.OfflineQoS0, - MaxRxPacketSize: m.MaxPacketSize, - MaxRxTopicAlias: m.TopicAliasMaximum, - MaxTxTopicAlias: 0, - } -} - -func (m *Manager) configureSession(config *StartConfig, ses *session, id string, idGenerated bool) (*systree.ClientConnectStatus, error) { - sub, sessionPresent := m.getSubscriber(id, config.Req.IsClean(), config.Req.Version()) - - sConfig := &sessionReConfig{ - subscriber: sub, - will: m.getWill(config.Req), - killOnDisconnect: false, - } +func (m *Manager) loadContainer(cn connection.Session, params *connection.ConnectParams) (cont *containerInfo, err error) { + newContainer := m.allocContainer(params.ID, time.Now(), cn) + if ss, present := m.sessions.LoadOrStore(params.ID, newContainer); present { + // release lock of newly allocated session as lock from old one will be used + newContainer.release() - cConfig := m.newConnectionPreConfig(config) + // there is some old session exists, check if it has active network connection then stop if so + // if it is offline (either waiting for expiration to fire or will event or) switch back to online + // and use this session henceforth + currContainer := ss.(*container) - if config.Req.Version() >= packet.ProtocolV50 { - if err := readSessionProperties(config.Req, sConfig, cConfig); err != nil { - return nil, err - } + // lock id to prevent other upcoming session make any changes until we done + currContainer.acquire() - ids := "" - if idGenerated { - ids = id + if current := currContainer.session(); current != nil { + m.OnReplaceAttempt(params.ID, m.AllowReplace) + if !m.AllowReplace { + // we do not make any changes to current network connection + // response to new one with error and release both new & old sessions + err = packet.CodeRefusedIdentifierRejected + if params.Version >= packet.ProtocolV50 { + err = packet.CodeInvalidClientID + } + currContainer.release() + newContainer = nil + return + } + // session will be replaced with new connection + // stop current active connection + current.stop(packet.CodeSessionTakenOver) } - m.writeSessionProperties(config.Resp, ids) - if err := config.Resp.PropertySet(packet.PropertyServerKeepAlive, m.KeepAlive); err != nil { - return nil, err + if val := currContainer.expiry.Load(); val != nil { + exp := val.(*expiry) + exp.cancel() + currContainer.expiry = atomic.Value{} + m.sessionsCount.Done() } - } - // MQTT v5 has different meaning of clean comparing to MQTT v3 - // - v3: if session is clean it lasts when Network connection os close - // - v5: clean means clean start and server must wipe any previously created session with same id - // but keep this one if Network Connection is closed - if (config.Req.Version() <= packet.ProtocolV311 && config.Req.IsClean()) || - (sConfig.expireIn != nil && *sConfig.expireIn == 0) { - sConfig.killOnDisconnect = true + newContainer = currContainer.swap(newContainer) + m.sessions.Store(params.ID, currContainer) + m.sessionsCount.Add(1) + } else { + m.sessionsCount.Add(1) } - ses.reconfigure(sConfig, false) - - var status *systree.ClientConnectStatus + sub, present := newContainer.subscriber( + params.CleanStart, + subscriber.Config{ + ID: params.ID, + OfflinePublish: m, + Topics: m.TopicsMgr, + Version: params.Version, + }) - if err := ses.allocConnection(cConfig); err == nil { - if !config.Req.IsClean() { - m.persistence.Delete([]byte(id)) // nolint: errcheck + if params.CleanStart { + if err = m.persistence.Delete([]byte(params.ID)); err != nil && err != persistence.ErrNotFound { + m.log.Error("Couldn't wipe session", zap.String("ClientID", params.ID), zap.Error(err)) + } else { + err = nil } + } else { + persisted := m.persistence.Exists([]byte(params.ID)) + present = present || persisted - status = &systree.ClientConnectStatus{ - Username: cConfig.Username, - Timestamp: time.Now().Format(time.RFC3339), - ReceiveMaximum: uint32(cConfig.SendQuota), - MaximumPacketSize: cConfig.MaxTxPacketSize, - GeneratedID: idGenerated, - SessionPresent: sessionPresent, - Address: config.Conn.RemoteAddr().String(), - KeepAlive: config.Req.KeepAlive(), - Protocol: config.Req.Version(), - ConnAckCode: config.Resp.ReturnCode(), - CleanSession: config.Req.IsClean(), - KillOnDisconnect: sConfig.killOnDisconnect, + if params.Durable && !persisted { + // TODO: create persistence entry } } - config.Resp.SetSessionPresent(sessionPresent) - - return status, nil -} - -func boolToByte(v bool) byte { - if v { - return 1 + cont = &containerInfo{ + ses: newContainer.ses.Load().(*session), + sub: sub, + present: present, } - return 0 + return } -func readSessionProperties(req *packet.Connect, sc *sessionReConfig, cc *connection.PreConfig) (err error) { - // [MQTT-3.1.2.11.2] - if prop := req.PropertyGet(packet.PropertySessionExpiryInterval); prop != nil { - if val, e := prop.AsInt(); e == nil { - sc.expireIn = &val +func (m *Manager) writeSessionProperties(resp *packet.ConnAck, id string) error { + boolToByte := func(v bool) byte { + if v { + return 1 } - } - // [MQTT-3.1.2.11.3] - if prop := req.PropertyGet(packet.PropertyWillDelayInterval); prop != nil { - if val, e := prop.AsInt(); e == nil { - sc.willDelay = val - } + return 0 } - // [MQTT-3.1.2.11.4] - if prop := req.PropertyGet(packet.PropertyReceiveMaximum); prop != nil { - if val, e := prop.AsShort(); e == nil { - cc.SendQuota = int32(val) - } - } - - // [MQTT-3.1.2.11.5] - if prop := req.PropertyGet(packet.PropertyMaximumPacketSize); prop != nil { - if val, e := prop.AsInt(); e == nil { - cc.MaxTxPacketSize = val - } - } - - // [MQTT-3.1.2.11.6] - if prop := req.PropertyGet(packet.PropertyTopicAliasMaximum); prop != nil { - if val, e := prop.AsShort(); e == nil { - cc.MaxTxTopicAlias = val - } - } - - // [MQTT-3.1.2.11.10] - if prop := req.PropertyGet(packet.PropertyAuthMethod); prop != nil { - if val, e := prop.AsString(); e == nil { - cc.AuthMethod = val - } - } - - // [MQTT-3.1.2.11.11] - if prop := req.PropertyGet(packet.PropertyAuthData); prop != nil { - if len(cc.AuthMethod) == 0 { - err = packet.CodeProtocolError - return - } - if val, e := prop.AsBinary(); e == nil { - cc.AuthData = val - } - } - - return -} - -func (m *Manager) writeSessionProperties(resp *packet.ConnAck, id string) { // [MQTT-3.2.2.3.2] if server receive max less than 65536 than let client to know about if m.ReceiveMax < types.DefaultReceiveMax { - resp.PropertySet(packet.PropertyReceiveMaximum, m.ReceiveMax) // nolint: errcheck + if err := resp.PropertySet(packet.PropertyReceiveMaximum, m.ReceiveMax); err != nil { + return err + } } // [MQTT-3.2.2.3.3] if supported server's QoS less than 2 notify client if m.MaximumQoS < packet.QoS2 { - resp.PropertySet(packet.PropertyMaximumQoS, byte(m.MaximumQoS)) // nolint: errcheck + if err := resp.PropertySet(packet.PropertyMaximumQoS, byte(m.MaximumQoS)); err != nil { + return err + } } // [MQTT-3.2.2.3.4] tell client whether retained messages supported - resp.PropertySet(packet.PropertyRetainAvailable, boolToByte(m.AvailableRetain)) // nolint: errcheck + if err := resp.PropertySet(packet.PropertyRetainAvailable, boolToByte(m.AvailableRetain)); err != nil { + return err + } // [MQTT-3.2.2.3.5] if server max packet size less than 268435455 than let client to know about if m.MaxPacketSize < types.DefaultMaxPacketSize { - resp.PropertySet(packet.PropertyMaximumPacketSize, m.MaxPacketSize) // nolint: errcheck + if err := resp.PropertySet(packet.PropertyMaximumPacketSize, m.MaxPacketSize); err != nil { + return err + } } // [MQTT-3.2.2.3.6] if len(id) > 0 { - resp.PropertySet(packet.PropertyAssignedClientIdentifier, id) // nolint: errcheck + if err := resp.PropertySet(packet.PropertyAssignedClientIdentifier, id); err != nil { + return err + } } // [MQTT-3.2.2.3.7] if m.TopicAliasMaximum > 0 { - resp.PropertySet(packet.PropertyTopicAliasMaximum, m.TopicAliasMaximum) // nolint: errcheck + if err := resp.PropertySet(packet.PropertyTopicAliasMaximum, m.TopicAliasMaximum); err != nil { + return err + } } // [MQTT-3.2.2.3.10] tell client whether server supports wildcard subscriptions or not - resp.PropertySet(packet.PropertyWildcardSubscriptionAvailable, boolToByte(m.AvailableWildcardSubscription)) // nolint: errcheck + if err := resp.PropertySet(packet.PropertyWildcardSubscriptionAvailable, boolToByte(m.AvailableWildcardSubscription)); err != nil { + return err + } // [MQTT-3.2.2.3.11] tell client whether server supports subscription identifiers or not - resp.PropertySet(packet.PropertySubscriptionIdentifierAvailable, boolToByte(m.AvailableSubscriptionID)) // nolint: errcheck - // [MQTT-3.2.2.3.12] tell client whether server supports shared subscriptions or not - resp.PropertySet(packet.PropertySharedSubscriptionAvailable, boolToByte(m.AvailableSharedSubscription)) // nolint: errcheck -} - -func (m *Manager) getSubscriber(id string, clean bool, v packet.ProtocolVersion) (subscriber.ConnectionProvider, bool) { - var sub subscriber.ConnectionProvider - var present bool - - if clean { - if sb, ok := m.subscribers.Load(id); ok { - sub = sb.(subscriber.ConnectionProvider) - sub.Offline(true) - m.subscribers.Delete(id) - } - if err := m.persistence.Delete([]byte(id)); err != nil && err != persistence.ErrNotFound { - m.log.Error("Couldn't wipe session", zap.String("ClientID", id), zap.Error(err)) - } + if err := resp.PropertySet(packet.PropertySubscriptionIdentifierAvailable, boolToByte(m.AvailableSubscriptionID)); err != nil { + return err } - - if sb, ok := m.subscribers.Load(id); !ok { - sub = subscriber.New(&subscriber.Config{ - ID: id, - Topics: m.TopicsMgr, - OnOfflinePublish: m.onPublish, - OfflineQoS0: m.OfflineQoS0, - Version: v, - }) - - present = m.persistence.Exists([]byte(id)) - m.subscribers.Store(id, sub) - m.log.Debug("Subscriber created", zap.String("ClientID", id)) - } else { - m.log.Debug("Subscriber obtained", zap.String("ClientID", id)) - sub = sb.(subscriber.ConnectionProvider) - present = true + // [MQTT-3.2.2.3.12] tell client whether server supports shared subscriptions or not + if err := resp.PropertySet(packet.PropertySharedSubscriptionAvailable, boolToByte(m.AvailableSharedSubscription)); err != nil { + return err } - return sub, present -} - -func (m *Manager) genClientID() string { - b := make([]byte, 15) - if _, err := io.ReadFull(rand.Reader, b); err != nil { - return "" + if m.ForceKeepAlive { + if err := resp.PropertySet(packet.PropertyServerKeepAlive, m.KeepAlive); err != nil { + return err + } } - return base64.URLEncoding.EncodeToString(b) + return nil } -func (m *Manager) onDisconnect(id string, reason packet.ReasonCode, retain bool) { +func (m *Manager) connectionClosed(id string, reason packet.ReasonCode, retain bool) { m.log.Debug("Disconnected", zap.String("ClientID", id)) m.Systree.Clients().Disconnected(id, reason, retain) } -func (m *Manager) onSubscriberShutdown(sub subscriber.ConnectionProvider) { - m.log.Debug("Shutdown subscriber", zap.String("ClientID", sub.ID())) +func (m *Manager) subscriberShutdown(id string, sub subscriber.SessionProvider) { sub.Offline(true) - m.subscribers.Delete(sub.ID()) + val, _ := m.sessions.Load(id) + + wrap := val.(*container) + wrap.sub = nil } -func (m *Manager) onSessionClose(id string, reason exitReason) { - if reason == exitReasonClean || reason == exitReasonExpired { - if err := m.persistence.Delete([]byte(id)); err != nil && err != persistence.ErrNotFound { - m.log.Error("Couldn't wipe session", zap.String("ClientID", id), zap.Error(err)) - } +func (m *Manager) sessionTimer(id string, expired bool) { + rs := "shutdown" + if expired { + rs = "expired" - rs := "clean" - if reason == exitReasonExpired { - rs = "expired" - } + m.persistence.Delete([]byte(id)) + m.sessions.Delete(id) + m.sessionsCount.Done() + } + + state := &systree.SessionDeletedStatus{ + Timestamp: time.Now().Format(time.RFC3339), + Reason: rs, + } + + m.Systree.Sessions().Removed(id, state) +} + +func (m *Manager) sessionOffline(id string, durable bool, exp *expiry) { + if !durable { state := &systree.SessionDeletedStatus{ Timestamp: time.Now().Format(time.RFC3339), - Reason: rs, + Reason: "clean", } m.Systree.Sessions().Removed(id, state) } - m.log.Debug("Session close", zap.String("ClientID", id)) + if exp != nil { + obj, _ := m.sessions.Load(id) + wrap := obj.(*container) + wrap.expiry.Store(exp) + exp.start() + } - m.sessions.Delete(id) - m.sessionsCount.Done() + if !durable || exp == nil { + m.sessions.Delete(id) + m.sessionsCount.Done() + } } -func (m *Manager) loadSessions() error { - subs := map[string]*subscriberConfig{} - - delayedWills := []*packet.Publish{} - - err := m.persistence.LoadForEach(func(id []byte, state *persistence.SessionState) error { - sID := string(id) +func (m *Manager) configurePersistedSubscribers() { + //for id, t := range m.subscriberConfigs { + // sub := subscriber.New( + // &subscriber.Config{ + // ID: id, + // Topics: m.TopicsMgr, + // OnOfflinePublish: m.onPublish, + // OfflineQoS0: m.OfflineQoS0, + // Version: t.version, + // }) + // + // for topic, ops := range t.topics { + // if _, _, err := sub.Subscribe(topic, ops); err != nil { + // m.log.Error("Couldn't subscribe", zap.Error(err)) + // } + // } + // + // m.subscribers.Store(id, sub) + //} + // + //m.subscriberConfigs = make(map[string]*subscriberConfig) +} - if len(state.Errors) != 0 { - m.log.Error("Session load", zap.String("ClientID", sID), zap.Errors("errors", state.Errors)) - if err := m.persistence.SubscriptionsDelete(id); err != nil && err != persistence.ErrNotFound { - m.log.Error("Persisted subscriber delete", zap.Error(err)) - } +func (m *Manager) processDelayedWills() { + for _, will := range m.delayedWills { + if err := m.TopicsMgr.Publish(will); err != nil { + m.log.Error("Publish delayed will", zap.Error(err)) } + } - if state.Expire != nil { - since, err := time.Parse(time.RFC3339, state.Expire.Since) - if err != nil { - m.log.Error("Parse expiration value", zap.String("ClientID", sID), zap.Error(err)) - if err := m.persistence.SubscriptionsDelete(id); err != nil && err != persistence.ErrNotFound { - m.log.Error("Persisted subscriber delete", zap.Error(err)) - } - return nil - } - - var will *packet.Publish - var willIn uint32 - var expireIn uint32 - - // if persisted state has delayed will lets check if it has not elapsed its time - if len(state.Expire.WillIn) > 0 && len(state.Expire.WillData) > 0 { - pkt, _, _ := packet.Decode(packet.ProtocolVersion(state.Version), state.Expire.WillData) - will, _ = pkt.(*packet.Publish) + m.delayedWills = []*packet.Publish{} +} - if val, err := strconv.Atoi(state.Expire.WillIn); err == nil { - willIn = uint32(val) - willAt := since.Add(time.Duration(willIn) * time.Second) +// decodeSessionExpiry +func (m *Manager) decodeSessionExpiry(ctx *loadContext, id string, state *persistence.SessionState) error { + if state.Expire == nil { + return nil + } - if time.Now().Before(willAt) { - // will delay elapsed. notify that - delayedWills = append(delayedWills, will) - will = nil - } - } else { - m.log.Error("Decode will at", zap.String("ClientID", sID), zap.Error(err)) - } - } + since, err := time.Parse(time.RFC3339, state.Expire.Since) + if err != nil { + prev := err + m.log.Error("Parse expiration value", zap.String("ClientID", id), zap.Error(err)) + if err = m.persistence.SubscriptionsDelete([]byte(id)); err != nil && err != persistence.ErrNotFound { + m.log.Error("Persisted subscriber delete", zap.Error(err)) + } - if len(state.Expire.ExpireIn) > 0 { - if val, err := strconv.Atoi(state.Expire.ExpireIn); err == nil { - expireIn = uint32(val) - expireAt := since.Add(time.Duration(expireIn) * time.Second) - - if time.Now().Before(expireAt) { - // persisted session has expired, wipe it - if err := m.persistence.Delete(id); err != nil && err != persistence.ErrNotFound { - m.log.Error("Persisted session delete", zap.Error(err)) - } - return nil - } - } else { - m.log.Error("Decode expire at", zap.String("ClientID", sID), zap.Error(err)) - } - } + return prev + } - // persisted session has either delayed will or expiry - // create it and run timer - if will != nil || expireIn > 0 { - createdAt, _ := time.Parse(time.RFC3339, state.Timestamp) - ses := m.allocSession(sID, createdAt) - var exp *uint32 - if expireIn > 0 { - exp = &expireIn - } + var will *packet.Publish + var willIn uint32 + var expireIn uint32 - setup := &sessionReConfig{ - subscriber: nil, - expireIn: exp, - will: will, - willDelay: willIn, - killOnDisconnect: false, - } + // if persisted state has delayed will lets check if it has not elapsed its time + if len(state.Expire.WillIn) > 0 && len(state.Expire.WillData) > 0 { + pkt, _, _ := packet.Decode(packet.ProtocolVersion(state.Version), state.Expire.WillData) + will, _ = pkt.(*packet.Publish) + var val int + if val, err = strconv.Atoi(state.Expire.WillIn); err == nil { + willIn = uint32(val) + willAt := since.Add(time.Duration(willIn) * time.Second) - ses.s.reconfigure(setup, true) - m.sessions.Store(id, ses) - m.sessionsCount.Add(1) - ses.release() + if time.Now().After(willAt) { + // will delay elapsed. notify keep in list and publish when all persisted sessions loaded + ctx.delayedWills = append(ctx.delayedWills, will) + will = nil + willIn = 0 } + } else { + m.log.Error("Decode will at", zap.String("ClientID", id), zap.Error(err)) } + } - if len(state.Subscriptions) > 0 { - if sCfg, err := m.loadSubscriber(state.Subscriptions); err == nil { - subs[sID] = sCfg - } else { - m.log.Error("Decode subscriber", zap.String("ClientID", sID), zap.Error(err)) - } + if len(state.Expire.ExpireIn) > 0 { + var val int + if val, err = strconv.Atoi(state.Expire.ExpireIn); err == nil { + expireIn = uint32(val) + expireAt := since.Add(time.Duration(expireIn) * time.Second) - if err := m.persistence.SubscriptionsDelete(id); err != nil && err != persistence.ErrNotFound { - m.log.Error("Persisted subscriber delete", zap.Error(err)) + if time.Now().After(expireAt) { + // persisted session has expired, wipe it + if err = m.persistence.Delete([]byte(id)); err != nil && err != persistence.ErrNotFound { + m.log.Error("Delete expired session", zap.Error(err)) + } + return nil } + } else { + m.log.Error("Decode expire at", zap.String("ClientID", id), zap.Error(err)) } + } - status := &systree.SessionCreatedStatus{ - Clean: false, - Timestamp: state.Timestamp, + // persisted session has either delayed will or expiry + // create it and run timer + if will != nil || expireIn > 0 { + var createdAt time.Time + if createdAt, err = time.Parse(time.RFC3339, state.Timestamp); err != nil { + m.log.Named("persistence").Error("Decode createdAt failed, using current timestamp", + zap.String("ClientID", id), + zap.Error(err)) + createdAt = time.Now() } - m.Systree.Sessions().Created(sID, status) - return nil - }) - - for id, t := range subs { - sub := subscriber.New( - &subscriber.Config{ - ID: id, - Topics: m.TopicsMgr, - OnOfflinePublish: m.onPublish, - OfflineQoS0: m.OfflineQoS0, - Version: t.version, - }) - - for topic, ops := range t.topics { - if _, _, err = sub.Subscribe(topic, ops); err != nil { - m.log.Error("Couldn't subscribe", zap.Error(err)) - } + if _, ok := ctx.preloadConfigs[id]; !ok { + ctx.preloadConfigs[id] = &preloadConfig{} } - m.subscribers.Store(id, sub) - } - - // publish delayed wills if any - for _, will := range delayedWills { - if err = m.TopicsMgr.Publish(will); err != nil { - m.log.Error("Publish delayed will", zap.Error(err)) + ctx.preloadConfigs[id].exp = &expiryConfig{ + expiryEvent: m, + messenger: m.TopicsMgr, + createdAt: createdAt, + will: will, + willDelay: willIn, + expireIn: &expireIn, } } - return err + return nil } -func (m *Manager) loadSubscriber(from []byte) (*subscriberConfig, error) { +// decodeSubscriber function invoke only during server startup. Used to decode persisted session +// which has active subscriptions +func (m *Manager) decodeSubscriber(ctx *loadContext, id string, from []byte) error { + if len(from) == 0 { + return nil + } + subscriptions := subscriber.Subscriptions{} offset := 0 version := packet.ProtocolVersion(from[offset]) @@ -702,7 +771,7 @@ func (m *Manager) loadSubscriber(from []byte) (*subscriberConfig, error) { for offset != remaining { t, total, e := packet.ReadLPBytes(from[offset:]) if e != nil { - return nil, e + return e } offset += total @@ -717,72 +786,81 @@ func (m *Manager) loadSubscriber(from []byte) (*subscriberConfig, error) { subscriptions[string(t)] = params } - return &subscriberConfig{ + if _, ok := ctx.preloadConfigs[id]; !ok { + ctx.preloadConfigs[id] = &preloadConfig{} + } + + ctx.preloadConfigs[id].sub = &subscriberConfig{ version: version, topics: subscriptions, - }, nil + } + + return nil } -func (m *Manager) storeSubscribers() error { +func (m *Manager) encodeSubscribers() error { // 4. shutdown and persist subscriptions from non-clean session //for id, s := range m.subscribers { - m.subscribers.Range(func(k, v interface{}) bool { - id := k.(string) - s := v.(subscriber.ConnectionProvider) - s.Offline(true) - - topics := s.Subscriptions() - - // calculate size of the encoded entry - // consist of: - // _ _ _ _ _ _ _ _ _ _ _ - // |_|_|_|_|_|...|_|_|_|_|_|_| - // ___ _ _________ _ _______ - // | | | | | - // | | | | 4 bytes - subscription id - // | | | | 1 byte - topic options - // | | | n bytes - topic - // | | 1 bytes - protocol version - // | 2 bytes - length prefix - - size := 0 - for topic := range topics { - size += 2 + len(topic) + 1 + int(unsafe.Sizeof(uint32(0))) - } - - buf := make([]byte, size+1) - offset := 0 - buf[offset] = byte(s.Version()) - offset++ - - for s, params := range topics { - total, _ := packet.WriteLPBytes(buf[offset:], []byte(s)) - offset += total - buf[offset] = byte(params.Ops) - offset++ - binary.BigEndian.PutUint32(buf[offset:], params.ID) - offset += 4 - } - - if err := m.persistence.SubscriptionsStore([]byte(id), buf); err != nil { - m.log.Error("Couldn't persist subscriptions", zap.String("ClientID", id), zap.Error(err)) - } - - return true - }) + //m.subscribers.Range(func(k, v interface{}) bool { + // id := k.(string) + // s := v.(subscriber.SessionProvider) + // s.Offline(true) + // + // topics := s.Subscriptions() + // + // // calculate size of the encoded entry + // // consist of: + // // _ _ _ _ _ _ _ _ _ _ _ + // // |_|_|_|_|_|...|_|_|_|_|_|_| + // // ___ _ _________ _ _______ + // // | | | | | + // // | | | | 4 bytes - subscription id + // // | | | | 1 byte - topic options + // // | | | n bytes - topic + // // | | 1 bytes - protocol version + // // | 2 bytes - length prefix + // + // size := 0 + // for topic := range topics { + // size += 2 + len(topic) + 1 + int(unsafe.Sizeof(uint32(0))) + // } + // + // buf := make([]byte, size+1) + // offset := 0 + // buf[offset] = byte(s.Version()) + // offset++ + // + // for s, params := range topics { + // total, _ := packet.WriteLPBytes(buf[offset:], []byte(s)) + // offset += total + // buf[offset] = byte(params.Ops) + // offset++ + // binary.BigEndian.PutUint32(buf[offset:], params.ID) + // offset += 4 + // } + // + // if err := m.persistence.SubscriptionsStore([]byte(id), buf); err != nil { + // m.log.Error("Couldn't persist subscriptions", zap.String("ClientID", id), zap.Error(err)) + // } + // + // return true + //}) return nil } -func (m *Manager) onPublish(id string, p *packet.Publish) { +func (m *Manager) Publish(id string, p *packet.Publish) { pkt := persistence.PersistedPacket{UnAck: false} - if p.Expired(false) { + var expired bool + var expireAt time.Time + + if expireAt, _, expired = p.Expired(); expired { return } - if tm := p.GetExpiry(); !tm.IsZero() { - pkt.ExpireAt = tm.Format(time.RFC3339) + if !expireAt.IsZero() { + pkt.ExpireAt = expireAt.Format(time.RFC3339) } p.SetPacketID(0) diff --git a/connection/ack.go b/connection/ack.go index e0f20c2..2a57abd 100644 --- a/connection/ack.go +++ b/connection/ack.go @@ -22,7 +22,7 @@ func (a *ackQueue) release(pkt packet.Provider) { id, _ := pkt.ID() if value, ok := a.messages.Load(id); ok { - if orig, ok := value.(packet.Provider); ok && a.onRelease != nil { + if orig, k := value.(packet.Provider); k && a.onRelease != nil { a.onRelease(orig, pkt) } a.messages.Delete(id) diff --git a/connection/connection.go b/connection/connection.go index 410e1f5..aff1e20 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -16,37 +16,101 @@ package connection import ( "container/list" + "crypto/rand" + "encoding/base64" "errors" + "io" "net" "sync" + "sync/atomic" "time" "github.com/VolantMQ/persistence" - "github.com/VolantMQ/volantmq/auth" "github.com/VolantMQ/volantmq/configuration" "github.com/VolantMQ/volantmq/packet" - "github.com/VolantMQ/volantmq/subscriber" "github.com/VolantMQ/volantmq/systree" "github.com/VolantMQ/volantmq/types" "github.com/troian/easygo/netpoll" "go.uber.org/zap" ) +type state int + +const ( + stateConnecting state = iota + stateAuth + stateConnected + stateReAuth + stateDisconnected + stateConnectFailed +) + // nolint: golint var ( ErrOverflow = errors.New("session: overflow") ErrPersistence = errors.New("session: error during persistence restore") ) +var expectedPacketType = map[state]map[packet.Type]bool{ + stateConnecting: {packet.CONNECT: true}, + stateAuth: { + packet.AUTH: true, + packet.DISCONNECT: true, + }, + stateConnected: { + packet.PUBLISH: true, + packet.PUBACK: true, + packet.PUBREC: true, + packet.PUBREL: true, + packet.PUBCOMP: true, + packet.SUBSCRIBE: true, + packet.SUBACK: true, + packet.UNSUBSCRIBE: true, + packet.UNSUBACK: true, + packet.PINGREQ: true, + packet.AUTH: true, + packet.DISCONNECT: true, + }, + stateReAuth: { + packet.PUBLISH: true, + packet.PUBACK: true, + packet.PUBREC: true, + packet.PUBREL: true, + packet.PUBCOMP: true, + packet.SUBSCRIBE: true, + packet.SUBACK: true, + packet.UNSUBSCRIBE: true, + packet.UNSUBACK: true, + packet.PINGREQ: true, + packet.AUTH: true, + packet.DISCONNECT: true, + }, +} + +func (s state) desc() string { + switch s { + case stateConnecting: + return "CONNECTING" + case stateAuth: + return "AUTH" + case stateConnected: + return "CONNECTED" + case stateReAuth: + return "RE-AUTH" + case stateDisconnected: + return "DISCONNECTED" + default: + return "CONNECT_FAILED" + } +} + // DisconnectParams session state when stopped type DisconnectParams struct { - ExpireAt *uint32 - Desc *netpoll.Desc - Reason packet.ReasonCode - Will bool + Reason packet.ReasonCode + Packets []persistence.PersistedPacket } -type onDisconnect func(*DisconnectParams) +//type onDisconnect func(*DisconnectParams) // Callbacks provided by sessions manager to signal session state type Callbacks struct { @@ -62,80 +126,105 @@ type WillConfig struct { QoS packet.QosType } -// PreConfig used by session manager to when configuring session -// a bit of ugly -// TODO(troian): try rid it off -type PreConfig struct { - EventPoll netpoll.EventPoll - Username string - AuthMethod string - AuthData []byte - State persistence.Packets - Metric systree.Metric - Conn net.Conn - Auth auth.SessionPermissions - Desc *netpoll.Desc - MaxRxPacketSize uint32 +type AuthParams struct { + AuthMethod string + AuthData []byte + Reason packet.ReasonCode +} + +type ConnectParams struct { + AuthParams + ID string + Error error + ExpireIn *uint32 + Will *packet.Publish + Username []byte + Password []byte + WillDelay uint32 MaxTxPacketSize uint32 - SendQuota int32 - MaxTxTopicAlias uint16 - MaxRxTopicAlias uint16 + SendQuota uint16 KeepAlive uint16 + IDGen bool + CleanStart bool + Durable bool Version packet.ProtocolVersion - RetainAvailable bool - PreserveOrder bool - OfflineQoS0 bool -} - -// Config is system wide configuration parameters for every session -type Config struct { - *PreConfig - ID string - Subscriber subscriber.ConnectionProvider - Messenger types.TopicMessenger - OnDisconnect onDisconnect - ExpireIn *uint32 - WillDelay uint32 - KillOnDisconnect bool -} - -// Type connection -type Type struct { - *Config - preProcessPublish func(*packet.Publish) error - postProcessPublish func(*packet.Publish) error - pubIn ackQueue - pubOut ackQueue - quit chan struct{} - onStart types.Once - onConnDisconnect types.Once - started sync.WaitGroup - txWg sync.WaitGroup - rxWg sync.WaitGroup - txTopicAlias map[string]uint16 - rxTopicAlias map[uint16]string - txTimer *time.Timer - log *zap.Logger - keepAliveTimer *time.Timer - txGMessages list.List - txQMessages list.List - txGLock sync.Mutex - txQLock sync.Mutex - keepAlive time.Duration - txAvailable chan int - rxRecv []byte - retained struct { - lock sync.Mutex - list []*packet.Publish - } - flowInUse sync.Map - flowCounter uint64 - rxRemaining int - txRunning uint32 - rxRunning uint32 +} + +type SessionCallbacks interface { + SignalPublish(*packet.Publish) error + SignalSubscribe(*packet.Subscribe) (packet.Provider, error) + SignalUnSubscribe(*packet.UnSubscribe) (packet.Provider, error) + SignalDisconnect(*packet.Disconnect) (packet.Provider, error) + SignalOffline() + SignalConnectionClose(DisconnectParams) +} + +type ackQueues struct { + pubIn ackQueue + pubOut ackQueue +} + +type flow struct { + flowInUse sync.Map + flowCounter uint32 +} + +type tx struct { + txGMessages list.List + txQMessages list.List + txGLock sync.Mutex + txQLock sync.Mutex + txTopicAlias map[string]uint16 + txWg sync.WaitGroup + txTimer *time.Timer + txRunning uint32 + txAvailable chan int + txQuotaExceeded bool +} + +type rx struct { + desc *netpoll.Desc + rxWg sync.WaitGroup + rxTopicAlias map[uint16]string + rxRecv []byte + keepAlive time.Duration + keepAliveTimer *time.Timer + rxRunning uint32 + rxRemaining int +} + +// impl of the connection +type impl struct { + SessionCallbacks + id string + metric systree.PacketsMetric + conn net.Conn + ePoll netpoll.EventPoll + signalAuth OnAuthCb + ackQueues + tx + rx + flow + quit chan struct{} + connect chan interface{} + onStart types.Once + onConnDisconnect types.OnceWait + started sync.WaitGroup + log *zap.Logger + keepAlive time.Duration + authMethod string + connectProcessed uint32 + maxRxPacketSize uint32 + maxTxPacketSize uint32 + txQuota int32 + rxQuota int32 + state state topicAliasCurrMax uint16 - txQuotaExceeded bool - will bool + maxTxTopicAlias uint16 + maxRxTopicAlias uint16 + version packet.ProtocolVersion + retainAvailable bool + offlineQoS0 bool } type unacknowledged struct { @@ -150,132 +239,396 @@ type sizeAble interface { Size() (int, error) } +type baseAPI interface { + Stop(error) bool +} + +type Initial interface { + baseAPI + Accept() (chan interface{}, error) + Send(packet.Provider) + Acknowledge(p *packet.ConnAck, opts ...Option) bool + Session() Session +} + +type Session interface { + baseAPI + persistence.PacketLoader + Publish(string, *packet.Publish) + LoadRemaining(g, q *list.List) + SetOptions(opts ...Option) error +} + +var _ Initial = (*impl)(nil) +var _ Session = (*impl)(nil) + // New allocate new connection object -func New(c *Config) (s *Type, err error) { - s = &Type{ - Config: c, - quit: make(chan struct{}), - txAvailable: make(chan int, 1), - txTopicAlias: make(map[string]uint16), - rxTopicAlias: make(map[uint16]string), - txTimer: time.NewTimer(1 * time.Second), - will: true, +func New(opts ...Option) Initial { + s := &impl{ + state: stateConnecting, + quit: make(chan struct{}), + } + + for _, opt := range opts { + opt(s) } + s.txAvailable = make(chan int, 1) + s.txTopicAlias = make(map[string]uint16) + s.rxTopicAlias = make(map[uint16]string) + s.txTimer = time.NewTimer(1 * time.Second) s.txTimer.Stop() s.started.Add(1) s.pubIn.onRelease = s.onReleaseIn s.pubOut.onRelease = s.onReleaseOut - s.log = configuration.GetLogger().Named("connection." + s.ID) + s.log = configuration.GetLogger().Named("connection") + + return s +} + +// Accept start handling incoming connection +func (s *impl) Accept() (chan interface{}, error) { + s.connect = make(chan interface{}) + + s.desc = netpoll.Must(netpoll.HandleReadOnce(s.conn)) + s.keepAliveTimer = time.AfterFunc(time.Duration(s.keepAlive), s.keepAliveFired) + return s.connect, s.ePoll.Start(s.desc, s.rxConnection) +} + +// Session +func (s *impl) Session() Session { + return s +} + +// Send +func (s *impl) Send(pkt packet.Provider) { + if pkt.Type() == packet.AUTH { + s.state = stateAuth + } + + s.runKeepAlive() + s.gPush(pkt) + + s.ePoll.Resume(s.desc) +} - if s.Version >= packet.ProtocolV50 { - s.preProcessPublish = s.preProcessPublishV50 - s.postProcessPublish = s.postProcessPublishV50 +func (s *impl) Acknowledge(p *packet.ConnAck, opts ...Option) bool { + ack := true + s.ePoll.Stop(s.desc) + + close(s.connect) + + if p.ReturnCode() == packet.CodeSuccess { + s.state = stateConnected + + for _, opt := range opts { + opt(s) + } + + s.runKeepAlive() + if err := s.ePoll.Start(s.desc, s.rxRun); err != nil { + s.log.Error("Cannot start receiver", zap.String("ClientID", s.id), zap.Error(err)) + s.state = stateConnectFailed + ack = false + } } else { - s.preProcessPublish = func(*packet.Publish) error { return nil } - s.postProcessPublish = func(*packet.Publish) error { return nil } + s.state = stateConnectFailed + ack = false } - if c.KeepAlive > 0 { - s.keepAlive = time.Second * time.Duration(c.KeepAlive) - s.keepAlive = s.keepAlive + (s.keepAlive / 2) - s.keepAliveTimer = time.AfterFunc(s.keepAlive, s.keepAliveExpired) + s.gPushFront(p) + + if !ack { + s.Stop(nil) } - gList := list.New() - qList := list.New() + return ack +} - subscribedPublish := func(p *packet.Publish) { - if p.QoS() == packet.QoS0 { - gList.PushBack(p) - } else { - qList.PushBack(p) +// Stop connection. Function assumed to be invoked once server about to either shutdown, disconnect +// or session is being replaced +// Effective only first invoke +func (s *impl) Stop(reason error) bool { + return s.onConnectionClose(reason) +} + +func (s *impl) LoadRemaining(g, q *list.List) { + select { + case <-s.quit: + return + default: + } + s.gLoadList(g) + s.qLoadList(q) +} + +func (s *impl) LoadPersistedPacket(entry persistence.PersistedPacket) error { + var err error + var pkt packet.Provider + if pkt, _, err = packet.Decode(s.version, entry.Data); err != nil { + s.log.Error("Couldn't decode persisted message", zap.Error(err)) + return ErrPersistence + } + + if entry.UnAck { + switch p := pkt.(type) { + case *packet.Publish: + id, _ := p.ID() + s.flowReAcquire(id) + case *packet.Ack: + id, _ := p.ID() + s.flowReAcquire(id) + } + + s.qLoad(&unacknowledged{packet: pkt}) + } else { + if p, ok := pkt.(*packet.Publish); ok { + if len(entry.ExpireAt) > 0 { + if tm, err := time.Parse(time.RFC3339, entry.ExpireAt); err == nil { + p.SetExpireAt(tm) + } else { + s.log.Error("Parse publish expiry", zap.String("ClientID", s.id), zap.Error(err)) + } + } + + if p.QoS() == packet.QoS0 { + s.gLoad(pkt) + } else { + s.qLoad(pkt) + } } } - // transmitter queue is ready, assign to subscriber new online callback - // and signal it to forward messages to online callback by creating channel - s.Subscriber.Online(subscribedPublish) + return nil +} - // restore persisted state of the session if any - if err = s.loadPersistence(); err != nil { - return +func (s *impl) Publish(id string, pkt *packet.Publish) { + if pkt.QoS() == packet.QoS0 { + s.gPush(pkt) + } else { + s.qPush(pkt) + } +} + +func genClientID() string { + b := make([]byte, 15) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "" } - s.Subscriber.OnlineRedirect(s.onSubscribedPublish) + return base64.URLEncoding.EncodeToString(b) +} - s.gLoadList(gList) - s.qLoadList(qList) +func (s *impl) getWill(pkt *packet.Connect) *packet.Publish { + var p *packet.Publish - return + if willTopic, willPayload, willQoS, willRetain, will := pkt.Will(); will { + p = packet.NewPublish(pkt.Version()) + if err := p.Set(willTopic, willPayload, willQoS, willRetain, false); err != nil { + s.log.Error("Configure will packet", zap.String("ClientID", s.id), zap.Error(err)) + p = nil + } + } + + return p } -// Start run connection -func (s *Type) Start() { - s.onStart.Do(func() { - s.txRun() - s.EventPoll.Start(s.Desc, s.rxRun) // nolint: errcheck - s.started.Done() - }) +func (s *impl) onConnect(pkt *packet.Connect) (packet.Provider, error) { + if atomic.CompareAndSwapUint32(&s.connectProcessed, 0, 1) { + id := string(pkt.ClientID()) + idGen := false + if len(id) == 0 { + idGen = true + id = genClientID() + } + + s.id = id + + params := &ConnectParams{ + ID: id, + IDGen: idGen, + Will: s.getWill(pkt), + KeepAlive: pkt.KeepAlive(), + Version: pkt.Version(), + CleanStart: pkt.IsClean(), + Durable: true, + } + + s.version = params.Version + + s.readConnProperties(pkt, params) + + // MQTT v5 has different meaning of clean comparing to MQTT v3 + // - v3: if session is clean it is clean start and session lasts when Network connection closed + // - v5: clean only means "clean start" and sessions lasts on connection close on if expire propery + // exists and set to 0 + if (params.Version <= packet.ProtocolV311 && params.CleanStart) || + (params.Version >= packet.ProtocolV50 && params.ExpireIn != nil && *params.ExpireIn == 0) { + params.Durable = false + } + + s.connect <- params + return nil, nil + } + + // It's protocol error to send CONNECT packet more than once + return nil, packet.CodeProtocolError } -// Stop session. Function assumed to be invoked once server about to either shutdown, disconnect -// or session is being replaced -// Effective only first invoke -func (s *Type) Stop(reason packet.ReasonCode) { - s.onConnectionClose(true, reason) +func (s *impl) onAuth(pkt *packet.Auth) (packet.Provider, error) { + // AUTH packets are allowed for v5.0 only + if s.version < packet.ProtocolV50 { + return nil, packet.CodeRefusedServerUnavailable + } + + reason := pkt.ReasonCode() + + // Client must not send AUTH packets before server has requested it + // during auth or re-auth Client must respond only AUTH with CodeContinueAuthentication + // if connection is being established Client must send AUTH only with CodeReAuthenticate + if (s.state == stateConnecting) || + ((s.state == stateAuth || s.state == stateReAuth) && (reason != packet.CodeContinueAuthentication)) || + ((s.state == stateConnected) && reason != (packet.CodeReAuthenticate)) { + return nil, packet.CodeProtocolError + } + + params := AuthParams{ + Reason: reason, + } + + // [MQTT-3.15.2.2.2] + if prop := pkt.PropertyGet(packet.PropertyAuthMethod); prop != nil { + if val, e := prop.AsString(); e == nil { + params.AuthMethod = val + } + } + + // AUTH packet must provide AuthMethod property + if len(params.AuthMethod) == 0 { + return nil, packet.CodeProtocolError + } + + // [MQTT-4.12.0-7] - If the Client does not include an Authentication Method in the CONNECT, + // the Client MUST NOT send an AUTH packet to the Server + // [MQTT-4.12.1-1] - The Client MUST set the Authentication Method to the same value as + // the Authentication Method originally used to authenticate the Network Connection + if len(s.authMethod) == 0 || s.authMethod != params.AuthMethod { + return nil, packet.CodeProtocolError + } + + // [MQTT-3.15.2.2.3] + if prop := pkt.PropertyGet(packet.PropertyAuthData); prop != nil { + if val, e := prop.AsBinary(); e == nil { + params.AuthData = val + } + } + + if s.state == stateConnecting || s.state == stateAuth { + s.connect <- params + return nil, nil + } + + return s.signalAuth(s.id, ¶ms) +} + +func (s *impl) readConnProperties(req *packet.Connect, params *ConnectParams) { + if s.version < packet.ProtocolV50 { + return + } + + // [MQTT-3.1.2.11.2] + if prop := req.PropertyGet(packet.PropertySessionExpiryInterval); prop != nil { + if val, e := prop.AsInt(); e == nil { + params.ExpireIn = &val + } + } + + // [MQTT-3.1.2.11.3] + if prop := req.PropertyGet(packet.PropertyWillDelayInterval); prop != nil { + if val, e := prop.AsInt(); e == nil { + params.WillDelay = val + } + } + + // [MQTT-3.1.2.11.4] + if prop := req.PropertyGet(packet.PropertyReceiveMaximum); prop != nil { + if val, e := prop.AsShort(); e == nil { + s.txQuota = int32(val) + params.SendQuota = val + } + } + + // [MQTT-3.1.2.11.5] + if prop := req.PropertyGet(packet.PropertyMaximumPacketSize); prop != nil { + if val, e := prop.AsInt(); e == nil { + s.maxTxPacketSize = val + } + } + + // [MQTT-3.1.2.11.6] + if prop := req.PropertyGet(packet.PropertyTopicAliasMaximum); prop != nil { + if val, e := prop.AsShort(); e == nil { + s.maxTxTopicAlias = val + } + } + + // [MQTT-3.1.2.11.10] + if prop := req.PropertyGet(packet.PropertyAuthMethod); prop != nil { + if val, e := prop.AsString(); e == nil { + params.AuthMethod = val + s.authMethod = val + } + } + + // [MQTT-3.1.2.11.11] + if prop := req.PropertyGet(packet.PropertyAuthData); prop != nil { + if len(params.AuthMethod) == 0 { + params.Error = packet.CodeProtocolError + return + } + if val, e := prop.AsBinary(); e == nil { + params.AuthData = val + } + } + + return } -func (s *Type) processIncoming(p packet.Provider) error { +func (s *impl) processIncoming(p packet.Provider) error { var err error var resp packet.Provider + // [MQTT-3.1.2-33] - If a Client sets an Authentication Method in the CONNECT, + // the Client MUST NOT send any packets other than AUTH or DISCONNECT packets + // until it has received a CONNACK packet + if _, ok := expectedPacketType[s.state][p.Type()]; !ok { + s.log.Info("Unexpected packet for current state", + zap.String("ClientID", s.id), + zap.String("state", s.state.desc()), + zap.String("packet", p.Type().Name())) + return packet.CodeProtocolError + } + switch pkt := p.(type) { + case *packet.Connect: + resp, err = s.onConnect(pkt) + case *packet.Auth: + resp, err = s.onAuth(pkt) case *packet.Publish: resp, err = s.onPublish(pkt) case *packet.Ack: - resp = s.onAck(pkt) + resp, err = s.onAck(pkt) case *packet.Subscribe: - resp = s.onSubscribe(pkt) + resp, err = s.SignalSubscribe(pkt) case *packet.UnSubscribe: - resp = s.onUnSubscribe(pkt) + resp, err = s.SignalUnSubscribe(pkt) case *packet.PingReq: // For PINGREQ message, we should send back PINGRESP - mR, _ := packet.New(s.Version, packet.PINGRESP) - resp, _ = mR.(*packet.PingResp) + resp = packet.NewPingResp(s.version) case *packet.Disconnect: - // For DISCONNECT message, we should quit without sending Will - s.will = false - - if s.Version == packet.ProtocolV50 { - // FIXME: CodeRefusedBadUsernameOrPassword has same id as CodeDisconnectWithWill - if pkt.ReasonCode() == packet.CodeRefusedBadUsernameOrPassword { - s.will = true - } - - err = errors.New("disconnect") - - if prop := pkt.PropertyGet(packet.PropertySessionExpiryInterval); prop != nil { - if val, ok := prop.AsInt(); ok == nil { - // If the Session Expiry Interval in the CONNECT packet was zero, then it is a Protocol Error to set a non- - // zero Session Expiry Interval in the DISCONNECT packet sent by the Client. If such a non-zero Session - // Expiry Interval is received by the Server, it does not treat it as a valid DISCONNECT packet. The Server - // uses DISCONNECT with Reason Code 0x82 (Protocol Error) as described in section 4.13. - if (s.ExpireIn != nil && *s.ExpireIn == 0) && val != 0 { - err = packet.CodeProtocolError - } else { - s.ExpireIn = &val - } - } - } - } - default: - s.log.Error("Unsupported incoming message type", - zap.String("ClientID", s.ID), - zap.String("type", p.Type().Name())) - return nil + resp, err = s.SignalDisconnect(pkt) } if resp != nil { @@ -285,53 +638,7 @@ func (s *Type) processIncoming(p packet.Provider) error { return err } -func (s *Type) loadPersistence() error { - if s.State == nil { - return nil - } - - return s.State.PacketsForEach([]byte(s.ID), func(entry persistence.PersistedPacket) error { - var err error - var pkt packet.Provider - if pkt, _, err = packet.Decode(s.Version, entry.Data); err != nil { - s.log.Error("Couldn't decode persisted message", zap.Error(err)) - return ErrPersistence - } - - if entry.UnAck { - switch p := pkt.(type) { - case *packet.Publish: - id, _ := p.ID() - s.flowReAcquire(id) - case *packet.Ack: - id, _ := p.ID() - s.flowReAcquire(id) - } - - s.qLoad(&unacknowledged{packet: pkt}) - } else { - if p, ok := pkt.(*packet.Publish); ok { - if len(entry.ExpireAt) > 0 { - if tm, err := time.Parse(time.RFC3339, entry.ExpireAt); err == nil { - p.SetExpiry(tm) - } else { - s.log.Error("Parse publish expiry", zap.String("ClientID", s.ID), zap.Error(err)) - } - } - - if p.QoS() == packet.QoS0 { - s.gLoad(pkt) - } else { - s.qLoad(pkt) - } - } - } - - return nil - }) -} - -func (s *Type) persist() { +func (s *impl) getToPersist() []persistence.PersistedPacket { var packets []persistence.PersistedPacket persistAppend := func(p interface{}) { @@ -340,9 +647,9 @@ func (s *Type) persist() { switch tp := p.(type) { case *packet.Publish: - if (s.OfflineQoS0 || tp.QoS() != packet.QoS0) && !tp.Expired(false) { - if tm := tp.GetExpiry(); !tm.IsZero() { - pPkt.ExpireAt = tm.Format(time.RFC3339) + if expireAt, _, expired := tp.Expired(); expired && (s.offlineQoS0 || tp.QoS() != packet.QoS0) { + if !expireAt.IsZero() { + pPkt.ExpireAt = expireAt.Format(time.RFC3339) } if tp.QoS() != packet.QoS0 { @@ -361,18 +668,19 @@ func (s *Type) persist() { pPkt.UnAck = true } - var err error - if pPkt.Data, err = packet.Encode(pkt); err != nil { - s.log.Error("Couldn't encode message for persistence", zap.Error(err)) - } else { - packets = append(packets, pPkt) + if pkt != nil { + var err error + if pPkt.Data, err = packet.Encode(pkt); err != nil { + s.log.Error("Couldn't encode message for persistence", zap.Error(err)) + } else { + packets = append(packets, pPkt) + } } } var next *list.Element for elem := s.txQMessages.Front(); elem != nil; elem = next { next = elem.Next() - persistAppend(s.txQMessages.Remove(elem)) } @@ -391,71 +699,61 @@ func (s *Type) persist() { return true }) - if err := s.State.PacketsStore([]byte(s.ID), packets); err != nil { - s.log.Error("Persist packets", zap.String("ClientID", s.ID), zap.Error(err)) - } -} - -// onSubscribedPublish is the method that gets added to the topic subscribers list by the -// processSubscribe() method. When the server finishes the ack cycle for a -// PUBLISH message, it will call the subscriber, which is this method. -// -// For the server, when this method is called, it means there's a message that -// should be published to the client on the other end of this connection. So we -// will call publish() to send the message. -func (s *Type) onSubscribedPublish(p *packet.Publish) { - if p.QoS() == packet.QoS0 { - s.gPush(p) - } else { - s.qPush(p) - } + return packets } // forward PUBLISH message to topics manager which takes care about subscribers -func (s *Type) publishToTopic(p *packet.Publish) error { - if err := s.postProcessPublish(p); err != nil { - return err - } - - p.SetPublishID(s.Subscriber.Hash()) - - // [MQTT-3.3.1.3] - if p.Retain() { - if err := s.Messenger.Retain(p); err != nil { - s.log.Error("Error retaining message", zap.String("ClientID", s.ID), zap.Error(err)) +func (s *impl) publishToTopic(p *packet.Publish) error { + // v5.0 + // If the Server included Retain Available in its CONNACK response to a Client with its value set to 0 and it + // receives a PUBLISH packet with the RETAIN flag is set to 1, then it uses the DISCONNECT Reason + // Code of 0x9A (Retain not supported) as described in section 4.13. + if s.version >= packet.ProtocolV50 { + // [MQTT-3.3.2.3.4] + if prop := p.PropertyGet(packet.PropertyTopicAlias); prop != nil { + if val, err := prop.AsShort(); err == nil { + if len(p.Topic()) != 0 { + // renew alias with new topic + s.rxTopicAlias[val] = p.Topic() + } else { + if topic, kk := s.rxTopicAlias[val]; kk { + // do not check for error as topic has been validated when arrived + p.SetTopic(topic) // nolint: errcheck + } else { + return packet.CodeInvalidTopicAlias + } + } + } else { + return packet.CodeInvalidTopicAlias + } } - // [MQTT-3.3.1-7] - if p.QoS() == packet.QoS0 { - _m, _ := packet.New(s.Version, packet.PUBLISH) - m := _m.(*packet.Publish) - m.SetQoS(p.QoS()) // nolint: errcheck - m.SetTopic(p.Topic()) // nolint: errcheck - s.retained.lock.Lock() - s.retained.list = append(s.retained.list, m) - s.retained.lock.Unlock() + // [MQTT-3.3.2.3.3] + if prop := p.PropertyGet(packet.PropertyPublicationExpiry); prop != nil { + if val, err := prop.AsInt(); err == nil { + s.log.Warn("Set pub expiration", zap.String("ClientID", s.id), zap.Duration("val", time.Duration(val)*time.Second)) + p.SetExpireAt(time.Now().Add(time.Duration(val) * time.Second)) + } else { + return err + } } } - if err := s.Messenger.Publish(p); err != nil { - s.log.Error("Couldn't publish", zap.String("ClientID", s.ID), zap.Error(err)) - } - - return nil + return s.SignalPublish(p) } // onReleaseIn ack process for incoming messages -func (s *Type) onReleaseIn(o, n packet.Provider) { +func (s *impl) onReleaseIn(o, n packet.Provider) { switch p := o.(type) { case *packet.Publish: - s.publishToTopic(p) // nolint: errcheck + s.SignalPublish(p) } } // onReleaseOut process messages that required ack cycle // onAckTimeout if publish message has not been acknowledged withing specified ackTimeout // server should mark it as a dup and send again -func (s *Type) onReleaseOut(o, n packet.Provider) { +func (s *impl) onReleaseOut(o, n packet.Provider) { switch n.Type() { case packet.PUBACK: fallthrough @@ -466,53 +764,3 @@ func (s *Type) onReleaseOut(o, n packet.Provider) { } } } - -func (s *Type) preProcessPublishV50(p *packet.Publish) error { - // v5.0 - // If the Server included Retain Available in its CONNACK response to a Client with its value set to 0 and it - // receives a PUBLISH packet with the RETAIN flag is set to 1, then it uses the DISCONNECT Reason - // Code of 0x9A (Retain not supported) as described in section 4.13. - if s.Version >= packet.ProtocolV50 && !s.RetainAvailable && p.Retain() { - return packet.CodeRetainNotSupported - } - - if prop := p.PropertyGet(packet.PropertyTopicAlias); prop != nil { - if val, ok := prop.AsShort(); ok == nil && (val == 0 || val > s.MaxRxTopicAlias) { - return packet.CodeInvalidTopicAlias - } - } - - return nil -} - -func (s *Type) postProcessPublishV50(p *packet.Publish) error { - // [MQTT-3.3.2.3.4] - if prop := p.PropertyGet(packet.PropertyTopicAlias); prop != nil { - if val, err := prop.AsShort(); err == nil { - if len(p.Topic()) != 0 { - // renew alias with new topic - s.rxTopicAlias[val] = p.Topic() - } else { - if topic, kk := s.rxTopicAlias[val]; kk { - // do not check for error as topic has been validated when arrived - p.SetTopic(topic) // nolint: errcheck - } else { - return packet.CodeInvalidTopicAlias - } - } - } else { - return packet.CodeInvalidTopicAlias - } - } - - // [MQTT-3.3.2.3.3] - if prop := p.PropertyGet(packet.PropertyPublicationExpiry); prop != nil { - if val, err := prop.AsInt(); err == nil { - p.SetExpiry(time.Now().Add(time.Duration(val) * time.Second)) - } else { - return err - } - } - - return nil -} diff --git a/connection/flowControl.go b/connection/flowControl.go index 5722f5b..1dadfd5 100644 --- a/connection/flowControl.go +++ b/connection/flowControl.go @@ -12,19 +12,12 @@ var ( errQuotaExceeded = errors.New("quota exceeded") ) -//type packetsFlowControl struct { -// counter uint64 -// quit chan struct{} -// inUse sync.Map -// quota int32 -//} - -func (s *Type) flowReAcquire(id packet.IDType) { - atomic.AddInt32(&s.SendQuota, -1) +func (s *impl) flowReAcquire(id packet.IDType) { + atomic.AddInt32(&s.txQuota, -1) s.flowInUse.Store(id, true) } -func (s *Type) flowAcquire() (packet.IDType, error) { +func (s *impl) flowAcquire() (packet.IDType, error) { select { case <-s.quit: return 0, errExit @@ -32,7 +25,7 @@ func (s *Type) flowAcquire() (packet.IDType, error) { } var err error - if atomic.AddInt32(&s.SendQuota, -1) == 0 { + if atomic.AddInt32(&s.txQuota, -1) == 0 { err = errQuotaExceeded } @@ -49,8 +42,8 @@ func (s *Type) flowAcquire() (packet.IDType, error) { return id, err } -func (s *Type) flowRelease(id packet.IDType) bool { +func (s *impl) flowRelease(id packet.IDType) bool { s.flowInUse.Delete(id) - return atomic.AddInt32(&s.SendQuota, 1) == 1 + return atomic.AddInt32(&s.txQuota, 1) == 1 } diff --git a/connection/keepAlive.go b/connection/keepAlive.go new file mode 100644 index 0000000..85e0d46 --- /dev/null +++ b/connection/keepAlive.go @@ -0,0 +1,15 @@ +package connection + +import ( + "errors" +) + +func (s *impl) runKeepAlive() { + if s.keepAlive > 0 { + s.keepAliveTimer.Reset(s.keepAlive) + } +} + +func (s *impl) keepAliveFired() { + s.onConnectionClose(errors.New("time out")) +} diff --git a/connection/netCallbacks.go b/connection/netCallbacks.go index e5cae34..4fb1180 100644 --- a/connection/netCallbacks.go +++ b/connection/netCallbacks.go @@ -3,13 +3,11 @@ package connection import ( "sync/atomic" - "github.com/VolantMQ/volantmq/auth" "github.com/VolantMQ/volantmq/packet" - "github.com/VolantMQ/volantmq/topics/types" "go.uber.org/zap" ) -func (s *Type) txShutdown() { +func (s *impl) txShutdown() { atomic.StoreUint32(&s.txRunning, 2) s.txTimer.Stop() s.txWg.Wait() @@ -21,81 +19,84 @@ func (s *Type) txShutdown() { } } -func (s *Type) rxShutdown() { +func (s *impl) rxShutdown() { + atomic.StoreUint32(&s.rxRunning, 0) s.rxWg.Wait() - - if s.keepAlive > 0 { - s.keepAliveTimer.Stop() - s.keepAliveTimer = nil - } } -func (s *Type) onConnectionClose(will bool, err error) { - s.onConnDisconnect.Do(func() { - // make sure connection has been started before proceeding to any shutdown procedures - s.started.Wait() +func (s *impl) onConnectionClose(status error) bool { + return s.onConnDisconnect.Do(func() { + s.keepAliveTimer.Stop() + close(s.quit) + var err error // shutdown quit channel tells all routines finita la commedia - close(s.quit) - if e := s.EventPoll.Stop(s.Desc); e != nil { - s.log.Error("remove receiver from netpoll", zap.String("ClientID", s.ID), zap.Error(e)) + s.ePoll.Stop(s.desc) + + if s.state != stateConnecting && s.state != stateAuth && s.state != stateConnectFailed { + s.SignalOffline() + } else if s.state == stateConnecting || s.state == stateAuth { + select { + case <-s.connect: + default: + close(s.connect) + } } + // clean up transmitter to allow send disconnect command to client if needed s.txShutdown() - // put subscriber in offline mode - s.Subscriber.Offline(s.KillOnDisconnect) - - if err != nil && s.Version >= packet.ProtocolV50 { + if reason, ok := status.(packet.ReasonCode); ok && + reason != packet.CodeSuccess && s.version >= packet.ProtocolV50 && + s.state != stateConnecting && s.state != stateAuth && s.state != stateConnectFailed { // server wants to tell client disconnect reason - reason, _ := err.(packet.ReasonCode) - p, _ := packet.New(s.Version, packet.DISCONNECT) - pkt, _ := p.(*packet.Disconnect) + pkt := packet.NewDisconnect(s.version) pkt.SetReasonCode(reason) var buf []byte - buf, err = packet.Encode(pkt) - if err != nil { - s.log.Error("encode disconnect packet", zap.String("ClientID", s.ID), zap.Error(err)) + if buf, err = packet.Encode(pkt); err != nil { + s.log.Error("encode disconnect packet", zap.String("ClientID", s.id), zap.Error(err)) } else { - if _, err = s.Conn.Write(buf); err != nil { - s.log.Error("Couldn't write disconnect message", zap.String("ClientID", s.ID), zap.Error(err)) + var written int + if written, err = s.conn.Write(buf); written != len(buf) { + s.log.Error("Couldn't write disconnect message", + zap.String("ClientID", s.id), + zap.Int("packet size", len(buf)), + zap.Int("written", written)) + } else if err != nil { + s.log.Debug("Couldn't write disconnect message", + zap.String("ClientID", s.id), + zap.Error(err)) } } } - if err = s.Conn.Close(); err != nil { - s.log.Error("close connection", zap.String("ClientID", s.ID), zap.Error(err)) + if err = s.desc.Close(); err != nil { + s.log.Error("Close polling descriptor", zap.String("ClientID", s.id), zap.Error(err)) + } + + if err = s.conn.Close(); err != nil { + s.log.Error("close connection", zap.String("ClientID", s.id), zap.Error(err)) } s.rxShutdown() - // [MQTT-3.3.1-7] - // Discard retained messages with QoS 0 - s.retained.lock.Lock() - //for _, m := range s.retained.list { - // s.topics.Retain(m) // nolint: errcheck - //} - s.retained.list = []*packet.Publish{} - s.retained.lock.Unlock() - s.Conn = nil - - params := &DisconnectParams{ - Will: will, - ExpireAt: s.ExpireIn, - Desc: s.Desc, - Reason: packet.CodeSuccess, - } + s.conn = nil - if rc, ok := err.(packet.ReasonCode); ok { - params.Reason = rc - } + if s.state != stateConnecting && s.state != stateAuth && s.state != stateConnectFailed { + params := DisconnectParams{ + Packets: s.getToPersist(), + Reason: packet.CodeSuccess, + } - if !s.KillOnDisconnect { - s.persist() + if rc, ok := err.(packet.ReasonCode); ok { + params.Reason = rc + } + + s.SignalConnectionClose(params) } - s.OnDisconnect(params) + s.state = stateDisconnected }) } @@ -103,13 +104,21 @@ func (s *Type) onConnectionClose(will bool, err error) { // On QoS == 0, we should just take the next step, no ack required // On QoS == 1, send back PUBACK, then take the next step // On QoS == 2, we need to put it in the ack queue, send back PUBREC -func (s *Type) onPublish(pkt *packet.Publish) (packet.Provider, error) { +func (s *impl) onPublish(pkt *packet.Publish) (packet.Provider, error) { // check for topic access var err error reason := packet.CodeSuccess - if err = s.preProcessPublish(pkt); err != nil { - return nil, err + if s.version >= packet.ProtocolV50 { + if !s.retainAvailable && pkt.Retain() { + return nil, packet.CodeRetainNotSupported + } + + if prop := pkt.PropertyGet(packet.PropertyTopicAlias); prop != nil { + if val, ok := prop.AsShort(); ok == nil && (val == 0 || val > s.maxRxTopicAlias) { + return nil, packet.CodeInvalidTopicAlias + } + } } var resp packet.Provider @@ -117,27 +126,32 @@ func (s *Type) onPublish(pkt *packet.Publish) (packet.Provider, error) { // To deal with V3.1.1 two ways left: // - ignore the message but send acks // - return error which leads to disconnect - if status := s.Auth.ACL(s.ID, s.Username, pkt.Topic(), auth.AccessTypeWrite); status == auth.StatusDeny { - reason = packet.CodeAdministrativeAction - } + //if status := s.ACL(s.ID, pkt.Topic(), auth.AccessTypeWrite); status == auth.StatusDeny { + // reason = packet.CodeAdministrativeAction + //} switch pkt.QoS() { case packet.QoS2: - resp, _ = packet.New(s.Version, packet.PUBREC) - r, _ := resp.(*packet.Ack) - id, _ := pkt.ID() + if s.rxQuota == 0 { + reason = packet.CodeReceiveMaximumExceeded + } else { + s.rxQuota-- + resp, _ = packet.New(s.version, packet.PUBREC) + r, _ := resp.(*packet.Ack) + id, _ := pkt.ID() - r.SetPacketID(id) - r.SetReason(reason) + r.SetPacketID(id) + r.SetReason(reason) - // [MQTT-4.3.3-9] - // store incoming QoS 2 message before sending PUBREC as theoretically PUBREL - // might come before store in case message store done after write PUBREC - if reason < packet.CodeUnspecifiedError { - s.pubIn.store(pkt) + // [MQTT-4.3.3-9] + // store incoming QoS 2 message before sending PUBREC as theoretically PUBREL + // might come before store in case message store done after write PUBREC + if reason < packet.CodeUnspecifiedError { + s.pubIn.store(pkt) + } } case packet.QoS1: - resp, _ = packet.New(s.Version, packet.PUBACK) + resp, _ = packet.New(s.version, packet.PUBACK) r, _ := resp.(*packet.Ack) id, _ := pkt.ID() @@ -151,7 +165,7 @@ func (s *Type) onPublish(pkt *packet.Publish) (packet.Provider, error) { if reason < packet.CodeUnspecifiedError { if err = s.publishToTopic(pkt); err != nil { s.log.Error("Couldn't publish message", - zap.String("ClientID", s.ID), + zap.String("ClientID", s.id), zap.Uint8("QoS", uint8(pkt.QoS())), zap.Error(err)) } @@ -162,7 +176,7 @@ func (s *Type) onPublish(pkt *packet.Publish) (packet.Provider, error) { } // onAck handle ack acknowledgment received from remote -func (s *Type) onAck(msg packet.Provider) packet.Provider { +func (s *impl) onAck(msg packet.Provider) (packet.Provider, error) { var resp packet.Provider switch mIn := msg.(type) { case *packet.Ack: @@ -178,7 +192,7 @@ func (s *Type) onAck(msg packet.Provider) packet.Provider { id, _ := msg.ID() - if s.Version == packet.ProtocolV50 && mIn.Reason() >= packet.CodeUnspecifiedError { + if s.version == packet.ProtocolV50 && mIn.Reason() >= packet.CodeUnspecifiedError { // v5.9 [MQTT-4.9] if s.flowRelease(id) { s.signalQuota() @@ -188,7 +202,7 @@ func (s *Type) onAck(msg packet.Provider) packet.Provider { } if !discard { - resp, _ = packet.New(s.Version, packet.PUBREL) + resp, _ = packet.New(s.version, packet.PUBREL) r, _ := resp.(*packet.Ack) r.SetPacketID(id) @@ -200,117 +214,23 @@ func (s *Type) onAck(msg packet.Provider) packet.Provider { } case packet.PUBREL: // Remote has released PUBLISH - resp, _ = packet.New(s.Version, packet.PUBCOMP) + resp, _ = packet.New(s.version, packet.PUBCOMP) r, _ := resp.(*packet.Ack) id, _ := msg.ID() r.SetPacketID(id) + s.rxQuota++ s.pubIn.release(msg) case packet.PUBCOMP: // PUBREL message has been acknowledged, release from queue s.pubOut.release(msg) default: s.log.Error("Unsupported ack message type", - zap.String("ClientID", s.ID), + zap.String("ClientID", s.id), zap.String("type", msg.Type().Name())) } - default: - //return err } - return resp -} - -func (s *Type) onSubscribe(msg *packet.Subscribe) packet.Provider { - m, _ := packet.New(s.Version, packet.SUBACK) - resp, _ := m.(*packet.SubAck) - - id, _ := msg.ID() - resp.SetPacketID(id) - - var retCodes []packet.ReasonCode - var retainedPublishes []*packet.Publish - - msg.RangeTopics(func(t string, ops packet.SubscriptionOptions) { - reason := packet.CodeSuccess // nolint: ineffassign - //authorized := true - // TODO: check permissions here - - //if authorized { - subsID := uint32(0) - - // V5.0 [MQTT-3.8.2.1.2] - if prop := msg.PropertyGet(packet.PropertySubscriptionIdentifier); prop != nil { - if v, e := prop.AsInt(); e == nil { - subsID = v - } - } - - subsParams := topicsTypes.SubscriptionParams{ - ID: subsID, - Ops: ops, - } - - if grantedQoS, retained, err := s.Subscriber.Subscribe(t, &subsParams); err != nil { - // [MQTT-3.9.3] - if s.Version == packet.ProtocolV50 { - reason = packet.CodeUnspecifiedError - } else { - reason = packet.QosFailure - } - } else { - reason = packet.ReasonCode(grantedQoS) - retainedPublishes = append(retainedPublishes, retained...) - } - - retCodes = append(retCodes, reason) - }) - - if err := resp.AddReturnCodes(retCodes); err != nil { - return nil - } - - // Now put retained messages into publish queue - for _, rp := range retainedPublishes { - if pkt, err := rp.Clone(s.Version); err == nil { - pkt.SetRetain(true) - s.onSubscribedPublish(pkt) - } else { - s.log.Error("Couldn't clone PUBLISH message", zap.String("ClientID", s.ID), zap.Error(err)) - } - } - - return resp -} - -func (s *Type) onUnSubscribe(msg *packet.UnSubscribe) packet.Provider { - var retCodes []packet.ReasonCode - - for _, t := range msg.Topics() { - // TODO: check permissions here - authorized := true - reason := packet.CodeSuccess - - if authorized { - if err := s.Subscriber.UnSubscribe(t); err != nil { - s.log.Error("Couldn't unsubscribe from topic", zap.Error(err)) - } else { - reason = packet.CodeNoSubscriptionExisted - } - } else { - reason = packet.CodeNotAuthorized - } - - retCodes = append(retCodes, reason) - } - - m, _ := packet.New(s.Version, packet.UNSUBACK) - resp, _ := m.(*packet.UnSubAck) - - id, _ := msg.ID() - resp.SetPacketID(id) - resp.AddReturnCodes(retCodes) // nolint: errcheck - - return resp + return resp, nil } diff --git a/connection/options.go b/connection/options.go new file mode 100644 index 0000000..e82226f --- /dev/null +++ b/connection/options.go @@ -0,0 +1,130 @@ +package connection + +import ( + "errors" + "net" + "time" + + "github.com/VolantMQ/volantmq/packet" + "github.com/VolantMQ/volantmq/systree" + "github.com/troian/easygo/netpoll" +) + +type OnAuthCb func(string, *AuthParams) (packet.Provider, error) + +type Option func(*impl) error + +func (s *impl) SetOptions(opts ...Option) error { + for _, opt := range opts { + if err := opt(s); err != nil { + return err + } + } + + return nil +} +func OfflineQoS0(val bool) Option { + return func(t *impl) error { + t.offlineQoS0 = val + return nil + } +} + +func KeepAlive(val int) Option { + return func(t *impl) error { + vl := time.Duration(val) * time.Second + vl = vl + (vl / 2) + t.keepAlive = vl + return nil + } +} + +func Metric(val systree.PacketsMetric) Option { + return func(t *impl) error { + t.metric = val + return nil + } +} + +func EPoll(val netpoll.EventPoll) Option { + return func(t *impl) error { + t.ePoll = val + return nil + } +} + +func MaxRxPacketSize(val uint32) Option { + return func(t *impl) error { + t.maxRxPacketSize = val + return nil + } +} + +func MaxTxPacketSize(val uint32) Option { + return func(t *impl) error { + t.maxTxPacketSize = val + return nil + } +} + +func TxQuota(val int32) Option { + return func(t *impl) error { + t.txQuota = val + return nil + } +} + +func RxQuota(val int32) Option { + return func(t *impl) error { + t.rxQuota = val + return nil + } +} + +func MaxTxTopicAlias(val uint16) Option { + return func(t *impl) error { + t.maxTxTopicAlias = val + return nil + } +} + +func MaxRxTopicAlias(val uint16) Option { + return func(t *impl) error { + t.maxRxTopicAlias = val + return nil + } +} + +func RetainAvailable(val bool) Option { + return func(t *impl) error { + t.retainAvailable = val + return nil + } +} + +func OnAuth(val OnAuthCb) Option { + return func(t *impl) error { + t.signalAuth = val + return nil + } +} + +func NetConn(val net.Conn) Option { + return func(t *impl) error { + if t.conn != nil { + return errors.New("already set") + } + t.conn = val + return nil + } +} + +func AttachSession(val SessionCallbacks) Option { + return func(t *impl) error { + if t.SessionCallbacks != nil { + return errors.New("already set") + } + t.SessionCallbacks = val + return nil + } +} diff --git a/connection/receiver.go b/connection/receiver.go index 751807a..9b98d07 100644 --- a/connection/receiver.go +++ b/connection/receiver.go @@ -3,79 +3,90 @@ package connection import ( "bufio" "encoding/binary" - "sync/atomic" - "errors" + "sync/atomic" "github.com/VolantMQ/volantmq/packet" "github.com/troian/easygo/netpoll" ) -func (s *Type) keepAliveExpired() { - s.onConnectionClose(true, nil) -} - -func (s *Type) rxRun(event netpoll.Event) { - select { - case <-s.quit: - return - default: - } - +func (s *impl) rxRun(event netpoll.Event) { if atomic.CompareAndSwapUint32(&s.rxRunning, 0, 1) { - s.rxWg.Wait() - s.rxWg.Add(1) - - exit := false - if event&(netpoll.EventReadHup|netpoll.EventWriteHup|netpoll.EventHup|netpoll.EventErr) != 0 { - exit = true + mask := netpoll.EventHup | netpoll.EventReadHup | netpoll.EventWriteHup | netpoll.EventErr | netpoll.EventPollClosed + if (event & mask) != 0 { + go s.onConnectionClose(nil) + } else { + go func() { + s.rxWg.Wait() + s.rxWg.Add(1) + s.rxRoutine() + }() } + } +} - go s.rxRoutine(exit) +func (s *impl) rxConnection(event netpoll.Event) { + mask := netpoll.EventHup | netpoll.EventReadHup | netpoll.EventWriteHup | netpoll.EventErr | netpoll.EventPollClosed + if (event & mask) != 0 { + go func() { + s.connect <- errors.New("disconnected") + }() + } else { + go func() { + s.connectionRoutine() + }() } } -func (s *Type) rxRoutine(exit bool) { +func (s *impl) rxRoutine() { var err error defer func() { s.rxWg.Done() - if err != nil { - if _, ok := err.(packet.ReasonCode); !ok { - err = nil - } - s.onConnectionClose(s.will, err) + s.onConnectionClose(err) } }() - if exit { - err = errors.New("disconnect") - return - } - - buf := bufio.NewReader(s.Conn) + buf := bufio.NewReader(s.conn) for atomic.LoadUint32(&s.rxRunning) == 1 { - if s.keepAlive > 0 { - s.keepAliveTimer.Reset(s.keepAlive) - } + s.runKeepAlive() + var pkt packet.Provider if pkt, err = s.readPacket(buf); err == nil { + s.metric.Received(pkt.Type()) err = s.processIncoming(pkt) } if err != nil { atomic.StoreUint32(&s.rxRunning, 0) + break } } - if _, ok := err.(packet.ReasonCode); !ok { - err = s.EventPoll.Resume(s.Desc) + if _, ok := err.(packet.ReasonCode); ok { + return } + + err = s.ePoll.Resume(s.desc) } -func (s *Type) readPacket(buf *bufio.Reader) (packet.Provider, error) { +func (s *impl) connectionRoutine() { + buf := bufio.NewReader(s.conn) + + pkt, err := s.readPacket(buf) + + s.keepAliveTimer.Stop() + if err == nil { + s.metric.Received(pkt.Type()) + err = s.processIncoming(pkt) + } else { + s.connect <- err + } +} + +func (s *impl) readPacket(buf *bufio.Reader) (packet.Provider, error) { var err error if len(s.rxRecv) == 0 { @@ -93,11 +104,6 @@ func (s *Type) readPacket(buf *bufio.Reader) (packet.Provider, error) { return nil, err } - // If not enough bytes are returned, then continue until there's enough. - if len(header) < peekCount { - continue - } - // If we got enough bytes, then check the last byte to see if the continuation // bit is set. If so, increment cnt and continue peeking if header[peekCount-1] >= 0x80 { @@ -114,7 +120,7 @@ func (s *Type) readPacket(buf *bufio.Reader) (packet.Provider, error) { s.rxRecv = make([]byte, s.rxRemaining) } - if s.rxRemaining > int(s.MaxRxPacketSize) { + if s.rxRemaining > int(s.maxRxPacketSize) { return nil, packet.CodePacketTooLarge } @@ -122,14 +128,17 @@ func (s *Type) readPacket(buf *bufio.Reader) (packet.Provider, error) { for offset != s.rxRemaining { var n int - if n, err = buf.Read(s.rxRecv[offset:]); err != nil { + + n, err = buf.Read(s.rxRecv[offset:]) + offset += n + if err != nil { + s.rxRemaining -= offset return nil, err } - offset += n } var pkt packet.Provider - pkt, _, err = packet.Decode(s.Version, s.rxRecv) + pkt, _, err = packet.Decode(s.version, s.rxRecv) s.rxRecv = []byte{} s.rxRemaining = 0 diff --git a/connection/transmitter.go b/connection/transmitter.go index 09a5e70..1a77d75 100644 --- a/connection/transmitter.go +++ b/connection/transmitter.go @@ -2,20 +2,25 @@ package connection import ( "container/list" - "errors" + "math/rand" "net" + "reflect" "sync/atomic" "time" - "reflect" - - "math/rand" - "github.com/VolantMQ/volantmq/packet" "go.uber.org/zap" ) -func (s *Type) gPush(pkt packet.Provider) { +func (s *impl) gPushFront(pkt packet.Provider) { + s.txGLock.Lock() + s.txGMessages.PushFront(pkt) + s.txGLock.Unlock() + s.txSignalAvailable() + s.txRun() +} + +func (s *impl) gPush(pkt packet.Provider) { s.txGLock.Lock() s.txGMessages.PushBack(pkt) s.txGLock.Unlock() @@ -23,17 +28,20 @@ func (s *Type) gPush(pkt packet.Provider) { s.txRun() } -func (s *Type) gLoad(pkt packet.Provider) { +func (s *impl) gLoad(pkt packet.Provider) { + s.txGLock.Lock() s.txGMessages.PushBack(pkt) + s.txGLock.Unlock() s.txSignalAvailable() } -func (s *Type) gLoadList(l *list.List) { +func (s *impl) gLoadList(l *list.List) { + s.txGLock.Lock() s.txGMessages.PushBackList(l) - s.txSignalAvailable() + s.txGLock.Unlock() } -func (s *Type) qPush(pkt interface{}) { +func (s *impl) qPush(pkt interface{}) { s.txQLock.Lock() s.txQMessages.PushBack(pkt) s.txQLock.Unlock() @@ -41,17 +49,20 @@ func (s *Type) qPush(pkt interface{}) { s.txRun() } -func (s *Type) qLoad(pkt interface{}) { +func (s *impl) qLoad(pkt interface{}) { + s.txQLock.Lock() s.txQMessages.PushBack(pkt) + s.txQLock.Unlock() s.txSignalAvailable() } -func (s *Type) qLoadList(l *list.List) { +func (s *impl) qLoadList(l *list.List) { + s.txQLock.Lock() s.txQMessages.PushBackList(l) - s.txSignalAvailable() + s.txQLock.Unlock() } -func (s *Type) txSignalAvailable() { +func (s *impl) txSignalAvailable() { select { case <-s.quit: return @@ -63,7 +74,7 @@ func (s *Type) txSignalAvailable() { } } -func (s *Type) signalQuota() { +func (s *impl) signalQuota() { s.txQLock.Lock() s.txQuotaExceeded = false l := s.txQMessages.Len() @@ -75,7 +86,7 @@ func (s *Type) signalQuota() { } } -func (s *Type) txRun() { +func (s *impl) txRun() { select { case <-s.quit: return @@ -89,32 +100,31 @@ func (s *Type) txRun() { } } -func (s *Type) flushBuffers(buf net.Buffers) error { - _, e := buf.WriteTo(s.Conn) +func (s *impl) flushBuffers(buf net.Buffers) error { + _, e := buf.WriteTo(s.conn) buf = net.Buffers{} - // todo metrics return e } -func (s *Type) packetFitsSize(value interface{}) bool { +func (s *impl) packetFitsSize(value interface{}) bool { var sz int var err error if obj, ok := value.(sizeAble); !ok { s.log.Fatal("Object does not belong to allowed types", - zap.String("ClientID", s.ID), + zap.String("ClientID", s.id), zap.String("Type", reflect.TypeOf(value).String())) } else { if sz, err = obj.Size(); err != nil { - s.log.Error("Couldn't calculate message size", zap.String("ClientID", s.ID), zap.Error(err)) + s.log.Error("Couldn't calculate message size", zap.String("ClientID", s.id), zap.Error(err)) return false } } // ignore any packet with size bigger than negotiated - if sz > int(s.MaxTxPacketSize) { + if sz > int(s.maxTxPacketSize) { s.log.Warn("Ignore packet with size bigger than negotiated with client", - zap.String("ClientID", s.ID), - zap.Uint32("negotiated", s.MaxTxPacketSize), + zap.String("ClientID", s.id), + zap.Uint32("negotiated", s.maxTxPacketSize), zap.Int("actual", sz)) return false } @@ -122,21 +132,21 @@ func (s *Type) packetFitsSize(value interface{}) bool { return true } -func (s *Type) gAvailable() bool { +func (s *impl) gAvailable() bool { defer s.txGLock.Unlock() s.txGLock.Lock() return s.txGMessages.Len() > 0 } -func (s *Type) qAvailable() bool { +func (s *impl) qAvailable() bool { defer s.txQLock.Unlock() s.txQLock.Lock() return !s.txQuotaExceeded && s.txQMessages.Len() > 0 } -func (s *Type) gPopPacket() packet.Provider { +func (s *impl) gPopPacket() packet.Provider { defer s.txGLock.Unlock() s.txGLock.Lock() @@ -146,17 +156,16 @@ func (s *Type) gPopPacket() packet.Provider { return nil } - value := s.txGMessages.Remove(elem) - - return value.(packet.Provider) + return s.txGMessages.Remove(elem).(packet.Provider) } -func (s *Type) qPopPacket() packet.Provider { +func (s *impl) qPopPacket() packet.Provider { defer s.txQLock.Unlock() s.txQLock.Lock() + var pkt packet.Provider + if elem := s.txQMessages.Front(); !s.txQuotaExceeded && elem != nil { - var pkt packet.Provider value := elem.Value switch m := value.(type) { case *packet.Publish: @@ -165,9 +174,7 @@ func (s *Type) qPopPacket() packet.Provider { if err == errExit { atomic.StoreUint32(&s.txRunning, 0) return nil - } - - if err == errQuotaExceeded { + } else if err == errQuotaExceeded { s.txQuotaExceeded = true } @@ -179,82 +186,74 @@ func (s *Type) qPopPacket() packet.Provider { } s.txQMessages.Remove(elem) s.pubOut.store(pkt) - - return pkt } - return nil + return pkt } -func (s *Type) txRoutine() { +func (s *impl) txRoutine() { var err error defer func() { s.txWg.Done() if err != nil { - s.onConnectionClose(true, nil) + s.onConnectionClose(err) } }() sendBuffers := net.Buffers{} + for atomic.LoadUint32(&s.txRunning) == 1 { select { - case <-s.quit: - err = errors.New("exit") - atomic.StoreUint32(&s.txRunning, 0) - return - case <-s.txTimer.C: - if err = s.flushBuffers(sendBuffers); err != nil { - atomic.StoreUint32(&s.txRunning, 0) - return - } - sendBuffers = net.Buffers{} - - if s.qAvailable() || s.gAvailable() { - s.txSignalAvailable() - } else { - atomic.StoreUint32(&s.txRunning, 0) - } case <-s.txAvailable: - // check if there any control packets except PUBLISH QoS 1/2 - // and process them prevLen := len(sendBuffers) - for _, pkt := range s.popPackets() { - switch _p := pkt.(type) { - case *packet.Publish: - if _p.Expired(true) { - pkt = nil - } else { - s.setTopicAlias(_p) + + for packets := s.popPackets(); len(packets) > 0; packets = s.popPackets() { + for _, pkt := range packets { + switch pack := pkt.(type) { + case *packet.Publish: + if _, expireLeft, expired := pack.Expired(); expired { + continue + } else { + if expireLeft > 0 { + if err = pkt.PropertySet(packet.PropertyPublicationExpiry, expireLeft); err != nil { + s.log.Error("Set publication expire", zap.String("ClientID", s.id), zap.Error(err)) + } + } + s.setTopicAlias(pack) + } } - } - if pkt != nil { if ok := s.packetFitsSize(pkt); ok { if buf, e := packet.Encode(pkt); e != nil { - s.log.Error("Message encode", zap.String("ClientID", s.ID), zap.Error(err)) + s.log.Error("Message encode", zap.String("ClientID", s.id), zap.Error(err)) } else { + // todo (troian) might not be good place to do metrics + s.metric.Sent(pkt.Type()) sendBuffers = append(sendBuffers, buf) } } } - } - available := true + if len(sendBuffers) >= 5 { + break + } + } + available := false if s.qAvailable() || s.gAvailable() { s.txSignalAvailable() - } else { - available = false + available = true } - if prevLen == 0 { + if prevLen == 0 && len(sendBuffers) < 5 { s.txTimer.Reset(1 * time.Millisecond) } else if len(sendBuffers) >= 5 { s.txTimer.Stop() if err = s.flushBuffers(sendBuffers); err != nil { atomic.StoreUint32(&s.txRunning, 0) + return } sendBuffers = net.Buffers{} @@ -263,11 +262,28 @@ func (s *Type) txRoutine() { atomic.StoreUint32(&s.txRunning, 0) } } + case <-s.txTimer.C: + if err = s.flushBuffers(sendBuffers); err != nil { + atomic.StoreUint32(&s.txRunning, 0) + return + } + + sendBuffers = net.Buffers{} + + if s.qAvailable() || s.gAvailable() { + s.txSignalAvailable() + } else { + atomic.StoreUint32(&s.txRunning, 0) + } + case <-s.quit: + atomic.StoreUint32(&s.txRunning, 0) + return } } } -func (s *Type) popPackets() []packet.Provider { +// + +func (s *impl) popPackets() []packet.Provider { var packets []packet.Provider if pkt := s.gPopPacket(); pkt != nil { packets = append(packets, pkt) @@ -280,23 +296,23 @@ func (s *Type) popPackets() []packet.Provider { return packets } -func (s *Type) setTopicAlias(pkt *packet.Publish) { - if s.MaxTxTopicAlias > 0 { - var ok bool +// + +func (s *impl) setTopicAlias(pkt *packet.Publish) { + if s.maxTxTopicAlias > 0 { + var exists bool var alias uint16 - if alias, ok = s.txTopicAlias[pkt.Topic()]; !ok { - if s.topicAliasCurrMax < s.MaxTxTopicAlias { + if alias, exists = s.txTopicAlias[pkt.Topic()]; !exists { + if s.topicAliasCurrMax < s.maxTxTopicAlias { s.topicAliasCurrMax++ alias = s.topicAliasCurrMax - ok = true } else { - alias = uint16(rand.Intn(int(s.MaxTxTopicAlias)) + 1) + alias = uint16(rand.Intn(int(s.maxTxTopicAlias))) } - } else { - ok = false + + s.txTopicAlias[pkt.Topic()] = alias } - if err := pkt.PropertySet(packet.PropertyTopicAlias, alias); err == nil && !ok { + if err := pkt.PropertySet(packet.PropertyTopicAlias, alias); err == nil && exists { pkt.SetTopic("") // nolint: errcheck } } diff --git a/go.test.sh b/go.test.sh index 80d33f9..2b62108 100755 --- a/go.test.sh +++ b/go.test.sh @@ -4,7 +4,7 @@ set -e echo "" > coverage.txt for d in $(go list ./... | grep -v vendor); do - go test -race -coverprofile=profile.out -covermode=atomic $d + go test -race -coverprofile=profile.out -covermode=atomic ${d} if [ -f profile.out ]; then cat profile.out >> coverage.txt rm profile.out diff --git a/packet/auth.go b/packet/auth.go index 7989c1a..3f90324 100644 --- a/packet/auth.go +++ b/packet/auth.go @@ -28,11 +28,17 @@ type Auth struct { var _ Provider = (*Auth)(nil) -// newAuth creates a new AUTH message +// newAuth creates a new AUTH packet func newAuth() *Auth { - msg := &Auth{} + return &Auth{} +} + +// NewAuth creates a new AUTH packet +func NewAuth(v ProtocolVersion) *Auth { + p := newAuth() + p.init(AUTH, v, p.size, p.encodeMessage, p.decodeMessage) - return msg + return p } // ReasonCode get authentication reason diff --git a/packet/connack.go b/packet/connack.go index eb6c185..98a70e2 100644 --- a/packet/connack.go +++ b/packet/connack.go @@ -29,11 +29,20 @@ type ConnAck struct { var _ Provider = (*ConnAck)(nil) -// newConnAck creates a new CONNACK message +// newConnAck creates a new CONNACK packet func newConnAck() *ConnAck { return &ConnAck{} } +// NewConnAck creates a new CONNACK packet +func NewConnAck(v ProtocolVersion) *ConnAck { + p := newConnAck() + + p.init(CONNACK, v, p.size, p.encodeMessage, p.decodeMessage) + + return p +} + // SessionPresent returns the session present flag value func (msg *ConnAck) SessionPresent() bool { return msg.sessionPresent diff --git a/packet/connack_test.go b/packet/connack_test.go index e0719bf..d2f9ae9 100644 --- a/packet/connack_test.go +++ b/packet/connack_test.go @@ -40,77 +40,77 @@ func TestConnAckMessageFields(t *testing.T) { } func TestConnAckMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNACK << 4), 2, 0, // session not present byte(CodeSuccess), // connection accepted } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) require.NoError(t, err) msg, ok := m.(*ConnAck) require.Equal(t, true, ok, "Invalid message type") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.False(t, msg.SessionPresent(), "Error decoding session present flag.") require.Equal(t, CodeSuccess, msg.ReturnCode(), "Error decoding return code.") } // testing wrong message length func TestConnAckMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNACK << 4), 3, 0, // session not present byte(CodeSuccess), // connection accepted } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err, "Error decoding message.") } // testing wrong message size func TestConnAckMessageDecode3(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNACK << 4), 2, 0, // session not present } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err, "Error decoding message.") } // testing wrong reserve bits func TestConnAckMessageDecode4(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNACK << 4), 2, 64, // <- wrong size 0, // connection accepted } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err, "Error decoding message.") } // testing invalid return code func TestConnAckMessageDecode5(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNACK << 4), 2, 0, 6, // <- wrong code } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err, "Error decoding message.") } func TestConnAckMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNACK << 4), 2, 1, // session present @@ -132,38 +132,38 @@ func TestConnAckMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message") - require.Equal(t, len(msgBytes), n, "Error encoding message") - require.Equal(t, msgBytes, dst[:n], "Error encoding connack message") + require.Equal(t, len(buf), n, "Error encoding message") + require.Equal(t, buf, dst[:n], "Error encoding connack message") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestConnAckDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNACK << 4), 2, 0, // session not present 0, // connection accepted } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*ConnAck) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } func TestConnAckEncodeEnsureSize(t *testing.T) { diff --git a/packet/connect.go b/packet/connect.go index a719b2c..18554f4 100644 --- a/packet/connect.go +++ b/packet/connect.go @@ -29,7 +29,7 @@ func init() { clientIDRegexp = regexp.MustCompile(`^[0-9a-zA-Z \-_,.|]*$`) } -// Connect After a Network Connection is established by a Client to a Server, the first Packet +// Connect Accept After a Network Connection is established by a Client to a Server, the first Packet // sent from the Client to the Server MUST be a CONNECT Packet [MQTT-3.1.0-1]. // // A Client can only send the CONNECT Packet once over a Network Connection. The Server @@ -61,17 +61,16 @@ type Connect struct { var _ Provider = (*Connect)(nil) func newConnect() *Connect { - msg := &Connect{} - - return msg + return &Connect{} } -// Version returns the the 8 bit unsigned value that represents the revision level -// of the protocol used by the Client. The value of the Protocol Level field for -// the version 3.1.1 of the protocol is 4 (0x04). -//func (msg *Connect) Version() byte { -// return msg.version -//} +// NewConnect creates a new CONNECT packet +func NewConnect(v ProtocolVersion) *Connect { + p := newConnect() + p.init(CONNECT, v, p.size, p.encodeMessage, p.decodeMessage) + p.properties.reset() + return p +} // IsClean returns the bit that specifies the handling of the Session state. // The Client and Server can store Session state to enable reliable messaging to @@ -201,7 +200,7 @@ func (msg *Connect) SetCredentials(u []byte, p []byte) error { } // willFlag returns the bit that specifies whether a Will Message should be stored -// on the server. If the Will Flag is set to 1 this indicates that, if the Connect +// on the server. If the Will Flag is set to 1 this indicates that, if the Accept // request is accepted, a Will Message MUST be stored on the Server and associated // with the Network Connection. func (msg *Connect) willFlag() bool { @@ -345,9 +344,7 @@ func (msg *Connect) decodeMessage(from []byte) (int, error) { // V3.1.1 [MQTT-3.1.2-2] // V5.0 [MQTT-3.1.2-2] - if verStr, ok := SupportedVersions[msg.version]; !ok { - return offset, ErrInvalidProtocolVersion - } else if verStr != string(protoName) { + if verStr, ok := SupportedVersions[msg.version]; !ok || verStr != string(protoName) { return offset, ErrInvalidProtocolVersion } diff --git a/packet/connect_test.go b/packet/connect_test.go index 2f4f92e..c7299c1 100644 --- a/packet/connect_test.go +++ b/packet/connect_test.go @@ -119,7 +119,7 @@ func TestConnectMessageFields(t *testing.T) { } func TestConnectMessageDecodeV3(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNECT << 4), 62, 0, // Length MSB (0) @@ -146,13 +146,13 @@ func TestConnectMessageDecodeV3(t *testing.T) { 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) require.NoError(t, err, "Error decoding message") msg, ok := m.(*Connect) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, 206, int(msg.connectFlags), "Incorrect flag value.") require.Equal(t, 10, int(msg.KeepAlive()), "Incorrect KeepAlive value.") require.Equal(t, "volantmq", string(msg.ClientID()), "Incorrect client ID value.") @@ -169,13 +169,13 @@ func TestConnectMessageDecodeV3(t *testing.T) { require.Equal(t, "volantmq", string(username), "Incorrect username value.") require.Equal(t, "verysecret", string(password), "Incorrect password value.") - _, _, err = Decode(ProtocolV50, msgBytes) + _, _, err = Decode(ProtocolV50, buf) require.NoError(t, err) } func TestConnectMessageDecode2(t *testing.T) { // missing last byte 't' - msgBytes := []byte{ + buf := []byte{ byte(CONNECT << 4), 60, 0, // Length MSB (0) @@ -202,14 +202,14 @@ func TestConnectMessageDecode2(t *testing.T) { 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.EqualError(t, err, ErrInsufficientDataSize.Error()) - _, _, err = Decode(ProtocolV50, msgBytes) + _, _, err = Decode(ProtocolV50, buf) require.EqualError(t, err, ErrInsufficientDataSize.Error()) // missing last byte 't' - msgBytes = []byte{ + buf = []byte{ byte(CONNECT << 4), 60, 0, // Length MSB (0) @@ -237,13 +237,13 @@ func TestConnectMessageDecode2(t *testing.T) { 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', } - _, _, err = Decode(ProtocolV50, msgBytes) + _, _, err = Decode(ProtocolV50, buf) require.Error(t, err) } func TestConnectMessageDecode3(t *testing.T) { // missing last byte 't' - msgBytes := []byte{ + buf := []byte{ byte(CONNECT << 4), 60, 0, // Length MSB (0) @@ -271,16 +271,16 @@ func TestConnectMessageDecode3(t *testing.T) { 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.EqualError(t, err, ErrInsufficientDataSize.Error()) - _, _, err = Decode(ProtocolV50, msgBytes) + _, _, err = Decode(ProtocolV50, buf) require.EqualError(t, err, ErrInsufficientDataSize.Error()) } func TestConnectMessageDecode4(t *testing.T) { // extra bytes - msgBytes := []byte{ + buf := []byte{ byte(CONNECT << 4), 60, 0, // Length MSB (0) @@ -308,13 +308,13 @@ func TestConnectMessageDecode4(t *testing.T) { 'e', 'x', 't', 'r', 'a', } - _, n, err := Decode(ProtocolV311, msgBytes) + _, n, err := Decode(ProtocolV311, buf) require.NoError(t, err) require.Equal(t, 62, n) // extra bytes - msgBytes = []byte{ + buf = []byte{ byte(CONNECT << 4), 60, 0, // Length MSB (0) @@ -342,12 +342,12 @@ func TestConnectMessageDecode4(t *testing.T) { 'e', 'x', 't', 'r', 'a', } - _, _, err = Decode(ProtocolV311, msgBytes) + _, _, err = Decode(ProtocolV311, buf) require.Error(t, err) // extra bytes - msgBytes = []byte{ + buf = []byte{ byte(CONNECT << 4), 60, 0, // Length MSB (0) @@ -376,14 +376,14 @@ func TestConnectMessageDecode4(t *testing.T) { 'e', 'x', 't', 'r', 'a', } - _, n, err = Decode(ProtocolV311, msgBytes) + _, n, err = Decode(ProtocolV311, buf) require.NoError(t, err) require.Equal(t, 63, n) } func TestConnectMessageDecode5(t *testing.T) { // missing client Id, clean session == 0 - msgBytes := []byte{ + buf := []byte{ byte(CONNECT << 4), 53, 0, // Length MSB (0) @@ -409,13 +409,13 @@ func TestConnectMessageDecode5(t *testing.T) { 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestConnectMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNECT << 4), 62, 0, // Length MSB (0) @@ -460,12 +460,12 @@ func TestConnectMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") require.Equal(t, ProtocolV311, msg.Version()) // V5.0 - msgBytes = []byte{ + buf = []byte{ byte(CONNECT << 4), 63, 0, // Length MSB (0) @@ -512,14 +512,14 @@ func TestConnectMessageEncode(t *testing.T) { require.NoError(t, err, "Error decoding message") require.Equal(t, ProtocolV50, msg.Version()) - require.Equal(t, len(msgBytes), n, "Error decoding message") - require.Equal(t, msgBytes, dst[:n], "Error decoding message") + require.Equal(t, len(buf), n, "Error decoding message") + require.Equal(t, buf, dst[:n], "Error decoding message") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestConnectDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(CONNECT << 4), 60, 0, // Length MSB (0) @@ -546,22 +546,22 @@ func TestConnectDecodeEncodeEquiv(t *testing.T) { 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', } - m, n, err := Decode(ProtocolV50, msgBytes) + m, n, err := Decode(ProtocolV50, buf) msg, ok := m.(*Connect) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV50, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } diff --git a/packet/disconnect.go b/packet/disconnect.go index 98ab51e..a9fe0a1 100644 --- a/packet/disconnect.go +++ b/packet/disconnect.go @@ -28,6 +28,13 @@ func newDisconnect() *Disconnect { return &Disconnect{} } +// NewDisconnect creates a new DISCONNECT packet +func NewDisconnect(v ProtocolVersion) *Disconnect { + p := newDisconnect() + p.init(DISCONNECT, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + // ReasonCode get disconnect reason func (msg *Disconnect) ReasonCode() ReasonCode { return msg.reasonCode @@ -85,10 +92,13 @@ func (msg *Disconnect) encodeMessage(to []byte) (int, error) { var err error if msg.version == ProtocolV50 { pLen := msg.properties.FullLen() + // The Reason Code and Property Length can be omitted if the Reason Code is 0x00 (Normal disconnection) + // and there are no Properties. In this case the DISCONNECT has a Remaining Length of 0 if pLen > 1 || msg.reasonCode != CodeSuccess { to[offset] = byte(msg.reasonCode) offset++ + // [MQTT-3.14.2.2.1] if pLen > 1 { var n int n, err = msg.properties.encode(to[offset:]) diff --git a/packet/disconnect_test.go b/packet/disconnect_test.go index a1e4f1b..3084ee1 100644 --- a/packet/disconnect_test.go +++ b/packet/disconnect_test.go @@ -21,35 +21,35 @@ import ( ) func TestDisconnectMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(DISCONNECT << 4), 0, } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Disconnect) require.NoError(t, err, "Error decoding message.") require.Equal(t, true, ok, "Invalid message type") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, DISCONNECT, msg.Type(), "Error decoding message.") - msgBytes = []byte{ + buf = []byte{ byte(DISCONNECT << 4), 1, 0, } - _, _, err = Decode(ProtocolV50, msgBytes) + _, _, err = Decode(ProtocolV50, buf) require.EqualError(t, CodeMalformedPacket, err.Error()) - _, _, err = Decode(ProtocolV311, msgBytes) + _, _, err = Decode(ProtocolV311, buf) require.EqualError(t, CodeRefusedServerUnavailable, err.Error()) } func TestDisconnectMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(DISCONNECT << 4), 0, } @@ -62,34 +62,34 @@ func TestDisconnectMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestDisconnectDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(DISCONNECT << 4), 0, } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Disconnect) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } diff --git a/packet/errors.go b/packet/errors.go index 2200b73..6efa24f 100644 --- a/packet/errors.go +++ b/packet/errors.go @@ -9,8 +9,8 @@ const ( ErrInvalidUnSubscribe Error = iota // ErrInvalidUnSubAck Invalid UNSUBACK message ErrInvalidUnSubAck - // ErrPackedIDNotMatched Packet ID does not match ErrDupViolation + // ErrPackedIDNotMatched Packet ID does not match ErrPackedIDNotMatched ErrInvalid // ErrPackedIDZero cannot be 0 diff --git a/packet/header.go b/packet/header.go index 05900ac..cf99351 100644 --- a/packet/header.go +++ b/packet/header.go @@ -23,16 +23,16 @@ type header struct { const ( offsetPacketType byte = 0x04 - //offsetPublishFlagRetain byte = 0x00 + // offsetPublishFlagRetain byte = 0x00 offsetPublishFlagQoS byte = 0x01 - //offsetPublishFlagDup byte = 0x03 + // offsetPublishFlagDup byte = 0x03 offsetConnFlagWillQoS byte = 0x03 - //offsetSubscribeOps byte = 0x06 - //offsetSubscriptionQoS byte = 0x00 + // offsetSubscribeOps byte = 0x06 + // offsetSubscriptionQoS byte = 0x00 offsetSubscriptionNL byte = 0x02 offsetSubscriptionRAP byte = 0x03 offsetSubscriptionRetainHandling byte = 0x04 - //offsetSubscriptionReserved byte = 0x06 + // offsetSubscriptionReserved byte = 0x06 ) @@ -55,6 +55,18 @@ const ( maskSubscriptionReserved byte = 0xC0 ) +func (h *header) init(t Type, v ProtocolVersion, sz func() int, enc, dec func([]byte) (int, error)) { + h.mType = t + h.version = v + h.cb.encode = enc + h.cb.decode = dec + h.cb.size = sz + + if v >= ProtocolV50 { + h.properties.reset() + } +} + // Name returns a string representation of the message type. Examples include // "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of // the message types and cannot be changed. @@ -85,10 +97,13 @@ func (h *header) RemainingLength() int32 { return h.remLen } +// Version protocol version used by packet func (h *header) Version() ProtocolVersion { return h.version } +// ID packet id, valid only for +// PUBLISH (QoS 1/2), PUBACK, PUBREC, PUBREL, PUBCOMP, SUBSCRIBE, SUBACK, UNSUBSCRIBE, UNSUBACK func (h *header) ID() (IDType, error) { if len(h.packetID) == 0 { return 0, ErrNotSet @@ -97,6 +112,7 @@ func (h *header) ID() (IDType, error) { return IDType(binary.BigEndian.Uint16(h.packetID)), nil } +// Encode packet into buffer, Size() should be called to determine expected buffer size func (h *header) Encode(to []byte) (int, error) { expectedSize, err := h.Size() if err != nil { @@ -121,6 +137,7 @@ func (h *header) Encode(to []byte) (int, error) { return offset, err } +// SetVersion protocol version used to encode packet func (h *header) SetVersion(v ProtocolVersion) { h.version = v } @@ -136,6 +153,12 @@ func (h *header) Size() (int, error) { return h.size() + ml, nil } +// PropertiesDiscard discard all previously set properties +func (h *header) PropertiesDiscard() { + h.properties.reset() +} + +// PropertyGet get property value, nil if not present func (h *header) PropertyGet(id PropertyID) PropertyToType { if h.version != ProtocolV50 { return nil @@ -144,6 +167,7 @@ func (h *header) PropertyGet(id PropertyID) PropertyToType { return h.properties.Get(id) } +// PropertySet set value func (h *header) PropertySet(id PropertyID, val interface{}) error { if h.version != ProtocolV50 { return ErrNotSupported @@ -152,6 +176,7 @@ func (h *header) PropertySet(id PropertyID, val interface{}) error { return h.properties.Set(h.mType, id, val) } +// PropertyForEach iterate over properties func (h *header) PropertyForEach(f func(PropertyID, PropertyToType)) error { if h.version != ProtocolV50 { return ErrNotSupported @@ -225,7 +250,6 @@ func (h *header) decode(from []byte) (int, error) { offset := 0 // decode and validate fixed header - //h.mTypeFlags = src[total] h.mType = Type(from[offset] >> offsetPacketType) h.mFlags = from[offset] & maskMessageFlags diff --git a/packet/header_test.go b/packet/header_test.go index 6224856..34d6f39 100644 --- a/packet/header_test.go +++ b/packet/header_test.go @@ -21,75 +21,75 @@ import ( ) func TestMessageHeaderFields(t *testing.T) { - header := &header{} + hdr := &header{} - header.setRemainingLength(33) // nolint: errcheck + hdr.setRemainingLength(33) // nolint: errcheck - require.Equal(t, int32(33), header.RemainingLength()) + require.Equal(t, int32(33), hdr.RemainingLength()) - err := header.setRemainingLength(268435456) + err := hdr.setRemainingLength(268435456) require.Error(t, err) - err = header.setRemainingLength(-1) + err = hdr.setRemainingLength(-1) require.Error(t, err) - header.setType(PUBREL) + hdr.setType(PUBREL) - require.Equal(t, PUBREL, header.Type()) - require.Equal(t, "PUBREL", header.Name()) - require.Equal(t, 2, int(header.Flags())) - require.Equal(t, PUBREL.Desc(), header.Desc()) + require.Equal(t, PUBREL, hdr.Type()) + require.Equal(t, "PUBREL", hdr.Name()) + require.Equal(t, 2, int(hdr.Flags())) + require.Equal(t, PUBREL.Desc(), hdr.Desc()) } // Not enough bytes func TestMessageHeaderDecode(t *testing.T) { buf := []byte{0x6f, 193, 2} - header := &header{} + hdr := &header{} - _, err := header.decode(buf) + _, err := hdr.decode(buf) require.Error(t, err) } // Remaining length too big func TestMessageHeaderDecode2(t *testing.T) { buf := []byte{0x62, 0xff, 0xff, 0xff, 0xff} - header := &header{} + hdr := &header{} - _, err := header.decode(buf) + _, err := hdr.decode(buf) require.EqualError(t, err, ErrInsufficientDataSize.Error()) } func TestMessageHeaderDecode3(t *testing.T) { buf := []byte{0x62, 0xff} - header := &header{} + hdr := &header{} - _, err := header.decode(buf) + _, err := hdr.decode(buf) require.Error(t, err) } func TestMessageHeaderDecode4(t *testing.T) { buf := []byte{0x62, 0xff, 0xff, 0xff, 0x7f} - header := &header{ + hdr := &header{ mType: PUBREL, mFlags: 2, } - _, err := header.decode(buf) + _, err := hdr.decode(buf) require.EqualError(t, ErrInsufficientDataSize, err.Error()) - require.Equal(t, maxRemainingLength, header.RemainingLength()) + require.Equal(t, maxRemainingLength, hdr.RemainingLength()) } func TestMessageHeaderDecode5(t *testing.T) { buf := []byte{0x62, 0xff, 0x7f} - header := &header{ + hdr := &header{ mType: PUBREL, mFlags: 2, } - _, err := header.decode(buf) + _, err := hdr.decode(buf) require.Error(t, err) } @@ -97,29 +97,29 @@ func TestMessageHeaderDecode6(t *testing.T) { buf := []byte{byte(PUBLISH<> offsetPacketType), mFlags: buf[0] | maskMessageFlags, } - _, err := header.decode(buf) + _, err := hdr.decode(buf) require.EqualError(t, err, CodeRefusedServerUnavailable.Error()) } func TestMessageHeaderEncode1(t *testing.T) { - header := &header{} + hdr := &header{} //headerBytes := []byte{0x62, 193, 2} - //header.setVT(ProtocolV311, PUBREL) + //hdr.setVT(ProtocolV311, PUBREL) //require.NoError(t, err) - err := header.setRemainingLength(321) + err := hdr.setRemainingLength(321) require.NoError(t, err) //buf := make([]byte, 3) - //n, err := header.encode(buf) + //n, err := hdr.encode(buf) //require.NoError(t, err) //require.Equal(t, 3, n) @@ -127,33 +127,33 @@ func TestMessageHeaderEncode1(t *testing.T) { } func TestMessageHeaderEncode2(t *testing.T) { - header := &header{} + hdr := &header{} //header.setVT(ProtocolV311, PUBREL) //require.NoError(t, err) - header.remLen = 268435456 + hdr.remLen = 268435456 //buf := make([]byte, 5) - //_, err = header.encode(buf) + //_, err = hdr.encode(buf) // //require.Error(t, err) } func TestMessageHeaderEncode3(t *testing.T) { - header := &header{} + hdr := &header{} //headerBytes := []byte{0x62, 0xff, 0xff, 0xff, 0x7f} - //header.setVT(ProtocolV311, PUBREL) + //hdr.setVT(ProtocolV311, PUBREL) //require.NoError(t, err) - err := header.setRemainingLength(maxRemainingLength) + err := hdr.setRemainingLength(maxRemainingLength) require.NoError(t, err) //buf := make([]byte, 5) - //n, err := header.encode(buf) + //n, err := hdr.encode(buf) // //require.NoError(t, err) //require.Equal(t, 5, n) diff --git a/packet/packet.go b/packet/packet.go index 691d416..dd643a3 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -1,6 +1,7 @@ package packet import ( + "fmt" "strings" "unicode/utf8" ) @@ -99,6 +100,8 @@ type Provider interface { // Version get protocol version used by message Version() ProtocolVersion + PropertiesDiscard() + PropertyGet(PropertyID) PropertyToType PropertySet(PropertyID, interface{}) error @@ -144,54 +147,54 @@ func newMessage(v ProtocolVersion, t Type) (Provider, error) { switch t { case CONNECT: - m = newConnect() + m = NewConnect(v) case CONNACK: - m = newConnAck() + m = NewConnAck(v) case PUBLISH: - m = newPublish() + m = NewPublish(v) case PUBACK: - m = newPubAck() + m = NewPubAck(v) case PUBREC: - m = newPubRec() + m = NewPubRec(v) case PUBREL: - m = newPubRel() + m = NewPubRel(v) case PUBCOMP: - m = newPubComp() + m = NewPubComp(v) case SUBSCRIBE: - m = newSubscribe() + m = NewSubscribe(v) case SUBACK: - m = newSubAck() + m = NewSubAck(v) case UNSUBSCRIBE: - m = newUnSubscribe() + m = NewUnSubscribe(v) case UNSUBACK: - m = newUnSubAck() + m = NewUnSubAck(v) case PINGREQ: - m = newPingReq() + m = NewPingReq(v) case PINGRESP: - m = newPingResp() + m = NewPingResp(v) case DISCONNECT: - m = newDisconnect() + m = NewDisconnect(v) case AUTH: - if v != ProtocolV50 { + if v < ProtocolV50 { return nil, ErrInvalidMessageType } - m = newAuth() + m = NewAuth(v) default: return nil, ErrInvalidMessageType } m.setType(t) - h := m.getHeader() - - h.version = v - h.cb.encode = m.encodeMessage - h.cb.decode = m.decodeMessage - h.cb.size = m.size - - if v >= ProtocolV50 { - h.properties.properties = make(map[PropertyID]interface{}) - } + //h := m.getHeader() + // + //h.version = v + //h.cb.encode = m.encodeMessage + //h.cb.decode = m.decodeMessage + //h.cb.size = m.size + // + //if v >= ProtocolV50 { + // h.properties.properties = make(map[PropertyID]interface{}) + //} return m, nil } @@ -213,7 +216,7 @@ func Encode(p Provider) ([]byte, error) { // Decode buf into message and return Provider type func Decode(v ProtocolVersion, buf []byte) (msg Provider, total int, err error) { defer func() { - // TODO: this case might be improved + // TODO(troian): this case might be improved // Panic might be provided during message decode with malformed len // For example on length-prefixed payloads/topics or properties: @@ -227,6 +230,7 @@ func Decode(v ProtocolVersion, buf []byte) (msg Provider, total int, err error) // but it might be worth doing such checks (there might be many for each message) on each decode // as it is abnormal and server must close connection if r := recover(); r != nil { + fmt.Println(r) msg = nil total = 0 err = ErrPanicDetected diff --git a/packet/packetType.go b/packet/packetType.go index c1f0cdd..b28a0e7 100644 --- a/packet/packetType.go +++ b/packet/packetType.go @@ -16,7 +16,7 @@ const ( // Dir: Client to Server CONNECT - // CONNACK Connect acknowledgement + // CONNACK Accept acknowledgement // version: v3.1, v3.1.1, v5.0 // Dir: Server to Client CONNACK @@ -87,7 +87,7 @@ var typeName = [AUTH + 1]string{ var typeDescription = [AUTH + 1]string{ "Reserved", "Client request to connect to Server", - "Connect acknowledgement", + "Accept acknowledgement", "Publish message", "Publish acknowledgement", "Publish received (assured delivery part 1)", diff --git a/packet/ping_test.go b/packet/ping_test.go index c3feb87..c1039bb 100644 --- a/packet/ping_test.go +++ b/packet/ping_test.go @@ -21,22 +21,22 @@ import ( ) func TestPingReqMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PINGREQ << 4), 0, } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*PingReq) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, PINGREQ, msg.Type(), "Error decoding message.") } func TestPingReqMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PINGREQ << 4), 0, } @@ -51,25 +51,25 @@ func TestPingReqMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") } func TestPingRespMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PINGRESP << 4), 0, } - msg, n, err := Decode(ProtocolV311, msgBytes) + msg, n, err := Decode(ProtocolV311, buf) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, PINGRESP, msg.Type(), "Error decoding message.") } func TestPingRespMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PINGRESP << 4), 0, } @@ -83,58 +83,58 @@ func TestPingRespMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestPingReqDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PINGREQ << 4), 0, } - msg, n, err := Decode(ProtocolV311, msgBytes) + msg, n, err := Decode(ProtocolV311, buf) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestPingRespDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PINGRESP << 4), 0, } - msg, n, err := Decode(ProtocolV311, msgBytes) + msg, n, err := Decode(ProtocolV311, buf) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } diff --git a/packet/pingreq.go b/packet/pingreq.go index 8253d56..2986371 100644 --- a/packet/pingreq.go +++ b/packet/pingreq.go @@ -29,6 +29,13 @@ func newPingReq() *PingReq { return &PingReq{} } +// NewPingReq creates a new PINGREQ packet +func NewPingReq(v ProtocolVersion) *PingReq { + p := newPingReq() + p.init(PINGREQ, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + // decode message func (msg *PingReq) decodeMessage(src []byte) (int, error) { return 0, nil diff --git a/packet/pingresp.go b/packet/pingresp.go index 40d80a1..986021e 100644 --- a/packet/pingresp.go +++ b/packet/pingresp.go @@ -22,10 +22,17 @@ type PingResp struct { var _ Provider = (*PingResp)(nil) -func newPingResp() Provider { +func newPingResp() *PingResp { return &PingResp{} } +// NewPingResp creates a new PINGRESP packet +func NewPingResp(v ProtocolVersion) *PingResp { + p := newPingResp() + p.init(PINGRESP, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + // decode message func (msg *PingResp) decodeMessage(src []byte) (int, error) { return 0, nil diff --git a/packet/property.go b/packet/property.go index 2f5f250..f5083fa 100644 --- a/packet/property.go +++ b/packet/property.go @@ -25,7 +25,7 @@ const ( ErrPropertyWrongType ) -// Error +// Error description func (e PropertyError) Error() string { switch e { case ErrPropertyNotFound: @@ -323,6 +323,11 @@ func (p PropertyID) IsValidPacketType(t Type) bool { return true } +func (p *property) reset() { + p.properties = make(map[PropertyID]interface{}) + p.len = 0 +} + // Len of the encoded property field. Does not include size property len prefix func (p *property) Len() (uint32, int) { return p.len, uvarintCalc(p.len) diff --git a/packet/puback.go b/packet/puback.go index fc42ab8..3fadf52 100644 --- a/packet/puback.go +++ b/packet/puback.go @@ -29,16 +29,32 @@ func newPubAck() *Ack { return &Ack{} } -func newPubRec() *Ack { - return &Ack{} +// NewPubAck creates a new PUBACK packet +func NewPubAck(v ProtocolVersion) *Ack { + p := newPubAck() + p.init(PUBACK, v, p.size, p.encodeMessage, p.decodeMessage) + return p } -func newPubRel() *Ack { - return &Ack{} +// NewPubRec creates a new PUBREC packet +func NewPubRec(v ProtocolVersion) *Ack { + p := newPubAck() + p.init(PUBREC, v, p.size, p.encodeMessage, p.decodeMessage) + return p } -func newPubComp() *Ack { - return &Ack{} +// NewPubRel creates a new PUBREL packet +func NewPubRel(v ProtocolVersion) *Ack { + p := newPubAck() + p.init(PUBREL, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + +// NewPubComp creates a new PUBCOMP packet +func NewPubComp(v ProtocolVersion) *Ack { + p := newPubAck() + p.init(PUBCOMP, v, p.size, p.encodeMessage, p.decodeMessage) + return p } // SetPacketID sets the ID of the packet. diff --git a/packet/puback_test.go b/packet/puback_test.go index ba25d22..a0e6016 100644 --- a/packet/puback_test.go +++ b/packet/puback_test.go @@ -34,19 +34,19 @@ func TestPubAckMessageFields(t *testing.T) { } func TestPubAckMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBACK << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Ack) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "decode length does not match") + require.Equal(t, len(buf), n, "decode length does not match") require.Equal(t, PUBACK, msg.Type(), "Message type does not match") id, _ := msg.ID() @@ -55,19 +55,19 @@ func TestPubAckMessageDecode(t *testing.T) { // test insufficient bytes func TestPubAckMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBACK << 4), 2, 7, // packet ID LSB (7) } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestPubAckMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBACK << 4), 2, 0, // packet ID MSB (0) @@ -86,36 +86,36 @@ func TestPubAckMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestPubAckDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBACK << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Ack) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } diff --git a/packet/pubcomp_test.go b/packet/pubcomp_test.go index 7f09764..89ecfab 100644 --- a/packet/pubcomp_test.go +++ b/packet/pubcomp_test.go @@ -35,19 +35,19 @@ func TestPubCompMessageFields(t *testing.T) { } func TestPubCompMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBCOMP << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Ack) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, PUBCOMP, msg.Type(), "Error decoding message.") id, _ := msg.ID() @@ -56,19 +56,19 @@ func TestPubCompMessageDecode(t *testing.T) { // test insufficient bytes func TestPubCompMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBCOMP << 4), 2, 7, // packet ID LSB (7) } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestPubCompMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBCOMP << 4), 2, 0, // packet ID MSB (0) @@ -87,36 +87,36 @@ func TestPubCompMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestPubCompDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBCOMP << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Ack) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } diff --git a/packet/publish.go b/packet/publish.go index b4fdf55..930237d 100644 --- a/packet/publish.go +++ b/packet/publish.go @@ -14,7 +14,9 @@ package packet -import "time" +import ( + "time" +) // Publish A PUBLISH Control Packet is sent from a Client to a Server or from Server to a Client // to transport an Application Message. @@ -33,48 +35,50 @@ func newPublish() *Publish { return &Publish{} } -// SetExpiry time object -func (msg *Publish) SetExpiry(tm time.Time) { - msg.expireAt = tm +// NewPublish creates a new PUBLISH packet +func NewPublish(v ProtocolVersion) *Publish { + p := newPublish() + p.init(PUBLISH, v, p.size, p.encodeMessage, p.decodeMessage) + return p } -// GetExpiry time object -func (msg *Publish) GetExpiry() time.Time { - return msg.expireAt +// SetExpireAt time object +func (msg *Publish) SetExpireAt(tm time.Time) { + msg.expireAt = tm } // Expired check if packet has elapsed it's time or not -// if not expirable returns false -func (msg *Publish) Expired(set bool) bool { - expired := false +// returns false if does not expire +func (msg *Publish) Expired() (time.Time, uint32, bool) { if !msg.expireAt.IsZero() { + var df uint32 + expired := true now := time.Now() - if df := uint32(msg.expireAt.Sub(now) / time.Second); msg.expireAt.After(now) && df > 0 { - if set && msg.version >= ProtocolV50 { - // TODO(troian): check error - msg.properties.Set(msg.mType, PropertyPublicationExpiry, df) // nolint: errcheck + + if msg.expireAt.After(now) { + if df = uint32(msg.expireAt.Sub(now) / time.Second); df > 0 { + expired = false } - } else { - expired = true } + + return msg.expireAt, df, expired } - return expired + return time.Time{}, 0, false } // Clone packet // qos, topic, payload, retain and properties func (msg *Publish) Clone(v ProtocolVersion) (*Publish, error) { // message version should be same as session as encode/decode depends on it - _pkt, _ := New(msg.version, PUBLISH) - pkt, _ := _pkt.(*Publish) + pkt := NewPublish(msg.version) // [MQTT-3.3.1-9] // [MQTT-3.3.1-3] pkt.Set(msg.topic, msg.Payload(), msg.QoS(), msg.Retain(), false) // nolint: errcheck - // clone expiration setting with no matter of version as message should not be delivered to V3 brokers - // when it expired + // clone expiration setting with no matter of version as expired publish packet + // should not be delivered to V3 brokers when it expired pkt.expireAt = msg.expireAt if msg.version == ProtocolV50 && v == ProtocolV50 { diff --git a/packet/publish_test.go b/packet/publish_test.go index 6ea60ed..5cc93b3 100644 --- a/packet/publish_test.go +++ b/packet/publish_test.go @@ -42,20 +42,26 @@ func TestPublishDecode1(t *testing.T) { } func TestPublishExpire(t *testing.T) { - p, err := New(ProtocolV50, PUBLISH) - require.NoError(t, err) - pkt, ok := p.(*Publish) - require.True(t, ok) + pkt := NewPublish(ProtocolV50) + require.NotNil(t, pkt) - require.False(t, pkt.Expired(false)) + _, _, expired := pkt.Expired() - pkt.SetExpiry(time.Now().Add(2 * time.Second)) + require.False(t, expired) + pkt.SetExpireAt(time.Now().Add(2 * time.Second)) time.Sleep(3 * time.Second) - require.True(t, pkt.Expired(false)) + _, _, expired = pkt.Expired() + require.True(t, expired) + + pkt.SetExpireAt(time.Now().Add(3 * time.Second)) - pkt.SetExpiry(time.Now().Add(3 * time.Second)) - require.False(t, pkt.Expired(true)) + var expLeft uint32 + _, expLeft, expired = pkt.Expired() + require.False(t, expired) + + err := pkt.PropertySet(PropertyPublicationExpiry, expLeft) + require.NoError(t, err) prop := pkt.PropertyGet(PropertyPublicationExpiry) require.NotNil(t, prop) diff --git a/packet/pubrec_test.go b/packet/pubrec_test.go index 8dd5305..e47b98d 100644 --- a/packet/pubrec_test.go +++ b/packet/pubrec_test.go @@ -34,18 +34,18 @@ func TestPubRecMessageFields(t *testing.T) { } func TestPubRecMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBREC << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Ack) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, PUBREC, msg.Type(), "Error decoding message.") id, _ := msg.ID() @@ -54,19 +54,19 @@ func TestPubRecMessageDecode(t *testing.T) { // test insufficient bytes func TestPubRecMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBREC << 4), 2, 7, // packet ID LSB (7) } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestPubRecMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBREC << 4), 2, 0, // packet ID MSB (0) @@ -85,36 +85,36 @@ func TestPubRecMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestPubRecDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBREC << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Ack) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } diff --git a/packet/pubrel_test.go b/packet/pubrel_test.go index 89eae04..093bd6f 100644 --- a/packet/pubrel_test.go +++ b/packet/pubrel_test.go @@ -34,19 +34,19 @@ func TestPubRelMessageFields(t *testing.T) { } func TestPubRelMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBREL<<4) | 2, 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Ack) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, PUBREL, msg.Type(), "Error decoding message.") id, _ := msg.ID() @@ -55,19 +55,19 @@ func TestPubRelMessageDecode(t *testing.T) { // test insufficient bytes func TestPubRelMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBREL<<4) | 2, 2, 7, // packet ID LSB (7) } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestPubRelMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(PUBREL<= CodeSuccess && c <= CodeRefusedNotAuthorized { - return true - } - return false + return c <= CodeRefusedNotAuthorized } // IsValidV5 check either reason code is valid for MQTT V5.0 or not diff --git a/packet/suback.go b/packet/suback.go index 44cca5b..ebc7f31 100644 --- a/packet/suback.go +++ b/packet/suback.go @@ -31,6 +31,13 @@ func newSubAck() *SubAck { return &SubAck{} } +// NewSubAck creates a new SUBACK packet +func NewSubAck(v ProtocolVersion) *SubAck { + p := newSubAck() + p.init(SUBACK, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + // ReturnCodes returns the list of QoS returns from the subscriptions sent in the SUBSCRIBE message. func (msg *SubAck) ReturnCodes() []ReasonCode { return msg.returnCodes diff --git a/packet/suback_test.go b/packet/suback_test.go index c04afab..f72b37b 100644 --- a/packet/suback_test.go +++ b/packet/suback_test.go @@ -39,7 +39,7 @@ func TestSubAckMessageFields(t *testing.T) { } func TestSubAckMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBACK << 4), 6, 0, // packet ID MSB (0) @@ -50,7 +50,7 @@ func TestSubAckMessageDecode(t *testing.T) { 0x80, // return code 4 } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*SubAck) require.Equal(t, true, ok, "Invalid message type") @@ -58,14 +58,14 @@ func TestSubAckMessageDecode(t *testing.T) { t.Log(err) } require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, SUBACK, msg.Type(), "Error decoding message.") require.Equal(t, 4, len(msg.ReturnCodes()), "Error adding return code.") } // test with wrong return code func TestSubAckMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBACK << 4), 6, 0, // packet ID MSB (0) @@ -76,12 +76,12 @@ func TestSubAckMessageDecode2(t *testing.T) { 0x81, // return code 4 } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestSubAckMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBACK << 4), 6, 0, // packet ID MSB (0) @@ -108,14 +108,14 @@ func TestSubAckMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error encoding message.") - require.Equal(t, len(msgBytes), n, "Encoded length does not match") - require.Equal(t, msgBytes, dst[:n], "Raw message does not match") + require.Equal(t, len(buf), n, "Encoded length does not match") + require.Equal(t, buf, dst[:n], "Raw message does not match") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestSubAckDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBACK << 4), 6, 0, // packet ID MSB (0) @@ -126,22 +126,22 @@ func TestSubAckDecodeEncodeEquiv(t *testing.T) { 0x80, // return code 4 } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*SubAck) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error encoding message.") - require.Equal(t, len(msgBytes), n2, "Error encoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error encoding message.") + require.Equal(t, len(buf), n2, "Error encoding message.") + require.Equal(t, buf, dst[:n2], "Error encoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message") - require.Equal(t, len(msgBytes), n3, "Error decoding message") + require.Equal(t, len(buf), n3, "Error decoding message") } diff --git a/packet/subscribe.go b/packet/subscribe.go index d912483..2d6fd5c 100644 --- a/packet/subscribe.go +++ b/packet/subscribe.go @@ -36,6 +36,13 @@ func newSubscribe() *Subscribe { return &Subscribe{} } +// NewSubscribe creates a new SUBSCRIBE packet +func NewSubscribe(v ProtocolVersion) *Subscribe { + p := newSubscribe() + p.init(SUBSCRIBE, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + // RangeTopics loop through list of topics func (msg *Subscribe) RangeTopics(fn func(string, SubscriptionOptions)) { for i, t := range msg.topics { diff --git a/packet/subscribe_test.go b/packet/subscribe_test.go index b6efaae..ffa9f21 100644 --- a/packet/subscribe_test.go +++ b/packet/subscribe_test.go @@ -41,7 +41,7 @@ func TestSubscribeMessageFields(t *testing.T) { } func TestSubscribeMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBSCRIBE<<4) | 2, 37, 0, // packet ID MSB (0) @@ -61,32 +61,32 @@ func TestSubscribeMessageDecode(t *testing.T) { } //msg := NewSubscribeMessage() - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Subscribe) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, SUBSCRIBE, msg.Type(), "Error decoding message.") require.Equal(t, 3, len(msg.topics), "Error decoding topics.") } // test empty topic list func TestSubscribeMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBSCRIBE<<4) | 2, 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestSubscribeMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBSCRIBE<<4) | 2, 36, 0, // packet ID MSB (0) @@ -119,7 +119,7 @@ func TestSubscribeMessageEncode(t *testing.T) { dst := make([]byte, 100) n, err := msg.Encode(dst) require.NoError(t, err, "Error encoding message.") - require.Equal(t, len(msgBytes), n, "Error encoding message.") + require.Equal(t, len(buf), n, "Error encoding message.") //msg1 := NewSubscribeMessage() var m1 Provider @@ -128,14 +128,14 @@ func TestSubscribeMessageEncode(t *testing.T) { require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, 3, len(msg1.topics), "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestSubscribeDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBSCRIBE<<4) | 2, 37, 0, // packet ID MSB (0) @@ -155,28 +155,28 @@ func TestSubscribeDecodeEncodeEquiv(t *testing.T) { } //msg := NewSubscribeMessage() - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Subscribe) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message") - require.Equal(t, len(msgBytes), n, "Raw message length does not match") + require.Equal(t, len(buf), n, "Raw message length does not match") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error encoding message") - require.Equal(t, len(msgBytes), n2, "Raw message length does not match") + require.Equal(t, len(buf), n2, "Raw message length does not match") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message") - require.Equal(t, len(msgBytes), n3, "Raw message length does not match") + require.Equal(t, len(buf), n3, "Raw message length does not match") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestSubscribeDecodeOrder(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(SUBSCRIBE<<4) | 2, 37, 0, // packet ID MSB (0) @@ -195,11 +195,11 @@ func TestSubscribeDecodeOrder(t *testing.T) { 2, // QoS } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*Subscribe) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message") - require.Equal(t, len(msgBytes), n, "Raw message length does not match") + require.Equal(t, len(buf), n, "Raw message length does not match") i := 0 msg.RangeTopics(func(topic string, ops SubscriptionOptions) { @@ -214,7 +214,7 @@ func TestSubscribeDecodeOrder(t *testing.T) { require.Equal(t, "/a/b/#/cdd", topic) require.Equal(t, SubscriptionOptions(QoS2), ops) default: - assert.Error(t, errors.New("Invalid topics count")) + assert.Error(t, errors.New("invalid topics count")) } i++ }) diff --git a/packet/unsuback.go b/packet/unsuback.go index 1982211..f843641 100644 --- a/packet/unsuback.go +++ b/packet/unsuback.go @@ -30,6 +30,13 @@ func newUnSubAck() *UnSubAck { return msg } +// NewUnSubAck creates a new UNSUBACK packet +func NewUnSubAck(v ProtocolVersion) *UnSubAck { + p := newUnSubAck() + p.init(UNSUBACK, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + // SetPacketID sets the ID of the packet. func (msg *UnSubAck) SetPacketID(v IDType) { msg.setPacketID(v) @@ -71,6 +78,10 @@ func (msg *UnSubAck) decodeMessage(from []byte) (int, error) { if err != nil { return offset, err } + + for _, c := range from[offset:msg.remLen] { + msg.returnCodes = append(msg.returnCodes, ReasonCode(c)) + } } return offset, nil @@ -89,6 +100,11 @@ func (msg *UnSubAck) encodeMessage(to []byte) (int, error) { var n int n, err = msg.properties.encode(to[offset:]) offset += n + + for _, c := range msg.returnCodes { + to[offset] = byte(c) + offset++ + } } return offset, err @@ -100,6 +116,7 @@ func (msg *UnSubAck) size() int { if msg.version == ProtocolV50 { total += int(msg.properties.FullLen()) + total += len(msg.returnCodes) } return total diff --git a/packet/unsuback_test.go b/packet/unsuback_test.go index 3275078..1cdc7dd 100644 --- a/packet/unsuback_test.go +++ b/packet/unsuback_test.go @@ -34,19 +34,19 @@ func TestUnSubAckMessageFields(t *testing.T) { } func TestUnSubAckMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBACK << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*UnSubAck) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, UNSUBACK, msg.Type(), "Error decoding message.") id, _ := msg.ID() @@ -55,19 +55,19 @@ func TestUnSubAckMessageDecode(t *testing.T) { // test insufficient bytes func TestUnSubAckMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBACK << 4), 2, 7, // packet ID LSB (7) } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestUnSubAckMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBACK << 4), 2, 0, // packet ID MSB (0) @@ -86,36 +86,36 @@ func TestUnSubAckMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") + require.Equal(t, buf, dst[:n], "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestUnSubAckDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBACK << 4), 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) msg, ok := m.(*UnSubAck) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n2, "Error decoding message.") - require.Equal(t, msgBytes, dst[:n2], "Error decoding message.") + require.Equal(t, len(buf), n2, "Error decoding message.") + require.Equal(t, buf, dst[:n2], "Error decoding message.") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n3, "Error decoding message.") + require.Equal(t, len(buf), n3, "Error decoding message.") } diff --git a/packet/unsubscribe.go b/packet/unsubscribe.go index e6f3483..0cb649b 100644 --- a/packet/unsubscribe.go +++ b/packet/unsubscribe.go @@ -31,6 +31,13 @@ func newUnSubscribe() *UnSubscribe { return &UnSubscribe{} } +// NewUnSubscribe creates a new UNSUBSCRIBE packet +func NewUnSubscribe(v ProtocolVersion) *UnSubscribe { + p := newUnSubscribe() + p.init(UNSUBSCRIBE, v, p.size, p.encodeMessage, p.decodeMessage) + return p +} + // Topics returns a list of topics sent by the Client. func (msg *UnSubscribe) Topics() []string { return msg.topics diff --git a/packet/unsubscribe_test.go b/packet/unsubscribe_test.go index cb01f28..ae8b44a 100644 --- a/packet/unsubscribe_test.go +++ b/packet/unsubscribe_test.go @@ -39,7 +39,7 @@ func TestUnSubscribeMessageFields(t *testing.T) { } func TestUnSubscribeMessageDecode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBSCRIBE<<4) | 2, 34, 0, // packet ID MSB (0) @@ -55,34 +55,34 @@ func TestUnSubscribeMessageDecode(t *testing.T) { '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) require.NoError(t, err, "Error decoding message.") msg, ok := m.(*UnSubscribe) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, UNSUBSCRIBE, msg.Type(), "Error decoding message.") require.Equal(t, 3, len(msg.topics), "Error decoding topics.") } // test empty topic list func TestUnSubscribeMessageDecode2(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBSCRIBE<<4) | 2, 2, 0, // packet ID MSB (0) 7, // packet ID LSB (7) } - _, _, err := Decode(ProtocolV311, msgBytes) + _, _, err := Decode(ProtocolV311, buf) require.Error(t, err) } func TestUnSubscribeMessageEncode(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBSCRIBE<<4) | 2, 33, 0, // packet ID MSB (0) @@ -113,7 +113,7 @@ func TestUnSubscribeMessageEncode(t *testing.T) { n, err := msg.Encode(dst) require.NoError(t, err, "Error encoding message.") - require.Equal(t, len(msgBytes), n, "Error encoding message.") + require.Equal(t, len(buf), n, "Error encoding message.") //msg1 := NewUnSubscribeMessage() @@ -123,16 +123,16 @@ func TestUnSubscribeMessageEncode(t *testing.T) { require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message.") - require.Equal(t, len(msgBytes), n, "Error decoding message.") + require.Equal(t, len(buf), n, "Error decoding message.") require.Equal(t, 3, len(msg1.topics), "Error decoding message.") - //require.Equal(t, msgBytes, dst[:n], "Error decoding message.") + //require.Equal(t, buf, dst[:n], "Error decoding message.") } // test to ensure encoding and decoding are the same // decode, encode, and decode again func TestUnSubscribeDecodeEncodeEquiv(t *testing.T) { - msgBytes := []byte{ + buf := []byte{ byte(UNSUBSCRIBE<<4) | 2, 34, 0, // packet ID MSB (0) @@ -148,22 +148,22 @@ func TestUnSubscribeDecodeEncodeEquiv(t *testing.T) { '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', } - m, n, err := Decode(ProtocolV311, msgBytes) + m, n, err := Decode(ProtocolV311, buf) require.NoError(t, err, "Error decoding message.") msg, ok := m.(*UnSubscribe) require.Equal(t, true, ok, "Invalid message type") require.NoError(t, err, "Error decoding message") - require.Equal(t, len(msgBytes), n, "Raw message length does not match") + require.Equal(t, len(buf), n, "Raw message length does not match") dst := make([]byte, 100) n2, err := msg.Encode(dst) require.NoError(t, err, "Error encoding message.") - require.Equal(t, len(msgBytes), n2, "Raw message length does not match") + require.Equal(t, len(buf), n2, "Raw message length does not match") require.Equal(t, 3, len(msg.topics), "Topics count does not match") _, n3, err := Decode(ProtocolV311, dst) require.NoError(t, err, "Error decoding message") - require.Equal(t, len(msgBytes), n3, "Raw message length does not match") + require.Equal(t, len(buf), n3, "Raw message length does not match") } diff --git a/subscriber/subscriber.go b/subscriber/subscriber.go index 15646a4..0914f70 100644 --- a/subscriber/subscriber.go +++ b/subscriber/subscriber.go @@ -1,82 +1,68 @@ package subscriber import ( - "unsafe" - "sync" + "sync/atomic" + "unsafe" "github.com/VolantMQ/volantmq/packet" "github.com/VolantMQ/volantmq/topics/types" ) -// ConnectionProvider passed to present network connection -type ConnectionProvider interface { - ID() string +// SessionProvider passed to present network connection +type SessionProvider interface { Subscriptions() Subscriptions Subscribe(string, *topicsTypes.SubscriptionParams) (packet.QosType, []*packet.Publish, error) UnSubscribe(string) error HasSubscriptions() bool - Online(c OnlinePublish) - OnlineRedirect(c OnlinePublish) + Online(c Publisher) Offline(bool) Hash() uintptr - Version() packet.ProtocolVersion } -// OnlinePublish invoked when subscriber respective to sessions receive message -type OnlinePublish func(*packet.Publish) +type Publisher interface { + Publish(string, *packet.Publish) +} -// OfflinePublish invoked when subscriber respective to sessions receive message -type OfflinePublish func(string, *packet.Publish) +type publisher struct { + Publisher + sync.WaitGroup +} // Subscriptions contains active subscriptions with respective subscription parameters type Subscriptions map[string]*topicsTypes.SubscriptionParams // Config subscriber config options type Config struct { - ID string - Topics topicsTypes.SubscriberInterface - OnOfflinePublish OfflinePublish - OfflineQoS0 bool - Version packet.ProtocolVersion + ID string + OfflinePublish Publisher + Topics topicsTypes.SubscriberInterface + Version packet.ProtocolVersion } // Type subscriber object type Type struct { - id string - subscriptions Subscriptions - topics topicsTypes.SubscriberInterface - publishOffline OfflinePublish - publishOnline OnlinePublish - access sync.WaitGroup - wgOffline sync.WaitGroup - wgOnline sync.WaitGroup - publishLock sync.RWMutex // todo: find better way - isOnline chan struct{} - offlineQoS0 bool - version packet.ProtocolVersion + subscriptions Subscriptions + publish atomic.Value + access sync.WaitGroup + Config } +var _ SessionProvider = (*Type)(nil) + // New allocate new subscriber -func New(c *Config) *Type { +func New(c Config) *Type { p := &Type{ - isOnline: make(chan struct{}), - subscriptions: make(Subscriptions), - id: c.ID, - publishOffline: c.OnOfflinePublish, - version: c.Version, - offlineQoS0: c.OfflineQoS0, - topics: c.Topics, + subscriptions: make(Subscriptions), + Config: c, } - close(p.isOnline) - return p } -// ID get subscriber id -func (s *Type) ID() string { - return s.id +// GetID get subscriber id +func (s *Type) GetID() string { + return s.ID } // Hash returns address of the provider struct. @@ -100,9 +86,9 @@ func (s *Type) Release() { s.access.Done() } -// Version return MQTT protocol version -func (s *Type) Version() packet.ProtocolVersion { - return s.version +// GetVersion return MQTT protocol version +func (s *Type) GetVersion() packet.ProtocolVersion { + return s.Version } // Subscriptions list active subscriptions @@ -112,7 +98,7 @@ func (s *Type) Subscriptions() Subscriptions { // Subscribe to given topic func (s *Type) Subscribe(topic string, params *topicsTypes.SubscriptionParams) (packet.QosType, []*packet.Publish, error) { - q, r, err := s.topics.Subscribe(topic, s, params) + q, r, err := s.Topics.Subscribe(topic, s, params) s.subscriptions[topic] = params @@ -121,16 +107,15 @@ func (s *Type) Subscribe(topic string, params *topicsTypes.SubscriptionParams) ( // UnSubscribe from given topic func (s *Type) UnSubscribe(topic string) error { - err := s.topics.UnSubscribe(topic, s) delete(s.subscriptions, topic) - return err + return s.Topics.UnSubscribe(topic, s) } // Publish message accordingly to subscriber state // online: forward message to session // offline: persist message func (s *Type) Publish(p *packet.Publish, grantedQoS packet.QosType, ops packet.SubscriptionOptions, ids []uint32) error { - pkt, err := p.Clone(s.version) + pkt, err := p.Clone(s.Version) if err != nil { return err } @@ -164,48 +149,31 @@ func (s *Type) Publish(p *packet.Publish, grantedQoS packet.QosType, ops packet. // originally published as QoS 2 might get lost on the hop to the Client, but the Server should never // send a duplicate of that Message. A QoS 1 Message published to the same topic might either get // lost or duplicated on its transmission to that Client. - //case message.QoS0: + // case message.QoS0: } - select { - case <-s.isOnline: - // if session is offline forward message to persisted storage - // only with QoS1 and QoS2 and QoS0 if set by config - qos := pkt.QoS() - if qos != packet.QoS0 || (s.offlineQoS0 && qos == packet.QoS0) { - defer s.wgOffline.Done() - s.publishLock.RLock() - s.wgOffline.Add(1) - s.publishLock.RUnlock() - s.publishOffline(s.id, pkt) - } - default: - // forward message to publish queue - defer s.wgOnline.Done() - s.publishLock.RLock() - s.wgOnline.Add(1) - s.publishLock.RUnlock() - s.publishOnline(pkt) - } + pb := s.publish.Load().(*publisher) + pb.Add(1) + pb.Publish(s.ID, pkt) + pb.Done() return nil } // Online moves subscriber to online state // since this moment all of publishes are forwarded to provided callback -func (s *Type) Online(c OnlinePublish) { - s.wgOffline.Wait() - defer s.publishLock.Unlock() - s.publishLock.Lock() - s.publishOnline = c - s.isOnline = make(chan struct{}) -} +func (s *Type) Online(c Publisher) { + p := s.publish.Load() + + pb := &publisher{ + Publisher: c, + } + s.publish.Store(pb) -// OnlineRedirect set new online publish callback -func (s *Type) OnlineRedirect(c OnlinePublish) { - defer s.publishLock.Unlock() - s.publishLock.Lock() - s.publishOnline = c + if p != nil { + old := p.(*publisher) + old.Wait() + } } // Offline put session offline @@ -214,19 +182,15 @@ func (s *Type) Offline(shutdown bool) { // if session is clean then remove all remaining subscriptions if shutdown { for topic := range s.subscriptions { - s.topics.UnSubscribe(topic, s) // nolint: errcheck + s.Topics.UnSubscribe(topic, s) // nolint: errcheck delete(s.subscriptions, topic) } - } - - // wait all of remaining publishes are finished - select { - case <-s.isOnline: - default: - close(s.isOnline) - // Wait all of online publishes done - s.publishLock.Lock() - s.wgOnline.Wait() - s.publishLock.Unlock() + } else { + pb := &publisher{ + Publisher: s.OfflinePublish, + } + old := s.publish.Load().(*publisher) + s.publish.Store(pb) + old.Wait() } } diff --git a/systree/clients.go b/systree/clients.go index 52a2659..9100c33 100644 --- a/systree/clients.go +++ b/systree/clients.go @@ -1,11 +1,10 @@ package systree import ( + "encoding/json" "sync/atomic" "time" - "encoding/json" - "github.com/VolantMQ/volantmq/packet" "github.com/VolantMQ/volantmq/types" ) @@ -20,7 +19,7 @@ type ClientConnectStatus struct { KeepAlive uint16 GeneratedID bool CleanSession bool - KillOnDisconnect bool + Durable bool SessionPresent bool PreserveOrder bool MaximumQoS packet.QosType diff --git a/systree/sessions.go b/systree/sessions.go index 44ab519..838cc8c 100644 --- a/systree/sessions.go +++ b/systree/sessions.go @@ -38,7 +38,7 @@ func newSessions(topicPrefix string, retained *[]types.RetainObject) sessions { return c } -// Connected add to statistic new client +// Created add to statistic new client func (t *sessions) Created(id string, status *SessionCreatedStatus) { newVal := atomic.AddUint64(&t.curr.val, 1) if atomic.LoadUint64(&t.max.val) < newVal { @@ -65,7 +65,7 @@ func (t *sessions) Created(id string, status *SessionCreatedStatus) { } } -// Disconnected remove client from statistic +// Removed remove client from statistic func (t *sessions) Removed(id string, status *SessionDeletedStatus) { atomic.AddUint64(&t.curr.val, ^uint64(0)) if t.topicsManager != nil { diff --git a/systree/tree.go b/systree/tree.go index 887af60..44ef5d4 100644 --- a/systree/tree.go +++ b/systree/tree.go @@ -1,6 +1,8 @@ package systree -import "github.com/VolantMQ/volantmq/types" +import ( + "github.com/VolantMQ/volantmq/types" +) type impl struct { server server @@ -13,8 +15,8 @@ type impl struct { // NewTree allocate systree provider func NewTree(base string) (Provider, []types.RetainObject, []DynamicValue, error) { - retains := []types.RetainObject{} - staticRetains := []types.RetainObject{} + var retains []types.RetainObject + var staticRetains []types.RetainObject tr := &impl{ newServer(base, &retains, &staticRetains), @@ -25,7 +27,7 @@ func NewTree(base string) (Provider, []types.RetainObject, []DynamicValue, error newSessions(base, &retains), } - dynUpdates := []DynamicValue{} + var dynUpdates []DynamicValue for _, d := range retains { v := d.(DynamicValue) dynUpdates = append(dynUpdates, v) @@ -35,6 +37,7 @@ func NewTree(base string) (Provider, []types.RetainObject, []DynamicValue, error return tr, retains, dynUpdates, nil } +// SetCallbacks func (t *impl) SetCallbacks(cb types.TopicMessenger) { t.clients.topicsManager = cb t.sessions.topicsManager = cb @@ -45,7 +48,7 @@ func (t *impl) Sessions() Sessions { return &t.sessions } -// Session get session stat provider +// Clients get clients stat provider func (t *impl) Clients() Clients { return &t.clients } diff --git a/topics/mem/node.go b/topics/mem/node.go index 88fc3c8..d3e4bad 100644 --- a/topics/mem/node.go +++ b/topics/mem/node.go @@ -9,16 +9,16 @@ import ( "github.com/VolantMQ/volantmq/types" ) -type subscribedEntry struct { +type subscriber struct { s topicsTypes.Subscriber p *topicsTypes.SubscriptionParams } -type subscribedEntries map[uintptr]*subscribedEntry +type subscribers map[uintptr]*subscriber -func (s *subscribedEntry) acquire() *publishEntry { +func (s *subscriber) acquire() *publish { s.s.Acquire() - pe := &publishEntry{ + pe := &publish{ s: s.s, qos: s.p.Granted, ops: s.p.Ops, @@ -31,26 +31,26 @@ func (s *subscribedEntry) acquire() *publishEntry { return pe } -type publishEntry struct { +type publish struct { s topicsTypes.Subscriber ops packet.SubscriptionOptions qos packet.QosType ids []uint32 } -type publishEntries map[uintptr][]*publishEntry +type publishes map[uintptr][]*publish type node struct { retained interface{} - subs subscribedEntries + subs subscribers parent *node children map[string]*node - getSubscribers func(uintptr, *publishEntries) + getSubscribers func(uintptr, *publishes) } func newNode(overlap bool, parent *node) *node { n := &node{ - subs: make(subscribedEntries), + subs: make(subscribers), children: make(map[string]*node), parent: parent, } @@ -106,7 +106,7 @@ func (mT *provider) subscriptionInsert(filter string, sub topicsTypes.Subscriber // Otherwise create new entry exists := false if s, ok := root.subs[sub.Hash()]; !ok { - root.subs[sub.Hash()] = &subscribedEntry{ + root.subs[sub.Hash()] = &subscriber{ s: sub, p: p, } @@ -133,7 +133,7 @@ func (mT *provider) subscriptionRemove(topic string, sub topicsTypes.Subscriber) // otherwise try remove subscriber or set error if not exists if sub == nil { // If subscriber == nil, then it's signal to remove ALL subscribers - root.subs = make(subscribedEntries) + root.subs = make(subscribers) } else { id := sub.Hash() if _, ok := root.subs[id]; ok { @@ -162,7 +162,7 @@ func (mT *provider) subscriptionRemove(topic string, sub topicsTypes.Subscriber) return err } -func subscriptionRecurseSearch(root *node, levels []string, publishID uintptr, p *publishEntries) { +func subscriptionRecurseSearch(root *node, levels []string, publishID uintptr, p *publishes) { if len(levels) == 0 { // leaf level of the topic // get all subscribers and return @@ -185,7 +185,7 @@ func subscriptionRecurseSearch(root *node, levels []string, publishID uintptr, p } } -func (mT *provider) subscriptionSearch(topic string, publishID uintptr, p *publishEntries) { +func (mT *provider) subscriptionSearch(topic string, publishID uintptr, p *publishes) { root := mT.root levels := strings.Split(topic, "/") level := levels[0] @@ -288,7 +288,7 @@ func (sn *node) getRetained(retained *[]*packet.Publish) { } // if publish has expiration set check if there time left to live - if !p.Expired(false) { + if _, _, expired := p.Expired(); !expired { *retained = append(*retained, p) } else { // publish has expired, thus nobody should get it @@ -305,7 +305,7 @@ func (sn *node) allRetained(retained *[]*packet.Publish) { } } -func (sn *node) overlappingSubscribers(publishID uintptr, p *publishEntries) { +func (sn *node) overlappingSubscribers(publishID uintptr, p *publishes) { for id, sub := range sn.subs { if s, ok := (*p)[id]; ok { if sub.p.ID > 0 { @@ -324,14 +324,14 @@ func (sn *node) overlappingSubscribers(publishID uintptr, p *publishEntries) { } } -func (sn *node) nonOverlappingSubscribers(publishID uintptr, p *publishEntries) { +func (sn *node) nonOverlappingSubscribers(publishID uintptr, p *publishes) { for id, sub := range sn.subs { if !sub.p.Ops.NL() || id != publishID { pe := sub.acquire() if _, ok := (*p)[id]; ok { (*p)[id] = append((*p)[id], pe) } else { - (*p)[id] = []*publishEntry{pe} + (*p)[id] = []*publish{pe} } } } diff --git a/topics/mem/topics.go b/topics/mem/topics.go index 17ce0f6..a9d4a0d 100644 --- a/topics/mem/topics.go +++ b/topics/mem/topics.go @@ -16,7 +16,6 @@ package mem import ( "sync" - "time" "github.com/VolantMQ/persistence" @@ -75,7 +74,7 @@ func NewMemProvider(config *topicsTypes.MemConfig) (topicsTypes.Provider, error) if m, ok := pkt.(*packet.Publish); ok { if len(d.ExpireAt) > 0 { if tm, err := time.Parse(time.RFC3339, d.ExpireAt); err == nil { - m.SetExpiry(tm) + m.SetExpireAt(tm) } else { p.log.Error("Decode publish expire at", zap.Error(err)) } @@ -174,16 +173,16 @@ func (mT *provider) Close() error { var encoded []persistence.PersistedPacket for _, pkt := range res { - // Discard retained QoS0 messages - if pkt.QoS() != packet.QoS0 && !pkt.Expired(false) { + // Discard retained expired and QoS0 messages + if expireAt, _, expired := pkt.Expired(); !expired && pkt.QoS() != packet.QoS0 { if buf, err := packet.Encode(pkt); err != nil { mT.log.Error("Couldn't encode retained message", zap.Error(err)) } else { entry := persistence.PersistedPacket{ Data: buf, } - if tm := pkt.GetExpiry(); !tm.IsZero() { - entry.ExpireAt = tm.Format(time.RFC3339) + if !expireAt.IsZero() { + entry.ExpireAt = expireAt.Format(time.RFC3339) } encoded = append(encoded, entry) } @@ -240,7 +239,7 @@ func (mT *provider) publisher() { mT.wgPublisherStarted.Done() for msg := range mT.inbound { - pubEntries := publishEntries{} + pubEntries := publishes{} mT.smu.Lock() mT.subscriptionSearch(msg.Topic(), msg.PublishID(), &pubEntries) diff --git a/topics/mem/trie_test.go b/topics/mem/trie_test.go index 18ef642..759477e 100644 --- a/topics/mem/trie_test.go +++ b/topics/mem/trie_test.go @@ -42,7 +42,7 @@ func TestMatch1(t *testing.T) { } prov.Subscribe("sport/tennis/player1/#", sub, p) // nolint: errcheck - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) @@ -58,7 +58,7 @@ func TestMatch2(t *testing.T) { } prov.Subscribe("sport/tennis/player1/#", sub, p) // nolint: errcheck - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) @@ -75,7 +75,7 @@ func TestSNodeMatch3(t *testing.T) { prov.Subscribe("sport/tennis/#", sub, p) // nolint: errcheck - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -89,29 +89,29 @@ func TestMatch4(t *testing.T) { } prov.Subscribe("#", sub, p) // nolint: errcheck - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("/sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") err := prov.subscriptionRemove("#", sub) require.NoError(t, err) - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("#", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") prov.subscriptionInsert("/#", sub, p) - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") @@ -120,19 +120,19 @@ func TestMatch4(t *testing.T) { prov.subscriptionInsert("bla/bla/#", sub, p) - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 0, len(subscribers), "should not return subscribers") - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("bla/bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("bla/bla/bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers), "should return subscribers") } @@ -149,7 +149,7 @@ func TestMatch5(t *testing.T) { prov.subscriptionInsert("sport/tennis/+/+/#", sub1, p) prov.subscriptionInsert("sport/tennis/player1/anzel", sub2, p) - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 2, len(subscribers)) @@ -166,7 +166,7 @@ func TestMatch6(t *testing.T) { prov.subscriptionInsert("sport/tennis/+/+/+/+/#", sub1, p) prov.subscriptionInsert("sport/tennis/player1/anzel", sub2, p) - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("sport/tennis/player1/anzel/bla/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers)) } @@ -186,7 +186,7 @@ func TestMatch7(t *testing.T) { prov.subscriptionInsert("sport/tennis", sub2, p) - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("sport/tennis/player1/anzel", 0, &subscribers) require.Equal(t, 1, len(subscribers)) require.Equal(t, sub1, subscribers[sub1.Hash()][0].s) @@ -202,7 +202,7 @@ func TestMatch8(t *testing.T) { prov.subscriptionInsert("+/+", sub, p) - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("/finance", 0, &subscribers) require.Equal(t, 1, len(subscribers)) @@ -218,7 +218,7 @@ func TestMatch9(t *testing.T) { prov.subscriptionInsert("/+", sub1, p) - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("/finance", 0, &subscribers) require.Equal(t, 1, len(subscribers)) @@ -234,7 +234,7 @@ func TestMatch10(t *testing.T) { prov.subscriptionInsert("+", sub1, p) - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("/finance", 0, &subscribers) require.Equal(t, 0, len(subscribers)) @@ -249,28 +249,28 @@ func TestInsertRemove(t *testing.T) { prov.subscriptionInsert("#", sub, p) - subscribers := publishEntries{} + subscribers := publishes{} prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 1, len(subscribers)) - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("/bla", 0, &subscribers) require.Equal(t, 0, len(subscribers)) err := prov.subscriptionRemove("#", sub) require.NoError(t, err) - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("#", 0, &subscribers) require.Equal(t, 0, len(subscribers)) prov.subscriptionInsert("/#", sub, p) - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("bla", 0, &subscribers) require.Equal(t, 0, len(subscribers)) - subscribers = publishEntries{} + subscribers = publishes{} prov.subscriptionSearch("/bla", 0, &subscribers) require.Equal(t, 1, len(subscribers)) @@ -317,7 +317,7 @@ func TestInsert1(t *testing.T) { require.Equal(t, 0, len(level5.children)) require.Equal(t, 1, len(level5.subs)) - var e *subscribedEntry + var e *subscriber e, ok = level5.subs[sub1.Hash()] require.Equal(t, true, ok) @@ -343,7 +343,7 @@ func TestSNodeInsert2(t *testing.T) { require.Equal(t, 0, len(n2.children)) require.Equal(t, 1, len(n2.subs)) - var e *subscribedEntry + var e *subscriber e, ok = n2.subs[sub1.Hash()] require.Equal(t, true, ok) @@ -381,7 +381,7 @@ func TestSNodeInsert3(t *testing.T) { require.Equal(t, 0, len(n4.children)) require.Equal(t, 1, len(n4.subs)) - var e *subscribedEntry + var e *subscriber e, ok = n4.subs[sub1.Hash()] require.Equal(t, true, ok) @@ -413,7 +413,7 @@ func TestSNodeInsert4(t *testing.T) { require.Equal(t, 0, len(n3.children)) require.Equal(t, 1, len(n3.subs)) - var e *subscribedEntry + var e *subscriber e, ok = n3.subs[sub1.Hash()] require.Equal(t, true, ok) @@ -447,7 +447,7 @@ func TestSNodeInsertDup(t *testing.T) { require.Equal(t, 0, len(n3.children)) require.Equal(t, 1, len(n3.subs)) - var e *subscribedEntry + var e *subscriber e, ok = n3.subs[sub1.Hash()] require.Equal(t, true, ok) diff --git a/topics/types/types.go b/topics/types/types.go index 35c23e4..a15be2f 100644 --- a/topics/types/types.go +++ b/topics/types/types.go @@ -2,7 +2,6 @@ package topicsTypes import ( "errors" - "regexp" "github.com/VolantMQ/volantmq/packet" diff --git a/transport/base.go b/transport/base.go index 3533996..4a1240e 100644 --- a/transport/base.go +++ b/transport/base.go @@ -1,14 +1,10 @@ package transport import ( - "errors" "sync" - "time" "github.com/VolantMQ/volantmq/auth" "github.com/VolantMQ/volantmq/clients" - "github.com/VolantMQ/volantmq/packet" - "github.com/VolantMQ/volantmq/routines" "github.com/VolantMQ/volantmq/systree" "go.uber.org/zap" ) @@ -24,10 +20,6 @@ type Config struct { // InternalConfig used by server implementation to configure internal specific needs type InternalConfig struct { - // AllowedVersions what protocol version server will handle - // If not set than defaults to 0x3 and 0x04 - AllowedVersions map[packet.ProtocolVersion]bool - Sessions *clients.Manager Metric systree.Metric @@ -76,14 +68,6 @@ func (c *baseConfig) handleConnection(conn conn) { return } - var err error - - defer func() { - if err != nil { - conn.Close() // nolint: errcheck, gas - } - }() - // To establish a connection, we must // 1. Read and decode the message.ConnectMessage from the wire // 2. If no decoding errors, then authenticate using username and password. @@ -96,66 +80,9 @@ func (c *baseConfig) handleConnection(conn conn) { // Read the CONNECT message from the wire, if error, then check to see if it's // a CONNACK error. If it's CONNACK error, send the proper CONNACK error back // to client. Exit regardless of error type. - conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(c.ConnectTimeout))) // nolint: errcheck, gas - - var req packet.Provider - - var buf []byte - if buf, err = routines.GetMessageBuffer(conn); err != nil { - c.log.Error("Couldn't get CONNECT message", zap.Error(err)) - return - } + //conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(c.ConnectTimeout))) // nolint: errcheck, gas - if req, _, err = packet.Decode(packet.ProtocolV50, buf); err != nil { - c.log.Warn("Couldn't decode message", zap.Error(err)) - - if _, ok := err.(packet.ReasonCode); ok { - if req != nil { - c.Metric.Packets().Received(req.Type()) - } - } - } else { - // Disable read deadline. Will set it later if keep-alive interval is bigger than 0 - conn.SetReadDeadline(time.Time{}) // nolint: errcheck - switch r := req.(type) { - case *packet.Connect: - m, _ := packet.New(req.Version(), packet.CONNACK) - resp, _ := m.(*packet.ConnAck) - - var reason packet.ReasonCode - // If protocol version is not in allowed list then give reject and pass control to session manager - // to handle response - if allowed, ok := c.AllowedVersions[r.Version()]; !ok || !allowed { - reason = packet.CodeRefusedUnacceptableProtocolVersion - if r.Version() == packet.ProtocolV50 { - reason = packet.CodeUnsupportedProtocol - } - } else { - user, pass := r.Credentials() - - if status := c.config.AuthManager.Password(string(user), string(pass)); status == auth.StatusAllow { - reason = packet.CodeSuccess - } else { - reason = packet.CodeRefusedBadUsernameOrPassword - if req.Version() == packet.ProtocolV50 { - reason = packet.CodeBadUserOrPassword - } - } - } - resp.SetReturnCode(reason) // nolint: errcheck - - c.Sessions.NewSession( - &clients.StartConfig{ - Req: r, - Resp: resp, - Conn: conn, - Auth: c.config.AuthManager, - }) - default: - c.log.Error("Unexpected message type", - zap.String("expected", "CONNECT"), - zap.String("received", r.Type().Name())) - err = errors.New("unexpected message type") - } + if err := c.Sessions.Handle(conn, c.config.AuthManager); err != nil { + conn.Close() // nolint: errcheck, gas } } diff --git a/transport/connWS.go b/transport/connWS.go index 7dffa8f..e0e0c4d 100644 --- a/transport/connWS.go +++ b/transport/connWS.go @@ -1,11 +1,10 @@ package transport import ( + "io" "net" "time" - "io" - "github.com/VolantMQ/volantmq/systree" "github.com/gorilla/websocket" ) diff --git a/transport/tcp.go b/transport/tcp.go index bff61fd..d23e4cf 100644 --- a/transport/tcp.go +++ b/transport/tcp.go @@ -87,10 +87,6 @@ func (l *tcp) Close() error { return err } -//func (l *tcp) Protocol() string { -// return "tcp" -//} - func (l *tcp) Serve() error { var tempDelay time.Duration // how long to sleep on accept failure diff --git a/transport/websocket.go b/transport/websocket.go index bdd58cd..29e74f4 100644 --- a/transport/websocket.go +++ b/transport/websocket.go @@ -1,11 +1,9 @@ package transport import ( - "net/http" - - "crypto/tls" - "context" + "crypto/tls" + "net/http" "time" "github.com/VolantMQ/volantmq/auth" diff --git a/types/types.go b/types/types.go index 6aed5cf..affc6c5 100644 --- a/types/types.go +++ b/types/types.go @@ -1,9 +1,8 @@ package types import ( - "sync/atomic" - "sync" + "sync/atomic" "go.uber.org/zap" ) @@ -17,7 +16,7 @@ type LogInterface struct { // Default configs const ( DefaultKeepAlive = 60 // DefaultKeepAlive default keep - DefaultConnectTimeout = 2 // DefaultConnectTimeout connect timeout + DefaultConnectTimeout = 5 // DefaultConnectTimeout connect timeout DefaultMaxPacketSize = 268435455 DefaultReceiveMax = 65535 DefaultAckTimeout = 20 // DefaultAckTimeout ack timeout diff --git a/volantmq.go b/volantmq.go index d3578bc..a09ea75 100644 --- a/volantmq.go +++ b/volantmq.go @@ -32,6 +32,8 @@ var ( ErrInvalidNodeName = errors.New("node name is invalid") ) +type option func(*Server) + // ServerConfig configuration of the MQTT server type ServerConfig struct { // Configuration of persistence provider @@ -244,6 +246,7 @@ func NewServer(config *ServerConfig) (Server, error) { } mConfig := &clients.Config{ + AllowedVersions: s.AllowedVersions, TopicsMgr: s.topicsMgr, ConnectTimeout: s.ConnectTimeout, Persist: s.Persistence, @@ -274,11 +277,11 @@ func (s *server) ListenAndServe(config interface{}) error { var err error internalConfig := transport.InternalConfig{ - Metric: s.sysTree.Metric(), - Sessions: s.sessionsMgr, - ConnectTimeout: s.ConnectTimeout, - KeepAlive: s.KeepAlive, - AllowedVersions: s.AllowedVersions, + Metric: s.sysTree.Metric(), + Sessions: s.sessionsMgr, + ConnectTimeout: s.ConnectTimeout, + KeepAlive: s.KeepAlive, + //AllowedVersions: s.AllowedVersions, } switch c := config.(type) { @@ -364,15 +367,14 @@ func (s *server) Close() error { } func (s *server) systreeUpdater() { - for _, m := range s.systree.publishes { - _m := m.Publish() - _msg, _ := packet.New(packet.ProtocolV311, packet.PUBLISH) - msg, _ := _msg.(*packet.Publish) - - msg.SetPayload(_m.Payload()) - msg.SetTopic(_m.Topic()) // nolint: errcheck - msg.SetQoS(_m.QoS()) // nolint: errcheck - s.topicsMgr.Publish(msg) // nolint: errcheck + for _, val := range s.systree.publishes { + p := val.Publish() + pkt := packet.NewPublish(packet.ProtocolV311) + + pkt.SetPayload(p.Payload()) + pkt.SetTopic(p.Topic()) // nolint: errcheck + pkt.SetQoS(p.QoS()) // nolint: errcheck + s.topicsMgr.Publish(pkt) // nolint: errcheck } s.systree.timer.Reset(s.SystreeUpdateInterval * time.Second)