Skip to content

Commit

Permalink
calculate a proper union
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDoanRivian committed Nov 1, 2024
1 parent eeccd0d commit 92444fd
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 18 deletions.
44 changes: 26 additions & 18 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -1104,16 +1104,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti

// sendHostPunchNotification signals the other side to punch some zero byte udp packets
func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) {
// if the other side is v6-only, tell them to punch to the first v6-addr in fromVpnAddrs
whereToPunch := fromVpnAddrs[0]
if !whereToPunch.Is6() && punchNotifDest.Is6() {
for i := range fromVpnAddrs {
if fromVpnAddrs[i].Is6() {
whereToPunch = fromVpnAddrs[i]
}
}
}

found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
Expand All @@ -1122,24 +1113,30 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if targetHI == nil {
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
} else {
useVersion = targetHI.GetCert().Certificate.Version()
crt := targetHI.GetCert().Certificate
useVersion = crt.Version()
// we can only retarget if we have a hostinfo
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
if ok {
whereToPunch = newDest
} else {
//TODO this means the destination will have no addresses in common with the punch-ee
//choosing to do nothing for now, but maybe we return an error?
}
}

if useVersion == cert.Version1 {
if !fromVpnAddrs[0].Is4() {
if !whereToPunch.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := fromVpnAddrs[0].As4()
b := whereToPunch.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
lhh.coalesceAnswers(useVersion, c, n)

} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch)
lhh.coalesceAnswers(useVersion, c, n)

} else {
panic("unsupported version")
return 0, errors.New("unsupported version")
}
lhh.coalesceAnswers(useVersion, c, n)

return n.MarshalTo(lhh.pb)
})
Expand Down Expand Up @@ -1333,7 +1330,6 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
}

func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//TODO: this is kinda stupid
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return
}
Expand Down Expand Up @@ -1441,3 +1437,15 @@ func netAddrToProtoV6AddrPort(addr netip.Addr, port uint16) *V6AddrPort {
Port: uint32(port),
}
}

// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
for i := range prefixes {
for j := range addrs {
if prefixes[i].Contains(addrs[j]) {
return addrs[j], true
}
}
}
return netip.Addr{}, false
}
60 changes: 60 additions & 0 deletions lighthouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,63 @@ func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort)
}
}
}

func Test_findNetworkUnion(t *testing.T) {
var out netip.Addr
var ok bool

tenDot := netip.MustParsePrefix("10.0.0.0/8")
oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
fe80 := netip.MustParsePrefix("fe80::/8")
fc00 := netip.MustParsePrefix("fc00::/7")

a1 := netip.MustParseAddr("10.0.0.1")
afe81 := netip.MustParseAddr("fe80::1")

//simple
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)

//mixed lengths
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)

//mixed family
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)

//ordering
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, afe81)

//some mismatches
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)

//falsey cases
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}

0 comments on commit 92444fd

Please sign in to comment.