From fcdf440964465db87aaceee41fd3fff74e44eb5e Mon Sep 17 00:00:00 2001 From: Aleksei Ilin Date: Fri, 18 Oct 2024 18:21:51 +0200 Subject: [PATCH] set: Add set support for size specifier Handle attribute NFTNL_SET_DESC_SIZE, as done in libnftnl: https://git.netfilter.org/libnftnl/tree/src/set.c#n424 Example: nft add set ip filter myset { type ipv4_addr\; size 65535\; flags dynamic\; } --- nftables_test.go | 40 +++++++++++++++++++++++ set.go | 40 +++++++++++++++++++++-- set_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 161 insertions(+), 2 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index 584d6f9..438cd86 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -4103,6 +4103,46 @@ func TestSetElementsInterval(t *testing.T) { } } +func TestSetSizeConcat(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv6, + Name: "filter", + }) + + set := &nftables.Set{ + Name: "test-set", + Table: filter, + KeyType: nftables.MustConcatSetType(nftables.TypeIP6Addr, nftables.TypeInetService, nftables.TypeIP6Addr), + Dynamic: true, + Concatenation: true, + Size: 200, + } + + if err := c.AddSet(set, nil); err != nil { + t.Errorf("c.AddSet(set) failed: %v", err) + } + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + sets, err := c.GetSets(filter) + if err != nil { + t.Errorf("c.GetSets() failed: %v", err) + } + if len(sets) != 1 { + t.Fatalf("len(sets) = %d, want 1", len(sets)) + } +} + func TestCreateListFlowtable(t *testing.T) { c, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) diff --git a/set.go b/set.go index 816847e..51abc04 100644 --- a/set.go +++ b/set.go @@ -267,6 +267,8 @@ type Set struct { // https://git.netfilter.org/nftables/tree/include/datatype.h?id=d486c9e626405e829221b82d7355558005b26d8a#n109 KeyByteOrder binaryutil.ByteOrder Comment string + // Indicates that the set has "size" specifier + Size uint32 } // SetElement represents a data point within a set. @@ -553,6 +555,21 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { } tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) } + + var descBytes []byte + + if s.Size > 0 { + // Marshal set size description + descSizeBytes, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)}, + }) + if err != nil { + return fmt.Errorf("fail to marshal set size description: %w", err) + } + + descBytes = append(descBytes, descSizeBytes...) + } + if s.Concatenation { // Length of concatenated types is a must, otherwise segfaults when executing nft list ruleset var concatDefinition []byte @@ -579,8 +596,13 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { if err != nil { return fmt.Errorf("fail to marshal concat definition %v", err) } - // Marshal concat size description as set description - tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: concatBytes}) + + descBytes = append(descBytes, concatBytes...) + } + + if len(descBytes) > 0 { + // Marshal set description + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: descBytes}) } // https://git.netfilter.org/libnftnl/tree/include/udata.h#n17 @@ -763,6 +785,20 @@ func setsFromMsg(msg netlink.Message) (*Set, error) { data := ad.Bytes() value, ok := userdata.GetUint32(data, userdata.NFTNL_UDATA_SET_MERGE_ELEMENTS) set.AutoMerge = ok && value == 1 + case unix.NFTA_SET_DESC: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return nil, fmt.Errorf("nested NewAttributeDecoder() failed: %w", err) + } + for nestedAD.Next() { + switch nestedAD.Type() { + case unix.NFTA_SET_DESC_SIZE: + set.Size = binary.BigEndian.Uint32(nestedAD.Bytes()) + } + } + if nestedAD.Err() != nil { + return nil, fmt.Errorf("decoding set description: %w", nestedAD.Err()) + } } } return &set, nil diff --git a/set_test.go b/set_test.go index dda0a56..65a8e00 100644 --- a/set_test.go +++ b/set_test.go @@ -1,7 +1,11 @@ package nftables import ( + "reflect" "testing" + "time" + + "github.com/mdlayher/netlink" ) // unknownNFTMagic is an nftMagic value that's unhandled by this @@ -185,3 +189,82 @@ func TestConcatSetTypeElements(t *testing.T) { }) } } + +func TestMarshalSet(t *testing.T) { + t.Parallel() + + tbl := &Table{ + Name: "ipv4table", + Family: TableFamilyIPv4, + } + + c, err := New(WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { + return req, nil + })) + if err != nil { + t.Fatal(err) + } + + c.AddTable(tbl) + + // Ensure the table is added. + const connMsgStart = 1 + if len(c.messages) != connMsgStart { + t.Fatalf("AddSet() wrong start message count: %d, expected: %d", len(c.messages), connMsgStart) + } + + tests := []struct { + name string + set Set + }{ + { + name: "Set without flags", + set: Set{ + Name: "test-set", + ID: uint32(1), + Table: tbl, + KeyType: TypeIPAddr, + }, + }, + { + name: "Set with size, timeout, dynamic flag specified", + set: Set{ + Name: "test-set", + ID: uint32(2), + HasTimeout: true, + Dynamic: true, + Size: 10, + Table: tbl, + KeyType: TypeIPAddr, + Timeout: 30 * time.Second, + }, + }, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := c.AddSet(&tt.set, nil); err != nil { + t.Fatal(err) + } + + connMsgSetIdx := connMsgStart + i + if len(c.messages) != connMsgSetIdx+1 { + t.Fatalf("AddSet() wrong message count: %d, expected: %d", len(c.messages), connMsgSetIdx+1) + } + msg := c.messages[connMsgSetIdx] + + nset, err := setsFromMsg(msg) + if err != nil { + t.Fatalf("setsFromMsg() error: %+v", err) + } + + // Table pointer is set after flush, which is not implemented in the test. + tt.set.Table = nil + + if !reflect.DeepEqual(&tt.set, nset) { + t.Fatalf("original %+v and recovered %+v Set structs are different", tt.set, nset) + } + }) + } +}