From 3f57b17eb3fc9d73184f8c1d41a96fe41dc6c3d5 Mon Sep 17 00:00:00 2001 From: Francesco Cheinasso Date: Thu, 11 Jan 2024 10:43:17 +0100 Subject: [PATCH] fix --- nftables_test.go | 17 ++++---- util.go | 50 ++++++++++++----------- util_test.go | 101 ++++++++++++++++++++--------------------------- 3 files changed, 77 insertions(+), 91 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index 743cde6..d0ec731 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -454,12 +454,7 @@ func TestConfigureNAT(t *testing.T) { t.Fatal(err) } - dnatfirstip, err := nftables.GetFirstIPFromCIDR("20.0.0.0/24") - if err != nil { - t.Fatal(err) - } - - dnatlastip, err := nftables.GetLastIPFromCIDR("20.0.0.0/24") + dnatfirstip, dnatlastip, err := nftables.GetFirstAndLastIPFromCIDR("20.0.0.0/24") if err != nil { t.Fatal(err) } @@ -478,8 +473,10 @@ func TestConfigureNAT(t *testing.T) { SourceRegister: 1, DestRegister: 1, Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: dstcidrmatch.Mask, + // By specifying Xor to 0x0,0x0,0x0,0x0 and Mask to the CIDR mask, + // the rule will match the CIDR of the IP (e.g in this case 10.0.0.0/24). + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: dstcidrmatch.Mask, }, &expr.Cmp{ Op: expr.CmpOpEq, @@ -488,11 +485,11 @@ func TestConfigureNAT(t *testing.T) { }, &expr.Immediate{ Register: 1, - Data: *dnatfirstip, + Data: dnatfirstip, }, &expr.Immediate{ Register: 2, - Data: *dnatlastip, + Data: dnatlastip, }, &expr.NAT{ Type: expr.NATTypeDestNAT, diff --git a/util.go b/util.go index 6122814..6fd7c39 100644 --- a/util.go +++ b/util.go @@ -46,34 +46,38 @@ func (genmsg *NFGenMsg) Decode(b []byte) { genmsg.ResourceID = binary.BigEndian.Uint16(b[2:]) } -// GetFirstIPFromCIDR returns the first IP address from a CIDR. -func GetFirstIPFromCIDR(cidr string) (*net.IP, error) { +// GetFirstAndLastIPFromCIDR returns the first and last IP address from a CIDR. +func GetFirstAndLastIPFromCIDR(cidr string) (firstIP, lastIP net.IP, err error) { _, subnet, err := net.ParseCIDR(cidr) if err != nil { - return nil, err + return nil, nil, err } - mask := binary.BigEndian.Uint32(subnet.Mask) - ip := binary.BigEndian.Uint32(subnet.IP) + firstIP = make(net.IP, len(subnet.IP)) + lastIP = make(net.IP, len(subnet.IP)) - // find the final address - firstIP := make(net.IP, 4) - binary.BigEndian.PutUint32(firstIP, ip&mask) - - return &firstIP, nil -} - -// GetLastIPFromCIDR returns the last IP address from a CIDR. -func GetLastIPFromCIDR(cidr string) (*net.IP, error) { - _, subnet, err := net.ParseCIDR(cidr) - if err != nil { - return nil, err + switch len(subnet.IP) { + case net.IPv4len: + mask := binary.BigEndian.Uint32(subnet.Mask) + ip := binary.BigEndian.Uint32(subnet.IP) + // To achieve the first IP address, we need to AND the IP with the mask. + // The AND operation will set all bits in the host part to 0. + binary.BigEndian.PutUint32(firstIP, ip&mask) + // To achieve the last IP address, we need to OR the IP network with the inverted mask. + // The AND between the IP and the mask will set all bits in the host part to 0, keeping the network part. + // The XOR between the mask and 0xffffffff will set all bits in the host part to 1, and the network part to 0. + // The OR operation will keep the host part unchanged, and sets the host part to all 1. + binary.BigEndian.PutUint32(lastIP, (ip&mask)|(mask^0xffffffff)) + case net.IPv6len: + mask1 := binary.BigEndian.Uint64(subnet.Mask[:8]) + mask2 := binary.BigEndian.Uint64(subnet.Mask[8:]) + ip1 := binary.BigEndian.Uint64(subnet.IP[:8]) + ip2 := binary.BigEndian.Uint64(subnet.IP[8:]) + binary.BigEndian.PutUint64(firstIP[:8], ip1&mask1) + binary.BigEndian.PutUint64(firstIP[8:], ip2&mask2) + binary.BigEndian.PutUint64(lastIP[:8], (ip1&mask1)|(mask1^0xffffffffffffffff)) + binary.BigEndian.PutUint64(lastIP[8:], (ip2&mask2)|(mask2^0xffffffffffffffff)) } - mask := binary.BigEndian.Uint32(subnet.Mask) - ip := binary.BigEndian.Uint32(subnet.IP) - // find the final address - lastIP := make(net.IP, 4) - binary.BigEndian.PutUint32(lastIP, (ip&mask)|(mask^0xffffffff)) - return &lastIP, nil + return firstIP, lastIP, nil } diff --git a/util_test.go b/util_test.go index d2747ad..8284736 100644 --- a/util_test.go +++ b/util_test.go @@ -6,87 +6,72 @@ import ( "testing" ) -func TestGetFirstIPFromCIDR(t *testing.T) { +func TestGetFirstAndLastIPFromCIDR(t *testing.T) { type args struct { cidr string } tests := []struct { - name string - args args - want *net.IP - wantErr bool + name string + args args + wantFirstIP net.IP + wantLastIP net.IP + wantErr bool }{ { - name: "Test 0", - args: args{cidr: "fakecidr"}, - want: nil, - wantErr: true, + name: "Test Fake", + args: args{cidr: "fakecidr"}, + wantFirstIP: nil, + wantLastIP: nil, + wantErr: true, }, { - name: "Test 1", - args: args{cidr: "10.0.0.0/24"}, - want: &net.IP{10, 0, 0, 0}, - wantErr: false, + name: "Test IPV4 1", + args: args{cidr: "10.0.0.0/24"}, + wantFirstIP: net.IP{10, 0, 0, 0}, + wantLastIP: net.IP{10, 0, 0, 255}, + wantErr: false, }, { - name: "Test 2", - args: args{cidr: "10.0.0.20/24"}, - want: &net.IP{10, 0, 0, 0}, - wantErr: false, + name: "Test IPV4 2", + args: args{cidr: "10.0.0.20/24"}, + wantFirstIP: net.IP{10, 0, 0, 0}, + wantLastIP: net.IP{10, 0, 0, 255}, + wantErr: false, }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := GetFirstIPFromCIDR(tt.args.cidr) - if (err != nil) != tt.wantErr { - t.Errorf("GetFirstIPFromCIDR() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetFirstIPFromCIDR() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestGetLastIPFromCIDR(t *testing.T) { - type args struct { - cidr string - } - tests := []struct { - name string - args args - want *net.IP - wantErr bool - }{ { - name: "Test 0", - args: args{cidr: "fakecidr"}, - want: nil, - wantErr: true, + name: "Test IPV4 2", + args: args{cidr: "10.0.0.0/19"}, + wantFirstIP: net.IP{10, 0, 0, 0}, + wantLastIP: net.IP{10, 0, 31, 255}, + wantErr: false, }, { - name: "Test 1", - args: args{cidr: "10.0.0.0/24"}, - want: &net.IP{10, 0, 0, 255}, - wantErr: false, + name: "Test IPV6 1", + args: args{cidr: "ff00::/16"}, + wantFirstIP: net.ParseIP("ff00::"), + wantLastIP: net.ParseIP("ff00:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), + wantErr: false, }, { - name: "Test 2", - args: args{cidr: "10.0.0.20/24"}, - want: &net.IP{10, 0, 0, 255}, - wantErr: false, + name: "Test IPV6 2", + args: args{cidr: "2001:db8::/62"}, + wantFirstIP: net.ParseIP("2001:db8::"), + wantLastIP: net.ParseIP("2001:db8:0000:0003:ffff:ffff:ffff:ffff"), + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetLastIPFromCIDR(tt.args.cidr) + gotFirstIP, gotLastIP, err := GetFirstAndLastIPFromCIDR(tt.args.cidr) if (err != nil) != tt.wantErr { - t.Errorf("GetLastIPFromCIDR() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetFirstAndLastIPFromCIDR() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetLastIPFromCIDR() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(gotFirstIP, tt.wantFirstIP) { + t.Errorf("GetFirstAndLastIPFromCIDR() gotFirstIP = %v, want %v", gotFirstIP, tt.wantFirstIP) + } + if !reflect.DeepEqual(gotLastIP, tt.wantLastIP) { + t.Errorf("GetFirstAndLastIPFromCIDR() gotLastIP = %v, want %v", gotLastIP, tt.wantLastIP) } }) }