Skip to content

Commit

Permalink
dag: method for filtering a set on arbitrary criteria
Browse files Browse the repository at this point in the history
  • Loading branch information
apparentlymart committed May 11, 2017
1 parent 510733f commit b28fc1c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
14 changes: 14 additions & 0 deletions dag/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ func (s *Set) Difference(other *Set) *Set {
return result
}

// Filter returns a set that contains the elements from the receiver
// where the given callback returns true.
func (s *Set) Filter(cb func(interface{}) bool) *Set {
result := new(Set)

for _, v := range s.m {
if cb(v) {
result.Add(v)
}
}

return result
}

// Len is the number of items in the set.
func (s *Set) Len() int {
if s == nil {
Expand Down
42 changes: 42 additions & 0 deletions dag/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,45 @@ func TestSetDifference(t *testing.T) {
})
}
}

func TestSetFilter(t *testing.T) {
cases := []struct {
Input []interface{}
Expected []interface{}
}{
{
[]interface{}{1, 2, 3},
[]interface{}{1, 2, 3},
},

{
[]interface{}{4, 5, 6},
[]interface{}{4},
},

{
[]interface{}{7, 8, 9},
[]interface{}{},
},
}

for i, tc := range cases {
t.Run(fmt.Sprintf("%d-%#v", i, tc.Input), func(t *testing.T) {
var input, expected Set
for _, v := range tc.Input {
input.Add(v)
}
for _, v := range tc.Expected {
expected.Add(v)
}

actual := input.Filter(func(v interface{}) bool {
return v.(int) < 5
})
match := actual.Intersection(&expected)
if match.Len() != expected.Len() {
t.Fatalf("bad: %#v", actual.List())
}
})
}
}

0 comments on commit b28fc1c

Please sign in to comment.