diff --git a/go/api/base_client.go b/go/api/base_client.go index 99772946d6..046a4240f2 100644 --- a/go/api/base_client.go +++ b/go/api/base_client.go @@ -7123,3 +7123,71 @@ func (client *baseClient) Time() ([]string, error) { } return handleStringArrayResponse(result) } + +// Returns the intersection of members from sorted sets specified by the given `keys`. +// To get the elements with their scores, see [ZInterWithScores]. +// +// Note: +// +// When in cluster mode, all keys must map to the same hash slot. +// +// See [valkey.io] for details. +// +// Parameters: +// +// keys - The keys of the sorted sets, see - [options.KeyArray]. +// +// Return value: +// +// The resulting sorted set from the intersection. +// +// Example: +// +// res, err := client.ZInter(options.NewKeyArray("key1", "key2", "key3")) +// fmt.Println(res) // []string{"member1", "member2", "member3"} +// +// [valkey.io]: https://valkey.io/commands/zinter/ +func (client *baseClient) ZInter(keys options.KeyArray) ([]string, error) { + args := keys.ToArgs() + result, err := client.executeCommand(C.ZInter, args) + if err != nil { + return nil, err + } + return handleStringArrayResponse(result) +} + +// Returns the intersection of members and their scores from sorted sets specified by the given +// `keysOrWeightedKeys`. +// +// Note: +// +// When in cluster mode, all keys must map to the same hash slot. +// +// See [valkey.io] for details. +// +// Parameters: +// +// options - The options for the ZInter command, see - [options.ZInterOptions]. +// +// Return value: +// +// A map of members to their scores. +// +// Example: +// +// res, err := client.ZInterWithScores(options.NewZInterOptionsBuilder(options.NewKeyArray("key1", "key2", "key3"))) +// fmt.Println(res) // map[member1:1.0 member2:2.0 member3:3.0] +// +// [valkey.io]: https://valkey.io/commands/zinter/ +func (client *baseClient) ZInterWithScores(zInterOptions *options.ZInterOptions) (map[string]float64, error) { + args, err := zInterOptions.ToArgs() + if err != nil { + return nil, err + } + args = append(args, options.WithScores) + result, err := client.executeCommand(C.ZInter, args) + if err != nil { + return nil, err + } + return handleStringDoubleMapResponse(result) +} diff --git a/go/api/options/constants.go b/go/api/options/constants.go index e728968633..479d7b3d78 100644 --- a/go/api/options/constants.go +++ b/go/api/options/constants.go @@ -3,13 +3,15 @@ package options const ( - CountKeyword string = "COUNT" // Valkey API keyword used to extract specific number of matching indices from a list. - MatchKeyword string = "MATCH" // Valkey API keyword used to indicate the match filter. - NoValue string = "NOVALUE" // Valkey API keyword for the no value option for hcsan command. - WithScore string = "WITHSCORE" // Valkey API keyword for the with score option for zrank and zrevrank commands. - WithScores string = "WITHSCORES" // Valkey API keyword for ZRandMember command to return scores along with members. - NoScores string = "NOSCORES" // Valkey API keyword for the no scores option for zscan command. - WithValues string = "WITHVALUES" // Valkey API keyword to query hash values along their names in `HRANDFIELD`. + CountKeyword string = "COUNT" // Valkey API keyword used to extract specific number of matching indices from a list. + MatchKeyword string = "MATCH" // Valkey API keyword used to indicate the match filter. + NoValue string = "NOVALUE" // Valkey API keyword for the no value option for hcsan command. + WithScore string = "WITHSCORE" // Valkey API keyword for the with score option for zrank and zrevrank commands. + WithScores string = "WITHSCORES" // Valkey API keyword for ZRandMember command to return scores along with members. + NoScores string = "NOSCORES" // Valkey API keyword for the no scores option for zscan command. + WithValues string = "WITHVALUES" // Valkey API keyword to query hash values along their names in `HRANDFIELD`. + AggregateKeyWord string = "AGGREGATE" // Valkey API keyword for the aggregate option for multiple commands. + WeightsKeyword string = "WEIGHTS" // Valkey API keyword for the weights option for multiple commands. ) type InfBoundary string diff --git a/go/api/options/weight_aggregate_options.go b/go/api/options/weight_aggregate_options.go new file mode 100644 index 0000000000..400150cc57 --- /dev/null +++ b/go/api/options/weight_aggregate_options.go @@ -0,0 +1,64 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +package options + +import "github.com/valkey-io/valkey-glide/go/glide/utils" + +// Aggregate represents the method of aggregating scores from multiple sets +type Aggregate string + +const ( + AggregateSum Aggregate = "SUM" // Aggregates by summing the scores of each element across sets + AggregateMin Aggregate = "MIN" // Aggregates by taking the minimum score of each element across sets + AggregateMax Aggregate = "MAX" // Aggregates by taking the maximum score of each element across sets +) + +// converts the Aggregate to its Valkey API representation +func (a Aggregate) ToArgs() []string { + return []string{AggregateKeyWord, string(a)} +} + +// This is a basic interface. Please use one of the following implementations: +// - KeyArray +// - WeightedKeys +type KeysOrWeightedKeys interface { + ToArgs() []string +} + +// represents a list of keys of the sorted sets involved in the aggregation operation +type KeyArray struct { + Keys []string +} + +// converts the KeyArray to its Valkey API representation +func (k KeyArray) ToArgs() []string { + args := []string{utils.IntToString(int64(len(k.Keys)))} + args = append(args, k.Keys...) + return args +} + +type KeyWeightPair struct { + Key string + Weight float64 +} + +// represents the mapping of sorted set keys to their score weights +type WeightedKeys struct { + KeyWeightPairs []KeyWeightPair +} + +// converts the WeightedKeys to its Valkey API representation +func (w WeightedKeys) ToArgs() []string { + keys := make([]string, 0, len(w.KeyWeightPairs)) + weights := make([]string, 0, len(w.KeyWeightPairs)) + args := make([]string, 0) + for _, pair := range w.KeyWeightPairs { + keys = append(keys, pair.Key) + weights = append(weights, utils.FloatToString(pair.Weight)) + } + args = append(args, utils.IntToString(int64(len(keys)))) + args = append(args, keys...) + args = append(args, WeightsKeyword) + args = append(args, weights...) + return args +} diff --git a/go/api/options/zinter_options.go b/go/api/options/zinter_options.go new file mode 100644 index 0000000000..c36bf4ef07 --- /dev/null +++ b/go/api/options/zinter_options.go @@ -0,0 +1,33 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +package options + +// This struct represents the optional arguments for the ZINTER command. +type ZInterOptions struct { + keysOrWeightedKeys KeysOrWeightedKeys + aggregate Aggregate +} + +func NewZInterOptionsBuilder(keysOrWeightedKeys KeysOrWeightedKeys) *ZInterOptions { + return &ZInterOptions{keysOrWeightedKeys: keysOrWeightedKeys} +} + +// SetAggregate sets the aggregate method for the ZInter command. +func (options *ZInterOptions) SetAggregate(aggregate Aggregate) *ZInterOptions { + options.aggregate = aggregate + return options +} + +func (options *ZInterOptions) ToArgs() ([]string, error) { + args := []string{} + + if options.keysOrWeightedKeys != nil { + args = append(args, options.keysOrWeightedKeys.ToArgs()...) + } + + if options.aggregate != "" { + args = append(args, options.aggregate.ToArgs()...) + } + + return args, nil +} diff --git a/go/api/sorted_set_commands.go b/go/api/sorted_set_commands.go index 1331a09e04..7842284f84 100644 --- a/go/api/sorted_set_commands.go +++ b/go/api/sorted_set_commands.go @@ -78,4 +78,8 @@ type SortedSetCommands interface { ZRandMemberWithCountWithScores(key string, count int64) ([]MemberAndScore, error) ZMScore(key string, members []string) ([]Result[float64], error) + + ZInter(keys options.KeyArray) ([]string, error) + + ZInterWithScores(options *options.ZInterOptions) (map[string]float64, error) } diff --git a/go/integTest/shared_commands_test.go b/go/integTest/shared_commands_test.go index 71fbcb8c93..fb5de4cbf5 100644 --- a/go/integTest/shared_commands_test.go +++ b/go/integTest/shared_commands_test.go @@ -7806,3 +7806,102 @@ func (suite *GlideTestSuite) TestBitFieldRO_MultipleGets() { assert.Equal(suite.T(), []int64{value1, value2}, []int64{getRO[0].Value(), getRO[1].Value()}) }) } + +func (suite *GlideTestSuite) TestZInter() { + suite.SkipIfServerVersionLowerThanBy("6.2.0") + suite.runWithDefaultClients(func(client api.BaseClient) { + key1 := "{key}-" + uuid.New().String() + key2 := "{key}-" + uuid.New().String() + key3 := "{key}-" + uuid.New().String() + memberScoreMap1 := map[string]float64{ + "one": 1.0, + "two": 2.0, + } + memberScoreMap2 := map[string]float64{ + "two": 3.5, + "three": 3.0, + } + + // Add members to sorted sets + res, err := client.ZAdd(key1, memberScoreMap1) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + res, err = client.ZAdd(key2, memberScoreMap2) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + // intersection results are aggregated by the max score of elements + zinterResult, err := client.ZInter(options.KeyArray{Keys: []string{key1, key2}}) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), []string{"two"}, zinterResult) + + // intersection with scores + zinterWithScoresResult, err := client.ZInterWithScores( + options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateSum), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"two": 5.5}, zinterWithScoresResult) + + // intersect results with max aggregate + zinterWithMaxAggregateResult, err := client.ZInterWithScores( + options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateMax), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"two": 3.5}, zinterWithMaxAggregateResult) + + // intersect results with min aggregate + zinterWithMinAggregateResult, err := client.ZInterWithScores( + options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateMin), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"two": 2.0}, zinterWithMinAggregateResult) + + // intersect results with sum aggregate + zinterWithSumAggregateResult, err := client.ZInterWithScores( + options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateSum), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"two": 5.5}, zinterWithSumAggregateResult) + + // Scores are multiplied by a 2.0 weight for key1 and key2 during aggregation + zinterWithWeightedKeysResult, err := client.ZInterWithScores( + options.NewZInterOptionsBuilder( + options.WeightedKeys{ + KeyWeightPairs: []options.KeyWeightPair{ + {Key: key1, Weight: 2.0}, + {Key: key2, Weight: 2.0}, + }, + }, + ).SetAggregate(options.AggregateSum), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"two": 11.0}, zinterWithWeightedKeysResult) + + // non-existent key - empty intersection + zinterWithNonExistentKeyResult, err := client.ZInterWithScores( + options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key3}}).SetAggregate(options.AggregateSum), + ) + assert.NoError(suite.T(), err) + assert.Empty(suite.T(), zinterWithNonExistentKeyResult) + + // empty key list - request error + _, err = client.ZInterWithScores(options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{}})) + assert.NotNil(suite.T(), err) + assert.IsType(suite.T(), &errors.RequestError{}, err) + + // key exists but not a set + _, err = client.Set(key3, "value") + assert.NoError(suite.T(), err) + + _, err = client.ZInter(options.KeyArray{Keys: []string{key1, key3}}) + assert.NotNil(suite.T(), err) + assert.IsType(suite.T(), &errors.RequestError{}, err) + + _, err = client.ZInterWithScores( + options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key3}}).SetAggregate(options.AggregateSum), + ) + assert.NotNil(suite.T(), err) + assert.IsType(suite.T(), &errors.RequestError{}, err) + }) +}