diff --git a/groups/api/grpc/client.go b/groups/api/grpc/client.go index e65c5c0cf4c..9a0e0f01151 100644 --- a/groups/api/grpc/client.go +++ b/groups/api/grpc/client.go @@ -31,7 +31,6 @@ type grpcClient struct { // NewClient returns new gRPC client instance. func NewClient(conn *grpc.ClientConn, timeout time.Duration) grpcGroupsV1.GroupsServiceClient { return &grpcClient{ - retrieveEntity: kitgrpc.NewClient( conn, svcName, diff --git a/groups/api/grpc/endpoint.go b/groups/api/grpc/endpoint.go index bbb6bac0072..d8d53d3de27 100644 --- a/groups/api/grpc/endpoint.go +++ b/groups/api/grpc/endpoint.go @@ -12,15 +12,12 @@ import ( func retrieveEntityEndpoint(svc groups.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(retrieveEntityReq) group, err := svc.RetrieveById(ctx, req.Id) - if err != nil { return retrieveEntityRes{}, err } return retrieveEntityRes{id: group.ID, domain: group.Domain, parentGroup: group.Parent, status: uint8(group.Status)}, nil - } } diff --git a/groups/api/grpc/endpoint_test.go b/groups/api/grpc/endpoint_test.go new file mode 100644 index 00000000000..4787187b0ab --- /dev/null +++ b/groups/api/grpc/endpoint_test.go @@ -0,0 +1,160 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc_test + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/absmach/magistrala/groups" + grpcapi "github.com/absmach/magistrala/groups/api/grpc" + prmocks "github.com/absmach/magistrala/groups/private/mocks" + grpcCommonV1 "github.com/absmach/magistrala/internal/grpc/common/v1" + grpcGroupsV1 "github.com/absmach/magistrala/internal/grpc/groups/v1" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + svcerr "github.com/absmach/magistrala/pkg/errors/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +const port = 7004 + +var ( + validID = testsutil.GenerateUUID(&testing.T{}) + valid = "valid" + validGroupResp = groups.Group{ + ID: testsutil.GenerateUUID(&testing.T{}), + Name: valid, + Description: valid, + Domain: testsutil.GenerateUUID(&testing.T{}), + Parent: testsutil.GenerateUUID(&testing.T{}), + Metadata: groups.Metadata{ + "name": "test", + }, + Children: []*groups.Group{}, + CreatedAt: time.Now().Add(-1 * time.Second), + UpdatedAt: time.Now(), + UpdatedBy: testsutil.GenerateUUID(&testing.T{}), + Status: groups.EnabledStatus, + } +) + +func startGRPCServer(svc *prmocks.Service, port int) { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + panic(fmt.Sprintf("failed to obtain port: %s", err)) + } + server := grpc.NewServer() + grpcGroupsV1.RegisterGroupsServiceServer(server, grpcapi.NewServer(svc)) + go func() { + if err := server.Serve(listener); err != nil { + panic(fmt.Sprintf("failed to serve: %s", err)) + } + }() +} + +func TestRetrieveEntityEndpoint(t *testing.T) { + svc := new(prmocks.Service) + startGRPCServer(svc, port) + grpAddr := fmt.Sprintf("localhost:%d", port) + conn, _ := grpc.NewClient(grpAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + client := grpcapi.NewClient(conn, time.Second) + + cases := []struct { + desc string + req *grpcCommonV1.RetrieveEntityReq + svcRes groups.Group + svcErr error + res *grpcCommonV1.RetrieveEntityRes + err error + }{ + { + desc: "retrieve group successfully", + req: &grpcCommonV1.RetrieveEntityReq{ + Id: validID, + }, + svcRes: validGroupResp, + svcErr: nil, + res: &grpcCommonV1.RetrieveEntityRes{ + Entity: &grpcCommonV1.EntityBasic{ + Id: validGroupResp.ID, + DomainId: validGroupResp.Domain, + ParentGroupId: validGroupResp.Parent, + Status: uint32(validGroupResp.Status), + }, + }, + err: nil, + }, + { + desc: "retrieve group with authentication error", + req: &grpcCommonV1.RetrieveEntityReq{ + Id: validID, + }, + svcErr: svcerr.ErrAuthentication, + res: &grpcCommonV1.RetrieveEntityRes{}, + err: svcerr.ErrAuthentication, + }, + { + desc: "retrieve group with authorization error", + req: &grpcCommonV1.RetrieveEntityReq{ + Id: validID, + }, + svcErr: svcerr.ErrAuthorization, + res: &grpcCommonV1.RetrieveEntityRes{}, + err: svcerr.ErrAuthorization, + }, + { + desc: "retrieve group with not found error", + req: &grpcCommonV1.RetrieveEntityReq{ + Id: validID, + }, + svcErr: svcerr.ErrNotFound, + res: &grpcCommonV1.RetrieveEntityRes{}, + err: svcerr.ErrNotFound, + }, + { + desc: "retrieve group with malformed entity error", + req: &grpcCommonV1.RetrieveEntityReq{ + Id: validID, + }, + svcErr: errors.ErrMalformedEntity, + res: &grpcCommonV1.RetrieveEntityRes{}, + err: errors.ErrMalformedEntity, + }, + { + desc: "retrieve group with conflict error", + req: &grpcCommonV1.RetrieveEntityReq{ + Id: validID, + }, + svcErr: svcerr.ErrConflict, + res: &grpcCommonV1.RetrieveEntityRes{}, + err: svcerr.ErrConflict, + }, + { + desc: "retrieve group with unknown error", + req: &grpcCommonV1.RetrieveEntityReq{ + Id: validID, + }, + svcErr: errors.ErrUnidentified, + res: &grpcCommonV1.RetrieveEntityRes{}, + err: errors.ErrUnidentified, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("RetrieveById", mock.Anything, tc.req.Id).Return(tc.svcRes, tc.svcErr) + res, err := client.RetrieveEntity(context.Background(), tc.req) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err)) + assert.Equal(t, tc.res, res, fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.res, res)) + svcCall.Unset() + }) + } +} diff --git a/groups/api/http/decode.go b/groups/api/http/decode.go index b9041b512ae..9b43b260019 100644 --- a/groups/api/http/decode.go +++ b/groups/api/http/decode.go @@ -195,6 +195,7 @@ func decodeHierarchyPageMeta(r *http.Request) (mggroups.HierarchyPageMeta, error Tree: tree, }, nil } + func decodePageMeta(r *http.Request) (mggroups.PageMeta, error) { s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus) if err != nil { diff --git a/groups/api/http/decode_test.go b/groups/api/http/decode_test.go index 306f8c45773..0db32f12958 100644 --- a/groups/api/http/decode_test.go +++ b/groups/api/http/decode_test.go @@ -8,7 +8,6 @@ import ( "fmt" "net/http" "net/url" - "reflect" "strings" "testing" @@ -33,14 +32,15 @@ func TestDecodeListGroupsRequest(t *testing.T) { header: map[string][]string{}, resp: listGroupsReq{ PageMeta: groups.PageMeta{ - Limit: 10, + Limit: 10, + Actions: []string{}, }, }, err: nil, }, { desc: "valid request with all parameters", - url: "http://localhost:8080?status=enabled&offset=10&limit=10&name=random&metadata={\"test\":\"test\"}&level=2&parent_id=random&tree=true&dir=-1&member_kind=random&permission=random&list_perms=true", + url: "http://localhost:8080?status=enabled&offset=10&limit=10&name=random&metadata={\"test\":\"test\"}&level=2&t&permission=random&list_perms=true", header: map[string][]string{ "Authorization": {"Bearer 123"}, }, @@ -53,6 +53,7 @@ func TestDecodeListGroupsRequest(t *testing.T) { Metadata: groups.Metadata{ "test": "test", }, + Actions: []string{}, }, }, err: nil, @@ -63,48 +64,6 @@ func TestDecodeListGroupsRequest(t *testing.T) { resp: nil, err: apiutil.ErrValidation, }, - { - desc: "valid request with invalid level", - url: "http://localhost:8080?level=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid parent", - url: "http://localhost:8080?parent_id=random&parent_id=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid tree", - url: "http://localhost:8080?tree=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid dir", - url: "http://localhost:8080?dir=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid member kind", - url: "http://localhost:8080?member_kind=random&member_kind=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid permission", - url: "http://localhost:8080?permission=random&permission=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid list permission", - url: "http://localhost:8080?&list_perms=random", - resp: nil, - err: apiutil.ErrValidation, - }, } for _, tc := range cases { @@ -134,27 +93,27 @@ func TestDecodeRetrieveGroupHierarchy(t *testing.T) { url: "http://localhost:8080", header: map[string][]string{}, resp: retrieveGroupHierarchyReq{ - HierarchyPageMeta: groups.HierarchyPageMeta{}, + HierarchyPageMeta: groups.HierarchyPageMeta{ + Direction: -1, + }, }, err: nil, }, { desc: "valid request with all parameters", - url: "http://localhost:8080?status=enabled&offset=10&limit=10&name=random&metadata={\"test\":\"test\"}&level=2&parent_id=random&tree=true&dir=-1&member_kind=random&permission=random&list_perms=true", + url: "http://localhost:8080?tree=true&level=2&dir=-1", header: map[string][]string{ "Authorization": {"Bearer 123"}, }, resp: retrieveGroupHierarchyReq{ - HierarchyPageMeta: groups.HierarchyPageMeta{}, + HierarchyPageMeta: groups.HierarchyPageMeta{ + Level: 2, + Direction: -1, + Tree: true, + }, }, err: nil, }, - { - desc: "valid request with invalid page metadata", - url: "http://localhost:8080?metadata=random", - resp: nil, - err: apiutil.ErrValidation, - }, { desc: "valid request with invalid level", url: "http://localhost:8080?level=random", @@ -168,14 +127,8 @@ func TestDecodeRetrieveGroupHierarchy(t *testing.T) { err: apiutil.ErrValidation, }, { - desc: "valid request with invalid permission", - url: "http://localhost:8080?permission=random&permission=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid list permission", - url: "http://localhost:8080?&list_perms=random", + desc: "valid request with invalid direction", + url: "http://localhost:8080?dir=random", resp: nil, err: apiutil.ErrValidation, }, @@ -207,9 +160,12 @@ func TestDecodeListChildrenRequest(t *testing.T) { desc: "valid request with no parameters", url: "http://localhost:8080", header: map[string][]string{}, - resp: listGroupsReq{ + resp: listChildrenGroupsReq{ + startLevel: 1, + endLevel: 0, PageMeta: groups.PageMeta{ - Limit: 10, + Limit: 10, + Actions: []string{}, }, }, err: nil, @@ -220,7 +176,9 @@ func TestDecodeListChildrenRequest(t *testing.T) { header: map[string][]string{ "Authorization": {"Bearer 123"}, }, - resp: listGroupsReq{ + resp: listChildrenGroupsReq{ + startLevel: 1, + endLevel: 0, PageMeta: groups.PageMeta{ Status: groups.EnabledStatus, Offset: 10, @@ -229,6 +187,7 @@ func TestDecodeListChildrenRequest(t *testing.T) { Metadata: groups.Metadata{ "test": "test", }, + Actions: []string{}, }, }, err: nil, @@ -239,30 +198,6 @@ func TestDecodeListChildrenRequest(t *testing.T) { resp: nil, err: apiutil.ErrValidation, }, - { - desc: "valid request with invalid level", - url: "http://localhost:8080?level=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid tree", - url: "http://localhost:8080?tree=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid permission", - url: "http://localhost:8080?permission=random&permission=random", - resp: nil, - err: apiutil.ErrValidation, - }, - { - desc: "valid request with invalid list permission", - url: "http://localhost:8080?&list_perms=random", - resp: nil, - err: apiutil.ErrValidation, - }, } for _, tc := range cases { @@ -290,7 +225,8 @@ func TestDecodePageMeta(t *testing.T) { desc: "valid request with no parameters", url: "http://localhost:8080", resp: groups.PageMeta{ - Limit: 10, + Limit: 10, + Actions: []string{}, }, err: nil, }, @@ -305,6 +241,7 @@ func TestDecodePageMeta(t *testing.T) { Metadata: groups.Metadata{ "test": "test", }, + Actions: []string{}, }, err: nil, }, @@ -528,271 +465,3 @@ func TestDecodeChangeGroupStatus(t *testing.T) { assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) } } - -func TestDecodeChangeGroupStatusRequest(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := DecodeChangeGroupStatusRequest(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("DecodeChangeGroupStatusRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("DecodeChangeGroupStatusRequest() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeRetrieveGroupHierarchy(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeRetrieveGroupHierarchy(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeRetrieveGroupHierarchy() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeRetrieveGroupHierarchy() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeAddParentGroupRequest(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeAddParentGroupRequest(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeAddParentGroupRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeAddParentGroupRequest() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeRemoveParentGroupRequest(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeRemoveParentGroupRequest(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeRemoveParentGroupRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeRemoveParentGroupRequest() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeAddChildrenGroupsRequest(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeAddChildrenGroupsRequest(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeAddChildrenGroupsRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeAddChildrenGroupsRequest() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeRemoveChildrenGroupsRequest(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeRemoveChildrenGroupsRequest(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeRemoveChildrenGroupsRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeRemoveChildrenGroupsRequest() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeRemoveAllChildrenGroupsRequest(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeRemoveAllChildrenGroupsRequest(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeRemoveAllChildrenGroupsRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeRemoveAllChildrenGroupsRequest() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeListChildrenGroupsRequest(t *testing.T) { - type args struct { - in0 context.Context - r *http.Request - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeListChildrenGroupsRequest(tt.args.in0, tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeListChildrenGroupsRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeListChildrenGroupsRequest() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodeHierarchyPageMeta(t *testing.T) { - type args struct { - r *http.Request - } - tests := []struct { - name string - args args - want groups.HierarchyPageMeta - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodeHierarchyPageMeta(tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeHierarchyPageMeta() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeHierarchyPageMeta() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_decodePageMeta(t *testing.T) { - type args struct { - r *http.Request - } - tests := []struct { - name string - args args - want groups.PageMeta - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := decodePageMeta(tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodePageMeta() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodePageMeta() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/groups/api/http/endpoint_test.go b/groups/api/http/endpoint_test.go index 7cd03c095f6..1f537964e98 100644 --- a/groups/api/http/endpoint_test.go +++ b/groups/api/http/endpoint_test.go @@ -2,3 +2,2028 @@ // SPDX-License-Identifier: Apache-2.0 package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/absmach/magistrala/groups" + "github.com/absmach/magistrala/groups/mocks" + mgapi "github.com/absmach/magistrala/internal/api" + "github.com/absmach/magistrala/internal/testsutil" + mglog "github.com/absmach/magistrala/logger" + "github.com/absmach/magistrala/pkg/apiutil" + mgauthn "github.com/absmach/magistrala/pkg/authn" + authnmocks "github.com/absmach/magistrala/pkg/authn/mocks" + "github.com/absmach/magistrala/pkg/errors" + svcerr "github.com/absmach/magistrala/pkg/errors/service" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + validGroupResp = groups.Group{ + ID: testsutil.GenerateUUID(&testing.T{}), + Name: valid, + Description: valid, + Domain: testsutil.GenerateUUID(&testing.T{}), + Parent: testsutil.GenerateUUID(&testing.T{}), + Metadata: groups.Metadata{ + "name": "test", + }, + Children: []*groups.Group{}, + CreatedAt: time.Now().Add(-1 * time.Second), + UpdatedAt: time.Now(), + UpdatedBy: testsutil.GenerateUUID(&testing.T{}), + Status: groups.EnabledStatus, + } + validID = testsutil.GenerateUUID(&testing.T{}) + validToken = "validToken" + invalidToken = "invalidToken" + contentType = "application/json" +) + +func newGroupsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { + authn := new(authnmocks.Authentication) + svc := new(mocks.Service) + mux := chi.NewRouter() + logger := mglog.NewMock() + mux = MakeHandler(svc, authn, mux, logger, "") + + return httptest.NewServer(mux), svc, authn +} + +func TestCreateGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + reqGroup := groups.Group{ + Name: valid, + Description: valid, + Metadata: map[string]interface{}{ + "name": "test", + }, + } + + cases := []struct { + desc string + token string + session mgauthn.Session + domainID string + req createGroupReq + contentType string + svcResp groups.Group + svcErr error + authnErr error + status int + err error + }{ + { + desc: "create group successfully", + token: validToken, + domainID: validID, + req: createGroupReq{ + Group: reqGroup, + }, + contentType: contentType, + svcResp: validGroupResp, + status: http.StatusCreated, + err: nil, + }, + { + desc: "create group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + req: createGroupReq{ + Group: reqGroup, + }, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "create group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + req: createGroupReq{ + Group: reqGroup, + }, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "create group with empty domainID", + token: validToken, + req: createGroupReq{ + Group: reqGroup, + }, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "create group with missing name", + token: validToken, + domainID: validID, + req: createGroupReq{ + Group: groups.Group{ + Description: valid, + Metadata: map[string]interface{}{ + "name": "test", + }, + }, + }, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "create group with name that is too long", + token: validToken, + domainID: validID, + req: createGroupReq{ + Group: groups.Group{ + Name: strings.Repeat("a", 1025), + Description: valid, + Metadata: map[string]interface{}{ + "name": "test", + }, + }, + }, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrNameSize, + }, + { + desc: "create group with invalid content type", + token: validToken, + domainID: validID, + req: createGroupReq{ + Group: reqGroup, + }, + contentType: "application/xml", + svcResp: validGroupResp, + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "create group with service error", + token: validToken, + domainID: validID, + req: createGroupReq{ + Group: reqGroup, + }, + contentType: contentType, + svcResp: groups.Group{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(tc.req) + req := testRequest{ + client: gs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/groups/", gs.URL, tc.domainID), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("CreateGroup", mock.Anything, tc.session, tc.req.Group).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + svcResp groups.Group + svcErr error + resp groups.Group + status int + authnErr error + err error + }{ + { + desc: "view group successfully", + token: validToken, + domainID: validID, + id: validID, + svcResp: validGroupResp, + svcErr: nil, + resp: validGroupResp, + status: http.StatusOK, + err: nil, + }, + { + desc: "view group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validID, + svcResp: validGroupResp, + svcErr: nil, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "view group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "view group with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "view group with service error", + token: validToken, + id: validID, + domainID: validID, + svcResp: validGroupResp, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/groups/%s", gs.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("ViewGroup", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + updateGroupReq := groups.Group{ + ID: validID, + Name: valid, + Description: valid, + Metadata: map[string]interface{}{ + "name": "test", + }, + } + + cases := []struct { + desc string + token string + id string + domainID string + updateReq groups.Group + contentType string + session mgauthn.Session + svcResp groups.Group + svcErr error + resp groups.Group + status int + authnErr error + err error + }{ + { + desc: "update group successfully", + token: validToken, + domainID: validID, + id: validID, + updateReq: updateGroupReq, + contentType: contentType, + svcResp: validGroupResp, + status: http.StatusOK, + err: nil, + }, + { + desc: "update group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validID, + updateReq: updateGroupReq, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "update group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validID, + updateReq: updateGroupReq, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update group with empty domainID", + token: validToken, + id: validID, + updateReq: updateGroupReq, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "update group with name that is too long", + token: validToken, + id: validID, + domainID: validID, + updateReq: groups.Group{ + ID: validID, + Name: strings.Repeat("a", 1025), + Description: valid, + Metadata: map[string]interface{}{ + "name": "test", + }, + }, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrNameSize, + }, + { + desc: "update group with invalid content type", + token: validToken, + id: validID, + domainID: validID, + updateReq: updateGroupReq, + contentType: "application/xml", + svcResp: validGroupResp, + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update group with service error", + token: validToken, + id: validID, + domainID: validID, + updateReq: updateGroupReq, + contentType: contentType, + svcResp: groups.Group{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(tc.updateReq) + req := testRequest{ + client: gs.Client(), + method: http.MethodPut, + url: fmt.Sprintf("%s/%s/groups/%s", gs.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("UpdateGroup", mock.Anything, tc.session, tc.updateReq).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestEnableGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + svcResp groups.Group + svcErr error + resp groups.Group + status int + authnErr error + err error + }{ + { + desc: "enable group successfully", + token: validToken, + domainID: validID, + id: validID, + svcResp: validGroupResp, + svcErr: nil, + resp: validGroupResp, + status: http.StatusOK, + err: nil, + }, + { + desc: "enable group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "enable group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "enable group with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "enable group with service error", + token: validToken, + id: validID, + domainID: validID, + svcResp: groups.Group{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "enable group with empty id", + token: validToken, + id: "", + domainID: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/groups/%s/enable", gs.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("EnableGroup", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDisableGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + svcResp groups.Group + svcErr error + resp groups.Group + status int + authnErr error + err error + }{ + { + desc: "disable group successfully", + token: validToken, + domainID: validID, + id: validID, + svcResp: validGroupResp, + svcErr: nil, + resp: validGroupResp, + status: http.StatusOK, + err: nil, + }, + { + desc: "disable group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "disable group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "disable group with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "disable group with service error", + token: validToken, + id: validID, + domainID: validID, + svcResp: groups.Group{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "disable group with empty id", + token: validToken, + id: "", + domainID: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/groups/%s/disable", gs.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("DisableGroup", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListGroups(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + query string + domainID string + token string + session mgauthn.Session + listGroupsResponse groups.Page + status int + authnErr error + err error + }{ + { + desc: "list groups successfully", + domainID: validID, + token: validToken, + status: http.StatusOK, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + err: nil, + }, + { + desc: "list groups with empty token", + domainID: validID, + token: "", + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "list groups with invalid token", + domainID: validID, + token: invalidToken, + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "list groups with offset", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "offset=1", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid offset", + domainID: validID, + token: validToken, + query: "offset=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with limit", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "limit=1", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid limit", + domainID: validID, + token: validToken, + query: "limit=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with limit greater than max", + token: validToken, + domainID: validID, + query: fmt.Sprintf("limit=%d", mgapi.MaxLimitSize+1), + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with name", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "name=clientname", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid name", + domainID: validID, + token: validToken, + query: "name=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with duplicate name", + domainID: validID, + token: validToken, + query: "name=1&name=2", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list groups with status", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "status=enabled", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid status", + domainID: validID, + token: validToken, + query: "status=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with duplicate status", + domainID: validID, + token: validToken, + query: "status=enabled&status=disabled", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list groups with tags", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "tag=tag1,tag2", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid tags", + domainID: validID, + token: validToken, + query: "tag=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with duplicate tags", + domainID: validID, + token: validToken, + query: "tag=tag1&tag=tag2", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list groups with metadata", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "metadata=%7B%22domain%22%3A%20%22example.com%22%7D&", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid metadata", + domainID: validID, + token: validToken, + query: "metadata=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with duplicate metadata", + domainID: validID, + token: validToken, + query: "metadata=%7B%22domain%22%3A%20%22example.com%22%7D&metadata=%7B%22domain%22%3A%20%22example.com%22%7D", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list groups with permissions", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "permission=view", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid permissions", + domainID: validID, + token: validToken, + query: "permission=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with duplicate permissions", + domainID: validID, + token: validToken, + query: "permission=view&permission=view", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list groups with list perms", + domainID: validID, + token: validToken, + listGroupsResponse: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + query: "list_perms=true", + status: http.StatusOK, + err: nil, + }, + { + desc: "list groups with invalid list perms", + domainID: validID, + token: validToken, + query: "list_perms=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list groups with duplicate list perms", + domainID: validID, + token: validToken, + query: "list_perms=true&listPerms=true", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodGet, + url: gs.URL + "/" + tc.domainID + "/groups?" + tc.query, + contentType: contentType, + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("ListGroups", mock.Anything, tc.session, mock.Anything).Return(tc.listGroupsResponse, tc.err) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var bodyRes respBody + err = json.NewDecoder(res.Body).Decode(&bodyRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if bodyRes.Err != "" || bodyRes.Message != "" { + err = errors.Wrap(errors.New(bodyRes.Err), errors.New(bodyRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDeleteGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + svcErr error + status int + authnErr error + err error + }{ + { + desc: "delete group successfully", + token: validToken, + domainID: validID, + id: validID, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "delete group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "delete group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "delete group with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "delete group with service error", + token: validToken, + id: validID, + domainID: validID, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/groups/%s", gs.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("DeleteGroup", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRetrieveGroupHierarchyEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + retrieveHierarchRes := groups.HierarchyPage{ + Groups: []groups.Group{validGroupResp}, + HierarchyPageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + } + + cases := []struct { + desc string + token string + session mgauthn.Session + domainID string + groupID string + query string + pageMeta groups.HierarchyPageMeta + svcRes groups.HierarchyPage + svcErr error + authnErr error + status int + err error + }{ + { + desc: "retrieve group hierarchy successfully", + token: validToken, + domainID: validID, + groupID: validID, + query: "level=1&dir=-1&tree=false", + pageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + svcRes: retrieveHierarchRes, + svcErr: nil, + status: http.StatusOK, + err: nil, + }, + { + desc: "retrieve group hierarchy with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + groupID: validID, + query: "level=1&dir=-1&tree=false", + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "retrieve group hierarchy with empty token", + token: "", + session: mgauthn.Session{}, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "retrieve group hierarchy with empty domainID", + token: validToken, + groupID: validID, + query: "level=1&dir=-1&tree=false", + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "retrieve group hierarchy with service error", + token: validToken, + groupID: validID, + domainID: validID, + query: "level=1&dir=-1&tree=false", + pageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + svcRes: groups.HierarchyPage{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "retrieve group hierarchy with invalid level", + token: validToken, + groupID: validID, + domainID: validID, + query: "level=invalid&dir=-1&tree=false", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "retrieve group hierarchy with invalid direction", + token: validToken, + groupID: validID, + domainID: validID, + query: "level=1&dir=invalid&tree=false", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "retrieve group hierarchy with invalid tree", + token: validToken, + groupID: validID, + domainID: validID, + query: "level=1&dir=-1&tree=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "retrieve group hierarchy with empty groupID", + token: validToken, + domainID: validID, + query: "level=1&dir=-1&tree=false", + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/groups/%s/hierarchy?%s", gs.URL, tc.domainID, tc.groupID, tc.query), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("RetrieveGroupHierarchy", mock.Anything, tc.session, tc.groupID, tc.pageMeta).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var bodyRes respBody + err = json.NewDecoder(res.Body).Decode(&bodyRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if bodyRes.Err != "" || bodyRes.Message != "" { + err = errors.Wrap(errors.New(bodyRes.Err), errors.New(bodyRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestAddParentGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + parentID string + session mgauthn.Session + contentType string + svcErr error + status int + authnErr error + err error + }{ + { + desc: "add parent group successfully", + token: validToken, + domainID: validID, + id: validGroupResp.ID, + parentID: validID, + contentType: contentType, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "add parent group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + parentID: validID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "add parent group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + parentID: validID, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "add parent group with empty domainID", + token: validToken, + id: validGroupResp.ID, + parentID: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "add parent group with service error", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + parentID: validID, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "add parent group with empty id", + token: validToken, + id: "", + domainID: validID, + parentID: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + { + desc: "add parent group with empty parentID", + token: validToken, + id: validID, + domainID: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "add self parenting group", + token: validToken, + id: validID, + domainID: validID, + parentID: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrSelfParentingNotAllowed, + }, + { + desc: "add parent group with invalid content type", + token: validToken, + id: validID, + domainID: validID, + parentID: validID, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + reqData := struct { + ParentID string `json:"parent_id"` + }{ + ParentID: tc.parentID, + } + data := toJSON(reqData) + req := testRequest{ + client: gs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/groups/%s/parent", gs.URL, tc.domainID, tc.id), + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(data), + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("AddParentGroup", mock.Anything, tc.session, tc.id, tc.parentID).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRemoveParentGroupEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + svcErr error + status int + authnErr error + err error + }{ + { + desc: "remove parent group successfully", + token: validToken, + domainID: validID, + id: validGroupResp.ID, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "remove parent group with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "remove parent group with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "remove parent group with empty domainID", + token: validToken, + id: validGroupResp.ID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "remove parent group with service error", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "remove parent group with empty id", + token: validToken, + id: "", + domainID: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/groups/%s/parent", gs.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("RemoveParentGroup", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestAddChildrenGroupsEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + childrenIDs []string + session mgauthn.Session + contentType string + svcErr error + status int + authnErr error + err error + }{ + { + desc: "add children groups successfully", + token: validToken, + domainID: validID, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "add children groups with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "add children groups with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "add children groups with empty domainID", + token: validToken, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "add children groups with service error", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + childrenIDs: []string{validID}, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "add children groups with empty id", + token: validToken, + id: "", + domainID: validID, + childrenIDs: []string{validID}, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + { + desc: "add children groups with empty childrenIDs", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "add children groups with invalid childrenIDs", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + childrenIDs: []string{"invalid"}, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "add self children group", + token: validToken, + id: validID, + domainID: validID, + childrenIDs: []string{validID}, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrSelfParentingNotAllowed, + }, + { + desc: "add children groups with invalid content type", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + childrenIDs: []string{validID}, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + reqData := struct { + ChildrenIDs []string `json:"children_ids"` + }{ + ChildrenIDs: tc.childrenIDs, + } + data := toJSON(reqData) + req := testRequest{ + client: gs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/groups/%s/children", gs.URL, tc.domainID, tc.id), + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(data), + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("AddChildrenGroups", mock.Anything, tc.session, tc.id, tc.childrenIDs).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRemoveChildrenGroupsEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + childrenIDs []string + contentType string + svcErr error + status int + authnErr error + err error + }{ + { + desc: "remove children groups successfully", + token: validToken, + domainID: validID, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "remove children groups with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "remove children groups with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "remove children groups with empty domainID", + token: validToken, + id: validGroupResp.ID, + childrenIDs: []string{validID}, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "remove children groups with service error", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + childrenIDs: []string{validID}, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "remove children groups with empty id", + token: validToken, + id: "", + domainID: validID, + contentType: contentType, + childrenIDs: []string{validID}, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + { + desc: "remove children groups with empty childrenIDs", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "remove children groups with invalid childrenIDs", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + childrenIDs: []string{"invalid"}, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "remove children groups with invalid content type", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + childrenIDs: []string{validID}, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + reqData := struct { + ChildrenIDs []string `json:"children_ids"` + }{ + ChildrenIDs: tc.childrenIDs, + } + data := toJSON(reqData) + req := testRequest{ + client: gs.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/groups/%s/children", gs.URL, tc.domainID, tc.id), + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(data), + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("RemoveChildrenGroups", mock.Anything, tc.session, tc.id, tc.childrenIDs).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRemoveAllChildrenGroupsEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + svcErr error + status int + authnErr error + err error + }{ + { + desc: "remove all children groups successfully", + token: validToken, + domainID: validID, + id: validGroupResp.ID, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "remove all children groups with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "remove all children groups with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "remove all children groups with empty domainID", + token: validToken, + id: validGroupResp.ID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "remove all children groups with service error", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "remove all children groups with empty id", + token: validToken, + id: "", + domainID: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/groups/%s/children/all", gs.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("RemoveAllChildrenGroups", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListChildrenGroupsEndpoint(t *testing.T) { + gs, svc, authn := newGroupsServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session mgauthn.Session + query string + pageMeta groups.PageMeta + svcRes groups.Page + svcErr error + authnErr error + status int + err error + }{ + { + desc: "list children groups successfully", + token: validToken, + domainID: validID, + id: validGroupResp.ID, + query: "limit=1&offset=0", + pageMeta: groups.PageMeta{ + Limit: 1, + Offset: 0, + Actions: []string{}, + }, + svcRes: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{validGroupResp}, + }, + svcErr: nil, + status: http.StatusOK, + err: nil, + }, + { + desc: "list children groups with invalid token", + token: invalidToken, + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + query: "limit=1&offset=0", + pageMeta: groups.PageMeta{ + Limit: 1, + Offset: 0, + Actions: []string{}, + }, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "list children groups with empty token", + token: "", + session: mgauthn.Session{}, + domainID: validID, + id: validGroupResp.ID, + query: "limit=1&offset=0", + pageMeta: groups.PageMeta{ + Limit: 1, + Offset: 0, + Actions: []string{}, + }, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "list children groups with empty domainID", + token: validToken, + id: validGroupResp.ID, + query: "limit=1&offset=0", + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "list children groups with service error", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + query: "limit=1&offset=0", + pageMeta: groups.PageMeta{ + Limit: 1, + Offset: 0, + Actions: []string{}, + }, + svcRes: groups.Page{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "list children groups with invalid limit", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + query: "limit=invalid&offset=0", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list children groups with invalid offset", + token: validToken, + id: validGroupResp.ID, + domainID: validID, + query: "limit=1&offset=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list children groups with empty id", + token: validToken, + domainID: validID, + query: "limit=1&offset=0", + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/groups/%s/children?%s", gs.URL, tc.domainID, tc.id, tc.query), + token: tc.token, + } + if tc.token == validToken { + tc.session = mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("ListChildrenGroups", mock.Anything, tc.session, tc.id, int64(1), int64(0), tc.pageMeta).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var bodyRes respBody + err = json.NewDecoder(res.Body).Decode(&bodyRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if bodyRes.Err != "" || bodyRes.Message != "" { + err = errors.Wrap(errors.New(bodyRes.Err), errors.New(bodyRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +type testRequest struct { + client *http.Client + method string + url string + contentType string + token string + body io.Reader +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + + req.Header.Set("Referer", "http://localhost") + + return tr.client.Do(req) +} + +func toJSON(data interface{}) string { + jsonData, err := json.Marshal(data) + if err != nil { + return "" + } + return string(jsonData) +} + +type respBody struct { + Err string `json:"error"` + Message string `json:"message"` + Total int `json:"total"` + Permissions []string `json:"permissions"` + ID string `json:"id"` + Tags []string `json:"tags"` + Status groups.Status `json:"status"` +} diff --git a/groups/api/http/endpoints.go b/groups/api/http/endpoints.go index 37c86d7099a..8ccee14f701 100644 --- a/groups/api/http/endpoints.go +++ b/groups/api/http/endpoints.go @@ -15,8 +15,6 @@ import ( "github.com/go-kit/kit/endpoint" ) -const groupTypeChannels = "channels" - func CreateGroupEndpoint(svc groups.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(createGroupReq) @@ -157,14 +155,14 @@ func ListGroupsEndpoint(svc groups.Service) endpoint.Endpoint { groups = append(groups, toViewGroupRes(g)) } - return groupPageRes{pageRes: pageRes{ - Limit: page.Limit, - Offset: page.Offset, - Total: page.Total, - }, + return groupPageRes{ + pageRes: pageRes{ + Limit: page.Limit, + Offset: page.Offset, + Total: page.Total, + }, Groups: groups, }, nil - } } @@ -208,7 +206,6 @@ func retrieveGroupHierarchyEndpoint(svc groups.Service) endpoint.Endpoint { groups = append(groups, toViewGroupRes(g)) } return retrieveGroupHierarchyRes{Level: hp.Level, Direction: hp.Direction, Groups: groups}, nil - } } @@ -339,73 +336,9 @@ func listChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint { } } -func buildGroupsResponseTree(page groups.Page) groupPageRes { - groupsMap := map[string]*groups.Group{} - // Parents' map keeps its array of children. - parentsMap := map[string][]*groups.Group{} - for i := range page.Groups { - if _, ok := groupsMap[page.Groups[i].ID]; !ok { - groupsMap[page.Groups[i].ID] = &page.Groups[i] - parentsMap[page.Groups[i].ID] = make([]*groups.Group, 0) - } - } - - for _, group := range groupsMap { - if children, ok := parentsMap[group.Parent]; ok { - children = append(children, group) - parentsMap[group.Parent] = children - } - } - - res := groupPageRes{ - pageRes: pageRes{ - Limit: page.Limit, - Offset: page.Offset, - Total: page.Total, - }, - Groups: []viewGroupRes{}, - } - - for _, group := range groupsMap { - if children, ok := parentsMap[group.ID]; ok { - group.Children = children - } - } - - for _, group := range groupsMap { - view := toViewGroupRes(*group) - if children, ok := parentsMap[group.Parent]; len(children) == 0 || !ok { - res.Groups = append(res.Groups, view) - } - } - - return res -} - func toViewGroupRes(group groups.Group) viewGroupRes { view := viewGroupRes{ Group: group, } return view } - -func buildGroupsResponse(gp groups.Page, filterByID bool) groupPageRes { - res := groupPageRes{ - pageRes: pageRes{ - Total: gp.Total, - }, - Groups: []viewGroupRes{}, - } - - for _, group := range gp.Groups { - view := viewGroupRes{ - Group: group, - } - if filterByID && group.Level == 0 { - continue - } - res.Groups = append(res.Groups, view) - } - - return res -} diff --git a/groups/api/http/requests.go b/groups/api/http/requests.go index ccb7f11ad1a..d4f0d73a72b 100644 --- a/groups/api/http/requests.go +++ b/groups/api/http/requests.go @@ -46,7 +46,6 @@ type listGroupsReq struct { } func (req listGroupsReq) validate() error { - if req.Limit > api.MaxLimitSize || req.Limit < 1 { return apiutil.ErrLimitSize } @@ -89,6 +88,9 @@ func (req retrieveGroupHierarchyReq) validate() error { if req.Level > groups.MaxLevel { return apiutil.ErrLevel } + if req.id == "" { + return apiutil.ErrMissingID + } return nil } diff --git a/groups/api/http/requests_test.go b/groups/api/http/requests_test.go index 703c0e9ed62..f35cdf53745 100644 --- a/groups/api/http/requests_test.go +++ b/groups/api/http/requests_test.go @@ -10,6 +10,7 @@ import ( "github.com/absmach/magistrala/groups" "github.com/absmach/magistrala/internal/api" + "github.com/absmach/magistrala/internal/testsutil" "github.com/absmach/magistrala/pkg/apiutil" "github.com/stretchr/testify/assert" ) @@ -107,15 +108,6 @@ func TestListGroupReqValidation(t *testing.T) { }, err: nil, }, - { - desc: "invalid upper level", - req: listGroupsReq{ - PageMeta: groups.PageMeta{ - Limit: 10, - }, - }, - err: apiutil.ErrInvalidLevel, - }, { desc: "invalid lower limit", req: listGroupsReq{ @@ -195,206 +187,309 @@ func TestChangeGroupStatusReqValidation(t *testing.T) { } } -func Test_createGroupReq_validate(t *testing.T) { - tests := []struct { - name string - req createGroupReq - wantErr bool +func TestRetrieveGroupHierarchyReqValidation(t *testing.T) { + cases := []struct { + desc string + req retrieveGroupHierarchyReq + err error }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("createGroupReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) + { + desc: "valid request", + req: retrieveGroupHierarchyReq{ + HierarchyPageMeta: groups.HierarchyPageMeta{ + Tree: true, + Level: 1, + Direction: -1, + }, + id: valid, + }, + }, + { + desc: "invalid level", + req: retrieveGroupHierarchyReq{ + HierarchyPageMeta: groups.HierarchyPageMeta{ + Tree: true, + Level: groups.MaxLevel + 1, + Direction: -1, + }, + id: valid, + }, + err: apiutil.ErrLevel, + }, + { + desc: "empty id", + req: retrieveGroupHierarchyReq{ + HierarchyPageMeta: groups.HierarchyPageMeta{ + Tree: true, + Level: 1, + Direction: -1, + }, + }, + err: apiutil.ErrMissingID, + }, } -} -func Test_updateGroupReq_validate(t *testing.T) { - tests := []struct { - name string - req updateGroupReq - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("updateGroupReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) + for _, tc := range cases { + err := tc.req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } } -func Test_listGroupsReq_validate(t *testing.T) { - tests := []struct { - name string - req listGroupsReq - wantErr bool +func TestAddParentGroupReqValidation(t *testing.T) { + cases := []struct { + desc string + req addParentGroupReq + err error }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("listGroupsReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) + { + desc: "valid request", + req: addParentGroupReq{ + id: testsutil.GenerateUUID(t), + ParentID: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "empty id", + req: addParentGroupReq{ + ParentID: testsutil.GenerateUUID(t), + }, + err: apiutil.ErrMissingID, + }, + { + desc: "empty parent id", + req: addParentGroupReq{ + id: testsutil.GenerateUUID(t), + }, + err: apiutil.ErrInvalidIDFormat, + }, + { + desc: "invalid parent id", + req: addParentGroupReq{ + id: testsutil.GenerateUUID(t), + ParentID: "invalid", + }, + err: apiutil.ErrInvalidIDFormat, + }, + { + desc: "same id", + req: addParentGroupReq{ + id: validID, + ParentID: validID, + }, + err: apiutil.ErrSelfParentingNotAllowed, + }, } -} -func Test_groupReq_validate(t *testing.T) { - tests := []struct { - name string - req groupReq - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("groupReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) }) } } -func Test_changeGroupStatusReq_validate(t *testing.T) { - tests := []struct { - name string - req changeGroupStatusReq - wantErr bool +func TestRemoveParentGroupReqValidation(t *testing.T) { + cases := []struct { + desc string + req removeParentGroupReq + err error }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("changeGroupStatusReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) + { + desc: "valid request", + req: removeParentGroupReq{ + id: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "empty id", + req: removeParentGroupReq{}, + err: apiutil.ErrMissingID, + }, } -} -func Test_retrieveGroupHierarchyReq_validate(t *testing.T) { - tests := []struct { - name string - req retrieveGroupHierarchyReq - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("retrieveGroupHierarchyReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) }) } } -func Test_addParentGroupReq_validate(t *testing.T) { - tests := []struct { - name string - req addParentGroupReq - wantErr bool +func TestAddChildrenGroupsReqValidation(t *testing.T) { + cases := []struct { + desc string + req addChildrenGroupsReq + err error }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("addParentGroupReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) + { + desc: "valid request", + req: addChildrenGroupsReq{ + id: testsutil.GenerateUUID(t), + ChildrenIDs: []string{testsutil.GenerateUUID(t)}, + }, + err: nil, + }, + { + desc: "empty id", + req: addChildrenGroupsReq{ + ChildrenIDs: []string{testsutil.GenerateUUID(t)}, + }, + err: apiutil.ErrMissingID, + }, + { + desc: "empty children ids", + req: addChildrenGroupsReq{ + id: testsutil.GenerateUUID(t), + }, + err: apiutil.ErrMissingChildrenGroupIDs, + }, + { + desc: "invalid child id", + req: addChildrenGroupsReq{ + id: testsutil.GenerateUUID(t), + ChildrenIDs: []string{"invalid"}, + }, + err: apiutil.ErrInvalidIDFormat, + }, + { + desc: "self parenting", + req: addChildrenGroupsReq{ + id: validID, + ChildrenIDs: []string{validID, testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)}, + }, + err: apiutil.ErrSelfParentingNotAllowed, + }, } -} -func Test_removeParentGroupReq_validate(t *testing.T) { - tests := []struct { - name string - req removeParentGroupReq - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("removeParentGroupReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) }) } } -func Test_addChildrenGroupsReq_validate(t *testing.T) { - tests := []struct { - name string - req addChildrenGroupsReq - wantErr bool +func TestRemoveChildrenGroupsReqValidation(t *testing.T) { + cases := []struct { + desc string + req removeChildrenGroupsReq + err error }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("addChildrenGroupsReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) + { + desc: "valid request", + req: removeChildrenGroupsReq{ + id: testsutil.GenerateUUID(t), + ChildrenIDs: []string{testsutil.GenerateUUID(t)}, + }, + err: nil, + }, + { + desc: "empty id", + req: removeChildrenGroupsReq{}, + err: apiutil.ErrMissingID, + }, + { + desc: "empty children ids", + req: removeChildrenGroupsReq{ + id: testsutil.GenerateUUID(t), + }, + err: apiutil.ErrMissingChildrenGroupIDs, + }, + { + desc: "invalid child id", + req: removeChildrenGroupsReq{ + id: testsutil.GenerateUUID(t), + ChildrenIDs: []string{"invalid"}, + }, + err: apiutil.ErrInvalidIDFormat, + }, } -} -func Test_removeChildrenGroupsReq_validate(t *testing.T) { - tests := []struct { - name string - req removeChildrenGroupsReq - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("removeChildrenGroupsReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) }) } } -func Test_removeAllChildrenGroupsReq_validate(t *testing.T) { - tests := []struct { - name string - req removeAllChildrenGroupsReq - wantErr bool +func TestRemoveAllChildrenGroupsReqValidation(t *testing.T) { + cases := []struct { + desc string + req removeAllChildrenGroupsReq + err error }{ - // TODO: Add test cases. + { + desc: "valid request", + req: removeAllChildrenGroupsReq{ + id: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "empty id", + req: removeAllChildrenGroupsReq{}, + err: apiutil.ErrMissingID, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("removeAllChildrenGroupsReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) }) } } -func Test_listChildrenGroupsReq_validate(t *testing.T) { - tests := []struct { - name string - req listChildrenGroupsReq - wantErr bool +func TestListChildrenGroupsReqValidation(t *testing.T) { + cases := []struct { + desc string + req listChildrenGroupsReq + err error }{ - // TODO: Add test cases. + { + desc: "valid request", + req: listChildrenGroupsReq{ + id: validID, + PageMeta: groups.PageMeta{ + Limit: 10, + }, + }, + err: nil, + }, + { + desc: "empty id", + req: listChildrenGroupsReq{}, + err: apiutil.ErrMissingID, + }, + { + desc: "invalid lower limit", + req: listChildrenGroupsReq{ + id: validID, + PageMeta: groups.PageMeta{ + Limit: 0, + }, + }, + err: apiutil.ErrLimitSize, + }, + { + desc: "invalid upper limit", + req: listChildrenGroupsReq{ + id: validID, + PageMeta: groups.PageMeta{ + Limit: api.MaxLimitSize + 1, + }, + }, + err: apiutil.ErrLimitSize, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.req.validate(); (err != nil) != tt.wantErr { - t.Errorf("listChildrenGroupsReq.validate() error = %v, wantErr %v", err, tt.wantErr) - } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) }) } } diff --git a/groups/api/http/responses.go b/groups/api/http/responses.go index afe8cb8f490..ee0df55d15b 100644 --- a/groups/api/http/responses.go +++ b/groups/api/http/responses.go @@ -162,8 +162,7 @@ func (res retrieveGroupHierarchyRes) Empty() bool { return false } -type addParentGroupRes struct { -} +type addParentGroupRes struct{} func (res addParentGroupRes) Code() int { return http.StatusNoContent @@ -177,8 +176,7 @@ func (res addParentGroupRes) Empty() bool { return true } -type removeParentGroupRes struct { -} +type removeParentGroupRes struct{} func (res removeParentGroupRes) Code() int { return http.StatusNoContent @@ -192,8 +190,7 @@ func (res removeParentGroupRes) Empty() bool { return true } -type addChildrenGroupsRes struct { -} +type addChildrenGroupsRes struct{} func (res addChildrenGroupsRes) Code() int { return http.StatusNoContent @@ -207,8 +204,7 @@ func (res addChildrenGroupsRes) Empty() bool { return true } -type removeChildrenGroupsRes struct { -} +type removeChildrenGroupsRes struct{} func (res removeChildrenGroupsRes) Code() int { return http.StatusNoContent @@ -222,8 +218,7 @@ func (res removeChildrenGroupsRes) Empty() bool { return true } -type removeAllChildrenGroupsRes struct { -} +type removeAllChildrenGroupsRes struct{} func (res removeAllChildrenGroupsRes) Code() int { return http.StatusNoContent diff --git a/groups/api/http/transport.go b/groups/api/http/transport.go index 98b26c6ff67..ea1dd3212a5 100644 --- a/groups/api/http/transport.go +++ b/groups/api/http/transport.go @@ -1,8 +1,8 @@ -package api - // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 +package api + import ( "log/slog" @@ -99,7 +99,6 @@ func MakeHandler(svc groups.Service, authn authn.Authentication, mux *chi.Mux, l api.EncodeResponse, opts..., ), "remove_parent_group").ServeHTTP) - }) r.Route("/children", func(r chi.Router) { @@ -132,7 +131,6 @@ func MakeHandler(svc groups.Service, authn authn.Authentication, mux *chi.Mux, l ), "list_children_groups").ServeHTTP) }) }) - }) return mux diff --git a/groups/builtinroles.go b/groups/builtinroles.go index c856210f691..fc647d9ebc2 100644 --- a/groups/builtinroles.go +++ b/groups/builtinroles.go @@ -1,3 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + package groups import "github.com/absmach/magistrala/pkg/roles" diff --git a/groups/groups.go b/groups/groups.go index ca9ad61a373..1aac7d19481 100644 --- a/groups/groups.go +++ b/groups/groups.go @@ -74,7 +74,6 @@ type HierarchyPageMeta struct { Direction int64 `json:"direction"` // ancestors (+1) or descendants (-1) // - `true` - result is JSON tree representing groups hierarchy, // - `false` - result is JSON array of groups. - // ToDo: Tree is build in API layer now, not in service layer. This need to be fine tuned. Tree bool `json:"tree"` } type HierarchyPage struct { @@ -114,7 +113,7 @@ type Repository interface { // UnassignParentGroup unassign parent group id fr given group id UnassignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error - UnassignAllChildrenGroup(ctx context.Context, id string) error + UnassignAllChildrenGroups(ctx context.Context, id string) error RetrieveUserGroups(ctx context.Context, domainID, userID string, pm PageMeta) (Page, error) diff --git a/groups/middleware/authorization.go b/groups/middleware/authorization.go index 88b9179119e..bb1778d7ca7 100644 --- a/groups/middleware/authorization.go +++ b/groups/middleware/authorization.go @@ -334,7 +334,6 @@ func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context, }); err != nil { return err } - // ToDo: Add session DomainUserID authz check for all children groups return am.svc.RemoveAllChildrenGroups(ctx, session, id) } diff --git a/groups/mocks/repository.go b/groups/mocks/repository.go index d77315040e2..f292d8750e9 100644 --- a/groups/mocks/repository.go +++ b/groups/mocks/repository.go @@ -260,6 +260,34 @@ func (_m *Repository) RetrieveByID(ctx context.Context, id string) (groups.Group return r0, r1 } +// RetrieveByIDAndUser provides a mock function with given fields: ctx, domainID, userID, groupID +func (_m *Repository) RetrieveByIDAndUser(ctx context.Context, domainID string, userID string, groupID string) (groups.Group, error) { + ret := _m.Called(ctx, domainID, userID, groupID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveByIDAndUser") + } + + var r0 groups.Group + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (groups.Group, error)); ok { + return rf(ctx, domainID, userID, groupID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) groups.Group); ok { + r0 = rf(ctx, domainID, userID, groupID) + } else { + r0 = ret.Get(0).(groups.Group) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, domainID, userID, groupID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RetrieveByIDs provides a mock function with given fields: ctx, pm, ids func (_m *Repository) RetrieveByIDs(ctx context.Context, pm groups.PageMeta, ids ...string) (groups.Page, error) { ret := _m.Called(ctx, pm, ids) @@ -289,7 +317,7 @@ func (_m *Repository) RetrieveByIDs(ctx context.Context, pm groups.PageMeta, ids } // RetrieveChildrenGroups provides a mock function with given fields: ctx, domainID, userID, groupID, startLevel, endLevel, pm -func (_m *Repository) RetrieveChildrenGroups(ctx context.Context, domainID string, userID string, groupID string, startLevel int, endLevel int, pm groups.PageMeta) (groups.Page, error) { +func (_m *Repository) RetrieveChildrenGroups(ctx context.Context, domainID string, userID string, groupID string, startLevel int64, endLevel int64, pm groups.PageMeta) (groups.Page, error) { ret := _m.Called(ctx, domainID, userID, groupID, startLevel, endLevel, pm) if len(ret) == 0 { @@ -298,16 +326,16 @@ func (_m *Repository) RetrieveChildrenGroups(ctx context.Context, domainID strin var r0 groups.Page var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, int, int, groups.PageMeta) (groups.Page, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, int64, int64, groups.PageMeta) (groups.Page, error)); ok { return rf(ctx, domainID, userID, groupID, startLevel, endLevel, pm) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, int, int, groups.PageMeta) groups.Page); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, int64, int64, groups.PageMeta) groups.Page); ok { r0 = rf(ctx, domainID, userID, groupID, startLevel, endLevel, pm) } else { r0 = ret.Get(0).(groups.Page) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, string, int, int, groups.PageMeta) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, int64, int64, groups.PageMeta) error); ok { r1 = rf(ctx, domainID, userID, groupID, startLevel, endLevel, pm) } else { r1 = ret.Error(1) @@ -741,12 +769,12 @@ func (_m *Repository) Save(ctx context.Context, g groups.Group) (groups.Group, e return r0, r1 } -// UnassignAllChildrenGroup provides a mock function with given fields: ctx, id -func (_m *Repository) UnassignAllChildrenGroup(ctx context.Context, id string) error { +// UnassignAllChildrenGroups provides a mock function with given fields: ctx, id +func (_m *Repository) UnassignAllChildrenGroups(ctx context.Context, id string) error { ret := _m.Called(ctx, id) if len(ret) == 0 { - panic("no return value specified for UnassignAllChildrenGroup") + panic("no return value specified for UnassignAllChildrenGroups") } var r0 error diff --git a/groups/mocks/service.go b/groups/mocks/service.go index 6dd47922ef4..e746d1ed00e 100644 --- a/groups/mocks/service.go +++ b/groups/mocks/service.go @@ -217,9 +217,9 @@ func (_m *Service) ListAvailableActions(ctx context.Context, session authn.Sessi return r0, r1 } -// ListChildrenGroups provides a mock function with given fields: ctx, session, id, pm -func (_m *Service) ListChildrenGroups(ctx context.Context, session authn.Session, id string, pm groups.PageMeta) (groups.Page, error) { - ret := _m.Called(ctx, session, id, pm) +// ListChildrenGroups provides a mock function with given fields: ctx, session, id, startLevel, endLevel, pm +func (_m *Service) ListChildrenGroups(ctx context.Context, session authn.Session, id string, startLevel int64, endLevel int64, pm groups.PageMeta) (groups.Page, error) { + ret := _m.Called(ctx, session, id, startLevel, endLevel, pm) if len(ret) == 0 { panic("no return value specified for ListChildrenGroups") @@ -227,17 +227,17 @@ func (_m *Service) ListChildrenGroups(ctx context.Context, session authn.Session var r0 groups.Page var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, groups.PageMeta) (groups.Page, error)); ok { - return rf(ctx, session, id, pm) + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, int64, int64, groups.PageMeta) (groups.Page, error)); ok { + return rf(ctx, session, id, startLevel, endLevel, pm) } - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, groups.PageMeta) groups.Page); ok { - r0 = rf(ctx, session, id, pm) + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, int64, int64, groups.PageMeta) groups.Page); ok { + r0 = rf(ctx, session, id, startLevel, endLevel, pm) } else { r0 = ret.Get(0).(groups.Page) } - if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, groups.PageMeta) error); ok { - r1 = rf(ctx, session, id, pm) + if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, int64, int64, groups.PageMeta) error); ok { + r1 = rf(ctx, session, id, startLevel, endLevel, pm) } else { r1 = ret.Error(1) } diff --git a/groups/postgres/groups.go b/groups/postgres/groups.go index 10bb41a1a46..2805aa4f940 100644 --- a/groups/postgres/groups.go +++ b/groups/postgres/groups.go @@ -59,6 +59,7 @@ func (repo groupRepository) Save(ctx context.Context, g mggroups.Group) (mggroup if err != nil { return mggroups.Group{}, err } + row, err := repo.db.NamedQueryContext(ctx, q, dbg) if err != nil { return mggroups.Group{}, postgres.HandleError(repoerr.ErrCreateEntity, err) @@ -297,8 +298,6 @@ func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm mggroups.PageM } func (repo groupRepository) RetrieveHierarchy(ctx context.Context, id string, hm mggroups.HierarchyPageMeta) (mggroups.HierarchyPage, error) { - // ToDo : use the query to userGroupsBaseQuery - // repo.userGroupsBaseQuery(domainID, userID) query := "" switch { // ancestors @@ -381,7 +380,6 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID } }() - //ToDo: Move this logic to service layer pq := `SELECT id, path FROM groups WHERE id = $1 LIMIT 1;` rows, err := tx.Queryx(pq, parentGroupID) if err != nil { @@ -458,8 +456,6 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID return nil } -// ToDo: Query need to change to ANY -// ToDo: If parent is changed, then path of all children need to be updated https://patshaughnessy.net/2017/12/14/manipulating-trees-using-sql-and-the-postgres-ltree-extension func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) (err error) { if len(groupIDs) == 0 { return nil @@ -540,19 +536,20 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup return nil } -func (repo groupRepository) UnassignAllChildrenGroup(ctx context.Context, id string) error { - +func (repo groupRepository) UnassignAllChildrenGroups(ctx context.Context, id string) error { query := ` UPDATE groups AS g SET parent_id = NULL - WHERE g.parent = :parent_id ; + WHERE g.parent_id = :parent_id ; ` - row, err := repo.db.NamedQueryContext(ctx, query, dbGroup{ParentID: &id}) + result, err := repo.db.NamedExecContext(ctx, query, dbGroup{ParentID: &id}) if err != nil { return postgres.HandleError(repoerr.ErrUpdateEntity, err) } - defer row.Close() + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } return nil } @@ -616,7 +613,7 @@ func (repo groupRepository) RetrieveChildrenGroups(ctx context.Context, domainID case startLevel > 0 && endLevel > 0 && startLevel < endLevel: levelCondition = fmt.Sprintf(" path ~ '%s.*{%d,%d}'::::lquery ", pGroup.Path, startLevel, endLevel) default: - return mggroups.Page{}, fmt.Errorf("invalid level range: start level: %d end level: %d", startLevel, endLevel) + return mggroups.Page{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid level range: start level: %d end level: %d", startLevel, endLevel)) } switch { @@ -721,6 +718,7 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID page.Groups = items return page, nil } + func (repo groupRepository) userGroupsBaseQuery(domainID, userID string) string { return fmt.Sprintf(` WITH direct_groups AS ( @@ -823,7 +821,6 @@ func (repo groupRepository) userGroupsBaseQuery(domainID, userID string) string FROM indirect_child_groups )`, userID, domainID, domainID) - } func buildQuery(gm mggroups.PageMeta, ids ...string) string { @@ -994,7 +991,6 @@ func toDBGroupPageMeta(pm mggroups.PageMeta) (dbGroupPageMeta, error) { }, nil } -// ToDo: check and remove field "Level" after new auth stabilize type dbGroupPageMeta struct { ID string `db:"id"` Name string `db:"name"` diff --git a/groups/postgres/groups_test.go b/groups/postgres/groups_test.go index 2c553111d43..39404cf4b4e 100644 --- a/groups/postgres/groups_test.go +++ b/groups/postgres/groups_test.go @@ -16,14 +16,16 @@ import ( "github.com/absmach/magistrala/internal/testsutil" "github.com/absmach/magistrala/pkg/errors" repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/absmach/magistrala/pkg/roles" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var ( - namegen = namegenerator.NewGenerator() - invalidID = strings.Repeat("a", 37) - validGroup = groups.Group{ + namegen = namegenerator.NewGenerator() + invalidID = strings.Repeat("a", 37) + validTimestamp = time.Now().UTC().Truncate(time.Millisecond) + validGroup = groups.Group{ ID: testsutil.GenerateUUID(&testing.T{}), Domain: testsutil.GenerateUUID(&testing.T{}), Name: namegen.Generate(), @@ -32,6 +34,26 @@ var ( CreatedAt: time.Now().UTC().Truncate(time.Microsecond), Status: groups.EnabledStatus, } + directAccess = "direct" + availableActions = []string{ + "update", + "read", + "membership", + "delete", + "subgroup_create", + "subgroup_client_create", + "subgroup_channel_create", + "subgroup_update", + "subgroup_read", + "subgroup_membership", + "subgroup_delete", + "subgroup_set_child", + "subgroup_set_parent", + "subgroup_manage_role", + "subgroup_add_role_users", + "subgroup_remove_role_users", + "subgroup_view_role_users", + } ) func TestSave(t *testing.T) { @@ -40,16 +62,37 @@ func TestSave(t *testing.T) { require.Nil(t, err, fmt.Sprintf("clean groups unexpected error: %s", err)) }) + validGroupRes := validGroup + validGroupRes.Path = validGroup.ID + validGroupRes.Level = 1 + repo := postgres.New(database) + parentGroup := validGroup + parentGroup.ID = testsutil.GenerateUUID(t) + parentGroup.Name = namegen.Generate() + + pgroup, err := repo.Save(context.Background(), parentGroup) + require.Nil(t, err, fmt.Sprintf("save group unexpected error: %s", err)) + + validChildGroup := validGroup + validChildGroup.ID = testsutil.GenerateUUID(t) + validChildGroup.Name = namegen.Generate() + validChildGroup.Parent = pgroup.ID + validChildGroupRes := validChildGroup + validChildGroupRes.Path = fmt.Sprintf("%s.%s", pgroup.Path, validChildGroupRes.ID) + validChildGroupRes.Level = 2 + cases := []struct { desc string group groups.Group + resp groups.Group err error }{ { desc: "add new group successfully", group: validGroup, + resp: validGroupRes, err: nil, }, { @@ -57,6 +100,12 @@ func TestSave(t *testing.T) { group: validGroup, err: repoerr.ErrConflict, }, + { + desc: "add group with parent", + group: validChildGroup, + resp: validChildGroupRes, + err: nil, + }, { desc: "add group with invalid ID", group: groups.Group{ @@ -65,7 +114,7 @@ func TestSave(t *testing.T) { Name: namegen.Generate(), Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, }, err: repoerr.ErrMalformedEntity, @@ -78,7 +127,7 @@ func TestSave(t *testing.T) { Name: namegen.Generate(), Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, }, err: repoerr.ErrMalformedEntity, @@ -87,14 +136,14 @@ func TestSave(t *testing.T) { desc: "add group with invalid parent", group: groups.Group{ ID: testsutil.GenerateUUID(t), - Parent: invalidID, + Parent: testsutil.GenerateUUID(t), Name: namegen.Generate(), Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, }, - err: repoerr.ErrMalformedEntity, + err: repoerr.ErrNotFound, }, { desc: "add group with invalid name", @@ -104,7 +153,7 @@ func TestSave(t *testing.T) { Name: strings.Repeat("a", 1025), Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, }, err: repoerr.ErrMalformedEntity, @@ -117,7 +166,7 @@ func TestSave(t *testing.T) { Name: namegen.Generate(), Description: strings.Repeat("a", 1025), Metadata: map[string]interface{}{"key": "value"}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, }, err: repoerr.ErrMalformedEntity, @@ -132,31 +181,20 @@ func TestSave(t *testing.T) { Metadata: map[string]interface{}{ "key": make(chan int), }, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, }, err: repoerr.ErrMalformedEntity, }, { - desc: "add group with empty domain", + desc: "add group with invalid domain", group: groups.Group{ ID: testsutil.GenerateUUID(t), Name: namegen.Generate(), + Domain: invalidID, Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), - Status: groups.EnabledStatus, - }, - err: repoerr.ErrMalformedEntity, - }, - { - desc: "add group with empty name", - group: groups.Group{ - ID: testsutil.GenerateUUID(t), - Domain: testsutil.GenerateUUID(t), - Description: strings.Repeat("a", 64), - Metadata: map[string]interface{}{"key": "value"}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, }, err: repoerr.ErrMalformedEntity, @@ -164,13 +202,11 @@ func TestSave(t *testing.T) { } for _, tc := range cases { - switch group, err := repo.Save(context.Background(), tc.group); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - assert.Equal(t, tc.group, group, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group, group)) - default: + t.Run(tc.desc, func(t *testing.T) { + group, err := repo.Save(context.Background(), tc.group) + assert.Equal(t, tc.resp, group, fmt.Sprintf("%s: expected %v got %+v\n", tc.desc, tc.resp, group)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } + }) } } @@ -186,71 +222,78 @@ func TestUpdate(t *testing.T) { require.Nil(t, err, fmt.Sprintf("save group unexpected error: %s", err)) cases := []struct { - desc string - group groups.Group - err error + desc string + update string + group groups.Group + err error }{ { - desc: "update group successfully", + desc: "update group successfully", + update: "all", group: groups.Group{ ID: group.ID, Name: namegen.Generate(), Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: nil, }, { - desc: "update group name", + desc: "update group name", + update: "name", group: groups.Group{ ID: group.ID, Name: namegen.Generate(), - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: nil, }, { - desc: "update group description", + desc: "update group description", + update: "description", group: groups.Group{ ID: group.ID, - Description: strings.Repeat("a", 64), - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + Description: strings.Repeat("b", 64), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: nil, }, { - desc: "update group metadata", + desc: "update group metadata", + update: "metadata", group: groups.Group{ ID: group.ID, - Metadata: map[string]interface{}{"key": "value"}, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + Metadata: map[string]interface{}{"key1": "value1"}, + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: nil, }, { - desc: "update group with invalid ID", + desc: "update group with invalid ID", + update: "all", group: groups.Group{ ID: testsutil.GenerateUUID(t), Name: namegen.Generate(), Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: repoerr.ErrNotFound, }, { - desc: "update group with empty ID", + desc: "update group with empty ID", + update: "all", group: groups.Group{ Name: namegen.Generate(), Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"key": "value"}, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: repoerr.ErrNotFound, @@ -258,15 +301,27 @@ func TestUpdate(t *testing.T) { } for _, tc := range cases { - switch group, err := repo.Update(context.Background(), tc.group); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - assert.Equal(t, tc.group.ID, group.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.ID, group.ID)) - assert.Equal(t, tc.group.UpdatedAt, group.UpdatedAt, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedAt, group.UpdatedAt)) - assert.Equal(t, tc.group.UpdatedBy, group.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedBy, group.UpdatedBy)) - default: + t.Run(tc.desc, func(t *testing.T) { + group, err := repo.Update(context.Background(), tc.group) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } + if err == nil { + assert.Equal(t, tc.group.ID, group.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.ID, group.ID)) + assert.Equal(t, tc.group.UpdatedAt, group.UpdatedAt, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedAt, group.UpdatedAt)) + assert.Equal(t, tc.group.UpdatedBy, group.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedBy, group.UpdatedBy)) + switch tc.update { + case "all": + assert.Equal(t, tc.group.Name, group.Name, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Name, group.Name)) + assert.Equal(t, tc.group.Description, group.Description, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Description, group.Description)) + assert.Equal(t, tc.group.Metadata, group.Metadata, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Metadata, group.Metadata)) + case "name": + assert.Equal(t, tc.group.Name, group.Name, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Name, group.Name)) + case "description": + assert.Equal(t, tc.group.Description, group.Description, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Description, group.Description)) + case "metadata": + assert.Equal(t, tc.group.Metadata, group.Metadata, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Metadata, group.Metadata)) + } + } + }) } } @@ -291,7 +346,7 @@ func TestChangeStatus(t *testing.T) { group: groups.Group{ ID: group.ID, Status: groups.DisabledStatus, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: nil, @@ -301,7 +356,7 @@ func TestChangeStatus(t *testing.T) { group: groups.Group{ ID: testsutil.GenerateUUID(t), Status: groups.DisabledStatus, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: repoerr.ErrNotFound, @@ -310,7 +365,7 @@ func TestChangeStatus(t *testing.T) { desc: "change status group with empty ID", group: groups.Group{ Status: groups.DisabledStatus, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedAt: validTimestamp, UpdatedBy: testsutil.GenerateUUID(t), }, err: repoerr.ErrNotFound, @@ -318,15 +373,16 @@ func TestChangeStatus(t *testing.T) { } for _, tc := range cases { - switch group, err := repo.ChangeStatus(context.Background(), tc.group); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - assert.Equal(t, tc.group.ID, group.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.ID, group.ID)) - assert.Equal(t, tc.group.UpdatedAt, group.UpdatedAt, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedAt, group.UpdatedAt)) - assert.Equal(t, tc.group.UpdatedBy, group.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedBy, group.UpdatedBy)) - default: + t.Run(tc.desc, func(t *testing.T) { + group, err := repo.ChangeStatus(context.Background(), tc.group) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } + if err == nil { + assert.Equal(t, tc.group.ID, group.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.ID, group.ID)) + assert.Equal(t, tc.group.UpdatedAt, group.UpdatedAt, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedAt, group.UpdatedAt)) + assert.Equal(t, tc.group.UpdatedBy, group.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedBy, group.UpdatedBy)) + assert.Equal(t, tc.group.Status, group.Status, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Status, group.Status)) + } + }) } } @@ -338,6 +394,9 @@ func TestRetrieveByID(t *testing.T) { repo := postgres.New(database) + validGroupRes := validGroup + validGroupRes.Path = validGroup.ID + group, err := repo.Save(context.Background(), validGroup) require.Nil(t, err, fmt.Sprintf("save group unexpected error: %s", err)) @@ -345,12 +404,14 @@ func TestRetrieveByID(t *testing.T) { desc string id string group groups.Group + resp groups.Group err error }{ { desc: "retrieve group by id successfully", id: group.ID, group: validGroup, + resp: validGroupRes, err: nil, }, { @@ -368,13 +429,143 @@ func TestRetrieveByID(t *testing.T) { } for _, tc := range cases { - switch group, err := repo.RetrieveByID(context.Background(), tc.id); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - assert.Equal(t, tc.group, group, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group, group)) - default: + t.Run(tc.desc, func(t *testing.T) { + group, err := repo.RetrieveByID(context.Background(), tc.id) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.resp, group, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group, group)) + } + }) + } +} + +func TestRetrieveByIDAndUser(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM groups") + require.Nil(t, err, fmt.Sprintf("clean groups unexpected error: %s", err)) + }) + + repo := postgres.New(database) + + domainID := testsutil.GenerateUUID(t) + userID := testsutil.GenerateUUID(t) + num := 10 + items := []groups.Group{} + for i := 0; i < num; i++ { + name := namegen.Generate() + group := groups.Group{ + ID: testsutil.GenerateUUID(t), + Domain: domainID, + Name: name, + Description: strings.Repeat("a", 64), + Metadata: map[string]interface{}{"name": name}, + CreatedAt: validTimestamp, + Status: groups.EnabledStatus, + } + grp, err := repo.Save(context.Background(), group) + require.Nil(t, err, fmt.Sprintf("create group unexpected error: %s", err)) + newRolesProvision := []roles.RoleProvision{ + { + Role: roles.Role{ + ID: testsutil.GenerateUUID(t) + "_" + grp.ID, + Name: "admin", + EntityID: grp.ID, + CreatedAt: validTimestamp, + CreatedBy: userID, + }, + OptionalActions: availableActions, + OptionalMembers: []string{userID}, + }, } + _, err = repo.AddRoles(context.Background(), newRolesProvision) + require.Nil(t, err, fmt.Sprintf("add roles unexpected error: %s", err)) + ngrp := grp + ngrp.RoleID = newRolesProvision[0].Role.ID + ngrp.RoleName = newRolesProvision[0].Role.Name + ngrp.AccessType = directAccess + items = append(items, ngrp) + } + + cases := []struct { + desc string + groupID string + userID string + domainID string + resp groups.Group + err error + }{ + { + desc: "retrieve group by id and user successfully", + groupID: items[0].ID, + userID: userID, + domainID: domainID, + resp: items[0], + err: nil, + }, + { + desc: "retrieve group by id and user successfully", + groupID: items[5].ID, + userID: userID, + domainID: domainID, + resp: items[5], + err: nil, + }, + { + desc: "retrieve group by id and user with invalid group ID", + groupID: invalidID, + userID: userID, + domainID: domainID, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve group by id and user with empty group ID", + groupID: "", + userID: userID, + domainID: domainID, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve group by id and user with invalid user ID", + groupID: items[0].ID, + userID: testsutil.GenerateUUID(t), + domainID: domainID, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve group by id and user with empty user ID", + groupID: items[0].ID, + userID: "", + domainID: domainID, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve group by id and user with invalid domain ID", + groupID: items[0].ID, + userID: userID, + domainID: invalidID, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve group by id and user with empty domain ID", + groupID: items[0].ID, + userID: userID, + domainID: "", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + group, err := repo.RetrieveByIDAndUser(context.Background(), tc.domainID, tc.userID, tc.groupID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + group.Actions = nil + group.Level = 1 + group.AccessProviderRoleActions = nil + assert.Equal(t, tc.resp, group, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, group)) + } + }) } } @@ -402,9 +593,11 @@ func TestRetrieveAll(t *testing.T) { Status: groups.EnabledStatus, } _, err := repo.Save(context.Background(), group) - require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) + require.Nil(t, err, fmt.Sprintf("create group unexpected error: %s", err)) items = append(items, group) - parentID = group.ID + if i%20 == 0 { + parentID = group.ID + } } cases := []struct { @@ -630,59 +823,23 @@ func TestRetrieveAll(t *testing.T) { }, err: errors.ErrMalformedEntity, }, - { - desc: "retrieve parent groups", - page: groups.Page{ - PageMeta: groups.PageMeta{ - Offset: 0, - Limit: uint64(num), - }, - }, - response: groups.Page{ - PageMeta: groups.PageMeta{ - Total: uint64(num), - Offset: 0, - Limit: uint64(num), - }, - Groups: items[:6], - }, - err: nil, - }, - { - desc: "retrieve children groups", - page: groups.Page{ - PageMeta: groups.PageMeta{ - Offset: 0, - Limit: uint64(num), - }, - }, - response: groups.Page{ - PageMeta: groups.PageMeta{ - Total: uint64(num), - Offset: 0, - Limit: uint64(num), - }, - Groups: items[150:], - }, - err: nil, - }, } for _, tc := range cases { - switch groups, err := repo.RetrieveAll(context.Background(), tc.page.PageMeta); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - assert.Equal(t, tc.response.Total, groups.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, groups.Total)) - assert.Equal(t, tc.response.Limit, groups.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, groups.Limit)) - assert.Equal(t, tc.response.Offset, groups.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, groups.Offset)) - for i := range tc.response.Groups { - tc.response.Groups[i].Level = groups.Groups[i].Level - tc.response.Groups[i].Path = groups.Groups[i].Path + t.Run(tc.desc, func(t *testing.T) { + switch groups, err := repo.RetrieveAll(context.Background(), tc.page.PageMeta); { + case err == nil: + assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.response.Total, groups.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, groups.Total)) + assert.Equal(t, tc.response.Limit, groups.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, groups.Limit)) + assert.Equal(t, tc.response.Offset, groups.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, groups.Offset)) + got := stripGroupDetails(groups.Groups) + resp := stripGroupDetails(tc.response.Groups) + assert.ElementsMatch(t, resp, got, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, resp, got)) + default: + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } - assert.ElementsMatch(t, groups.Groups, tc.response.Groups, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, tc.response.Groups, groups.Groups)) - default: - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } + }) } } @@ -712,7 +869,9 @@ func TestRetrieveByIDs(t *testing.T) { _, err := repo.Save(context.Background(), group) require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) items = append(items, group) - parentID = group.ID + if i%20 == 0 { + parentID = group.ID + } } cases := []struct { @@ -937,61 +1096,21 @@ func TestRetrieveByIDs(t *testing.T) { }, err: errors.ErrMalformedEntity, }, - { - desc: "retrieve parent groups", - page: groups.Page{ - PageMeta: groups.PageMeta{ - Offset: 0, - Limit: uint64(num), - }, - }, - ids: getIDs(items[0:20]), - response: groups.Page{ - PageMeta: groups.PageMeta{ - Total: 20, - Offset: 0, - Limit: uint64(num), - }, - Groups: items[:6], - }, - err: nil, - }, - { - desc: "retrieve children groups", - page: groups.Page{ - PageMeta: groups.PageMeta{ - Offset: 0, - Limit: uint64(num), - }, - }, - ids: getIDs(items[0:20]), - response: groups.Page{ - PageMeta: groups.PageMeta{ - Total: 20, - Offset: 0, - Limit: uint64(num), - }, - Groups: items[15:20], - }, - err: nil, - }, } for _, tc := range cases { - switch groups, err := repo.RetrieveByIDs(context.Background(), tc.page.PageMeta, tc.ids...); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - assert.Equal(t, tc.response.Total, groups.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, groups.Total)) - assert.Equal(t, tc.response.Limit, groups.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, groups.Limit)) - assert.Equal(t, tc.response.Offset, groups.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, groups.Offset)) - for i := range tc.response.Groups { - tc.response.Groups[i].Level = groups.Groups[i].Level - tc.response.Groups[i].Path = groups.Groups[i].Path - } - assert.ElementsMatch(t, groups.Groups, tc.response.Groups, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, tc.response.Groups, groups.Groups)) - default: + t.Run(tc.desc, func(t *testing.T) { + groups, err := repo.RetrieveByIDs(context.Background(), tc.page.PageMeta, tc.ids...) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } + if err == nil { + assert.Equal(t, tc.response.Total, groups.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, groups.Total)) + assert.Equal(t, tc.response.Limit, groups.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, groups.Limit)) + assert.Equal(t, tc.response.Offset, groups.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, groups.Offset)) + got := stripGroupDetails(groups.Groups) + resp := stripGroupDetails(tc.response.Groups) + assert.ElementsMatch(t, resp, got, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, resp, got)) + } + }) } } @@ -1029,12 +1148,10 @@ func TestDelete(t *testing.T) { } for _, tc := range cases { - switch err := repo.Delete(context.Background(), tc.id); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - default: + t.Run(tc.desc, func(t *testing.T) { + err := repo.Delete(context.Background(), tc.id) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } + }) } } @@ -1049,23 +1166,20 @@ func TestAssignParentGroup(t *testing.T) { num := 10 var items []groups.Group - parentID := "" for i := 0; i < num; i++ { name := namegen.Generate() group := groups.Group{ ID: testsutil.GenerateUUID(t), Domain: testsutil.GenerateUUID(t), - Parent: parentID, Name: name, Description: strings.Repeat("a", 64), Metadata: map[string]interface{}{"name": name}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedAt: validTimestamp, Status: groups.EnabledStatus, } _, err := repo.Save(context.Background(), group) require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) items = append(items, group) - parentID = group.ID } cases := []struct { @@ -1084,13 +1198,13 @@ func TestAssignParentGroup(t *testing.T) { desc: "assign parent group with invalid ID", id: testsutil.GenerateUUID(t), ids: []string{items[1].ID, items[2].ID, items[3].ID, items[4].ID, items[5].ID}, - err: repoerr.ErrCreateEntity, + err: repoerr.ErrUpdateEntity, }, { desc: "assign parent group with empty ID", id: "", ids: []string{items[1].ID, items[2].ID, items[3].ID, items[4].ID, items[5].ID}, - err: repoerr.ErrCreateEntity, + err: repoerr.ErrUpdateEntity, }, { desc: "assign parent group with invalid group IDs", @@ -1107,12 +1221,10 @@ func TestAssignParentGroup(t *testing.T) { } for _, tc := range cases { - switch err := repo.AssignParentGroup(context.Background(), tc.id, tc.ids...); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - default: + t.Run(tc.desc, func(t *testing.T) { + err := repo.AssignParentGroup(context.Background(), tc.id, tc.ids...) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } + }) } } @@ -1143,7 +1255,9 @@ func TestUnassignParentGroup(t *testing.T) { _, err := repo.Save(context.Background(), group) require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) items = append(items, group) - parentID = group.ID + if i == 0 { + parentID = group.ID + } } cases := []struct { @@ -1162,13 +1276,13 @@ func TestUnassignParentGroup(t *testing.T) { desc: "un-assign parent group with invalid ID", id: testsutil.GenerateUUID(t), ids: []string{items[1].ID, items[2].ID, items[3].ID, items[4].ID, items[5].ID}, - err: repoerr.ErrCreateEntity, + err: repoerr.ErrUpdateEntity, }, { desc: "un-assign parent group with empty ID", id: "", ids: []string{items[1].ID, items[2].ID, items[3].ID, items[4].ID, items[5].ID}, - err: repoerr.ErrCreateEntity, + err: repoerr.ErrUpdateEntity, }, { desc: "un-assign parent group with invalid group IDs", @@ -1185,12 +1299,598 @@ func TestUnassignParentGroup(t *testing.T) { } for _, tc := range cases { - switch err := repo.UnassignParentGroup(context.Background(), tc.id, tc.ids...); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - default: + t.Run(tc.desc, func(t *testing.T) { + err := repo.UnassignParentGroup(context.Background(), tc.id, tc.ids...) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestUnassignAllChildrenGroups(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM groups") + require.Nil(t, err, fmt.Sprintf("clean groups unexpected error: %s", err)) + }) + + repo := postgres.New(database) + + num := 10 + + var items []groups.Group + parentID := "" + for i := 0; i < num; i++ { + name := namegen.Generate() + group := groups.Group{ + ID: testsutil.GenerateUUID(t), + Domain: testsutil.GenerateUUID(t), + Parent: parentID, + Name: name, + Description: strings.Repeat("a", 64), + Metadata: map[string]interface{}{"name": name}, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + Status: groups.EnabledStatus, + } + _, err := repo.Save(context.Background(), group) + require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) + items = append(items, group) + if i == 0 { + parentID = group.ID + } + } + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "un-assign all children groups successfully", + id: items[0].ID, + err: nil, + }, + { + desc: "un-assign all children groups with invalid ID", + id: testsutil.GenerateUUID(t), + err: repoerr.ErrNotFound, + }, + { + desc: "un-assign all children groups with empty ID", + id: "", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.UnassignAllChildrenGroups(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestRetrieveHierarchy(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM groups") + require.Nil(t, err, fmt.Sprintf("clean groups unexpected error: %s", err)) + }) + + repo := postgres.New(database) + + num := 10 + + var items []groups.Group + parentID := "" + for i := 0; i < num; i++ { + name := namegen.Generate() + group := groups.Group{ + ID: testsutil.GenerateUUID(t), + Domain: testsutil.GenerateUUID(t), + Parent: parentID, + Name: name, + Description: strings.Repeat("a", 64), + Metadata: map[string]interface{}{"name": name}, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + Status: groups.EnabledStatus, + } + _, err := repo.Save(context.Background(), group) + require.Nil(t, err, fmt.Sprintf("create group unexpected error: %s", err)) + items = append(items, group) + if i == 0 { + parentID = group.ID + } + } + + cases := []struct { + desc string + id string + hm groups.HierarchyPageMeta + resp groups.HierarchyPage + err error + }{ + { + desc: "retrieve ancestors successfully", + id: items[1].ID, + hm: groups.HierarchyPageMeta{ + Level: 1, + Direction: +1, + Tree: false, + }, + resp: groups.HierarchyPage{ + Groups: []groups.Group{items[0], items[1]}, + HierarchyPageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: +1, + Tree: false, + }, + }, + err: nil, + }, + { + desc: "retrieve descendants successfully", + id: items[0].ID, + hm: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + resp: groups.HierarchyPage{ + Groups: items, + HierarchyPageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + }, + err: nil, + }, + { + desc: "retrieve hierarchy with invalid ID", + id: testsutil.GenerateUUID(t), + err: nil, + }, + { + desc: "retrieve hierarchy with empty ID", + id: "", + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + gpPage, err := repo.RetrieveHierarchy(context.Background(), tc.id, tc.hm) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + got := stripGroupDetails(gpPage.Groups) + resp := stripGroupDetails(tc.resp.Groups) + assert.ElementsMatch(t, resp, got, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, resp, got)) + } + }) + } +} + +func TestRetrieveAllParentGroups(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM groups") + require.Nil(t, err, fmt.Sprintf("clean groups unexpected error: %s", err)) + }) + + repo := postgres.New(database) + + parentID := "" + domainID := testsutil.GenerateUUID(t) + userID := testsutil.GenerateUUID(t) + num := 10 + halfindex := num/2 + 1 + items := []groups.Group{} + for i := 0; i < num; i++ { + name := namegen.Generate() + group := groups.Group{ + ID: testsutil.GenerateUUID(t), + Domain: domainID, + Name: name, + Parent: parentID, + Description: strings.Repeat("a", 64), + Metadata: map[string]interface{}{"name": name}, + CreatedAt: validTimestamp, + Status: groups.EnabledStatus, + } + grp, err := repo.Save(context.Background(), group) + require.Nil(t, err, fmt.Sprintf("create group unexpected error: %s", err)) + parentID = grp.ID + newRolesProvision := []roles.RoleProvision{ + { + Role: roles.Role{ + ID: testsutil.GenerateUUID(t) + "_" + grp.ID, + Name: "admin", + EntityID: grp.ID, + CreatedAt: validTimestamp, + CreatedBy: userID, + }, + OptionalActions: availableActions, + OptionalMembers: []string{userID}, + }, } + _, err = repo.AddRoles(context.Background(), newRolesProvision) + require.Nil(t, err, fmt.Sprintf("add roles unexpected error: %s", err)) + ngrp := grp + ngrp.RoleID = newRolesProvision[0].Role.ID + ngrp.RoleName = newRolesProvision[0].Role.Name + ngrp.AccessType = directAccess + items = append(items, ngrp) + } + + cases := []struct { + desc string + id string + domainID string + userID string + pageMeta groups.PageMeta + resp groups.Page + err error + }{ + { + desc: "retrieve all parent groups successfully", + id: items[num-1].ID, + domainID: domainID, + userID: userID, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: uint64(num), + }, + Groups: items, + }, + err: nil, + }, + { + desc: "retrieve half of all parent groups successfully", + id: items[num/2].ID, + domainID: domainID, + userID: userID, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: uint64(halfindex), + }, + Groups: items[:halfindex], + }, + err: nil, + }, + { + desc: "retrieve all parent groups with invalid group ID", + id: testsutil.GenerateUUID(t), + domainID: domainID, + userID: userID, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve all parent groups with empty group ID", + id: "", + domainID: domainID, + userID: userID, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve all parent groups with invalid domain ID", + id: items[num-1].ID, + domainID: testsutil.GenerateUUID(t), + userID: userID, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: nil, + }, + { + desc: "retrieve all parent groups with invalid user ID", + id: items[num-1].ID, + domainID: domainID, + userID: testsutil.GenerateUUID(t), + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + groups, err := repo.RetrieveAllParentGroups(context.Background(), tc.domainID, tc.userID, tc.id, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.resp.Total, groups.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.resp.Total, groups.Total)) + got := stripGroupDetails(groups.Groups) + resp := stripGroupDetails(tc.resp.Groups) + assert.ElementsMatch(t, resp, got, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, resp, got)) + } + }) + } +} + +func TestRetrieveChildrenGroups(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM groups") + require.Nil(t, err, fmt.Sprintf("clean groups unexpected error: %s", err)) + }) + + repo := postgres.New(database) + + parentID := "" + domainID := testsutil.GenerateUUID(t) + userID := testsutil.GenerateUUID(t) + num := 10 + items := []groups.Group{} + for i := 0; i < num; i++ { + name := namegen.Generate() + group := groups.Group{ + ID: testsutil.GenerateUUID(t), + Domain: domainID, + Name: name, + Parent: parentID, + Description: strings.Repeat("a", 64), + Metadata: map[string]interface{}{"name": name}, + CreatedAt: validTimestamp, + Status: groups.EnabledStatus, + } + grp, err := repo.Save(context.Background(), group) + require.Nil(t, err, fmt.Sprintf("create group unexpected error: %s", err)) + parentID = grp.ID + newRolesProvision := []roles.RoleProvision{ + { + Role: roles.Role{ + ID: testsutil.GenerateUUID(t) + "_" + grp.ID, + Name: "admin", + EntityID: grp.ID, + CreatedAt: validTimestamp, + CreatedBy: userID, + }, + OptionalActions: availableActions, + OptionalMembers: []string{userID}, + }, + } + _, err = repo.AddRoles(context.Background(), newRolesProvision) + require.Nil(t, err, fmt.Sprintf("add roles unexpected error: %s", err)) + ngrp := grp + ngrp.RoleID = newRolesProvision[0].Role.ID + ngrp.RoleName = newRolesProvision[0].Role.Name + ngrp.AccessType = directAccess + items = append(items, ngrp) + } + + cases := []struct { + desc string + id string + domainID string + userID string + startLevel int64 + endLevel int64 + pageMeta groups.PageMeta + resp groups.Page + err error + }{ + { + desc: "retrieve children groups from parent group level successfully", + id: items[0].ID, + domainID: domainID, + userID: userID, + startLevel: 0, + endLevel: -1, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: uint64(num), + }, + Groups: items, + }, + err: nil, + }, + { + desc: "Retrieve specific level of children groups from parent group level", + id: items[0].ID, + domainID: domainID, + userID: userID, + startLevel: 1, + endLevel: 1, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 1, + }, + Groups: []groups.Group{items[1]}, + }, + err: nil, + }, + { + desc: "Retrieve all children groups from specific level from parent group level", + id: items[0].ID, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + domainID: domainID, + userID: userID, + startLevel: 2, + endLevel: -1, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 8, + }, + Groups: items[2:], + }, + err: nil, + }, + { + desc: "Retrieve all children groups from specific level to specific level from parent group level", + id: items[0].ID, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + domainID: domainID, + userID: userID, + startLevel: 1, + endLevel: 2, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 2, + }, + Groups: items[1:3], + }, + err: nil, + }, + { + desc: "Retrieve all children groups with invalid group ID", + id: testsutil.GenerateUUID(t), + domainID: domainID, + userID: userID, + startLevel: 0, + endLevel: -1, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "Retrieve all children groups with empty group ID", + id: "", + domainID: domainID, + userID: userID, + startLevel: 0, + endLevel: -1, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "Retrieve all children groups with invalid domain ID", + id: items[0].ID, + domainID: testsutil.GenerateUUID(t), + userID: userID, + startLevel: 0, + endLevel: -1, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: nil, + }, + { + desc: "Retrieve all children groups with invalid user ID", + id: items[0].ID, + domainID: domainID, + userID: testsutil.GenerateUUID(t), + startLevel: 0, + endLevel: -1, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: nil, + }, + { + desc: "Retrieve all children groups with invalid start level", + id: items[0].ID, + domainID: domainID, + userID: userID, + startLevel: -1, + endLevel: -1, + pageMeta: groups.PageMeta{ + Offset: 0, + Limit: 20, + }, + resp: groups.Page{ + PageMeta: groups.PageMeta{ + Total: 0, + }, + Groups: []groups.Group(nil), + }, + err: repoerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + groups, err := repo.RetrieveChildrenGroups(context.Background(), tc.domainID, tc.userID, tc.id, tc.startLevel, tc.endLevel, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.resp.Total, groups.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.resp.Total, groups.Total)) + got := stripGroupDetails(groups.Groups) + resp := stripGroupDetails(tc.resp.Groups) + assert.ElementsMatch(t, resp, got, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, resp, got)) + } + }) } } @@ -1202,3 +1902,15 @@ func getIDs(groups []groups.Group) []string { return ids } + +func stripGroupDetails(groups []groups.Group) []groups.Group { + for i := range groups { + groups[i].Level = 0 + groups[i].Path = "" + groups[i].CreatedAt = validTimestamp + groups[i].Actions = nil + groups[i].AccessProviderRoleActions = nil + } + + return groups +} diff --git a/groups/postgres/setup_test.go b/groups/postgres/setup_test.go index 716255df8fa..e243bb4e43b 100644 --- a/groups/postgres/setup_test.go +++ b/groups/postgres/setup_test.go @@ -76,12 +76,11 @@ func TestMain(m *testing.M) { SSLRootCert: "", } - mig, err := gpostgres.Migration() + gmig, err := gpostgres.Migration() if err != nil { log.Fatalf("Could not get groups migration : %s", err) - } - if db, err = pgclient.Setup(dbConfig, *mig); err != nil { + if db, err = pgclient.Setup(dbConfig, *gmig); err != nil { log.Fatalf("Could not setup test DB connection: %s", err) } diff --git a/groups/private/service.go b/groups/private/service.go index 65ecafbd617..b02d512328c 100644 --- a/groups/private/service.go +++ b/groups/private/service.go @@ -1,3 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + package private import ( diff --git a/groups/roleoperations.go b/groups/roleoperations.go index 6b72cea7ef5..af66365b471 100644 --- a/groups/roleoperations.go +++ b/groups/roleoperations.go @@ -1,3 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + package groups import ( @@ -56,7 +59,7 @@ func NewOperationPerm() svcutil.OperationPerm { return svcutil.NewOperationPerm(expectedOperations, operationNames) } -// External Operations +// External Operations. const ( DomainOpCreateGroup svcutil.ExternalOperation = iota DomainOpListGroups @@ -72,6 +75,7 @@ var expectedExternalOperations = []svcutil.ExternalOperation{ ClientsOpListGroups, ChannelsOpListGroups, } + var externalOperationNames = []string{ "DomainOpCreateGroup", "DomainOpListGroups", @@ -139,8 +143,7 @@ func NewRolesOperationPermissionMap() map[svcutil.Operation]svcutil.Permission { } const ( - // External Permission - // Domains + // External Permissions for the domain. domainCreateGroupPermission = "channel_create_permission" domainListGroupPermission = "membership_permission" userListGroupsPermission = "membership_permission" diff --git a/groups/service.go b/groups/service.go index 92322abae7e..d22b8135531 100644 --- a/groups/service.go +++ b/groups/service.go @@ -10,7 +10,6 @@ import ( "github.com/absmach/magistrala" grpcChannelsV1 "github.com/absmach/magistrala/internal/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/magistrala/internal/grpc/clients/v1" "github.com/absmach/magistrala/pkg/apiutil" mgauthn "github.com/absmach/magistrala/pkg/authn" @@ -20,7 +19,7 @@ import ( "github.com/absmach/magistrala/pkg/roles" ) -var errGroupIDs = errors.New("invalid group ids") +var ErrGroupIDs = errors.New("invalid group ids") type service struct { repo Repository @@ -129,7 +128,6 @@ func (svc service) ListGroups(ctx context.Context, session mgauthn.Session, gm P } return page, nil } - } func (svc service) ListUserGroups(ctx context.Context, session mgauthn.Session, userID string, pm PageMeta) (Page, error) { @@ -144,7 +142,11 @@ func (svc service) UpdateGroup(ctx context.Context, session mgauthn.Session, g G g.UpdatedAt = time.Now() g.UpdatedBy = session.UserID - return svc.repo.Update(ctx, g) + group, err := svc.repo.Update(ctx, g) + if err != nil { + return Group{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + return group, nil } func (svc service) EnableGroup(ctx context.Context, session mgauthn.Session, id string) (Group, error) { @@ -227,7 +229,6 @@ func (svc service) AddParentGroup(ctx context.Context, session mgauthn.Session, return errors.Wrap(svcerr.ErrViewEntity, err) } - //ToDo: Move parent group check business logic from Repository.AssignParentGroup to here var pols []policies.Policy if group.Parent != "" { return errors.Wrap(svcerr.ErrConflict, fmt.Errorf("%s group already have parent", group.ID)) @@ -265,7 +266,6 @@ func (svc service) RemoveParentGroup(ctx context.Context, session mgauthn.Sessio } if group.Parent != "" { - var pols []policies.Policy pols = append(pols, policies.Policy{ Domain: session.DomainID, @@ -286,8 +286,11 @@ func (svc service) RemoveParentGroup(ctx context.Context, session mgauthn.Sessio } } }() + if err := svc.repo.UnassignParentGroup(ctx, group.Parent, group.ID); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } - return svc.repo.UnassignParentGroup(ctx, group.Parent, group.ID) + return nil } return nil @@ -299,7 +302,7 @@ func (svc service) AddChildrenGroups(ctx context.Context, session mgauthn.Sessio return errors.Wrap(svcerr.ErrViewEntity, err) } if len(childrenGroupsPage.Groups) == 0 { - return errGroupIDs + return ErrGroupIDs } for _, childGroup := range childrenGroupsPage.Groups { @@ -330,8 +333,11 @@ func (svc service) AddChildrenGroups(ctx context.Context, session mgauthn.Sessio } } }() + if err = svc.repo.AssignParentGroup(ctx, parentGroupID, childrenGroupIDs...); err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } - return svc.repo.AssignParentGroup(ctx, parentGroupID, childrenGroupIDs...) + return nil } func (svc service) RemoveChildrenGroups(ctx context.Context, session mgauthn.Session, parentGroupID string, childrenGroupIDs []string) (retErr error) { @@ -340,7 +346,7 @@ func (svc service) RemoveChildrenGroups(ctx context.Context, session mgauthn.Ses return errors.Wrap(svcerr.ErrViewEntity, err) } if len(childrenGroupsPage.Groups) == 0 { - return errGroupIDs + return ErrGroupIDs } var pols []policies.Policy @@ -369,8 +375,11 @@ func (svc service) RemoveChildrenGroups(ctx context.Context, session mgauthn.Ses } } }() + if err := svc.repo.UnassignParentGroup(ctx, parentGroupID, childrenGroupIDs...); err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } - return svc.repo.UnassignParentGroup(ctx, parentGroupID, childrenGroupIDs...) + return nil } func (svc service) RemoveAllChildrenGroups(ctx context.Context, session mgauthn.Session, id string) error { @@ -385,8 +394,11 @@ func (svc service) RemoveAllChildrenGroups(ctx context.Context, session mgauthn. if err := svc.policy.DeletePolicyFilter(ctx, pol); err != nil { return errors.Wrap(svcerr.ErrDeletePolicies, err) } + if err := svc.repo.UnassignAllChildrenGroups(ctx, id); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } - return svc.repo.UnassignAllChildrenGroup(ctx, id) + return nil } func (svc service) ListChildrenGroups(ctx context.Context, session mgauthn.Session, id string, startLevel, endLevel int64, pm PageMeta) (Page, error) { diff --git a/groups/service_test.go b/groups/service_test.go index 365aa8f7678..ea6695c2be1 100644 --- a/groups/service_test.go +++ b/groups/service_test.go @@ -2,3 +2,1280 @@ // SPDX-License-Identifier: Apache-2.0 package groups_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + chmocks "github.com/absmach/magistrala/channels/mocks" + climocks "github.com/absmach/magistrala/clients/mocks" + "github.com/absmach/magistrala/groups" + "github.com/absmach/magistrala/groups/mocks" + grpcChannelsV1 "github.com/absmach/magistrala/internal/grpc/channels/v1" + grpcClientsV1 "github.com/absmach/magistrala/internal/grpc/clients/v1" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/apiutil" + "github.com/absmach/magistrala/pkg/authn" + mgauthn "github.com/absmach/magistrala/pkg/authn" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + svcerr "github.com/absmach/magistrala/pkg/errors/service" + policysvc "github.com/absmach/magistrala/pkg/policies" + policymocks "github.com/absmach/magistrala/pkg/policies/mocks" + "github.com/absmach/magistrala/pkg/roles" + "github.com/absmach/magistrala/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + idProvider = uuid.New() + namegen = namegenerator.NewGenerator() + validGroup = groups.Group{ + ID: testsutil.GenerateUUID(&testing.T{}), + Name: namegen.Generate(), + Description: namegen.Generate(), + Metadata: map[string]interface{}{ + "key": "value", + }, + Status: groups.EnabledStatus, + } + parentGroupID = testsutil.GenerateUUID(&testing.T{}) + childGroupID = testsutil.GenerateUUID(&testing.T{}) + childGroup = groups.Group{ + ID: childGroupID, + Name: namegen.Generate(), + Description: namegen.Generate(), + Metadata: map[string]interface{}{ + "key": "value", + }, + Status: groups.EnabledStatus, + Parent: parentGroupID, + } + children = []*groups.Group{&childGroup} + parentGroup = groups.Group{ + ID: parentGroupID, + Name: namegen.Generate(), + Description: namegen.Generate(), + Metadata: map[string]interface{}{ + "key": "value", + }, + Status: groups.EnabledStatus, + Children: children, + } + validID = testsutil.GenerateUUID(&testing.T{}) + errRollbackRoles = errors.New("failed to rollback roles") + validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID} +) + +var ( + repo *mocks.Repository + policies *policymocks.Service + channels *chmocks.ChannelsServiceClient + clients *climocks.ClientsServiceClient +) + +func newService(t *testing.T) groups.Service { + repo = new(mocks.Repository) + policies = new(policymocks.Service) + channels = new(chmocks.ChannelsServiceClient) + clients = new(climocks.ClientsServiceClient) + availableActions := []roles.Action{} + builtInRoles := map[roles.BuiltInRoleName][]roles.Action{ + groups.BuiltInRoleAdmin: availableActions, + } + svc, err := groups.NewService(repo, policies, idProvider, channels, clients, idProvider, availableActions, builtInRoles) + assert.Nil(t, err, fmt.Sprintf(" Unexpected error while creating service %v", err)) + return svc +} + +func TestCreateGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + group groups.Group + saveResp groups.Group + saveErr error + deleteErr error + addPoliciesErr error + deletePoliciesErr error + addRoleErr error + err error + }{ + { + desc: "create group successfully", + group: validGroup, + saveResp: groups.Group{ + ID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + Domain: validID, + }, + err: nil, + }, + { + desc: "create group with invalid status", + group: groups.Group{ + Name: namegen.Generate(), + Description: namegen.Generate(), + Status: groups.Status(100), + }, + err: svcerr.ErrInvalidStatus, + }, + { + desc: "create group successfully with parent", + group: groups.Group{ + Name: namegen.Generate(), + Description: namegen.Generate(), + Status: groups.EnabledStatus, + Parent: testsutil.GenerateUUID(t), + }, + saveResp: groups.Group{ + ID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + Domain: testsutil.GenerateUUID(t), + Parent: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "create group with failed to save", + group: validGroup, + saveResp: groups.Group{}, + saveErr: errors.ErrMalformedEntity, + err: errors.Wrap(svcerr.ErrCreateEntity, errors.ErrMalformedEntity), + }, + { + desc: " create group with failed to add policies", + group: validGroup, + saveResp: groups.Group{ + ID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + Domain: validID, + }, + addPoliciesErr: svcerr.ErrAuthorization, + err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, svcerr.ErrAuthorization)), + }, + { + desc: " create group with failed to add policies and failed rollback", + group: validGroup, + saveResp: groups.Group{ + ID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + Domain: validID, + }, + addPoliciesErr: svcerr.ErrAuthorization, + deleteErr: svcerr.ErrRemoveEntity, + err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(apiutil.ErrRollbackTx, svcerr.ErrRemoveEntity)), + }, + { + desc: "create group with failed to add roles", + group: validGroup, + saveResp: groups.Group{ + ID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + Domain: validID, + }, + addRoleErr: svcerr.ErrCreateEntity, + err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, svcerr.ErrCreateEntity)), + }, + { + desc: "create groups with failed to add roles and failed to delete policies", + group: validGroup, + saveResp: groups.Group{ + ID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + Domain: validID, + }, + addRoleErr: svcerr.ErrCreateEntity, + deletePoliciesErr: svcerr.ErrRemoveEntity, + err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, errors.Wrap(errRollbackRoles, svcerr.ErrRemoveEntity))), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("Save", context.Background(), mock.Anything).Return(tc.saveResp, tc.saveErr) + policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr) + policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr) + repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.Role{}, tc.addRoleErr) + repoCall2 := repo.On("Delete", context.Background(), mock.Anything).Return(tc.deleteErr) + got, err := svc.CreateGroup(context.Background(), validSession, tc.group) + assert.Equal(t, tc.err, err, fmt.Sprintf("expected error %v but got %v", tc.err, err)) + if err == nil { + assert.NotEmpty(t, got.ID) + assert.NotEmpty(t, got.CreatedAt) + assert.NotEmpty(t, got.Domain) + assert.WithinDuration(t, time.Now(), got.CreatedAt, 2*time.Second) + ok := repoCall.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) + assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) + } + repoCall.Unset() + policyCall.Unset() + policyCall1.Unset() + repoCall1.Unset() + repoCall2.Unset() + }) + } +} + +func TestViewGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + session mgauthn.Session + id string + repoResp groups.Group + repoErr error + err error + }{ + { + desc: "view group successfully", + id: validGroup.ID, + session: validSession, + repoResp: validGroup, + }, + { + desc: "view group with failed to retrieve", + id: testsutil.GenerateUUID(t), + session: validSession, + repoErr: repoerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RetrieveByIDAndUser", context.Background(), tc.session.DomainID, tc.session.UserID, tc.id).Return(tc.repoResp, tc.repoErr) + got, err := svc.ViewGroup(context.Background(), validSession, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if err == nil { + assert.Equal(t, tc.repoResp, got) + ok := repo.AssertCalled(t, "RetrieveByIDAndUser", context.Background(), tc.session.DomainID, tc.session.UserID, tc.id) + assert.True(t, ok, fmt.Sprintf("RetrieveByIDAndUser was not called on %s", tc.desc)) + } + repoCall.Unset() + }) + } +} + +func TestUpdateGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + group groups.Group + repoResp groups.Group + repoErr error + err error + }{ + { + desc: "update group successfully", + group: groups.Group{ + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + }, + repoResp: validGroup, + }, + { + desc: "update group with repo error", + group: groups.Group{ + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + }, + repoErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("Update", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr) + got, err := svc.UpdateGroup(context.Background(), validSession, tc.group) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if err == nil { + assert.Equal(t, tc.repoResp, got) + ok := repo.AssertCalled(t, "Update", context.Background(), mock.Anything) + assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) + } + repoCall.Unset() + }) + } +} + +func TestEnableGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + id string + retrieveResp groups.Group + retrieveErr error + changeResp groups.Group + changeErr error + err error + }{ + { + desc: "enable group successfully", + id: testsutil.GenerateUUID(t), + retrieveResp: groups.Group{ + Status: groups.DisabledStatus, + }, + changeResp: validGroup, + }, + { + desc: "enable group with enabled group", + id: testsutil.GenerateUUID(t), + retrieveResp: groups.Group{ + Status: groups.EnabledStatus, + }, + err: errors.ErrStatusAlreadyAssigned, + }, + { + desc: "enable group with retrieve error", + id: testsutil.GenerateUUID(t), + retrieveResp: groups.Group{}, + retrieveErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveResp, tc.retrieveErr) + repoCall1 := repo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeResp, tc.changeErr) + got, err := svc.EnableGroup(context.Background(), validSession, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if err == nil { + assert.Equal(t, tc.changeResp, got) + ok := repo.AssertCalled(t, "RetrieveByID", context.Background(), tc.id) + assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) + } + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestDisableGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + id string + retrieveResp groups.Group + retrieveErr error + changeResp groups.Group + changeErr error + err error + }{ + { + desc: "disable group successfully", + id: testsutil.GenerateUUID(t), + retrieveResp: groups.Group{ + Status: groups.EnabledStatus, + }, + changeResp: validGroup, + }, + { + desc: "disable group with disabled group", + id: testsutil.GenerateUUID(t), + retrieveResp: groups.Group{ + Status: groups.DisabledStatus, + }, + err: errors.ErrStatusAlreadyAssigned, + }, + { + desc: "disable group with retrieve error", + id: testsutil.GenerateUUID(t), + retrieveResp: groups.Group{}, + retrieveErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveResp, tc.retrieveErr) + repoCall1 := repo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeResp, tc.changeErr) + got, err := svc.DisableGroup(context.Background(), validSession, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if err == nil { + assert.Equal(t, tc.changeResp, got) + ok := repo.AssertCalled(t, "RetrieveByID", context.Background(), tc.id) + assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) + } + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestListGroups(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + session mgauthn.Session + pageMeta groups.PageMeta + retrieveAllRes groups.Page + retrieveAllErr error + retrieveUserGroupRes groups.Page + retrieveUserGroupErr error + resp groups.Page + err error + }{ + { + desc: "list groups as super admin successfully", + session: mgauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true}, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + DomainID: validID, + }, + retrieveAllRes: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + resp: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: nil, + }, + { + desc: "list groups as super admin with failed to retrieve", + session: mgauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true}, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + DomainID: validID, + }, + retrieveAllErr: repoerr.ErrNotFound, + resp: groups.Page{}, + err: repoerr.ErrNotFound, + }, + { + desc: "list groups as non admin successfully", + session: validSession, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + }, + retrieveUserGroupRes: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + resp: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: nil, + }, + { + desc: "list groups as non admin with failed to retrieve user groups", + session: validSession, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + }, + retrieveUserGroupErr: repoerr.ErrNotFound, + resp: groups.Page{}, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RetrieveAll", context.Background(), tc.pageMeta).Return(tc.retrieveAllRes, tc.retrieveAllErr) + repoCall1 := repo.On("RetrieveUserGroups", context.Background(), tc.session.DomainID, tc.session.UserID, tc.pageMeta).Return(tc.retrieveUserGroupRes, tc.retrieveUserGroupErr) + got, err := svc.ListGroups(context.Background(), tc.session, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + assert.Equal(t, tc.resp, got) + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestListUserGroups(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + session mgauthn.Session + userID string + pageMeta groups.PageMeta + retrieveUserGroupRes groups.Page + retrieveUserGroupErr error + resp groups.Page + err error + }{ + { + desc: "list user groups successfully", + session: validSession, + userID: validID, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + }, + retrieveUserGroupRes: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + resp: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: nil, + }, + { + desc: "list user groups with failed to retrieve", + session: validSession, + userID: validID, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + }, + retrieveUserGroupErr: repoerr.ErrNotFound, + resp: groups.Page{}, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RetrieveUserGroups", context.Background(), tc.session.DomainID, tc.userID, tc.pageMeta).Return(tc.retrieveUserGroupRes, tc.retrieveUserGroupErr) + got, err := svc.ListUserGroups(context.Background(), tc.session, tc.userID, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + assert.Equal(t, tc.resp, got) + repoCall.Unset() + }) + } +} + +func TestRetrieveGroupHierarchy(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + id string + pageMeta groups.HierarchyPageMeta + retrieveHierarchyRes groups.HierarchyPage + retrieveHierarchyErr error + listAllObjectsRes policysvc.PolicyPage + listAllObjectsErr error + err error + }{ + { + desc: "retrieve group hierarchy successfully", + id: parentGroup.ID, + pageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + retrieveHierarchyRes: groups.HierarchyPage{ + HierarchyPageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + Groups: []groups.Group{parentGroup}, + }, + listAllObjectsRes: policysvc.PolicyPage{ + Policies: []string{parentGroupID, childGroupID}, + }, + err: nil, + }, + { + desc: "retrieve group hierarchy with failed to retrieve hierarchy", + id: parentGroup.ID, + pageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + retrieveHierarchyErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve group hierarchy with failed to list all objects", + id: parentGroup.ID, + pageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + retrieveHierarchyRes: groups.HierarchyPage{ + HierarchyPageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + Groups: []groups.Group{parentGroup}, + }, + listAllObjectsErr: svcerr.ErrAuthorization, + err: svcerr.ErrAuthorization, + }, + { + desc: "retrieve group hierarchy for group not allowed for user", + id: parentGroup.ID, + pageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + retrieveHierarchyRes: groups.HierarchyPage{ + HierarchyPageMeta: groups.HierarchyPageMeta{ + Level: 1, + Direction: -1, + Tree: false, + }, + Groups: []groups.Group{parentGroup}, + }, + listAllObjectsRes: policysvc.PolicyPage{ + Policies: []string{testsutil.GenerateUUID(t)}, + }, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RetrieveHierarchy", context.Background(), tc.id, tc.pageMeta).Return(tc.retrieveHierarchyRes, tc.retrieveHierarchyErr) + policyCall := policies.On("ListAllObjects", context.Background(), policysvc.Policy{ + SubjectType: policysvc.UserType, + Subject: validID, + Permission: "read_permission", + ObjectType: policysvc.GroupType, + }).Return(tc.listAllObjectsRes, tc.listAllObjectsErr) + _, err := svc.RetrieveGroupHierarchy(context.Background(), validSession, tc.id, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if tc.err == nil { + ok := repo.AssertCalled(t, "RetrieveHierarchy", context.Background(), tc.id, tc.pageMeta) + assert.True(t, ok, fmt.Sprintf("RetrieveHierarchy was not called on %s", tc.desc)) + } + repoCall.Unset() + policyCall.Unset() + }) + } +} + +func TestAddParentGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + id string + parentID string + retrieveResp groups.Group + retrieveErr error + addPoliciesErr error + deletePoliciesErr error + assignParentErr error + err error + }{ + { + desc: "add parent group successfully", + id: validGroup.ID, + parentID: parentGroupID, + retrieveResp: validGroup, + err: nil, + }, + { + desc: "add parent group with failed to retrieve", + id: validGroup.ID, + parentID: parentGroupID, + retrieveErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "add parent group to group with parent", + id: childGroupID, + parentID: parentGroupID, + retrieveResp: childGroup, + err: svcerr.ErrConflict, + }, + { + desc: "add parent group with failed to add policies", + id: validGroup.ID, + parentID: parentGroupID, + retrieveResp: validGroup, + addPoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrAddPolicies, + }, + { + desc: "add parent group with repo error in assign parent group", + id: validGroup.ID, + parentID: parentGroupID, + retrieveResp: validGroup, + assignParentErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "add parent group with repo error in assign parent group and failed to delete policies", + id: validGroup.ID, + parentID: parentGroupID, + retrieveResp: validGroup, + assignParentErr: repoerr.ErrNotFound, + deletePoliciesErr: svcerr.ErrAuthorization, + err: apiutil.ErrRollbackTx, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + pol := policysvc.Policy{ + Domain: validID, + SubjectType: policysvc.GroupType, + Subject: tc.parentID, + Relation: policysvc.ParentGroupRelation, + ObjectType: policysvc.GroupType, + Object: tc.id, + } + repoCall := repo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveResp, tc.retrieveErr) + policyCall := policies.On("AddPolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.addPoliciesErr) + policyCall1 := policies.On("DeletePolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.deletePoliciesErr) + repoCall1 := repo.On("AssignParentGroup", context.Background(), tc.parentID, []string{tc.id}).Return(tc.assignParentErr) + err := svc.AddParentGroup(context.Background(), validSession, tc.id, tc.parentID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + ok := repo.AssertCalled(t, "RetrieveByID", context.Background(), tc.id) + assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) + repoCall.Unset() + policyCall.Unset() + policyCall1.Unset() + repoCall1.Unset() + }) + } +} + +func TestRemoveParentGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + id string + retrieveResp groups.Group + retrieveErr error + deletePoliciesErr error + addPoliciesErr error + unassignParentErr error + err error + }{ + { + desc: "remove parent group successfully", + id: childGroupID, + retrieveResp: childGroup, + err: nil, + }, + { + desc: "remove parent group with failed to retrieve", + id: childGroupID, + retrieveErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "remove parent group with no parent", + id: validGroup.ID, + retrieveResp: validGroup, + err: nil, + }, + { + desc: "remove parent group with failed to delete policies", + id: childGroupID, + retrieveResp: childGroup, + deletePoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrDeletePolicies, + }, + { + desc: "remove parent group with repo error in unassign parent group", + id: childGroupID, + retrieveResp: childGroup, + unassignParentErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "remove parent group with repo error in unassign parent group and failed to add policies", + id: childGroupID, + retrieveResp: childGroup, + unassignParentErr: repoerr.ErrNotFound, + addPoliciesErr: svcerr.ErrAuthorization, + err: apiutil.ErrRollbackTx, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + pol := policysvc.Policy{ + Domain: validID, + SubjectType: policysvc.GroupType, + Subject: tc.retrieveResp.Parent, + Relation: policysvc.ParentGroupRelation, + ObjectType: policysvc.GroupType, + Object: tc.id, + } + repoCall := repo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveResp, tc.retrieveErr) + policyCall := policies.On("DeletePolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.deletePoliciesErr) + policyCall1 := policies.On("AddPolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.addPoliciesErr) + repoCall1 := repo.On("UnassignParentGroup", context.Background(), tc.retrieveResp.Parent, []string{tc.id}).Return(tc.unassignParentErr) + err := svc.RemoveParentGroup(context.Background(), validSession, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + ok := repo.AssertCalled(t, "RetrieveByID", context.Background(), tc.id) + assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) + repoCall.Unset() + policyCall.Unset() + policyCall1.Unset() + repoCall1.Unset() + }) + } +} + +func TestAddChildrenGroups(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + parentID string + childrenIDs []string + retrieveResp groups.Page + retrieveErr error + addPoliciesErr error + deletePoliciesErr error + assignParentErr error + err error + }{ + { + desc: "add children groups successfully", + parentID: parentGroupID, + childrenIDs: []string{validGroup.ID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: nil, + }, + { + desc: "add children groups with failed to retrieve", + parentID: parentGroupID, + childrenIDs: []string{validGroup.ID}, + retrieveErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "add non existent child group", + parentID: parentGroupID, + childrenIDs: []string{testsutil.GenerateUUID(&testing.T{})}, + retrieveResp: groups.Page{}, + err: groups.ErrGroupIDs, + }, + { + desc: "add child group with parent", + parentID: parentGroupID, + childrenIDs: []string{childGroupID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: svcerr.ErrConflict, + }, + { + desc: "add children groups with failed to add policies", + parentID: parentGroupID, + childrenIDs: []string{validGroup.ID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + addPoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrAddPolicies, + }, + { + desc: "add children groups with repo error in assign children groups", + parentID: parentGroupID, + childrenIDs: []string{validGroup.ID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + assignParentErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "add children groups with repo error in assign children groups and failed to delete policies", + parentID: parentGroupID, + childrenIDs: []string{validGroup.ID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{validGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + assignParentErr: repoerr.ErrNotFound, + deletePoliciesErr: svcerr.ErrAuthorization, + err: apiutil.ErrRollbackTx, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + pol := policysvc.Policy{ + Domain: validID, + SubjectType: policysvc.GroupType, + Subject: tc.parentID, + Relation: policysvc.ParentGroupRelation, + ObjectType: policysvc.GroupType, + Object: validGroup.ID, + } + repoCall := repo.On("RetrieveByIDs", context.Background(), groups.PageMeta{Limit: 1<<63 - 1}, tc.childrenIDs).Return(tc.retrieveResp, tc.retrieveErr) + policyCall := policies.On("AddPolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.addPoliciesErr) + policyCall1 := policies.On("DeletePolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.deletePoliciesErr) + repoCall1 := repo.On("AssignParentGroup", context.Background(), tc.parentID, tc.childrenIDs).Return(tc.assignParentErr) + err := svc.AddChildrenGroups(context.Background(), validSession, tc.parentID, tc.childrenIDs) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + repoCall.Unset() + policyCall.Unset() + policyCall1.Unset() + repoCall1.Unset() + }) + } +} + +func TestRemoveChildrenGroups(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + parentID string + childrenIDs []string + retrieveResp groups.Page + retrieveErr error + deletePoliciesErr error + addPoliciesErr error + unassignParentErr error + err error + }{ + { + desc: "remove children groups successfully", + parentID: parentGroupID, + childrenIDs: []string{childGroupID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: nil, + }, + { + desc: "remove children groups with failed to retrieve", + parentID: parentGroupID, + childrenIDs: []string{childGroupID}, + retrieveErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "remove non existent child group", + parentID: parentGroupID, + childrenIDs: []string{testsutil.GenerateUUID(&testing.T{})}, + retrieveResp: groups.Page{}, + err: groups.ErrGroupIDs, + }, + { + desc: "remove children groups from different parent", + parentID: validGroup.ID, + childrenIDs: []string{childGroupID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: svcerr.ErrConflict, + }, + { + desc: "remove children groups with failed to delete policies", + parentID: parentGroupID, + childrenIDs: []string{childGroupID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + deletePoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrDeletePolicies, + }, + { + desc: "remove children groups with repo error in unassign children groups", + parentID: parentGroupID, + childrenIDs: []string{childGroupID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + unassignParentErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "remove children groups with repo error in unassign children groups and failed to add policies", + parentID: parentGroupID, + childrenIDs: []string{childGroupID}, + retrieveResp: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + unassignParentErr: repoerr.ErrNotFound, + addPoliciesErr: svcerr.ErrAuthorization, + err: apiutil.ErrRollbackTx, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + pol := policysvc.Policy{ + Domain: validID, + SubjectType: policysvc.GroupType, + Subject: tc.parentID, + Relation: policysvc.ParentGroupRelation, + ObjectType: policysvc.GroupType, + Object: childGroupID, + } + repoCall := repo.On("RetrieveByIDs", context.Background(), groups.PageMeta{Limit: 1<<63 - 1}, tc.childrenIDs).Return(tc.retrieveResp, tc.retrieveErr) + policyCall := policies.On("DeletePolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.deletePoliciesErr) + policyCall1 := policies.On("AddPolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.addPoliciesErr) + repoCall1 := repo.On("UnassignParentGroup", context.Background(), tc.parentID, tc.childrenIDs).Return(tc.unassignParentErr) + err := svc.RemoveChildrenGroups(context.Background(), validSession, tc.parentID, tc.childrenIDs) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + repoCall.Unset() + policyCall.Unset() + policyCall1.Unset() + repoCall1.Unset() + }) + } +} + +func TestRemoveAllChildrenGroups(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + parentID string + deletePolicyErr error + unassignAllChildrenErr error + err error + }{ + { + desc: "remove all children groups successfully", + parentID: parentGroupID, + err: nil, + }, + { + desc: "remove all children groups with failed to delete policy", + parentID: parentGroupID, + deletePolicyErr: svcerr.ErrAuthorization, + err: svcerr.ErrDeletePolicies, + }, + { + desc: "remove all children groups with failed to unassign all children", + parentID: parentGroupID, + deletePolicyErr: nil, + unassignAllChildrenErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + policyCall := policies.On("DeletePolicyFilter", context.Background(), policysvc.Policy{ + Domain: validID, + SubjectType: policysvc.GroupType, + Subject: tc.parentID, + Relation: policysvc.ParentGroupRelation, + ObjectType: policysvc.GroupType, + }).Return(tc.deletePolicyErr) + repoCall := repo.On("UnassignAllChildrenGroups", context.Background(), tc.parentID).Return(tc.unassignAllChildrenErr) + err := svc.RemoveAllChildrenGroups(context.Background(), validSession, tc.parentID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + policyCall.Unset() + repoCall.Unset() + }) + } +} + +func TestListAllChildrenGroups(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + session mgauthn.Session + pageMeta groups.PageMeta + parentID string + startLevel int64 + endLevel int64 + retrieveRes groups.Page + retrieveErr error + resp groups.Page + err error + }{ + { + desc: "list all children groups successfully", + session: validSession, + parentID: parentGroupID, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + }, + startLevel: 0, + endLevel: -1, + retrieveRes: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + resp: groups.Page{ + Groups: []groups.Group{childGroup}, + PageMeta: groups.PageMeta{ + Total: 1, + }, + }, + err: nil, + }, + { + desc: "list all children groups with failed to retrieve", + session: validSession, + parentID: parentGroupID, + pageMeta: groups.PageMeta{ + Limit: 10, + Offset: 0, + }, + startLevel: 0, + endLevel: -1, + retrieveErr: repoerr.ErrNotFound, + resp: groups.Page{}, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RetrieveChildrenGroups", context.Background(), tc.session.DomainID, tc.session.UserID, tc.parentID, tc.startLevel, tc.endLevel, tc.pageMeta).Return(tc.retrieveRes, tc.retrieveErr) + page, err := svc.ListChildrenGroups(context.Background(), tc.session, tc.parentID, tc.startLevel, tc.endLevel, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + assert.Equal(t, tc.resp, page) + repoCall.Unset() + }) + } +} + +func TestDeleteGroup(t *testing.T) { + svc := newService(t) + + cases := []struct { + desc string + id string + changeStatusRes groups.Group + changeStatusErr error + deletePoliciesErr error + deleteErr error + unsetFromChannels error + unsetFromClients error + err error + }{ + { + desc: "delete group successfully", + id: validGroup.ID, + err: nil, + }, + { + desc: "delete group with parent successfully", + id: childGroupID, + changeStatusRes: childGroup, + err: nil, + }, + { + desc: "delete group with failed to remove parent group from channels", + id: validGroup.ID, + unsetFromChannels: svcerr.ErrRemoveEntity, + err: svcerr.ErrRemoveEntity, + }, + { + desc: "delete group with failed to remove parent group from clients", + id: validGroup.ID, + unsetFromChannels: nil, + unsetFromClients: svcerr.ErrRemoveEntity, + err: svcerr.ErrRemoveEntity, + }, + { + desc: "delete group with failed to change status", + id: validGroup.ID, + changeStatusErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "delete group with failed to delete", + id: validGroup.ID, + changeStatusRes: validGroup, + deleteErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "delete group with failed to delete policies", + id: validGroup.ID, + changeStatusRes: validGroup, + deleteErr: nil, + deletePoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrDeletePolicies, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("ChangeStatus", context.Background(), groups.Group{ID: tc.id, Status: groups.DeletedStatus}).Return(tc.changeStatusRes, tc.changeStatusErr) + repoCall1 := repo.On("Delete", context.Background(), tc.id).Return(tc.deleteErr) + svcCall := channels.On("UnsetParentGroupFromChannels", context.Background(), &grpcChannelsV1.UnsetParentGroupFromChannelsReq{ParentGroupId: tc.id}).Return(&grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, tc.unsetFromChannels) + svcCall1 := clients.On("UnsetParentGroupFromClient", context.Background(), &grpcClientsV1.UnsetParentGroupFromClientReq{ParentGroupId: tc.id}).Return(&grpcClientsV1.UnsetParentGroupFromClientRes{}, tc.unsetFromClients) + repoCall2 := repo.On("RetrieveEntitiesRolesActionsMembers", context.Background(), []string{tc.id}).Return([]roles.EntityActionRole{}, []roles.EntityMemberRole{}, nil) + policyCall := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePoliciesErr) + policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(nil) + err := svc.DeleteGroup(context.Background(), validSession, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + policyCall.Unset() + repoCall.Unset() + repoCall1.Unset() + svcCall.Unset() + svcCall1.Unset() + repoCall2.Unset() + policyCall1.Unset() + }) + } +} diff --git a/groups/status.go b/groups/status.go index 22f72e9af15..3a45357ea3d 100644 --- a/groups/status.go +++ b/groups/status.go @@ -19,7 +19,7 @@ const ( EnabledStatus Status = iota // DisabledStatus represents disabled Group. DisabledStatus - // DeletedStatus + // DeletedStatus represents deleted Group. DeletedStatus // AllStatus is used for querying purposes to list groups irrespective diff --git a/internal/api/common.go b/internal/api/common.go index 88d94ecb8f5..049c378b7fb 100644 --- a/internal/api/common.go +++ b/internal/api/common.go @@ -189,7 +189,9 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { errors.Contains(err, apiutil.ErrMissingLastName), errors.Contains(err, apiutil.ErrInvalidUsername), errors.Contains(err, apiutil.ErrMissingIdentity), - errors.Contains(err, apiutil.ErrInvalidProfilePictureURL): + errors.Contains(err, apiutil.ErrInvalidProfilePictureURL), + errors.Contains(err, apiutil.ErrSelfParentingNotAllowed), + errors.Contains(err, apiutil.ErrMissingChildrenGroupIDs): err = unwrap(err) w.WriteHeader(http.StatusBadRequest)