Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unusable networks from remote tunnels at handshake time #1318

Merged
merged 2 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"hash/fnv"
"net/netip"
"reflect"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -438,9 +437,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one IP and no subnets
//TODO: we can make this more performant
if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
Expand Down
15 changes: 8 additions & 7 deletions firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func TestFirewall_Drop(t *testing.T) {
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
}
h.buildNetworks(&c)
h.buildNetworks(c.networks, c.unsafeNetworks)

fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
Expand Down Expand Up @@ -332,7 +332,7 @@ func TestFirewall_Drop2(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(c.Certificate)
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())

c1 := cert.CachedCertificate{
Certificate: &dummyCert{
Expand All @@ -342,11 +342,12 @@ func TestFirewall_Drop2(t *testing.T) {
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
}
h1 := HostInfo{
vpnAddrs: []netip.Addr{network.Addr()},
ConnectionState: &ConnectionState{
peerCert: &c1,
},
}
h1.buildNetworks(c1.Certificate)
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())

fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
Expand Down Expand Up @@ -394,7 +395,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h1.buildNetworks(c1.Certificate)
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())

c2 := cert.CachedCertificate{
Certificate: &dummyCert{
Expand All @@ -409,7 +410,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h2.buildNetworks(c2.Certificate)
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())

c3 := cert.CachedCertificate{
Certificate: &dummyCert{
Expand All @@ -424,7 +425,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h3.buildNetworks(c3.Certificate)
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())

fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
Expand Down Expand Up @@ -471,7 +472,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(c.Certificate)
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())

fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
Expand Down
54 changes: 45 additions & 9 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}

var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
Expand All @@ -189,15 +190,32 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}

if addr.IsValid() {
// addr can be invalid when the tunnel is being relayed.
// We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
}

// vpnAddrs outside our vpn networks are of no use to us, filter them out
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}

filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}

if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
}

myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
Expand Down Expand Up @@ -294,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet

hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr)
hostinfo.buildNetworks(remoteCert.Certificate)
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())

existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil {
Expand Down Expand Up @@ -431,7 +449,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha

hostinfo := hh.hostinfo
if addr.IsValid() {
//TODO: this is kind of nonsense now
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
Expand Down Expand Up @@ -492,7 +510,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
e = e.WithField("cert", remoteCert)
}

e.Info("Invalid vpn ip from host")
e.Info("Empty networks from host")
return true
}

Expand All @@ -516,9 +534,26 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}

vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, n := range vpnNetworks {
vpnAddrs[i] = n.Addr()
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr()
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}

filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}

if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
}

// Ensure the right host responded
Expand Down Expand Up @@ -558,7 +593,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2)

duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
Expand All @@ -569,9 +604,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
Info("Handshake message received")

// Build up the radix for the firewall if we have subnets in the cert
hostinfo.buildNetworks(remoteCert.Certificate)
hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())

// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)

Expand Down
16 changes: 10 additions & 6 deletions hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,12 @@ type HostInfo struct {
ConnectionState *ConnectionState
remoteIndexId uint32
localIndexId uint32
vpnAddrs []netip.Addr
recvError atomic.Uint32

// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
recvError atomic.Uint32

// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[struct{}]
Expand Down Expand Up @@ -712,18 +716,18 @@ func (i *HostInfo) RecvErrorExceeded() bool {
return true
}

func (i *HostInfo) buildNetworks(c cert.Certificate) {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed
return
}

i.networks = new(bart.Table[struct{}])
for _, network := range c.Networks() {
for _, network := range networks {
i.networks.Insert(network, struct{}{})
}

for _, network := range c.UnsafeNetworks() {
for _, network := range unsafeNetworks {
i.networks.Insert(network, struct{}{})
}
}
Expand Down
2 changes: 1 addition & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type Interface struct {
myBroadcastAddrsTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
dropLocalBroadcast bool
dropMulticast bool
Expand Down
Loading