diff --git a/journal/api/endpoint.go b/journal/api/endpoint.go index 2042aeba3f..a8128783ba 100644 --- a/journal/api/endpoint.go +++ b/journal/api/endpoint.go @@ -37,3 +37,26 @@ func retrieveJournalsEndpoint(svc journal.Service) endpoint.Endpoint { }, nil } } + +func retrieveClientTelemetryEndpoint(svc journal.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(retrieveClientTelemetryReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + telemetry, err := svc.RetrieveClientTelemetry(ctx, session, req.clientID) + if err != nil { + return nil, err + } + + return clientTelemetryRes{ + ClientTelemetry: telemetry, + }, nil + } +} diff --git a/journal/api/endpoint_test.go b/journal/api/endpoint_test.go index b470a4cbf7..38e944f9a4 100644 --- a/journal/api/endpoint_test.go +++ b/journal/api/endpoint_test.go @@ -371,7 +371,7 @@ func TestListEntityJournalsEndpoint(t *testing.T) { desc: "with empty domain ID", token: validToken, url: "/group/", - status: http.StatusNotFound, + status: http.StatusBadRequest, svcErr: nil, }, } @@ -402,3 +402,86 @@ func TestListEntityJournalsEndpoint(t *testing.T) { }) } } + +func TestRetrieveClientTelemetryEndpoint(t *testing.T) { + es, svc, authn := newjournalServer() + + clientID := testsutil.GenerateUUID(t) + userID := testsutil.GenerateUUID(t) + domanID := testsutil.GenerateUUID(t) + + cases := []struct { + desc string + token string + session smqauthn.Session + clientID string + domainID string + url string + contentType string + status int + authnErr error + svcErr error + }{ + { + desc: "successful", + token: validToken, + clientID: clientID, + domainID: domanID, + url: fmt.Sprintf("/client/%s/telemetry", clientID), + status: http.StatusOK, + svcErr: nil, + }, + { + desc: "with service error", + token: validToken, + clientID: clientID, + domainID: domanID, + url: fmt.Sprintf("/client/%s/telemetry", clientID), + status: http.StatusForbidden, + svcErr: svcerr.ErrAuthorization, + }, + { + desc: "with empty token", + clientID: clientID, + domainID: domanID, + url: fmt.Sprintf("/client/%s/telemetry", clientID), + status: http.StatusUnauthorized, + svcErr: nil, + }, + { + desc: "with invalid client ID", + token: validToken, + domainID: domanID, + clientID: "invalid", + url: "/client/invalid/telemetry", + status: http.StatusNotFound, + svcErr: svcerr.ErrNotFound, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + if c.token == validToken { + c.session = smqauthn.Session{ + UserID: userID, + DomainID: c.domainID, + DomainUserID: c.domainID + "_" + userID, + } + } + authCall := authn.On("Authenticate", mock.Anything, c.token).Return(c.session, c.authnErr) + svcCall := svc.On("RetrieveClientTelemetry", mock.Anything, c.session, c.clientID).Return(journal.ClientTelemetry{}, c.svcErr) + req := testRequest{ + client: es.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/journal%s", es.URL, c.domainID, c.url), + token: c.token, + } + resp, err := req.make() + assert.Nil(t, err, c.desc) + defer resp.Body.Close() + assert.Equal(t, c.status, resp.StatusCode, c.desc) + svcCall.Unset() + authCall.Unset() + }) + } +} diff --git a/journal/api/requests.go b/journal/api/requests.go index 8d52fd8ecf..41e4dd3462 100644 --- a/journal/api/requests.go +++ b/journal/api/requests.go @@ -30,3 +30,15 @@ func (req retrieveJournalsReq) validate() error { return nil } + +type retrieveClientTelemetryReq struct { + clientID string +} + +func (req retrieveClientTelemetryReq) validate() error { + if req.clientID == "" { + return apiutil.ErrMissingID + } + + return nil +} diff --git a/journal/api/requests_test.go b/journal/api/requests_test.go index e88c52c9f0..eba8e325de 100644 --- a/journal/api/requests_test.go +++ b/journal/api/requests_test.go @@ -124,3 +124,33 @@ func TestRetrieveJournalsReqValidate(t *testing.T) { }) } } + +func TestRetrieveClientTelemetryReqValidate(t *testing.T) { + cases := []struct { + desc string + req retrieveClientTelemetryReq + err error + }{ + { + desc: "valid", + req: retrieveClientTelemetryReq{ + clientID: "id", + }, + err: nil, + }, + { + desc: "missing client id", + req: retrieveClientTelemetryReq{ + clientID: "", + }, + err: apiutil.ErrMissingID, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + err := c.req.validate() + assert.Equal(t, c.err, err) + }) + } +} diff --git a/journal/api/responses.go b/journal/api/responses.go index 417c8c87d7..b4e0f9496a 100644 --- a/journal/api/responses.go +++ b/journal/api/responses.go @@ -27,3 +27,7 @@ func (res pageRes) Code() int { func (res pageRes) Empty() bool { return false } + +type clientTelemetryRes struct { + journal.ClientTelemetry `json:",inline"` +} diff --git a/journal/api/transport.go b/journal/api/transport.go index 7c7f5bc83c..f3c856c928 100644 --- a/journal/api/transport.go +++ b/journal/api/transport.go @@ -48,12 +48,23 @@ func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slo opts..., ), "list_user_journals").ServeHTTP) - mux.With(api.AuthenticateMiddleware(authn, true)).Get("/{domainID}/journal/{entityType}/{entityID}", otelhttp.NewHandler(kithttp.NewServer( - retrieveJournalsEndpoint(svc), - decodeRetrieveEntityJournalReq, - api.EncodeResponse, - opts..., - ), "list__entity_journals").ServeHTTP) + mux.Route("/{domainID}/journal", func(r chi.Router) { + r.Use(api.AuthenticateMiddleware(authn, true)) + + r.Get("/{entityType}/{entityID}", otelhttp.NewHandler(kithttp.NewServer( + retrieveJournalsEndpoint(svc), + decodeRetrieveEntityJournalReq, + api.EncodeResponse, + opts..., + ), "list__entity_journals").ServeHTTP) + + r.Get("/client/{clientID}/telemetry", otelhttp.NewHandler(kithttp.NewServer( + retrieveClientTelemetryEndpoint(svc), + decodeRetrieveClientTelemetryReq, + api.EncodeResponse, + opts..., + ), "view_client_telemetry").ServeHTTP) + }) mux.Get("/health", supermq.Health(svcName, instanceID)) mux.Handle("/metrics", promhttp.Handler()) @@ -160,3 +171,11 @@ func decodePageQuery(r *http.Request) (journal.Page, error) { Direction: dir, }, nil } + +func decodeRetrieveClientTelemetryReq(_ context.Context, r *http.Request) (interface{}, error) { + req := retrieveClientTelemetryReq{ + clientID: chi.URLParam(r, "clientID"), + } + + return req, nil +} diff --git a/journal/journal.go b/journal/journal.go index c24982ce55..54d9b96f9a 100644 --- a/journal/journal.go +++ b/journal/journal.go @@ -137,6 +137,16 @@ func (page JournalsPage) MarshalJSON() ([]byte, error) { return json.Marshal(a) } +type ClientTelemetry struct { + ClientID string `json:"client_id"` + DomainID string `json:"domain_id"` + Subscriptions []string `json:"subscriptions"` + InboundMessages uint64 `json:"inbound_messages"` + OutboundMessages uint64 `json:"outbound_messages"` + FirstSeen time.Time `json:"first_seen"` + LastSeen time.Time `json:"last_seen"` +} + // Service provides access to the journal log service. // //go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" @@ -146,6 +156,9 @@ type Service interface { // RetrieveAll retrieves all journals from the database with the given page. RetrieveAll(ctx context.Context, session smqauthn.Session, page Page) (JournalsPage, error) + + // RetrieveClientTelemetry retrieves telemetry data for a client. + RetrieveClientTelemetry(ctx context.Context, session smqauthn.Session, clientID string) (ClientTelemetry, error) } // Repository provides access to the journal log database. @@ -157,4 +170,13 @@ type Repository interface { // RetrieveAll retrieves all journals from the database with the given page. RetrieveAll(ctx context.Context, page Page) (JournalsPage, error) + + // SaveClientTelemetry persists telemetry data for a client to the database. + SaveClientTelemetry(ctx context.Context, ct ClientTelemetry) error + + // RetrieveClientTelemetry retrieves telemetry data for a client from the database. + RetrieveClientTelemetry(ctx context.Context, clientID, domainID string) (ClientTelemetry, error) + + // DeleteClientTelemetry removes telemetry data for a client from the database. + DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error } diff --git a/journal/middleware/authorization.go b/journal/middleware/authorization.go index 6a3d72b2c4..2cf9edccff 100644 --- a/journal/middleware/authorization.go +++ b/journal/middleware/authorization.go @@ -12,9 +12,11 @@ import ( "github.com/absmach/supermq/pkg/policies" ) -var _ journal.Service = (*authorizationMiddleware)(nil) +var ( + _ journal.Service = (*authorizationMiddleware)(nil) -var readPermission = "read_permission" + readPermission = "read_permission" +) type authorizationMiddleware struct { svc journal.Service @@ -62,3 +64,21 @@ func (am *authorizationMiddleware) RetrieveAll(ctx context.Context, session smqa return am.svc.RetrieveAll(ctx, session, page) } + +func (am *authorizationMiddleware) RetrieveClientTelemetry(ctx context.Context, session smqauthn.Session, clientID string) (journal.ClientTelemetry, error) { + req := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.UserID, + Permission: readPermission, + ObjectType: policies.ClientType, + Object: clientID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return journal.ClientTelemetry{}, err + } + + return am.svc.RetrieveClientTelemetry(ctx, session, clientID) +} diff --git a/journal/middleware/logging.go b/journal/middleware/logging.go index d2a937076c..afaf9077e5 100644 --- a/journal/middleware/logging.go +++ b/journal/middleware/logging.go @@ -69,3 +69,21 @@ func (lm *loggingMiddleware) RetrieveAll(ctx context.Context, session smqauthn.S return lm.service.RetrieveAll(ctx, session, page) } + +func (lm *loggingMiddleware) RetrieveClientTelemetry(ctx context.Context, session smqauthn.Session, clientID string) (ct journal.ClientTelemetry, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("client_id", clientID), + slog.String("domain_id", session.DomainID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Retrieve client telemetry failed", args...) + return + } + lm.logger.Info("Retrieve client telemetry completed successfully", args...) + }(time.Now()) + + return lm.service.RetrieveClientTelemetry(ctx, session, clientID) +} diff --git a/journal/middleware/metrics.go b/journal/middleware/metrics.go index 2c0698893d..ce70231ab1 100644 --- a/journal/middleware/metrics.go +++ b/journal/middleware/metrics.go @@ -47,3 +47,12 @@ func (mm *metricsMiddleware) RetrieveAll(ctx context.Context, session smqauthn.S return mm.service.RetrieveAll(ctx, session, page) } + +func (mm *metricsMiddleware) RetrieveClientTelemetry(ctx context.Context, session smqauthn.Session, clientID string) (journal.ClientTelemetry, error) { + defer func(begin time.Time) { + mm.counter.With("method", "retrieve_client_telemetry").Add(1) + mm.latency.With("method", "retrieve_client_telemetry").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.RetrieveClientTelemetry(ctx, session, clientID) +} diff --git a/journal/middleware/tracing.go b/journal/middleware/tracing.go index 60774f7492..fcbec1e9e7 100644 --- a/journal/middleware/tracing.go +++ b/journal/middleware/tracing.go @@ -45,3 +45,13 @@ func (tm *tracing) RetrieveAll(ctx context.Context, session smqauthn.Session, pa return tm.svc.RetrieveAll(ctx, session, page) } + +func (tm *tracing) RetrieveClientTelemetry(ctx context.Context, session smqauthn.Session, clientID string) (j journal.ClientTelemetry, err error) { + ctx, span := tm.tracer.Start(ctx, "retrieve", trace.WithAttributes( + attribute.String("client_id", clientID), + attribute.String("domain_id", session.DomainID), + )) + defer span.End() + + return tm.svc.RetrieveClientTelemetry(ctx, session, clientID) +} diff --git a/journal/mocks/repository.go b/journal/mocks/repository.go index f5114c9ae3..32abe3ce4c 100644 --- a/journal/mocks/repository.go +++ b/journal/mocks/repository.go @@ -16,6 +16,24 @@ type Repository struct { mock.Mock } +// DeleteClientTelemetry provides a mock function with given fields: ctx, clientID, domainID +func (_m *Repository) DeleteClientTelemetry(ctx context.Context, clientID string, domainID string) error { + ret := _m.Called(ctx, clientID, domainID) + + if len(ret) == 0 { + panic("no return value specified for DeleteClientTelemetry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, clientID, domainID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // RetrieveAll provides a mock function with given fields: ctx, page func (_m *Repository) RetrieveAll(ctx context.Context, page journal.Page) (journal.JournalsPage, error) { ret := _m.Called(ctx, page) @@ -44,6 +62,34 @@ func (_m *Repository) RetrieveAll(ctx context.Context, page journal.Page) (journ return r0, r1 } +// RetrieveClientTelemetry provides a mock function with given fields: ctx, clientID, domainID +func (_m *Repository) RetrieveClientTelemetry(ctx context.Context, clientID string, domainID string) (journal.ClientTelemetry, error) { + ret := _m.Called(ctx, clientID, domainID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveClientTelemetry") + } + + var r0 journal.ClientTelemetry + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (journal.ClientTelemetry, error)); ok { + return rf(ctx, clientID, domainID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) journal.ClientTelemetry); ok { + r0 = rf(ctx, clientID, domainID) + } else { + r0 = ret.Get(0).(journal.ClientTelemetry) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, clientID, domainID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Save provides a mock function with given fields: ctx, _a1 func (_m *Repository) Save(ctx context.Context, _a1 journal.Journal) error { ret := _m.Called(ctx, _a1) @@ -62,6 +108,24 @@ func (_m *Repository) Save(ctx context.Context, _a1 journal.Journal) error { return r0 } +// SaveClientTelemetry provides a mock function with given fields: ctx, ct +func (_m *Repository) SaveClientTelemetry(ctx context.Context, ct journal.ClientTelemetry) error { + ret := _m.Called(ctx, ct) + + if len(ret) == 0 { + panic("no return value specified for SaveClientTelemetry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, journal.ClientTelemetry) error); ok { + r0 = rf(ctx, ct) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewRepository(t interface { diff --git a/journal/mocks/service.go b/journal/mocks/service.go index 84bb490abc..e194c89bb0 100644 --- a/journal/mocks/service.go +++ b/journal/mocks/service.go @@ -47,6 +47,34 @@ func (_m *Service) RetrieveAll(ctx context.Context, session authn.Session, page return r0, r1 } +// RetrieveClientTelemetry provides a mock function with given fields: ctx, session, clientID +func (_m *Service) RetrieveClientTelemetry(ctx context.Context, session authn.Session, clientID string) (journal.ClientTelemetry, error) { + ret := _m.Called(ctx, session, clientID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveClientTelemetry") + } + + var r0 journal.ClientTelemetry + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (journal.ClientTelemetry, error)); ok { + return rf(ctx, session, clientID) + } + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) journal.ClientTelemetry); ok { + r0 = rf(ctx, session, clientID) + } else { + r0 = ret.Get(0).(journal.ClientTelemetry) + } + + if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = rf(ctx, session, clientID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Save provides a mock function with given fields: ctx, _a1 func (_m *Service) Save(ctx context.Context, _a1 journal.Journal) error { ret := _m.Called(ctx, _a1) diff --git a/journal/postgres/init.go b/journal/postgres/init.go index dd5482f3d7..29ffb23565 100644 --- a/journal/postgres/init.go +++ b/journal/postgres/init.go @@ -27,8 +27,19 @@ func Migration() *migrate.MemoryMigrationSource { `CREATE INDEX idx_journal_default_group_filter ON journal(operation, (attributes->>'id'), (attributes->>'group_id'), occurred_at DESC);`, `CREATE INDEX idx_journal_default_client_filter ON journal(operation, (attributes->>'id'), (attributes->>'client_id'), occurred_at DESC);`, `CREATE INDEX idx_journal_default_channel_filter ON journal(operation, (attributes->>'id'), (attributes->>'channel_id'), occurred_at DESC);`, + `CREATE TABLE IF NOT EXISTS clients_telemetry ( + client_id VARCHAR(36) NOT NULL, + domain_id VARCHAR(36) NOT NULL, + subscriptions TEXT[], + inbound_messages BIGINT DEFAULT 0, + outbound_messages BIGINT DEFAULT 0, + first_seen TIMESTAMP, + last_seen TIMESTAMP, + PRIMARY KEY (client_id, domain_id) + )`, }, Down: []string{ + `DROP TABLE IF EXISTS clients_telemetry`, `DROP TABLE IF EXISTS journal`, }, }, diff --git a/journal/postgres/telemetry.go b/journal/postgres/telemetry.go new file mode 100644 index 0000000000..f231327319 --- /dev/null +++ b/journal/postgres/telemetry.go @@ -0,0 +1,135 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/absmach/supermq/journal" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jackc/pgtype" +) + +func (repo *repository) SaveClientTelemetry(ctx context.Context, ct journal.ClientTelemetry) error { + q := `INSERT INTO clients_telemetry (client_id, domain_id, messages, subscriptions, first_seen, last_seen) + VALUES (:client_id, :domain_id, :messages, :subscriptions, :first_seen, :last_seen);` + + dbct, err := toDBClientsTelemetry(ct) + if err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + if _, err := repo.db.NamedExecContext(ctx, q, dbct); err != nil { + return postgres.HandleError(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (repo *repository) DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error { + q := "DELETE FROM clients_telemetry AS ct WHERE ct.client_id = :client_id AND ct.domain_id = :domain_id;" + + dbct := dbClientTelemetry{ + ClientID: clientID, + DomainID: domainID, + } + + result, err := repo.db.NamedExecContext(ctx, q, dbct) + if err != nil { + return postgres.HandleError(repoerr.ErrRemoveEntity, err) + } + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + return nil +} + +func (repo *repository) RetrieveClientTelemetry(ctx context.Context, clientID, domainID string) (journal.ClientTelemetry, error) { + q := "SELECT * FROM clients_telemetry WHERE client_id = :client_id AND domain_id = :domain_id;" + + dbct := dbClientTelemetry{ + ClientID: clientID, + DomainID: domainID, + } + + rows, err := repo.db.NamedQueryContext(ctx, q, dbct) + if err != nil { + return journal.ClientTelemetry{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + dbct = dbClientTelemetry{} + if rows.Next() { + if err = rows.StructScan(&dbct); err != nil { + return journal.ClientTelemetry{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + + ct, err := toClientsTelemetry(dbct) + if err != nil { + return journal.ClientTelemetry{}, errors.Wrap(repoerr.ErrFailedOpDB, err) + } + + return ct, nil + } + + return journal.ClientTelemetry{}, repoerr.ErrNotFound +} + +type dbClientTelemetry struct { + ClientID string `db:"client_id"` + DomainID string `db:"domain_id"` + Subscriptions pgtype.TextArray `db:"subscriptions"` + InboundMessages uint64 `db:"inbound_messages"` + OutboundMessages uint64 `db:"outbound_messages"` + FirstSeen time.Time `db:"first_seen"` + LastSeen sql.NullTime `db:"last_seen"` +} + +func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) { + var subs pgtype.TextArray + if err := subs.Set(ct.Subscriptions); err != nil { + return dbClientTelemetry{}, err + } + + var lastSeen sql.NullTime + if ct.LastSeen != (time.Time{}) { + lastSeen = sql.NullTime{Time: ct.LastSeen, Valid: true} + } + + return dbClientTelemetry{ + ClientID: ct.ClientID, + DomainID: ct.DomainID, + Subscriptions: subs, + InboundMessages: ct.InboundMessages, + OutboundMessages: ct.OutboundMessages, + FirstSeen: ct.FirstSeen, + LastSeen: lastSeen, + }, nil +} + +func toClientsTelemetry(dbct dbClientTelemetry) (journal.ClientTelemetry, error) { + var subs []string + for _, e := range dbct.Subscriptions.Elements { + subs = append(subs, e.String) + } + + var lastSeen time.Time + if dbct.LastSeen.Valid { + lastSeen = dbct.LastSeen.Time + } + + return journal.ClientTelemetry{ + ClientID: dbct.ClientID, + DomainID: dbct.DomainID, + Subscriptions: subs, + InboundMessages: dbct.InboundMessages, + OutboundMessages: dbct.OutboundMessages, + FirstSeen: dbct.FirstSeen, + LastSeen: lastSeen, + }, nil +} diff --git a/journal/service.go b/journal/service.go index 6f288a9393..81f9ed61bf 100644 --- a/journal/service.go +++ b/journal/service.go @@ -42,3 +42,12 @@ func (svc *service) RetrieveAll(ctx context.Context, session smqauthn.Session, p return journalPage, nil } + +func (svc *service) RetrieveClientTelemetry(ctx context.Context, session smqauthn.Session, clientID string) (ClientTelemetry, error) { + ct, err := svc.repository.RetrieveClientTelemetry(ctx, clientID, session.DomainID) + if err != nil { + return ClientTelemetry{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + return ct, nil +}