diff --git a/aggs_bucket.go b/aggs_bucket.go index 8058ef4..c21eaad 100644 --- a/aggs_bucket.go +++ b/aggs_bucket.go @@ -57,7 +57,7 @@ func (agg *TermsAggregation) Aggs(aggs ...Aggregation) *TermsAggregation { return agg } -// Order sets the sort for terms agg +// Order sets the sorts for terms agg func (agg *TermsAggregation) Order(order map[string]string) *TermsAggregation { agg.order = order return agg diff --git a/aggs_metric.go b/aggs_metric.go index 0bfed92..8505a40 100644 --- a/aggs_metric.go +++ b/aggs_metric.go @@ -416,7 +416,7 @@ type TopHitsAgg struct { name string from uint64 size uint64 - sort []map[string]interface{} + sorts Sorts source Source } @@ -445,15 +445,10 @@ func (agg *TopHitsAgg) Size(size uint64) *TopHitsAgg { return agg } -// Sort sets how the top matching hits should be sorted. By default the hits are +// Sorts sets how the top matching hits should be sorted. By default the hits are // sorted by the score of the main query. -func (agg *TopHitsAgg) Sort(name string, order Order) *TopHitsAgg { - agg.sort = append(agg.sort, map[string]interface{}{ - name: map[string]interface{}{ - "order": order, - }, - }) - +func (agg *TopHitsAgg) Sorts(sorts ...map[string]Sort) *TopHitsAgg { + agg.sorts = sorts return agg } @@ -474,8 +469,8 @@ func (agg *TopHitsAgg) Map() map[string]interface{} { if agg.size > 0 { innerMap["size"] = agg.size } - if len(agg.sort) > 0 { - innerMap["sort"] = agg.sort + if len(agg.sorts) > 0 { + innerMap["sort"] = agg.sorts } if len(agg.source.includes) > 0 { innerMap["_source"] = agg.source.Map() diff --git a/aggs_metric_test.go b/aggs_metric_test.go index b742527..e61050d 100644 --- a/aggs_metric_test.go +++ b/aggs_metric_test.go @@ -154,5 +154,38 @@ func TestMetricAggs(t *testing.T) { }, }, }, + { + "top_hits agg", + TopHits("top_hits").Sorts( + Sorts{ + { + "field_1": { + Order: OrderDesc, + }, + }, + { + "field_2": { + Order: OrderAsc, + }, + }, + }..., + ), + map[string]interface{}{ + "top_hits": map[string]interface{}{ + "sort": []map[string]interface{}{ + { + "field_1": map[string]interface{}{ + "order": OrderDesc, + }, + }, + { + "field_2": map[string]interface{}{ + "order": OrderAsc, + }, + }, + }, + }, + }, + }, }) } diff --git a/common.go b/common.go index 7bf8cbb..b18c73f 100644 --- a/common.go +++ b/common.go @@ -19,8 +19,12 @@ func (source Source) Map() map[string]interface{} { return m } -// Sort represents a list of keys to sort by. -type Sort []map[string]interface{} +// Sorts represents a list of keys to sort by. +type Sorts []map[string]Sort + +type Sort struct { + Order Order `json:"order"` +} // Order is the ordering for a sort key (ascending, descending). type Order string diff --git a/search.go b/search.go index 81037f4..af685ca 100644 --- a/search.go +++ b/search.go @@ -23,7 +23,7 @@ type SearchRequest struct { postFilter Mappable query Mappable size *uint64 - sort Sort + sorts Sorts source Source timeout *time.Duration } @@ -64,14 +64,9 @@ func (req *SearchRequest) Size(size uint64) *SearchRequest { return req } -// Sort sets how the results should be sorted. -func (req *SearchRequest) Sort(name string, order Order) *SearchRequest { - req.sort = append(req.sort, map[string]interface{}{ - name: map[string]interface{}{ - "order": order, - }, - }) - +// Sorts sets how the results should be sorted. +func (req *SearchRequest) Sorts(sorts ...map[string]Sort) *SearchRequest { + req.sorts = sorts return req } @@ -133,8 +128,8 @@ func (req *SearchRequest) Map() map[string]interface{} { if req.size != nil { m["size"] = *req.size } - if len(req.sort) > 0 { - m["sort"] = req.sort + if len(req.sorts) > 0 { + m["sort"] = req.sorts } if req.from != nil { m["from"] = *req.from diff --git a/search_test.go b/search_test.go index be37d6b..c748648 100644 --- a/search_test.go +++ b/search_test.go @@ -53,8 +53,20 @@ func TestSearchMaps(t *testing.T) { Size(30). From(5). Explain(true). - Sort("field_1", OrderDesc). - Sort("field_2", OrderAsc). + Sorts( + Sorts{ + { + "field_1": { + Order: OrderDesc, + }, + }, + { + "field_2": { + Order: OrderAsc, + }, + }, + }..., + ). SourceIncludes("field_1", "field_2"). SourceExcludes("field_3"). Timeout(time.Duration(20000000000)),