Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sort ports and merge adjacent ones in the nft rule #9010

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions internal/app/machined/pkg/adapters/network/nftables_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package network

import (
"cmp"
"fmt"
"net/netip"
"os"
Expand Down Expand Up @@ -109,9 +110,11 @@ func (set NfTablesSet) SetElements() []nftables.SetElement {

return elements
case SetKindPort:
elements := make([]nftables.SetElement, 0, len(set.Ports))
ports := mergeAdjacentPorts(set.Ports)

for _, p := range set.Ports {
elements := make([]nftables.SetElement, 0, len(ports))

for _, p := range ports {
from := binaryutil.BigEndian.PutUint16(p[0])
to := binaryutil.BigEndian.PutUint16(p[1] + 1)

Expand Down Expand Up @@ -157,6 +160,26 @@ func (set NfTablesSet) SetElements() []nftables.SetElement {
}
}

func mergeAdjacentPorts(in [][2]uint16) [][2]uint16 {
ports := slices.Clone(in)

slices.SortFunc(ports, func(a, b [2]uint16) int {
// sort by the lower bound of the range, assume no overlap
return cmp.Compare(a[0], b[0])
})

for i := 0; i < len(ports)-1; {
if ports[i][1]+1 >= ports[i+1][0] {
ports[i][1] = ports[i+1][1]
ports = append(ports[:i+1], ports[i+2:]...)
} else {
i++
}
}

return ports
}

// NfTablesCompiled is a compiled representation of the rule.
type NfTablesCompiled struct {
Rules [][]expr.Any
Expand Down
55 changes: 50 additions & 5 deletions internal/app/machined/pkg/adapters/network/nftables_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,14 +526,14 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
Protocol: nethelpers.ProtocolTCP,
MatchSourcePort: &networkres.NfTablesPortMatch{
Ranges: []networkres.PortRange{
{
Lo: 1000,
Hi: 1025,
},
{
Lo: 2000,
Hi: 2000,
},
{
Lo: 1000,
Hi: 1025,
},
},
},
},
Expand Down Expand Up @@ -562,8 +562,8 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
{
Kind: network.SetKindPort,
Ports: [][2]uint16{
{1000, 1025},
{2000, 2000},
{1000, 1025},
},
},
},
Expand Down Expand Up @@ -713,3 +713,48 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
})
}
}

func TestNftablesSet(t *testing.T) { //nolint:tparallel
t.Parallel()

for _, test := range []struct {
name string

set network.NfTablesSet

expectedKeyType nftables.SetDatatype
expectedInterval bool
expectedData []nftables.SetElement
}{
{
name: "ports",

set: network.NfTablesSet{
Kind: network.SetKindPort,
Ports: [][2]uint16{
{443, 443},
{80, 81},
{5000, 5000},
{5001, 5001},
},
},

expectedKeyType: nftables.TypeInetService,
expectedInterval: true,
expectedData: []nftables.SetElement{ // network byte order
{Key: []uint8{0x0, 80}, IntervalEnd: false}, // 80 - 81
{Key: []uint8{0x0, 82}, IntervalEnd: true},
{Key: []uint8{0x1, 0xbb}, IntervalEnd: false}, // 443-443
{Key: []uint8{0x1, 0xbc}, IntervalEnd: true},
{Key: []uint8{0x13, 0x88}, IntervalEnd: false}, // 5000-5001
{Key: []uint8{0x13, 0x8a}, IntervalEnd: true},
},
},
} {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.expectedKeyType, test.set.KeyType())
assert.Equal(t, test.expectedInterval, test.set.IsInterval())
assert.Equal(t, test.expectedData, test.set.SetElements())
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,60 @@ func (s *NfTablesChainSuite) TestL4Match2() {
s.checkNftOutput(`table inet talos-test {
chain test-tcp {
type filter hook input priority filter; policy accept;
ip saddr != { 10.0.0.0/8 } tcp dport { 1023, 1024 } drop
meta nfproto ipv6 tcp dport { 1023, 1024 } drop
ip saddr != { 10.0.0.0/8 } tcp dport { 1023-1024 } drop
meta nfproto ipv6 tcp dport { 1023-1024 } drop
}
}`)
}

func (s *NfTablesChainSuite) TestL4MatchAdjacentPorts() {
chain := network.NewNfTablesChain(network.NamespaceName, "test-tcp")
chain.TypedSpec().Type = nethelpers.ChainTypeFilter
chain.TypedSpec().Hook = nethelpers.ChainHookInput
chain.TypedSpec().Priority = nethelpers.ChainPriorityFilter
chain.TypedSpec().Policy = nethelpers.VerdictAccept
chain.TypedSpec().Rules = []network.NfTablesRule{
{
MatchSourceAddress: &network.NfTablesAddressMatch{
IncludeSubnets: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
},
Invert: true,
},
MatchLayer4: &network.NfTablesLayer4Match{
Protocol: nethelpers.ProtocolTCP,
MatchDestinationPort: &network.NfTablesPortMatch{
Ranges: []network.PortRange{
{
Lo: 5000,
Hi: 5000,
},
{
Lo: 5001,
Hi: 5001,
},
{
Lo: 10250,
Hi: 10250,
},
{
Lo: 4240,
Hi: 4240,
},
},
},
},
Verdict: pointer.To(nethelpers.VerdictDrop),
},
}

s.Require().NoError(s.State().Create(s.Ctx(), chain))

s.checkNftOutput(`table inet talos-test {
chain test-tcp {
type filter hook input priority filter; policy accept;
ip saddr != { 10.0.0.0/8 } tcp dport { 4240, 5000-5001, 10250 } drop
meta nfproto ipv6 tcp dport { 4240, 5000-5001, 10250 } drop
}
}`)
}
Expand Down
Loading