Skip to content

Commit

Permalink
fix: clients channels listing sdk
Browse files Browse the repository at this point in the history
Signed-off-by: Arvindh <[email protected]>
  • Loading branch information
arvindh123 committed Jan 17, 2025
1 parent 9dc7809 commit 67571b6
Show file tree
Hide file tree
Showing 15 changed files with 143 additions and 570 deletions.
7 changes: 0 additions & 7 deletions channels/api/http/requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,6 @@ func TestListChannelsReqValidation(t *testing.T) {
},
err: apiutil.ErrNameSize,
},
{
desc: "invalid visibility",
req: listChannelsReq{
limit: 10,
},
err: apiutil.ErrInvalidVisibilityType,
},
}
for _, tc := range cases {
err := tc.req.validate()
Expand Down
26 changes: 0 additions & 26 deletions cli/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,32 +370,6 @@ var cmdUsers = []cobra.Command{
logOKCmd(*cmd)
},
},
{
Use: "clients <user_id> <domain_id> <user_auth_token>",
Short: "List clients",
Long: "List clients of user\n" +
"Usage:\n" +
"\tsupermq-cli users clients <user_id> <user_auth_token>\n",
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 3 {
logUsageCmd(*cmd, cmd.Use)
return
}

pm := smqsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
}

tp, err := sdk.ListUserClients(args[0], args[1], pm, args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
}

logJSONCmd(*cmd, tp)
},
},
{
Use: "search <query> <user_auth_token>",
Short: "Search users",
Expand Down
2 changes: 1 addition & 1 deletion clients/api/http/endpoints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ func TestListClients(t *testing.T) {
}

authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, "", mock.Anything).Return(tc.listClientsResponse, tc.err)
svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, mock.Anything).Return(tc.listClientsResponse, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))

Expand Down
7 changes: 0 additions & 7 deletions clients/api/http/requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,6 @@ func TestListClientsReqValidate(t *testing.T) {
},
err: apiutil.ErrLimitSize,
},
{
desc: "invalid visibility",
req: listClientsReq{
limit: 10,
},
err: apiutil.ErrInvalidVisibilityType,
},
{
desc: "name too long",
req: listClientsReq{
Expand Down
20 changes: 10 additions & 10 deletions clients/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Repository interface {
RetrieveAll(ctx context.Context, pm Page) (ClientsPage, error)

// RetrieveUserClients retrieve all clients of a given user id.
RetrieveUserClients(ctx context.Context, domainID, userID string, includeDomainClients bool, pm Page) (ClientsPage, error)
RetrieveUserClients(ctx context.Context, domainID, userID string, pm Page) (ClientsPage, error)

// SearchClients retrieves clients based on search criteria.
SearchClients(ctx context.Context, pm Page) (ClientsPage, error)
Expand Down Expand Up @@ -171,15 +171,15 @@ type Client struct {
Status Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
Identity string `json:"identity,omitempty"`
// Extended
ParentGroupPath string `json:"parent_group_path"`
RoleID string `json:"role_id"`
RoleName string `json:"role_name"`
Actions []string `json:"actions"`
AccessType string `json:"access_type"`
AccessProviderId string `json:"access_provider_id"`
AccessProviderRoleId string `json:"access_provider_role_id"`
AccessProviderRoleName string `json:"access_provider_role_name"`
AccessProviderRoleActions []string `json:"access_provider_role_actions"`
ParentGroupPath string `json:"parent_group_path,omitempty"`
RoleID string `json:"role_id,omitempty"`
RoleName string `json:"role_name,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
AccessProviderId string `json:"access_provider_id,omitempty"`
AccessProviderRoleId string `json:"access_provider_role_id,omitempty"`
AccessProviderRoleName string `json:"access_provider_role_name,omitempty"`
AccessProviderRoleActions []string `json:"access_provider_role_actions,omitempty"`
}

// ClientsPage contains page related metadata as well as list.
Expand Down
18 changes: 9 additions & 9 deletions clients/mocks/repository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 6 additions & 42 deletions clients/postgres/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,17 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
return page, nil
}

func (repo *clientRepo) RetrieveUserClients(ctx context.Context, domainID, userID string, includeDomainClients bool, pm clients.Page) (clients.ClientsPage, error) {
return repo.retrieveClients(ctx, domainID, userID, includeDomainClients, pm)
func (repo *clientRepo) RetrieveUserClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) {
return repo.retrieveClients(ctx, domainID, userID, pm)
}

func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID string, includeDomainClients bool, pm clients.Page) (clients.ClientsPage, error) {
func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) {
pageQuery, err := PageQuery(pm)
if err != nil {
return clients.ClientsPage{}, err
}

bq := repo.userClientBaseQuery(domainID, userID, includeDomainClients)
bq := repo.userClientBaseQuery(domainID, userID)

q := fmt.Sprintf(`
%s
Expand Down Expand Up @@ -370,42 +370,7 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
return page, nil
}

func (repo *clientRepo) userClientBaseQuery(domainID, userID string, includeDomainClients bool) string {
domainClientsQuery := ""
if includeDomainClients {
domainClientsQuery = fmt.Sprintf(`
UNION
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.identity,
c.secret,
c.tags,
c.metadata,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
'' AS parent_group_path,
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
'indirect_domain' AS access_type,
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
FROM
clients c
WHERE
c.domain_id = '%s'
AND
c.id NOT IN (SELECT id FROM group_direct_clients)
`, domainID)
}
func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
return fmt.Sprintf(`
WITH direct_clients AS (
SELECT
Expand Down Expand Up @@ -623,9 +588,8 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string, includeDoma
gdc.access_provider_role_actions
FROM
group_direct_clients AS gdc
%s
)
`, userID, domainID, userID, domainID, domainID, domainClientsQuery)
`, userID, domainID, userID, domainID, domainID)
}

func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
Expand Down
12 changes: 2 additions & 10 deletions clients/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,7 @@ func (svc service) ListClients(ctx context.Context, session authn.Session, pm Pa
}
return cp, nil
default:
includeDomainClients, ok := ctx.Value(ListDomainClients).(bool)
if !ok {
includeDomainClients = false
}
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, session.UserID, includeDomainClients, pm)
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, session.UserID, pm)
if err != nil {
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
Expand All @@ -151,11 +147,7 @@ func (svc service) ListClients(ctx context.Context, session authn.Session, pm Pa
}

func (svc service) ListUserClients(ctx context.Context, session authn.Session, userID string, pm Page) (ClientsPage, error) {
includeDomainClients, ok := ctx.Value(ListDomainClients).(bool)
if !ok {
includeDomainClients = false
}
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, userID, includeDomainClients, pm)
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, userID, pm)
if err != nil {
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
Expand Down
73 changes: 5 additions & 68 deletions clients/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,29 +411,6 @@ func TestListClients(t *testing.T) {
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as non admin with failed to list permissions",
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: clients.Page{
Offset: 0,
Limit: 100,
},
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
retrieveAllResponse: clients.ClientsPage{
Page: clients.Page{
Total: 2,
Offset: 0,
Limit: 100,
},
Clients: []clients.Client{client, client},
},
listPermissionsResponse: []string{},
response: clients.ClientsPage{},
listPermissionsErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as non admin with failed super admin",
userKind: "non-admin",
Expand All @@ -455,6 +432,7 @@ func TestListClients(t *testing.T) {
Offset: 0,
Limit: 100,
},
retrieveAllErr: repoerr.ErrNotFound,
response: clients.ClientsPage{},
listObjectsResponse: policysvc.PolicyPage{},
listObjectsErr: svcerr.ErrNotFound,
Expand All @@ -463,15 +441,13 @@ func TestListClients(t *testing.T) {
}

for _, tc := range cases {
listAllObjectsCall := pService.On("ListAllObjects", mock.Anything, mock.Anything).Return(tc.listObjectsResponse, tc.listObjectsErr)
retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
retrieveUserClientsCall := repo.On("RetrieveUserClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListClients(context.Background(), tc.session, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
listAllObjectsCall.Unset()
retrieveAllCall.Unset()
listPermissionsCall.Unset()
retrieveUserClientsCall.Unset()
}

cases2 := []struct {
Expand Down Expand Up @@ -534,29 +510,6 @@ func TestListClients(t *testing.T) {
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as admin with failed to list permissions",
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: clients.Page{
Offset: 0,
Limit: 100,
Domain: domainID,
},
listObjectsResponse: policysvc.PolicyPage{},
retrieveAllResponse: clients.ClientsPage{
Page: clients.Page{
Total: 2,
Offset: 0,
Limit: 100,
},
Clients: []clients.Client{client, client},
},
listPermissionsResponse: []string{},
listPermissionsErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as admin with failed to list clients",
userKind: "admin",
Expand All @@ -574,27 +527,11 @@ func TestListClients(t *testing.T) {
}

for _, tc := range cases2 {
listAllObjectsCall := pService.On("ListAllObjects", context.Background(), policysvc.Policy{
SubjectType: policysvc.UserType,
Subject: tc.session.DomainID + "_" + adminID,
Permission: "",
ObjectType: policysvc.ClientType,
}).Return(tc.listObjectsResponse, tc.listObjectsErr)
listAllObjectsCall2 := pService.On("ListAllObjects", context.Background(), policysvc.Policy{
SubjectType: policysvc.UserType,
Subject: tc.session.UserID,
Permission: "",
ObjectType: policysvc.ClientType,
}).Return(tc.listObjectsResponse, tc.listObjectsErr)
retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListClients(context.Background(), tc.session, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
listAllObjectsCall.Unset()
listAllObjectsCall2.Unset()
retrieveAllCall.Unset()
listPermissionsCall.Unset()
}
}

Expand Down
2 changes: 1 addition & 1 deletion clients/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestStatusUnmarshalJSON(t *testing.T) {
}
}

func TestUserMarshalJSON(t *testing.T) {
func TestClientMarshalJSON(t *testing.T) {
cases := []struct {
desc string
expected []byte
Expand Down
Loading

0 comments on commit 67571b6

Please sign in to comment.