diff --git a/README.md b/README.md index a9cd3c93..16d784ff 100644 --- a/README.md +++ b/README.md @@ -274,13 +274,16 @@ if err := pages.Err(); err != nil { --> ```go -page, err := channel.Presence.Get(ctx, nil) -for ; err == nil && page != nil; page, err = page.Next(ctx) { - for _, presence := range page.PresenceMessages() { +pages, err := channel.Presence.Get().Pages(ctx) +if err != nil { + panic(err) +} +for pages.Next(ctx) { + for _, presence := range pages.Items() { fmt.Println(presence) } } -if err != nil { +if err := pages.Err(); err != nil { panic(err) } ``` diff --git a/ably/ablytest/pagination.go b/ably/ablytest/pagination.go index e052157d..8b23a88f 100644 --- a/ably/ablytest/pagination.go +++ b/ably/ablytest/pagination.go @@ -23,7 +23,8 @@ func AllPages(dst, paginatedRequest interface{}) error { } type paginationOptions struct { - equal func(x, y interface{}) bool + equal func(x, y interface{}) bool + sortResult func([]interface{}) } type PaginationOption func(*paginationOptions) @@ -34,9 +35,16 @@ func PaginationWithEqual(equal func(x, y interface{}) bool) PaginationOption { } } +func PaginationWithSortResult(sort func([]interface{})) PaginationOption { + return func(o *paginationOptions) { + o.sortResult = sort + } +} + func TestPagination(expected, request interface{}, perPage int, options ...PaginationOption) error { opts := paginationOptions{ - equal: reflect.DeepEqual, + equal: reflect.DeepEqual, + sortResult: func(items []interface{}) {}, } for _, o := range options { o(&opts) @@ -47,40 +55,36 @@ func TestPagination(expected, request interface{}, perPage int, options ...Pagin for i := 0; i < reflect.ValueOf(expected).Len(); i++ { items = append(items, rexpected.Index(i).Interface()) } - return testPagination(reflect.ValueOf(request), items, perPage, opts.equal) + return testPagination(reflect.ValueOf(request), items, perPage, opts) } -func testPagination(request reflect.Value, expectedItems []interface{}, perPage int, equal func(x, y interface{}) bool) error { +func testPagination(request reflect.Value, expectedItems []interface{}, perPage int, opts paginationOptions) error { getPages, getItems := generalizePagination(request) - var expectedPages [][]interface{} - var page []interface{} - for _, item := range expectedItems { - page = append(page, item) - if len(page) == perPage { - expectedPages = append(expectedPages, page) - page = nil - } - } - if len(page) > 0 { - expectedPages = append(expectedPages, page) - } - for i := 0; i < 2; i++ { pages, err := getPages() if err != nil { return fmt.Errorf("calling Pages: %w", err) } - var gotPages [][]interface{} + var gotItems []interface{} + pageNum := 1 + expectedFullPages := len(expectedItems) / perPage for pages.next() { - gotPages = append(gotPages, pages.items()) + if (pageNum <= expectedFullPages && len(pages.items()) != perPage) || + (pageNum > expectedFullPages && len(pages.items()) >= perPage) { + return fmt.Errorf("page #%d got %d items, expected at most %d", pageNum, len(pages.items()), perPage) + } + gotItems = append(gotItems, pages.items()...) + pageNum++ } if err := pages.err(); err != nil { return fmt.Errorf("iterating pages: %w", err) } - if !PagesEqual(expectedPages, gotPages, equal) { - return fmt.Errorf("expected pages: %+v, got: %+v", expectedPages, gotPages) + opts.sortResult(gotItems) + + if !ItemsEqual(expectedItems, gotItems, opts.equal) { + return fmt.Errorf("expected items: %+v, got: %+v", expectedItems, gotItems) } if err := pages.first(); err != nil { @@ -101,7 +105,9 @@ func testPagination(request reflect.Value, expectedItems []interface{}, perPage return fmt.Errorf("iterating items: %w", err) } - if !ItemsEqual(expectedItems, gotItems, equal) { + opts.sortResult(gotItems) + + if !ItemsEqual(expectedItems, gotItems, opts.equal) { return fmt.Errorf("expected items: %+v, got: %+v", expectedItems, gotItems) } @@ -184,18 +190,6 @@ func generalizePagination(request reflect.Value) (func() (paginatedResult, error return pages, items } -func PagesEqual(expected, got [][]interface{}, equal func(x, y interface{}) bool) bool { - if len(expected) != len(got) { - return false - } - for i := range expected { - if !ItemsEqual(expected[i], got[i], equal) { - return false - } - } - return true -} - func ItemsEqual(expected, got []interface{}, equal func(x, y interface{}) bool) bool { if len(expected) != len(got) { return false diff --git a/ably/ablytest/sandbox.go b/ably/ablytest/sandbox.go index a3bdf828..36782f25 100644 --- a/ably/ablytest/sandbox.go +++ b/ably/ablytest/sandbox.go @@ -77,18 +77,22 @@ func DefaultConfig() *Config { }, Channels: []Channel{ { - Name: "persisted:presence_fixtures", - Presence: []Presence{ - {ClientID: "client_bool", Data: "true"}, - {ClientID: "client_int", Data: "true"}, - {ClientID: "client_string", Data: "true"}, - {ClientID: "client_json", Data: `{"test": "This is a JSONObject clientData payload"}`}, - }, + Name: "persisted:presence_fixtures", + Presence: PresenceFixtures(), }, }, } } +var PresenceFixtures = func() []Presence { + return []Presence{ + {ClientID: "client_bool", Data: "true"}, + {ClientID: "client_int", Data: "true"}, + {ClientID: "client_string", Data: "true"}, + {ClientID: "client_json", Data: `{"test": "This is a JSONObject clientData payload"}`}, + } +} + type Sandbox struct { Config *Config Environment string diff --git a/ably/export_test.go b/ably/export_test.go index 96b4b195..77855d66 100644 --- a/ably/export_test.go +++ b/ably/export_test.go @@ -7,10 +7,6 @@ import ( "time" ) -func (p *PaginatedResult) BuildPath(base, rel string) string { - return p.buildPath(base, rel) -} - func (opts *clientOptions) RestURL() string { return opts.restURL() } diff --git a/ably/http_paginated_response.go b/ably/http_paginated_response.go deleted file mode 100644 index dadfafe0..00000000 --- a/ably/http_paginated_response.go +++ /dev/null @@ -1,59 +0,0 @@ -package ably - -import ( - "context" - "net/http" - "reflect" - - "github.com/ably/ably-go/ably/proto" -) - -// HTTPPaginatedResponse represent a response from an http request. -type HTTPPaginatedResponse struct { - *PaginatedResult - StatusCode int //spec HP4 - Success bool //spec HP5 - ErrorCode ErrorCode //spec HP6 - ErrorMessage string //spec HP7 - Headers http.Header //spec HP8 -} - -func decodeHTTPPaginatedResult(opts *proto.ChannelOptions, typ reflect.Type, resp *http.Response) (interface{}, error) { - var o interface{} - err := decodeResp(resp, &o) - if err != nil { - return nil, err - } - return o, nil -} - -func newHTTPPaginatedResult(ctx context.Context, path string, params *PaginateParams, - query queryFunc, log *LoggerOptions) (*HTTPPaginatedResponse, error) { - p, err := newPaginatedResult(ctx, nil, paginatedRequest{typ: arrayTyp, path: path, params: params, query: query, logger: log, respCheck: func(_ *http.Response) error { - return nil - }, decoder: decodeHTTPPaginatedResult}) - if err != nil { - return nil, err - } - //spec RSC19d - return newHTTPPaginatedResultFromPaginatedResult(p), nil -} - -func newHTTPPaginatedResultFromPaginatedResult(p *PaginatedResult) *HTTPPaginatedResponse { - h := &HTTPPaginatedResponse{PaginatedResult: p} - h.StatusCode = p.statusCode - h.Success = p.success - h.ErrorCode = ErrorCode(p.errorCode) - h.ErrorMessage = p.errorMessage - return h -} - -// Next overrides PaginatedResult.Next -// spec HP2 -func (h *HTTPPaginatedResponse) Next(ctx context.Context) (*HTTPPaginatedResponse, error) { - p, err := h.PaginatedResult.Next(ctx) - if err != nil { - return nil, err - } - return newHTTPPaginatedResultFromPaginatedResult(p), nil -} diff --git a/ably/http_paginated_response_test.go b/ably/http_paginated_response_test.go index 7ce0e146..5edf6d83 100644 --- a/ably/http_paginated_response_test.go +++ b/ably/http_paginated_response_test.go @@ -3,6 +3,7 @@ package ably_test import ( "context" "net/http" + "net/url" "reflect" "sort" "testing" @@ -26,37 +27,43 @@ func TestHTTPPaginatedResponse(t *testing.T) { t.Fatal(err) } t.Run("request_time", func(ts *testing.T) { - res, err := client.Request(context.Background(), "get", "/time", nil, nil, nil) + res, err := client.Request("get", "/time").Pages(context.Background()) if err != nil { ts.Fatal(err) } - if res.StatusCode != http.StatusOK { - ts.Errorf("expected %d got %d", http.StatusOK, res.StatusCode) + if res.StatusCode() != http.StatusOK { + ts.Errorf("expected %d got %d", http.StatusOK, res.StatusCode()) } - if !res.Success { + if !res.Success() { ts.Error("expected success to be true") } - n := len(res.Items()) + res.Next(context.Background()) + var items []interface{} + err = res.Items(&items) + if err != nil { + ts.Error(err) + } + n := len(items) if n != 1 { ts.Errorf("expected 1 item got %d", n) } }) t.Run("request_404", func(ts *testing.T) { - res, err := client.Request(context.Background(), "get", "/keys/ablyjs.test/requestToken", nil, nil, nil) + res, err := client.Request("get", "/keys/ablyjs.test/requestToken").Pages(context.Background()) if err != nil { ts.Fatal(err) } - if res.StatusCode != http.StatusNotFound { - ts.Errorf("expected %d got %d", http.StatusNotFound, res.StatusCode) + if res.StatusCode() != http.StatusNotFound { + ts.Errorf("expected %d got %d", http.StatusNotFound, res.StatusCode()) } - if res.ErrorCode != ably.ErrNotFound { - ts.Errorf("expected %d got %d", ably.ErrNotFound, res.ErrorCode) + if res.ErrorCode() != ably.ErrNotFound { + ts.Errorf("expected %d got %d", ably.ErrNotFound, res.ErrorCode()) } - if res.Success { + if res.Success() { ts.Error("expected success to be false") } - if res.ErrorMessage == "" { + if res.ErrorMessage() == "" { ts.Error("expected error message") } }) @@ -71,17 +78,23 @@ func TestHTTPPaginatedResponse(t *testing.T) { ts.Run("post", func(ts *testing.T) { for _, message := range msgs { - res, err := client.Request(context.Background(), "POST", channelPath, nil, message, nil) + res, err := client.Request("POST", channelPath, ably.RequestWithBody(message)).Pages(context.Background()) if err != nil { ts.Fatal(err) } - if res.StatusCode != http.StatusCreated { - ts.Errorf("expected %d got %d", http.StatusCreated, res.StatusCode) + if res.StatusCode() != http.StatusCreated { + ts.Errorf("expected %d got %d", http.StatusCreated, res.StatusCode()) } - if !res.Success { + if !res.Success() { ts.Error("expected success to be true") } - n := len(res.Items()) + res.Next(context.Background()) + var items []interface{} + err = res.Items(&items) + if err != nil { + ts.Error(err) + } + n := len(items) if n != 1 { ts.Errorf("expected 1 item got %d", n) } @@ -89,21 +102,27 @@ func TestHTTPPaginatedResponse(t *testing.T) { }) ts.Run("get", func(ts *testing.T) { - res, err := client.Request(context.Background(), "get", channelPath, &ably.PaginateParams{ - Limit: 1, - Direction: "forwards", - }, nil, nil) + res, err := client.Request("get", channelPath, ably.RequestWithParams(url.Values{ + "limit": {"1"}, + "direction": {"forwards"}, + })).Pages(context.Background()) if err != nil { ts.Fatal(err) } - if res.StatusCode != http.StatusOK { - ts.Errorf("expected %d got %d", http.StatusOK, res.StatusCode) + if res.StatusCode() != http.StatusOK { + ts.Errorf("expected %d got %d", http.StatusOK, res.StatusCode()) } - n := len(res.Items()) + res.Next(context.Background()) + var items []interface{} + err = res.Items(&items) + if err != nil { + ts.Error(err) + } + n := len(items) if n != 1 { ts.Fatalf("expected 1 item got %d", n) } - m := res.Items()[0].(map[string]interface{}) + m := items[0].(map[string]interface{}) name := m["name"].(string) data := m["data"].(string) if name != msgs[0].Name { @@ -113,15 +132,18 @@ func TestHTTPPaginatedResponse(t *testing.T) { ts.Errorf("expected %v got %s", msgs[0].Data, data) } - res, err = res.Next(context.Background()) + if !res.Next(context.Background()) { + ts.Fatal(res.Err()) + } + err = res.Items(&items) if err != nil { - ts.Fatal(err) + ts.Error(err) } - n = len(res.Items()) + n = len(items) if n != 1 { ts.Fatalf("expected 1 item got %d", n) } - m = res.Items()[0].(map[string]interface{}) + m = items[0].(map[string]interface{}) name = m["name"].(string) data = m["data"].(string) if name != msgs[1].Name { diff --git a/ably/paginated_result.go b/ably/paginated_result.go index cfbffa37..4bfa13b1 100644 --- a/ably/paginated_result.go +++ b/ably/paginated_result.go @@ -2,16 +2,11 @@ package ably import ( "context" - "errors" - "fmt" - "io" "net/http" "net/url" "path" "reflect" "regexp" - "strconv" - "strings" "github.com/ably/ably-go/ably/proto" ) @@ -23,15 +18,15 @@ const ( Forwards Direction = "forwards" ) -type paginatedRequestNew struct { +type paginatedRequest struct { path string params url.Values query queryFunc } -func (r *REST) newPaginatedRequest(path string, params url.Values) paginatedRequestNew { - return paginatedRequestNew{ +func (r *REST) newPaginatedRequest(path string, params url.Values) paginatedRequest { + return paginatedRequest{ path: path, params: params, @@ -39,11 +34,11 @@ func (r *REST) newPaginatedRequest(path string, params url.Values) paginatedRequ } } -// PaginatedResultNew is a generic iterator for PaginatedResult pagination. +// PaginatedResult is a generic iterator for PaginatedResult pagination. // Items decoding is delegated to type-specific wrappers. // // See "Paginated results" section in the package-level documentation. -type PaginatedResultNew struct { +type PaginatedResult struct { basePath string nextLink string firstLink string @@ -56,7 +51,7 @@ type PaginatedResultNew struct { // load loads the first page of results. Must be called from the type-specific // wrapper Pages method that creates the PaginatedResult object. -func (p *PaginatedResultNew) load(ctx context.Context, r paginatedRequestNew) error { +func (p *PaginatedResult) load(ctx context.Context, r paginatedRequest) error { p.basePath = path.Dir(r.path) p.firstLink = (&url.URL{ Path: r.path, @@ -78,9 +73,9 @@ func (p *PaginatedResultNew) load(ctx context.Context, r paginatedRequestNew) er // It should return a destination object on which the page of results will be // decoded, and a pageLength function that, when called after the page has been // decoded, must return the length of the page. -func (p *PaginatedResultNew) loadItems( +func (p *PaginatedResult) loadItems( ctx context.Context, - r paginatedRequestNew, + r paginatedRequest, pageDecoder func() (page interface{}, pageLength func() int), ) ( next func(context.Context) (int, bool), @@ -115,7 +110,7 @@ func (p *PaginatedResultNew) loadItems( }, nil } -func (p *PaginatedResultNew) goTo(ctx context.Context, link string) error { +func (p *PaginatedResult) goTo(ctx context.Context, link string) error { var err error p.res, err = p.query(ctx, link) if err != nil { @@ -146,7 +141,7 @@ func (p *PaginatedResultNew) goTo(ctx context.Context, link string) error { // Items can then be inspected with the type-specific Items method. // // For items iterators, use the next function returned by loadItems instead. -func (p *PaginatedResultNew) next(ctx context.Context, into interface{}) bool { +func (p *PaginatedResult) next(ctx context.Context, into interface{}) bool { if !p.first { if p.nextLink == "" { return false @@ -164,13 +159,13 @@ func (p *PaginatedResultNew) next(ctx context.Context, into interface{}) bool { // First loads the first page of items. Next should be called before inspecting // the Items. -func (p *PaginatedResultNew) First(ctx context.Context) error { +func (p *PaginatedResult) First(ctx context.Context) error { p.first = true return p.goTo(ctx, p.firstLink) } // Err returns the error that caused Next to fail, if there was one. -func (p *PaginatedResultNew) Err() error { +func (p *PaginatedResult) Err() error { return p.err } @@ -189,103 +184,6 @@ func (err errInvalidType) Error() string { // occurred. type queryFunc func(ctx context.Context, url string) (*http.Response, error) -// PaginatedResult represents a single page coming back from the REST API. -// Any call to create a new page will generate a new instance. -type PaginatedResult struct { - path string - headers map[string]string - links []string - items []interface{} - typItems interface{} - opts *proto.ChannelOptions - req paginatedRequest - - statusCode int - success bool - errorCode int - errorMessage string - respHeaders http.Header -} - -type paginatedRequest struct { - typ reflect.Type - path string - params *PaginateParams - query queryFunc - logger *LoggerOptions - respCheck func(*http.Response) error - decoder func(*proto.ChannelOptions, reflect.Type, *http.Response) (interface{}, error) -} - -func decodePaginatedResult(opts *proto.ChannelOptions, typ reflect.Type, resp *http.Response) (interface{}, error) { - v := reflect.New(typ) - if err := decodeResp(resp, v.Interface()); err != nil { - return nil, err - } - return v.Elem().Interface(), nil -} - -func newPaginatedResult(ctx context.Context, opts *proto.ChannelOptions, req paginatedRequest) (*PaginatedResult, error) { - if req.decoder == nil { - req.decoder = decodePaginatedResult - } - p := &PaginatedResult{ - opts: opts, - req: req, - } - builtPath, err := p.buildPaginatedPath(req.path, req.params) - if err != nil { - return nil, err - } - resp, err := p.req.query(ctx, builtPath) - if err != nil { - return nil, err - } - if err = req.respCheck(resp); err != nil { - return nil, err - } - defer resp.Body.Close() - if p.respHeaders == nil { - p.respHeaders = make(http.Header) - } - p.statusCode = resp.StatusCode - p.success = 200 <= p.statusCode && p.statusCode < 300 - copyHeader(p.respHeaders, resp.Header) - if h := p.respHeaders.Get(proto.AblyErrorCodeHeader); h != "" { - i, err := strconv.Atoi(h) - if err != nil { - return nil, err - } - p.errorCode = i - } else if !p.success { - return nil, malformedPaginatedResponseError(resp) - } - if h := p.respHeaders.Get(proto.AblyErrormessageHeader); h != "" { - p.errorMessage = h - } else if !p.success { - return nil, malformedPaginatedResponseError(resp) - } - p.path = builtPath - p.links = resp.Header["Link"] - v, err := p.req.decoder(opts, p.req.typ, resp) - if err != nil { - return nil, err - } - p.typItems = v - return p, nil -} - -func malformedPaginatedResponseError(resp *http.Response) error { - body := make([]byte, 200) - n, err := io.ReadFull(resp.Body, body) - body = body[:n] - msg := fmt.Sprintf("invalid PaginatedResult HTTP response; status: %d; body (first %d bytes): %q", resp.StatusCode, len(body), body) - if err != nil && !errors.Is(err, io.EOF) { - return fmt.Errorf("%s; body read error: %w", msg, err) - } - return errors.New(msg) -} - func copyHeader(dest, src http.Header) { for k, v := range src { d := make([]string, len(v)) @@ -294,49 +192,6 @@ func copyHeader(dest, src http.Header) { } } -// Next returns the path to the next page as found in the response headers. -// The response headers from the REST API contains a relative link to the next result. -// (Link: <./path>; rel="next"). -// -// If there is no next link, both return values are nil. -func (p *PaginatedResult) Next(ctx context.Context) (*PaginatedResult, error) { - nextPath, ok := p.paginationHeaders()["next"] - if !ok { - return nil, nil - } - nextPage := p.buildPath(p.path, nextPath) - req := p.req - req.path = nextPage - req.params = nil - return newPaginatedResult(ctx, p.opts, req) -} - -// Items gives a slice of results of the current page. -func (p *PaginatedResult) Items() []interface{} { - if p.items == nil { - v := reflect.ValueOf(p.typItems) - if v.Kind() == reflect.Slice { - p.items = make([]interface{}, v.Len()) - for i := range p.items { - p.items[i] = v.Index(i).Interface() - } - } else { - p.items = []interface{}{p.typItems} - } - } - return p.items -} - -// PresenceMessages gives a slice of presence messages for the current path. -// The method panics if the underlying paginated result is not a presence message. -func (p *PaginatedResult) PresenceMessages() []*PresenceMessage { - items, ok := p.typItems.([]*proto.PresenceMessage) - if !ok { - panic(errInvalidType{typ: p.req.typ}) - } - return items -} - type Stats = proto.Stats type StatsMessageTypes = proto.MessageTypes type StatsMessageCount = proto.MessageCount @@ -347,48 +202,3 @@ type StatsRequestCount = proto.RequestCount type StatsPushStats = proto.PushStats type StatsXchgMessages = proto.XchgMessages type PushStats = proto.PushStats - -func (c *PaginatedResult) buildPaginatedPath(path string, params *PaginateParams) (string, error) { - if params == nil { - return path, nil - } - values := &url.Values{} - err := params.EncodeValues(values) - if err != nil { - return "", newError(50000, err) - } - queryString := values.Encode() - if len(queryString) > 0 { - return path + "?" + queryString, nil - } - return path, nil -} - -// buildPath finds the absolute path based on the path parameter and the new relative path. -func (p *PaginatedResult) buildPath(origPath string, newRelativePath string) string { - if i := strings.IndexRune(origPath, '?'); i != -1 { - origPath = origPath[:i] - } - return path.Join(path.Dir(origPath), newRelativePath) -} - -func (p *PaginatedResult) paginationHeaders() map[string]string { - if p.headers == nil { - p.headers = make(map[string]string) - for _, link := range p.links { - if result := relLinkRegexp.FindStringSubmatch(link); result != nil { - p.addMatch(result) - } - } - } - return p.headers -} - -func (p *PaginatedResult) addMatch(matches []string) { - matchingNames := relLinkRegexp.SubexpNames() - matchMap := map[string]string{} - for i, value := range matches { - matchMap[matchingNames[i]] = value - } - p.headers[matchMap["rel"]] = matchMap["url"] -} diff --git a/ably/paginated_result_test.go b/ably/paginated_result_test.go deleted file mode 100644 index 5f740027..00000000 --- a/ably/paginated_result_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package ably_test - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/ably/ably-go/ably" -) - -func TestPaginatedResult(t *testing.T) { - t.Parallel() - result := &ably.PaginatedResult{} - newPath := result.BuildPath("/path/to/resource?hello", "./newresource?world") - expected := "/path/to/newresource?world" - if newPath != expected { - t.Errorf("expected %s got %s", expected, newPath) - } -} - -func TestMalformedPaginatedResult(t *testing.T) { - bodyBytes, _ := json.Marshal([]string{"\x00 not really a PaginatedResult"}) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(400) - w.Write(bodyBytes) - })) - defer srv.Close() - - srvAddr := srv.Listener.Addr().(*net.TCPAddr) - client, err := ably.NewREST( - ably.WithToken("xxxxxxx.yyyyyyy:zzzzzzz"), - ably.WithTLS(false), - ably.WithRESTHost(srvAddr.IP.String()), - ably.WithPort(srvAddr.Port), - ) - if err != nil { - t.Fatal(err) - } - - resp, err := client.Request(context.Background(), "POST", "/foo", nil, nil, nil) - if resp != nil { - t.Errorf("expected no HTTPPaginatedResult; got %+v", resp) - } - if err == nil { - t.Fatal("expected an error") - } - errMsg := err.Error() - if !strings.Contains(errMsg, "status: 400") { - t.Errorf("expected error to contain status code; got: %v", err) - } - if !strings.Contains(errMsg, fmt.Sprintf("%q", bodyBytes)) { - t.Errorf("expected error to contain body; got: %v", err) - } -} diff --git a/ably/proto/http.go b/ably/proto/http.go index 38dd222b..21a0b551 100644 --- a/ably/proto/http.go +++ b/ably/proto/http.go @@ -5,7 +5,7 @@ const ( AblyVersionHeader = "X-Ably-Version" AblyLibHeader = "X-Ably-Lib" AblyErrorCodeHeader = "X-Ably-Errorcode" - AblyErrormessageHeader = "X-Ably-Errormessage" + AblyErrorMessageHeader = "X-Ably-Errormessage" LibraryVersion = "1.2.0-apipreview.4" LibraryName = "go" LibraryString = LibraryName + "-" + LibraryVersion diff --git a/ably/proto/presence_message.go b/ably/proto/presence_message.go index ff40e2cc..3b40d0f9 100644 --- a/ably/proto/presence_message.go +++ b/ably/proto/presence_message.go @@ -20,11 +20,11 @@ type PresenceMessage struct { } func (m PresenceMessage) String() string { - return fmt.Sprintf("", [...]string{ + return fmt.Sprintf("", [...]string{ "absent", "present", "enter", "leave", "update", - }[m.Action], m.Data) + }[m.Action], m.ClientID, m.Data) } diff --git a/ably/readme_examples_test.go b/ably/readme_examples_test.go index 5998f981..b2643a34 100644 --- a/ably/readme_examples_test.go +++ b/ably/readme_examples_test.go @@ -190,58 +190,62 @@ func TestReadmeExamples(t *testing.T) { /* README.md:267 */ } /* README.md:273 */ { - /* README.md:277 */ page, err := channel.Presence.Get(ctx, nil) - /* README.md:278 */ for ; err == nil && page != nil; page, err = page.Next(ctx) { - /* README.md:279 */ for _, presence := range page.PresenceMessages() { - /* README.md:280 */ fmt.Println(presence) - /* README.md:281 */ - } - /* README.md:282 */ + /* README.md:277 */ pages, err := channel.Presence.Get().Pages(ctx) + /* README.md:278 */ if err != nil { + /* README.md:279 */ panic(err) + /* README.md:280 */ } - /* README.md:283 */ if err != nil { - /* README.md:284 */ panic(err) + /* README.md:281 */ for pages.Next(ctx) { + /* README.md:282 */ for _, presence := range pages.Items() { + /* README.md:283 */ fmt.Println(presence) + /* README.md:284 */ + } /* README.md:285 */ } - /* README.md:289 */ + /* README.md:286 */ if err := pages.Err(); err != nil { + /* README.md:287 */ panic(err) + /* README.md:288 */ + } + /* README.md:292 */ } - /* README.md:295 */ { - /* README.md:299 */ pages, err := channel.Presence.History().Pages(ctx) - /* README.md:300 */ if err != nil { - /* README.md:301 */ panic(err) - /* README.md:302 */ + /* README.md:298 */ { + /* README.md:302 */ pages, err := channel.Presence.History().Pages(ctx) + /* README.md:303 */ if err != nil { + /* README.md:304 */ panic(err) + /* README.md:305 */ } - /* README.md:303 */ for pages.Next(ctx) { - /* README.md:304 */ for _, presence := range pages.Items() { - /* README.md:305 */ fmt.Println(presence) - /* README.md:306 */ + /* README.md:306 */ for pages.Next(ctx) { + /* README.md:307 */ for _, presence := range pages.Items() { + /* README.md:308 */ fmt.Println(presence) + /* README.md:309 */ } - /* README.md:307 */ - } - /* README.md:308 */ if err := pages.Err(); err != nil { - /* README.md:309 */ panic(err) /* README.md:310 */ } - /* README.md:314 */ + /* README.md:311 */ if err := pages.Err(); err != nil { + /* README.md:312 */ panic(err) + /* README.md:313 */ + } + /* README.md:317 */ } - /* README.md:320 */ { - /* README.md:324 */ pages, err := client.Stats().Pages(ctx) - /* README.md:325 */ if err != nil { - /* README.md:326 */ panic(err) - /* README.md:327 */ + /* README.md:323 */ { + /* README.md:327 */ pages, err := client.Stats().Pages(ctx) + /* README.md:328 */ if err != nil { + /* README.md:329 */ panic(err) + /* README.md:330 */ } - /* README.md:328 */ for pages.Next(ctx) { - /* README.md:329 */ for _, stat := range pages.Items() { - /* README.md:330 */ fmt.Println(stat) - /* README.md:331 */ + /* README.md:331 */ for pages.Next(ctx) { + /* README.md:332 */ for _, stat := range pages.Items() { + /* README.md:333 */ fmt.Println(stat) + /* README.md:334 */ } - /* README.md:332 */ - } - /* README.md:333 */ if err := pages.Err(); err != nil { - /* README.md:334 */ panic(err) /* README.md:335 */ } - /* README.md:339 */ + /* README.md:336 */ if err := pages.Err(); err != nil { + /* README.md:337 */ panic(err) + /* README.md:338 */ + } + /* README.md:342 */ } - /* README.md:343 */ + /* README.md:346 */ } } diff --git a/ably/rest_channel.go b/ably/rest_channel.go index 8ac47b56..852771a1 100644 --- a/ably/rest_channel.go +++ b/ably/rest_channel.go @@ -194,7 +194,7 @@ func (o *historyOptions) apply(opts ...HistoryOption) url.Values { // HistoryRequest represents a request prepared by the RESTChannel.History or // RealtimeChannel.History method, ready to be performed by its Pages or Items methods. type HistoryRequest struct { - r paginatedRequestNew + r paginatedRequest channel *RESTChannel } @@ -210,7 +210,7 @@ func (r HistoryRequest) Pages(ctx context.Context) (*MessagesPaginatedResult, er // // See "Paginated results" section in the package-level documentation. type MessagesPaginatedResult struct { - PaginatedResultNew + PaginatedResult items []*Message } @@ -293,7 +293,7 @@ func (t *fullMessagesDecoder) decodeMessagesData() { } type MessagesPaginatedItems struct { - PaginatedResultNew + PaginatedResult items []*Message item *Message next func(context.Context) (int, bool) diff --git a/ably/rest_channel_spec_test.go b/ably/rest_channel_spec_test.go new file mode 100644 index 00000000..1128920d --- /dev/null +++ b/ably/rest_channel_spec_test.go @@ -0,0 +1,211 @@ +package ably_test + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/ably/ably-go/ably" + "github.com/ably/ably-go/ably/ablytest" +) + +func TestRSL1f1(t *testing.T) { + t.Parallel() + app, err := ablytest.NewSandbox(nil) + if err != nil { + t.Fatal(err) + } + defer app.Close() + opts := app.Options() + // RSL1f + opts = append(opts, ably.WithUseTokenAuth(false)) + client, err := ably.NewREST(opts...) + if err != nil { + t.Fatal(err) + } + channel := client.Channels.Get("RSL1f") + id := "any_client_id" + var msgs []*ably.Message + size := 10 + for i := 0; i < size; i++ { + msgs = append(msgs, &ably.Message{ + ClientID: id, + Data: fmt.Sprint(i), + }) + } + err = channel.PublishMultiple(context.Background(), msgs) + if err != nil { + t.Fatal(err) + } + var m []*ably.Message + err = ablytest.AllPages(&m, channel.History()) + if err != nil { + t.Fatal(err) + } + n := len(m) + if n != size { + t.Errorf("expected %d messages got %d", size, n) + } + for _, v := range m { + if v.ClientID != id { + t.Errorf("expected clientId %s got %s data:%v", id, v.ClientID, v.Data) + } + } +} + +func TestRSL1g(t *testing.T) { + t.Parallel() + app, err := ablytest.NewSandbox(nil) + if err != nil { + t.Fatal(err) + } + defer app.Close() + opts := append(app.Options(), + ably.WithUseTokenAuth(true), + ) + clientID := "some_client_id" + opts = append(opts, ably.WithClientID(clientID)) + client, err := ably.NewREST(opts...) + if err != nil { + t.Fatal(err) + } + t.Run("RSL1g1b", func(ts *testing.T) { + channel := client.Channels.Get("RSL1g1b") + err := channel.PublishMultiple(context.Background(), []*ably.Message{ + {Name: "some 1"}, + {Name: "some 2"}, + {Name: "some 3"}, + }) + if err != nil { + ts.Fatal(err) + } + var history []*ably.Message + err = ablytest.AllPages(&history, channel.History()) + if err != nil { + ts.Fatal(err) + } + for _, m := range history { + if m.ClientID != clientID { + ts.Errorf("expected %s got %s", clientID, m.ClientID) + } + } + }) + t.Run("RSL1g2", func(ts *testing.T) { + channel := client.Channels.Get("RSL1g2") + err := channel.PublishMultiple(context.Background(), []*ably.Message{ + {Name: "1", ClientID: clientID}, + {Name: "2", ClientID: clientID}, + {Name: "3", ClientID: clientID}, + }) + if err != nil { + ts.Fatal(err) + } + var history []*ably.Message + err = ablytest.AllPages(&history, channel.History()) + if err != nil { + ts.Fatal(err) + } + for _, m := range history { + if m.ClientID != clientID { + ts.Errorf("expected %s got %s", clientID, m.ClientID) + } + } + }) + t.Run("RSL1g3", func(ts *testing.T) { + channel := client.Channels.Get("RSL1g3") + err := channel.PublishMultiple(context.Background(), []*ably.Message{ + {Name: "1", ClientID: clientID}, + {Name: "2", ClientID: "other client"}, + {Name: "3", ClientID: clientID}, + }) + if err == nil { + ts.Fatal("expected an error") + } + }) +} + +func TestHistory_RSL2_RSL2b3(t *testing.T) { + t.Parallel() + + for _, limit := range []int{2, 3, 20} { + t.Run(fmt.Sprintf("limit=%d", limit), func(t *testing.T) { + t.Parallel() + app, rest := ablytest.NewREST() + defer app.Close() + channel := rest.Channels.Get("test") + + fixtures := historyFixtures() + channel.PublishMultiple(context.Background(), fixtures) + + err := ablytest.TestPagination( + reverseMessages(fixtures), + channel.History(ably.HistoryWithLimit(limit)), + limit, + ablytest.PaginationWithEqual(messagesEqual), + ) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestHistory_Direction_RSL2b2(t *testing.T) { + t.Parallel() + + for _, c := range []struct { + direction ably.Direction + expected []*ably.Message + }{ + { + direction: ably.Backwards, + expected: reverseMessages(historyFixtures()), + }, + { + direction: ably.Forwards, + expected: historyFixtures(), + }, + } { + c := c + t.Run(fmt.Sprintf("direction=%v", c.direction), func(t *testing.T) { + app, rest := ablytest.NewREST() + defer app.Close() + channel := rest.Channels.Get("test") + + fixtures := historyFixtures() + channel.PublishMultiple(context.Background(), fixtures) + + expected := c.expected + + err := ablytest.TestPagination(expected, channel.History( + ably.HistoryWithLimit(len(expected)), + ably.HistoryWithDirection(c.direction), + ), 100, ablytest.PaginationWithEqual(messagesEqual)) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func historyFixtures() []*ably.Message { + var fixtures []*ably.Message + for i := 0; i < 10; i++ { + fixtures = append(fixtures, &ably.Message{Name: fmt.Sprintf("msg%d", i)}) + } + return fixtures +} + +func reverseMessages(msgs []*ably.Message) []*ably.Message { + var reversed []*ably.Message + for i := len(msgs) - 1; i >= 0; i-- { + reversed = append(reversed, msgs[i]) + } + return reversed +} + +func messagesEqual(x, y interface{}) bool { + mx, my := x.(*ably.Message), y.(*ably.Message) + return mx.Name == my.Name && reflect.DeepEqual(mx.Data, my.Data) +} diff --git a/ably/rest_channel_test.go b/ably/rest_channel_test.go index acd4f0c8..d4776f52 100644 --- a/ably/rest_channel_test.go +++ b/ably/rest_channel_test.go @@ -473,203 +473,3 @@ func TestIdempotent_retry(t *testing.T) { }) }) } - -func TestRSL1f1(t *testing.T) { - t.Parallel() - app, err := ablytest.NewSandbox(nil) - if err != nil { - t.Fatal(err) - } - defer app.Close() - opts := app.Options() - // RSL1f - opts = append(opts, ably.WithUseTokenAuth(false)) - client, err := ably.NewREST(opts...) - if err != nil { - t.Fatal(err) - } - channel := client.Channels.Get("RSL1f") - id := "any_client_id" - var msgs []*ably.Message - size := 10 - for i := 0; i < size; i++ { - msgs = append(msgs, &ably.Message{ - ClientID: id, - Data: fmt.Sprint(i), - }) - } - err = channel.PublishMultiple(context.Background(), msgs) - if err != nil { - t.Fatal(err) - } - var m []*ably.Message - err = ablytest.AllPages(&m, channel.History()) - if err != nil { - t.Fatal(err) - } - n := len(m) - if n != size { - t.Errorf("expected %d messages got %d", size, n) - } - for _, v := range m { - if v.ClientID != id { - t.Errorf("expected clientId %s got %s data:%v", id, v.ClientID, v.Data) - } - } -} - -func TestRSL1g(t *testing.T) { - t.Parallel() - app, err := ablytest.NewSandbox(nil) - if err != nil { - t.Fatal(err) - } - defer app.Close() - opts := append(app.Options(), - ably.WithUseTokenAuth(true), - ) - clientID := "some_client_id" - opts = append(opts, ably.WithClientID(clientID)) - client, err := ably.NewREST(opts...) - if err != nil { - t.Fatal(err) - } - t.Run("RSL1g1b", func(ts *testing.T) { - channel := client.Channels.Get("RSL1g1b") - err := channel.PublishMultiple(context.Background(), []*ably.Message{ - {Name: "some 1"}, - {Name: "some 2"}, - {Name: "some 3"}, - }) - if err != nil { - ts.Fatal(err) - } - var history []*ably.Message - err = ablytest.AllPages(&history, channel.History()) - if err != nil { - ts.Fatal(err) - } - for _, m := range history { - if m.ClientID != clientID { - ts.Errorf("expected %s got %s", clientID, m.ClientID) - } - } - }) - t.Run("RSL1g2", func(ts *testing.T) { - channel := client.Channels.Get("RSL1g2") - err := channel.PublishMultiple(context.Background(), []*ably.Message{ - {Name: "1", ClientID: clientID}, - {Name: "2", ClientID: clientID}, - {Name: "3", ClientID: clientID}, - }) - if err != nil { - ts.Fatal(err) - } - var history []*ably.Message - err = ablytest.AllPages(&history, channel.History()) - if err != nil { - ts.Fatal(err) - } - for _, m := range history { - if m.ClientID != clientID { - ts.Errorf("expected %s got %s", clientID, m.ClientID) - } - } - }) - t.Run("RSL1g3", func(ts *testing.T) { - channel := client.Channels.Get("RSL1g3") - err := channel.PublishMultiple(context.Background(), []*ably.Message{ - {Name: "1", ClientID: clientID}, - {Name: "2", ClientID: "other client"}, - {Name: "3", ClientID: clientID}, - }) - if err == nil { - ts.Fatal("expected an error") - } - }) -} - -func TestHistory_RSL2_RSL2b3(t *testing.T) { - t.Parallel() - - for _, limit := range []int{2, 3, 20} { - t.Run(fmt.Sprintf("limit=%d", limit), func(t *testing.T) { - t.Parallel() - app, rest := ablytest.NewREST() - defer app.Close() - channel := rest.Channels.Get("test") - - fixtures := historyFixtures() - channel.PublishMultiple(context.Background(), fixtures) - - err := ablytest.TestPagination( - reverseMessages(fixtures), - channel.History(ably.HistoryWithLimit(limit)), - limit, - ablytest.PaginationWithEqual(messagesEqual), - ) - if err != nil { - t.Fatal(err) - } - }) - } -} - -func TestHistory_Direction_RSL2b2(t *testing.T) { - t.Parallel() - - for _, c := range []struct { - direction ably.Direction - expected []*ably.Message - }{ - { - direction: ably.Backwards, - expected: reverseMessages(historyFixtures()), - }, - { - direction: ably.Forwards, - expected: historyFixtures(), - }, - } { - c := c - t.Run(fmt.Sprintf("direction=%v", c.direction), func(t *testing.T) { - app, rest := ablytest.NewREST() - defer app.Close() - channel := rest.Channels.Get("test") - - fixtures := historyFixtures() - channel.PublishMultiple(context.Background(), fixtures) - - expected := c.expected - - err := ablytest.TestPagination(expected, channel.History( - ably.HistoryWithLimit(len(expected)), - ably.HistoryWithDirection(c.direction), - ), 100, ablytest.PaginationWithEqual(messagesEqual)) - if err != nil { - t.Fatal(err) - } - }) - } -} - -func historyFixtures() []*ably.Message { - var fixtures []*ably.Message - for i := 0; i < 10; i++ { - fixtures = append(fixtures, &ably.Message{Name: fmt.Sprintf("msg%d", i)}) - } - return fixtures -} - -func reverseMessages(msgs []*ably.Message) []*ably.Message { - var reversed []*ably.Message - for i := len(msgs) - 1; i >= 0; i-- { - reversed = append(reversed, msgs[i]) - } - return reversed -} - -func messagesEqual(x, y interface{}) bool { - mx, my := x.(*ably.Message), y.(*ably.Message) - return mx.Name == my.Name && reflect.DeepEqual(mx.Data, my.Data) -} diff --git a/ably/rest_client.go b/ably/rest_client.go index 0a3ecbbe..3b6dc3e5 100644 --- a/ably/rest_client.go +++ b/ably/rest_client.go @@ -134,6 +134,9 @@ func NewREST(options ...ClientOption) (*REST, error) { cache: make(map[string]*RESTChannel), client: c, } + c.successFallbackHost = &fallbackCache{ + duration: c.opts.fallbackRetryTimeout(), + } return c, nil } @@ -218,7 +221,7 @@ func (o *statsOptions) apply(opts ...StatsOption) url.Values { // StatsRequest represents a request prepared by the REST.Stats or // Realtime.Stats method, ready to be performed by its Pages or Items methods. type StatsRequest struct { - r paginatedRequestNew + r paginatedRequest } // Pages returns an iterator for whole pages of Stats. @@ -233,7 +236,7 @@ func (r StatsRequest) Pages(ctx context.Context) (*StatsPaginatedResult, error) // // See "Paginated results" section in the package-level documentation. type StatsPaginatedResult struct { - PaginatedResultNew + PaginatedResult items []*Stats } @@ -267,7 +270,7 @@ func (r StatsRequest) Items(ctx context.Context) (*StatsPaginatedItems, error) { } type StatsPaginatedItems struct { - PaginatedResultNew + PaginatedResult items []*Stats item *Stats next func(context.Context) (int, bool) @@ -308,30 +311,174 @@ type request struct { header http.Header } -// Request sends http request to ably. -// spec RSC19 -func (c *REST) Request(ctx context.Context, method string, path string, params *PaginateParams, body interface{}, headers http.Header) (*HTTPPaginatedResponse, error) { +// Request prepares an arbitrary request to the REST API. +func (c *REST) Request(method string, path string, o ...RequestOption) RESTRequest { method = strings.ToUpper(method) - switch method { - case "GET", "POST", "PUT", "PATCH", "DELETE": // spec RSC19a - return newHTTPPaginatedResult(ctx, path, params, func(ctx context.Context, p string) (*http.Response, error) { + var opts requestOptions + opts.apply(o...) + return RESTRequest{r: paginatedRequest{ + path: path, + params: opts.params, + query: func(ctx context.Context, path string) (*http.Response, error) { + switch method { + case "GET", "POST", "PUT", "PATCH", "DELETE": // spec RSC19a + default: + return nil, fmt.Errorf("invalid HTTP method: %q", method) + } + req := &request{ Method: method, - Path: p, - In: body, - header: headers, + Path: path, + In: opts.body, + header: opts.headers, } return c.doWithHandle(ctx, req, func(resp *http.Response, out interface{}) (*http.Response, error) { return resp, nil }) - }, c.logger()) - default: - return nil, newErrorFromProto(&proto.ErrorInfo{ - Message: fmt.Sprintf("%s method is not supported", method), - Code: int(ErrMethodNotAllowed), - StatusCode: http.StatusMethodNotAllowed, - }) + }, + }} +} + +type requestOptions struct { + params url.Values + headers http.Header + body interface{} +} + +// A RequestOption configures a call to REST.Request. +type RequestOption func(*requestOptions) + +func RequestWithParams(params url.Values) RequestOption { + return func(o *requestOptions) { + o.params = params + } +} + +func RequestWithHeaders(headers http.Header) RequestOption { + return func(o *requestOptions) { + o.headers = headers + } +} + +func RequestWithBody(body interface{}) RequestOption { + return func(o *requestOptions) { + o.body = body + } +} + +func (o *requestOptions) apply(opts ...RequestOption) { + o.params = make(url.Values) + for _, opt := range opts { + opt(o) + } +} + +// RESTRequest represents a request prepared by the REST.Request method, ready +// to be performed by its Pages or Items methods. +type RESTRequest struct { + r paginatedRequest +} + +// Pages returns an iterator for whole pages of results. +// +// See "Paginated results" section in the package-level documentation. +func (r RESTRequest) Pages(ctx context.Context) (*HTTPPaginatedResponse, error) { + var res HTTPPaginatedResponse + return &res, res.load(ctx, r.r) +} + +// A HTTPPaginatedResponse is an iterator for the response of a REST request. +// +// See "Paginated results" section in the package-level documentation. +type HTTPPaginatedResponse struct { + PaginatedResult + items jsonRawArray +} + +func (r *HTTPPaginatedResponse) StatusCode() int { + return r.res.StatusCode +} + +func (r *HTTPPaginatedResponse) Success() bool { + return 200 <= r.res.StatusCode && r.res.StatusCode < 300 +} + +func (r *HTTPPaginatedResponse) ErrorCode() ErrorCode { + codeStr := r.res.Header.Get(proto.AblyErrorCodeHeader) + if codeStr == "" { + return ErrNotSet + } + code, err := strconv.Atoi(codeStr) + if err != nil { + return ErrNotSet + } + return ErrorCode(code) +} + +func (r *HTTPPaginatedResponse) ErrorMessage() string { + return r.res.Header.Get(proto.AblyErrorMessageHeader) +} + +func (r *HTTPPaginatedResponse) Headers() http.Header { + return r.res.Header +} + +// Next retrieves the next page of results. +// +// See the "Paginated results" section in the package-level documentation. +func (p *HTTPPaginatedResponse) Next(ctx context.Context) bool { + p.items = nil + return p.next(ctx, &p.items) +} + +// Items unmarshals the current page of results as JSON into the provided +// variable. +// +// See the "Paginated results" section in the package-level documentation. +func (p *HTTPPaginatedResponse) Items(dst interface{}) error { + return json.Unmarshal(p.items, dst) +} + +// Items returns a convenience iterator for single items, over an underlying +// paginated iterator. +// +// For each item, +// +// See "Paginated results" section in the package-level documentation. +func (r RESTRequest) Items(ctx context.Context) (*RESTPaginatedItems, error) { + var res RESTPaginatedItems + var err error + res.next, err = res.loadItems(ctx, r.r, func() (interface{}, func() int) { + res.items = nil + return &res.items, func() int { return len(res.items) } + }) + return &res, err +} + +type RESTPaginatedItems struct { + PaginatedResult + items []json.RawMessage + item json.RawMessage + next func(context.Context) (int, bool) +} + +// Next retrieves the next result. +// +// See the "Paginated results" section in the package-level documentation. +func (p *RESTPaginatedItems) Next(ctx context.Context) bool { + i, ok := p.next(ctx) + if !ok { + return false } + p.item = p.items[i] + return true +} + +// Item unmarshal the current result as JSON into the provided variable. +// +// See the "Paginated results" section in the package-level documentation. +func (p *RESTPaginatedItems) Item(dst interface{}) error { + return json.Unmarshal(p.item, dst) } func (c *REST) get(ctx context.Context, path string, out interface{}) (*http.Response, error) { @@ -422,12 +569,6 @@ func (f *fallbackCache) put(host string) { func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.Response, interface{}) (*http.Response, error)) (*http.Response, error) { log := c.opts.Logger.sugar() - if c.successFallbackHost == nil { - c.successFallbackHost = &fallbackCache{ - duration: c.opts.fallbackRetryTimeout(), - } - log.Verbosef("RestClient: setup fallback duration to %v", c.successFallbackHost.duration) - } req, err := c.newHTTPRequest(ctx, r) if err != nil { return nil, err @@ -681,3 +822,25 @@ func decodeResp(resp *http.Response, out interface{}) error { return decode(typ, bytes.NewReader(b), out) } + +// jsonRawArray is a json.RawMessage that, if it's not an array already, wrap +// itself in a JSON array when marshaled into. +type jsonRawArray json.RawMessage + +func (m *jsonRawArray) UnmarshalJSON(data []byte) error { + err := (*json.RawMessage)(m).UnmarshalJSON(data) + if err != nil { + return err + } + token, _ := json.NewDecoder(bytes.NewReader(*m)).Token() + if token != json.Delim('[') { + *m = append( + jsonRawArray("["), + append( + *m, + ']', + )..., + ) + } + return nil +} diff --git a/ably/rest_client_test.go b/ably/rest_client_test.go index 175c5c2e..52efec36 100644 --- a/ably/rest_client_test.go +++ b/ably/rest_client_test.go @@ -260,7 +260,7 @@ func TestRSC7(t *testing.T) { t.Fatal(err) } - _, _ = c.Request(context.Background(), "POST", "/foo", nil, nil, nil) + _, _ = c.Request("POST", "/foo").Pages(context.Background()) var req *http.Request ablytest.Instantly.Recv(t, &req, requests, t.Fatalf) diff --git a/ably/rest_presence.go b/ably/rest_presence.go index a67d67da..11b804ae 100644 --- a/ably/rest_presence.go +++ b/ably/rest_presence.go @@ -15,12 +15,45 @@ type RESTPresence struct { channel *RESTChannel } -// Get gives the channel's presence messages according to the given parameters. -// The returned result can be inspected for the presence messages via -// the PresenceMessages() method. -func (p *RESTPresence) Get(ctx context.Context, params *PaginateParams) (*PaginatedResult, error) { - path := p.channel.baseURL + "/presence" - return newPaginatedResult(ctx, nil, paginatedRequest{typ: presMsgType, path: path, params: params, query: query(p.client.get), logger: p.logger(), respCheck: checkValidHTTPResponse}) +func (c *RESTPresence) Get(o ...GetPresenceOption) PresenceRequest { + params := (&getPresenceOptions{}).apply(o...) + return PresenceRequest{ + r: c.client.newPaginatedRequest("/channels/"+c.channel.Name+"/presence", params), + channel: c.channel, + } +} + +// A GetPresenceOption configures a call to RESTPresence.Get or RealtimePresence.Get. +type GetPresenceOption func(*getPresenceOptions) + +func GetPresenceWithLimit(limit int) GetPresenceOption { + return func(o *getPresenceOptions) { + o.params.Set("limit", strconv.Itoa(limit)) + } +} + +func GetPresenceWithClientID(clientID string) GetPresenceOption { + return func(o *getPresenceOptions) { + o.params.Set("clientId", clientID) + } +} + +func GetPresenceWithConnectionID(connectionID string) GetPresenceOption { + return func(o *getPresenceOptions) { + o.params.Set("connectionId", connectionID) + } +} + +type getPresenceOptions struct { + params url.Values +} + +func (o *getPresenceOptions) apply(opts ...GetPresenceOption) url.Values { + o.params = make(url.Values) + for _, opt := range opts { + opt(o) + } + return o.params } func (p *RESTPresence) logger() *LoggerOptions { @@ -28,9 +61,9 @@ func (p *RESTPresence) logger() *LoggerOptions { } // History gives the channel's presence history. -func (c *RESTPresence) History(o ...PresenceHistoryOption) PresenceHistoryRequest { +func (c *RESTPresence) History(o ...PresenceHistoryOption) PresenceRequest { params := (&presenceHistoryOptions{}).apply(o...) - return PresenceHistoryRequest{ + return PresenceRequest{ r: c.client.newPaginatedRequest("/channels/"+c.channel.Name+"/presence/history", params), channel: c.channel, } @@ -75,17 +108,17 @@ func (o *presenceHistoryOptions) apply(opts ...PresenceHistoryOption) url.Values return o.params } -// PresenceHistoryRequest represents a request prepared by the RESTPresence.History or +// PresenceRequest represents a request prepared by the RESTPresence.History or // RealtimePresence.History method, ready to be performed by its Pages or Items methods. -type PresenceHistoryRequest struct { - r paginatedRequestNew +type PresenceRequest struct { + r paginatedRequest channel *RESTChannel } // Pages returns an iterator for whole pages of presence messages. // // See "Paginated results" section in the package-level documentation. -func (r PresenceHistoryRequest) Pages(ctx context.Context) (*PresencePaginatedResult, error) { +func (r PresenceRequest) Pages(ctx context.Context) (*PresencePaginatedResult, error) { res := PresencePaginatedResult{decoder: r.channel.fullPresenceDecoder} return &res, res.load(ctx, r.r) } @@ -94,7 +127,7 @@ func (r PresenceHistoryRequest) Pages(ctx context.Context) (*PresencePaginatedRe // // See "Paginated results" section in the package-level documentation. type PresencePaginatedResult struct { - PaginatedResultNew + PaginatedResult items []*PresenceMessage decoder func(*[]*PresenceMessage) interface{} } @@ -118,7 +151,7 @@ func (p *PresencePaginatedResult) Items() []*PresenceMessage { // paginated iterator. // // See "Paginated results" section in the package-level documentation. -func (r PresenceHistoryRequest) Items(ctx context.Context) (*PresencePaginatedItems, error) { +func (r PresenceRequest) Items(ctx context.Context) (*PresencePaginatedItems, error) { var res PresencePaginatedItems var err error res.next, err = res.loadItems(ctx, r.r, func() (interface{}, func() int) { @@ -177,7 +210,7 @@ func (t *fullPresenceDecoder) decodeMessagesData() { } type PresencePaginatedItems struct { - PaginatedResultNew + PaginatedResult items []*PresenceMessage item *PresenceMessage next func(context.Context) (int, bool) diff --git a/ably/rest_presence_test.go b/ably/rest_presence_spec_test.go similarity index 50% rename from ably/rest_presence_test.go rename to ably/rest_presence_spec_test.go index 022b87b7..0b761664 100644 --- a/ably/rest_presence_test.go +++ b/ably/rest_presence_spec_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "reflect" + "sort" "testing" "github.com/ably/ably-go/ably" @@ -12,70 +13,6 @@ import ( "github.com/ably/ably-go/ably/proto" ) -func TestChannel_Presence(t *testing.T) { - t.Parallel() - app, err := ablytest.NewSandbox(nil) - if err != nil { - t.Fatal(err) - } - defer app.Close() - client, err := ably.NewREST(app.Options()...) - if err != nil { - t.Fatal(err) - } - channel := client.Channels.Get("persisted:presence_fixtures") - presence := channel.Presence - - t.Run("Get", func(ts *testing.T) { - page, err := presence.Get(context.Background(), nil) - if err != nil { - ts.Fatal(err) - } - n := len(page.PresenceMessages()) - expect := len(app.Config.Channels[0].Presence) - if n != expect { - ts.Errorf("expected %d got %d", expect, n) - } - - ts.Run("With limit option", func(ts *testing.T) { - limit := 2 - page1, err := presence.Get(context.Background(), &ably.PaginateParams{Limit: limit}) - if err != nil { - ts.Fatal(err) - } - n := len(page1.PresenceMessages()) - if n != limit { - ts.Errorf("expected %d messages got %d", limit, n) - } - n = len(page1.Items()) - if n != limit { - ts.Errorf("expected %d items got %d", limit, n) - } - - page2, err := page1.Next(context.Background()) - if err != nil { - ts.Fatal(err) - } - n = len(page2.PresenceMessages()) - if n != limit { - ts.Errorf("expected %d messages got %d", limit, n) - } - n = len(page2.Items()) - if n != limit { - ts.Errorf("expected %d items got %d", limit, n) - } - - noPage, err := page2.Next(context.Background()) - if err != nil { - ts.Fatal(err) - } - if noPage != nil { - ts.Fatal("no more pages expected") - } - }) - }) -} - func TestPresenceHistory_RSP4_RSP4b3(t *testing.T) { t.Parallel() @@ -150,6 +87,123 @@ func TestPresenceHistory_Direction_RSP4b2(t *testing.T) { } } +func TestPresenceGet_RSP3_RSP3a1(t *testing.T) { + t.Parallel() + + for _, limit := range []int{2, 3, 20} { + t.Run(fmt.Sprintf("limit=%d", limit), func(t *testing.T) { + t.Parallel() + + app, rest := ablytest.NewREST() + defer app.Close() + channel := rest.Channels.Get("persisted:presence_fixtures") + + expected := persistedPresenceFixtures() + + var err error + if !ablytest.Soon.IsTrue(func() bool { + err = ablytest.TestPagination( + expected, + channel.Presence.Get(ably.GetPresenceWithLimit(limit)), + limit, + ablytest.PaginationWithEqual(presenceEqual), + ablytest.PaginationWithSortResult(sortPresenceByClientID), + ) + return err == nil + }) { + t.Fatal(err) + } + }) + } +} + +func TestPresenceGet_ClientID_RSP3a2(t *testing.T) { + t.Parallel() + + for _, clientID := range []string{ + "client_bool", + "client_string", + } { + clientID := clientID + t.Run(fmt.Sprintf("clientID=%v", clientID), func(t *testing.T) { + t.Parallel() + + app, rest := ablytest.NewREST() + defer app.Close() + channel := rest.Channels.Get("persisted:presence_fixtures") + + expected := persistedPresenceFixtures(func(p ablytest.Presence) bool { + return p.ClientID == clientID + }) + + var err error + if !ablytest.Soon.IsTrue(func() bool { + err = ablytest.TestPagination( + expected, + channel.Presence.Get(ably.GetPresenceWithClientID(clientID)), + 1, + ablytest.PaginationWithEqual(presenceEqual), + ablytest.PaginationWithSortResult(sortPresenceByClientID), + ) + return err == nil + }) { + t.Fatal(err) + } + }) + } +} + +func TestPresenceGet_ConnectionID_RSP3a3(t *testing.T) { + t.Parallel() + + app, rest := ablytest.NewREST() + defer app.Close() + + expectedByConnID := map[string]ably.Message{} + + for i := 0; i < 3; i++ { + realtime := app.NewRealtime() + defer safeclose(t, ablytest.FullRealtimeCloser(realtime)) + m := ably.Message{ + Data: fmt.Sprintf("msg%d", i), + ClientID: fmt.Sprintf("client%d", i), + } + realtime.Channels.Get("test").Presence.EnterClient(context.Background(), m.ClientID, m.Data) + expectedByConnID[realtime.Connection.ID()] = m + } + + channel := rest.Channels.Get("test") + + var rg ablytest.ResultGroup + + for connID, expected := range expectedByConnID { + connID, expected := connID, expected + rg.GoAdd(func(ctx context.Context) error { + var err error + if !ablytest.Soon.IsTrue(func() bool { + err = ablytest.TestPagination( + []*ably.PresenceMessage{{ + Action: proto.PresencePresent, + Message: expected, + }}, + channel.Presence.Get(ably.GetPresenceWithConnectionID(connID)), + 1, + ablytest.PaginationWithEqual(presenceEqual), + ablytest.PaginationWithSortResult(sortPresenceByClientID), + ) + return err == nil + }) { + return fmt.Errorf("connID %s: %w", connID, err) + } + return nil + }) + } + + if err := rg.Wait(); err != nil { + t.Fatal(err) + } +} + func presenceHistoryFixtures() []*ably.PresenceMessage { actions := []proto.PresenceAction{ proto.PresenceEnter, @@ -204,3 +258,34 @@ func presenceEqual(x, y interface{}) bool { mx.Name == my.Name && reflect.DeepEqual(mx.Data, my.Data) } + +func persistedPresenceFixtures(filter ...func(ablytest.Presence) bool) []interface{} { + var expected []interface{} +fixtures: + for _, p := range ablytest.PresenceFixtures() { + for _, f := range filter { + if !f(p) { + continue fixtures + } + } + expected = append(expected, &ably.PresenceMessage{ + Action: proto.PresencePresent, + Message: ably.Message{ + ClientID: p.ClientID, + Data: p.Data, + }, + }) + } + + // presence.get result order is undefined, so we need to sort both + // expected and actual items client-side to get consistent results. + sortPresenceByClientID(expected) + + return expected +} + +func sortPresenceByClientID(items []interface{}) { + sort.Slice(items, func(i, j int) bool { + return items[i].(*ably.PresenceMessage).ClientID < items[j].(*ably.PresenceMessage).ClientID + }) +}