Skip to content

Commit

Permalink
Fix ws adapter communication with broker and things service (absmach#…
Browse files Browse the repository at this point in the history
…1899)

Signed-off-by: Rodney Osodo <[email protected]>
  • Loading branch information
rodneyosodo authored Sep 1, 2023
1 parent 320921a commit 7c3add6
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion cmd/ws/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func main() {

svc := newService(tc, nps, logger, tracer)

hs := httpserver.New(ctx, cancel, svcName, httpServerConfig, api.MakeHandler(svc, logger, cfg.InstanceID), logger)
hs := httpserver.New(ctx, cancel, svcName, httpServerConfig, api.MakeHandler(ctx, svc, logger, cfg.InstanceID), logger)

if cfg.SendTelemetry {
chc := chclient.New(svcName, mainflux.Version, logger, cancel)
Expand Down
4 changes: 2 additions & 2 deletions coap/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic
Subject: key,
Object: chanID,
Action: policies.ReadAction,
EntityType: policies.GroupEntityType,
EntityType: policies.ThingEntityType,
}
res, err := svc.auth.Authorize(ctx, ar)
if err != nil {
Expand All @@ -100,7 +100,7 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, key, chanID, subtopi
Subject: key,
Object: chanID,
Action: policies.ReadAction,
EntityType: policies.GroupEntityType,
EntityType: policies.ThingEntityType,
}
res, err := svc.auth.Authorize(ctx, ar)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions mqtt/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (h *handler) AuthConnect(ctx context.Context) error {
return errors.ErrAuthentication
}

if err := h.es.Connect(ctx, string(s.Password)); err != nil {
if err := h.es.Connect(ctx, pwd); err != nil {
h.logger.Error(errors.Wrap(ErrFailedPublishConnectEvent, err).Error())
}

Expand Down Expand Up @@ -249,7 +249,7 @@ func (h *handler) authAccess(ctx context.Context, password, topic, action string
return errors.ErrAuthorization
}

return err
return nil
}

func parseSubtopic(subtopic string) (string, error) {
Expand Down
10 changes: 5 additions & 5 deletions ws/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func New(auth policies.AuthServiceClient, pubsub messaging.PubSub) Service {
}

func (svc *adapterService) Publish(ctx context.Context, thingKey string, msg *messaging.Message) error {
thid, err := svc.authorize(ctx, thingKey, msg.GetChannel())
thid, err := svc.authorize(ctx, thingKey, msg.GetChannel(), policies.WriteAction)
if err != nil {
return ErrUnauthorizedAccess
}
Expand All @@ -98,7 +98,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, thingKey, chanID, subt
return ErrUnauthorizedAccess
}

thid, err := svc.authorize(ctx, thingKey, chanID)
thid, err := svc.authorize(ctx, thingKey, chanID, policies.ReadAction)
if err != nil {
return ErrUnauthorizedAccess
}
Expand All @@ -122,7 +122,7 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, thingKey, chanID, su
return ErrUnauthorizedAccess
}

thid, err := svc.authorize(ctx, thingKey, chanID)
thid, err := svc.authorize(ctx, thingKey, chanID, policies.ReadAction)
if err != nil {
return ErrUnauthorizedAccess
}
Expand All @@ -137,11 +137,11 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, thingKey, chanID, su

// authorize checks if the thingKey is authorized to access the channel
// and returns the thingID if it is.
func (svc *adapterService) authorize(ctx context.Context, thingKey, chanID string) (string, error) {
func (svc *adapterService) authorize(ctx context.Context, thingKey, chanID, action string) (string, error) {
ar := &policies.AuthorizeReq{
Subject: thingKey,
Object: chanID,
Action: policies.ReadAction,
Action: action,
EntityType: policies.ThingEntityType,
}
res, err := svc.auth.Authorize(ctx, ar)
Expand Down
3 changes: 2 additions & 1 deletion ws/api/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package api_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -37,7 +38,7 @@ func newService(cc policies.AuthServiceClient) (ws.Service, mocks.MockPubSub) {

func newHTTPServer(svc ws.Service) *httptest.Server {
logger := mflog.NewMock()
mux := api.MakeHandler(svc, logger, instanceID)
mux := api.MakeHandler(context.Background(), svc, logger, instanceID)
return httptest.NewServer(mux)
}

Expand Down
11 changes: 5 additions & 6 deletions ws/api/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ import (

var channelPartRegExp = regexp.MustCompile(`^/channels/([\w\-]+)/messages(/[^?]*)?(\?.*)?$`)

func handshake(svc ws.Service) http.HandlerFunc {
func handshake(ctx context.Context, svc ws.Service) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req, err := decodeRequest(r)
if err != nil {
encodeError(w, err)
Expand Down Expand Up @@ -146,10 +145,10 @@ func process(ctx context.Context, svc ws.Service, req connReq, msgs <-chan []byt
Payload: msg,
Created: time.Now().UnixNano(),
}
_ = svc.Publish(ctx, req.thingKey, &m)
}
if err := svc.Unsubscribe(ctx, req.thingKey, req.chanID, req.subtopic); err != nil {
req.conn.Close()

if err := svc.Publish(ctx, req.thingKey, &m); err != nil {
logger.Warn(fmt.Sprintf("Failed to publish message: %s", err.Error()))
}
}
}

Expand Down
7 changes: 4 additions & 3 deletions ws/api/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package api

import (
"context"
"errors"
"net/http"

Expand Down Expand Up @@ -35,12 +36,12 @@ var (
)

// MakeHandler returns http handler with handshake endpoint.
func MakeHandler(svc ws.Service, l mflog.Logger, instanceID string) http.Handler {
func MakeHandler(ctx context.Context, svc ws.Service, l mflog.Logger, instanceID string) http.Handler {
logger = l

mux := bone.New()
mux.GetFunc("/channels/:chanID/messages", handshake(svc))
mux.GetFunc("/channels/:chanID/messages/*", handshake(svc))
mux.GetFunc("/channels/:chanID/messages", handshake(ctx, svc))
mux.GetFunc("/channels/:chanID/messages/*", handshake(ctx, svc))
mux.GetFunc("/version", mainflux.Health(protocol, instanceID))
mux.Handle("/metrics", promhttp.Handler())

Expand Down

0 comments on commit 7c3add6

Please sign in to comment.