Skip to content

Commit

Permalink
Merge pull request safing#463 from safing/fix/dns-request-flooding
Browse files Browse the repository at this point in the history
Improve nameserver and performance when in failing network condition
  • Loading branch information
dhaavi authored Dec 13, 2021
2 parents 99f8d8e + 049b803 commit b28e626
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 39 deletions.
2 changes: 1 addition & 1 deletion compat/notify.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (issue *appIssue) notify(proc *process.Process) {
"compat: detected %s issue with %s",
strings.ReplaceAll(
strings.TrimPrefix(
strings.TrimSuffix(issue.id, "-%d"),
strings.TrimSuffix(issue.id, "-%s"),
"compat:",
),
"-", " ",
Expand Down
135 changes: 135 additions & 0 deletions nameserver/failing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package nameserver

import (
"sync"
"time"

"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/resolver"
)

// failingQuery records the failure state of a single query, so that repeated
// requests for the same query can be answered with the recorded error for a
// limited time instead of being processed again.
type failingQuery struct {
// Until specifies until when the query should be regarded as failing.
// It is set by addFailingQuery to the time of the latest failure plus the
// calculated fail duration.
Until time.Time

// Keep specifies until when the failing status shall be kept.
// It is set by addFailingQuery to Until plus failingKeepAddedDuration.
// Entries whose Keep time has passed are removed by cleanFailingQueries.
Keep time.Time

// Times specifies how often this query failed.
// It is incremented by addFailingQuery and used there to lengthen the fail
// duration with every additional failure.
Times int

// Err holds the error the query failed with.
Err error
}

// Timing parameters for the handling of failing queries.
const (
// failingDelay is the longest time the request handler delays the response
// to a query that is currently regarded as failing.
failingDelay = 900 * time.Millisecond
// failingBaseDuration is the fail duration applied on the first failure of
// a query (see addFailingQuery).
failingBaseDuration = 900 * time.Millisecond
// failingFactorDuration is added to the fail duration once for every
// previously recorded failure of the same query.
failingFactorDuration = 500 * time.Millisecond
// failingMaxDuration is the upper bound of the calculated fail duration.
failingMaxDuration = 30 * time.Second
// failingKeepAddedDuration is how much longer than the fail duration an
// entry is kept (failingQuery.Keep) before it may be cleaned up.
failingKeepAddedDuration = 10 * time.Second
)

var (
// failingQueries holds the recorded failing queries, keyed by the value
// returned by the query's ID() method.
failingQueries = make(map[string]*failingQuery)
// failingQueriesLock guards access to failingQueries and its entries.
failingQueriesLock sync.RWMutex
// failingQueriesNetworkChangedFlag is checked by checkIfQueryIsFailing;
// when it is set, all recorded failing queries are discarded.
failingQueriesNetworkChangedFlag = netenv.GetNetworkChangedFlag()
)

// checkIfQueryIsFailing reports whether the given query failed recently and
// should still be regarded as failing.
//
// If the network changed since the last check, all recorded failing queries
// are discarded and the query is reported as not failing.
//
// It returns the error the query previously failed with and the time until
// which that failure is regarded as valid. Both return values are nil if the
// query is not currently regarded as failing. The returned time is a private
// copy, so the caller may read it without holding failingQueriesLock.
func checkIfQueryIsFailing(q *resolver.Query) (failingErr error, failingUntil *time.Time) {
	// If the network changed, reset the failed queries.
	if failingQueriesNetworkChangedFlag.IsSet() {
		failingQueriesNetworkChangedFlag.Refresh()

		failingQueriesLock.Lock()
		defer failingQueriesLock.Unlock()

		// Compiler optimized map reset.
		for key := range failingQueries {
			delete(failingQueries, key)
		}

		return nil, nil
	}

	failingQueriesLock.RLock()
	defer failingQueriesLock.RUnlock()

	// Quickly return if map is empty.
	if len(failingQueries) == 0 {
		return nil, nil
	}

	// Check if query failed recently.
	failing, ok := failingQueries[q.ID()]
	if !ok {
		return nil, nil
	}

	// Copy the deadline while holding the lock. Returning a pointer into the
	// shared map entry would let the caller read it after the lock is
	// released, racing with addFailingQuery updating the same field.
	until := failing.Until

	// Check if failing query should still be regarded as failing.
	if time.Now().After(until) {
		return nil, nil
	}

	// Return failing error and until when it's valid.
	return failing.Err, &until
}

// addFailingQuery records that the given query failed with err.
//
// Every recorded failure of the same query lengthens the time during which
// the query is regarded as failing: the duration starts at
// failingBaseDuration, grows by failingFactorDuration per previous failure
// and is capped at failingMaxDuration. The entry is kept for an additional
// failingKeepAddedDuration after that, so that the failure count survives
// short pauses between retries.
//
// Calling it with a nil error is a no-op.
func addFailingQuery(q *resolver.Query, err error) {
	// Check if we were given an error.
	if err == nil {
		return
	}

	// Exclude reverse and mDNS queries, as they fail _often_ and are usually not
	// retried quickly.
	// NOTE: This exclusion is currently disabled.
	// if strings.HasSuffix(q.FQDN, ".in-addr.arpa.") ||
	// 	strings.HasSuffix(q.FQDN, ".ip6.arpa.") ||
	// 	strings.HasSuffix(q.FQDN, ".local.") {
	// 	return
	// }

	queryID := q.ID()

	failingQueriesLock.Lock()
	defer failingQueriesLock.Unlock()

	failing, ok := failingQueries[queryID]
	if !ok {
		failing = &failingQuery{}
		failingQueries[queryID] = failing
	}

	// Always store the most recent error, so that a query that starts failing
	// for a different reason does not keep reporting the stale first error.
	failing.Err = err

	// Calculate fail duration.
	// Initial fail duration will be at 900ms, perfect for a normal retry after 1s,
	// but not any earlier.
	failDuration := failingBaseDuration + time.Duration(failing.Times)*failingFactorDuration
	if failDuration > failingMaxDuration {
		failDuration = failingMaxDuration
	}

	// Update failing query.
	failing.Times++
	failing.Until = time.Now().Add(failDuration)
	failing.Keep = failing.Until.Add(failingKeepAddedDuration)
}

// cleanFailingQueries removes recorded failing queries whose keep time has
// passed.
//
// The work done per call is bounded: iteration stops as soon as maxRemove
// entries were deleted or maxMiss entries were found that must still be kept,
// whichever happens first. As map iteration order is random, repeated calls
// eventually reach all entries.
//
// A limit that is zero or negative is never counted down to exactly zero and
// therefore does not stop the iteration.
func cleanFailingQueries(maxRemove, maxMiss int) {
	failingQueriesLock.Lock()
	defer failingQueriesLock.Unlock()

	cutoff := time.Now()
	for id, entry := range failingQueries {
		if !cutoff.After(entry.Keep) {
			// Entry is still within its keep period, count it as a miss.
			maxMiss--
			if maxMiss == 0 {
				return
			}
			continue
		}

		// Keep period is over, drop the entry.
		delete(failingQueries, id)
		maxRemove--
		if maxRemove == 0 {
			return
		}
	}
}
85 changes: 62 additions & 23 deletions nameserver/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,29 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
startTime := time.Now()
defer requestsHistogram.UpdateDuration(startTime)

// Only process first question, that's how everyone does it.
if len(request.Question) == 0 {
return errors.New("missing question")
// Check Question, only process the first, that's how everyone does it.
var originalQuestion dns.Question
switch len(request.Question) {
case 0:
log.Warning("nameserver: received query without question")
return sendResponse(ctx, w, request, nsutil.Refused("no question provided"))
case 1:
originalQuestion = request.Question[0]
default:
log.Warningf(
"nameserver: received query with multiple questions, first is %s.%s",
request.Question[0].Name,
dns.Type(request.Question[0].Qtype),
)
return sendResponse(ctx, w, request, nsutil.Refused("multiple question provided"))
}

// Check the Query Class.
if originalQuestion.Qclass != dns.ClassINET {
// We only serve IN records.
log.Warningf("nameserver: received unsupported qclass %d question for %s", originalQuestion.Qclass, originalQuestion.Name)
return sendResponse(ctx, w, request, nsutil.Refused("unsupported qclass"))
}
originalQuestion := request.Question[0]

// Check if we are handling a non-standard query name.
var nonStandardQuestionFormat bool
Expand All @@ -57,21 +75,15 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
// Get remote address of request.
remoteAddr, ok := w.RemoteAddr().(*net.UDPAddr)
if !ok {
log.Warningf("nameserver: failed to get remote address of request for %s%s, ignoring", q.FQDN, q.QType)
return nil
log.Warningf("nameserver: failed to get remote address of dns query: is type %+T", w.RemoteAddr())
return sendResponse(ctx, w, request, nsutil.Refused("unsupported transport"))
}
// log.Errorf("DEBUG: nameserver: handling new request for %s from %s:%d", q.ID(), remoteAddr.IP, remoteAddr.Port)

// Start context tracer for context-aware logging.
ctx, tracer := log.AddTracer(ctx)
defer tracer.Submit()
tracer.Tracef("nameserver: handling new request for %s from %s:%d", q.ID(), remoteAddr.IP, remoteAddr.Port)

// Check if there are more than one question.
if len(request.Question) > 1 {
tracer.Warningf("nameserver: received more than one question from (%s:%d), first question is %s", remoteAddr.IP, remoteAddr.Port, q.ID())
}

// Setup quick reply function.
reply := func(responder nsutil.Responder, rrProviders ...nsutil.RRProvider) error {
err := sendResponse(ctx, w, request, responder, rrProviders...)
Expand All @@ -82,13 +94,6 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
return nil
}

// Check the Query Class.
if originalQuestion.Qclass != dns.ClassINET {
// we only serve IN records, return nxdomain
tracer.Warningf("nameserver: only IN record requests are supported but received QClass %d, returning NXDOMAIN", originalQuestion.Qclass)
return reply(nsutil.Refused("unsupported qclass"))
}

// Handle request for localhost and the hostname.
if strings.HasSuffix(q.FQDN, "localhost.") || q.FQDN == hostname {
tracer.Tracef("nameserver: returning localhost records")
Expand All @@ -101,11 +106,43 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
return reply(nsutil.Refused("invalid domain"))
}

// Authenticate request - only requests from the local host, but with any of its IPs, are allowed.
// Check if query is failing.
// Some software retries failing queries excessively. This might not be a
// problem normally, but handling a request is pretty expensive for the
// Portmaster, as it has to find out who sent the query. If we know the query
// will fail with a very high probability, it is beneficial to just kill the
// query for some time before doing any expensive work.
defer cleanFailingQueries(10, 3)
failingErr, failingUntil := checkIfQueryIsFailing(q)
if failingErr != nil {
remainingFailingDuration := time.Until(*failingUntil)
tracer.Debugf("nameserver: returning previous error for %s: %s", q.ID(), failingErr)

// Delay the response a bit in order to mitigate request flooding.
if remainingFailingDuration < failingDelay {
// Delay for remaining fail duration.
tracer.Tracef("nameserver: delaying failing lookup until end of fail duration for %s", remainingFailingDuration.Round(time.Millisecond))
time.Sleep(remainingFailingDuration)
return reply(nsutil.ServerFailure(
"internal error: "+failingErr.Error(),
"delayed failing query to mitigate request flooding",
))
}
// Delay for default duration.
tracer.Tracef("nameserver: delaying failing lookup for %s", failingDelay.Round(time.Millisecond))
time.Sleep(failingDelay)
return reply(nsutil.ServerFailure(
"internal error: "+failingErr.Error(),
"delayed failing query to mitigate request flooding",
fmt.Sprintf("error is cached for another %s", remainingFailingDuration.Round(time.Millisecond)),
))
}

// Check if the request is local.
local, err := netenv.IsMyIP(remoteAddr.IP)
if err != nil {
tracer.Warningf("nameserver: failed to check if request for %s is local: %s", q.ID(), err)
return nil // Do no reply, drop request immediately.
return reply(nsutil.ServerFailure("internal error: failed to check if request is local"))
}

// Create connection ID for dns request.
Expand All @@ -127,12 +164,12 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
conn, err = network.NewConnectionFromExternalDNSRequest(ctx, q.FQDN, nil, connID, remoteAddr.IP)
if err != nil {
tracer.Warningf("nameserver: failed to get host/profile for request for %s%s: %s", q.FQDN, q.QType, err)
return nil // Do no reply, drop request immediately.
return reply(nsutil.ServerFailure("internal error: failed to get profile"))
}

default:
tracer.Warningf("nameserver: external request for %s%s, ignoring", q.FQDN, q.QType)
return nil // Do no reply, drop request immediately.
return reply(nsutil.Refused("external queries are not permitted"))
}
conn.Lock()
defer conn.Unlock()
Expand Down Expand Up @@ -226,13 +263,15 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
log.Tracer(ctx).Debugf("nameserver: device is offline, using backup cache for %s", q.ID())
default:
tracer.Warningf("nameserver: failed to resolve %s: %s", q.ID(), err)
addFailingQuery(q, err)
return reply(nsutil.ServerFailure("internal error: " + err.Error()))
}
}
// Handle special cases.
switch {
case rrCache == nil:
tracer.Warning("nameserver: received successful, but empty reply from resolver")
addFailingQuery(q, errors.New("emptry reply from resolver"))
return reply(nsutil.ServerFailure("internal error: empty reply"))
case rrCache.RCode == dns.RcodeNameError:
// Return now if NXDomain.
Expand Down
15 changes: 11 additions & 4 deletions netenv/online-status.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,17 +377,21 @@ func monitorOnlineStatus(ctx context.Context) error {
func getDynamicStatusTrigger() <-chan time.Time {
switch GetOnlineStatus() {
case StatusOffline:
return time.After(1 * time.Second)
// Will be triggered by network change anyway.
return time.After(20 * time.Second)
case StatusLimited, StatusPortal:
// Change will not be detected otherwise, but impact is minor.
return time.After(5 * time.Second)
case StatusSemiOnline:
// Very small impact.
return time.After(20 * time.Second)
case StatusOnline:
// Don't check until resolver reports problems.
return nil
case StatusUnknown:
return time.After(2 * time.Second)
return time.After(5 * time.Second)
default: // other unknown status
return time.After(1 * time.Minute)
return time.After(5 * time.Minute)
}
}

Expand Down Expand Up @@ -501,7 +505,9 @@ func checkOnlineStatus(ctx context.Context) {

// Check with primary dns check domain.
ips, err := net.LookupIP(DNSTestDomain)
if err == nil {
if err != nil {
log.Warningf("netenv: dns check query failed: %s", err)
} else {
// check for expected response
for _, ip := range ips {
if ip.Equal(DNSTestExpectedIP) {
Expand All @@ -514,6 +520,7 @@ func checkOnlineStatus(ctx context.Context) {
// If that did not work, check with fallback dns check domain.
ips, err = net.LookupIP(DNSFallbackTestDomain)
if err != nil {
log.Warningf("netenv: dns fallback check query failed: %s", err)
updateOnlineStatus(StatusLimited, nil, "dns fallback check query failed")
return
}
Expand Down
4 changes: 1 addition & 3 deletions network/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import (
// currently unused and is unlikely to be used within the next seconds.
func GetUnusedLocalPort(protocol uint8) (port uint16, ok bool) {
allConns := conns.clone()

tries := 1000
hundredth := tries / 100

// Try up to 1000 times to find an unused port.
nextPort:
Expand All @@ -25,7 +23,7 @@ nextPort:
port := uint16(rN + 10000)

// Shrink range when we chew through the tries.
portRangeStart := port - uint16(100-(i/hundredth))
portRangeStart := port - 10

// Check if the generated port is unused.
nextConnection:
Expand Down
4 changes: 2 additions & 2 deletions resolver/scopes.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ var (
".168.192.in-addr.arpa.",

// RFC4193: IPv6 private-address reverse-mapping domains.
".d.f.ip6.arpa",
".c.f.ip6.arpa",
".d.f.ip6.arpa.",
".c.f.ip6.arpa.",

// RFC6761: Special use domains for documentation and testing.
".example.",
Expand Down
10 changes: 4 additions & 6 deletions updates/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,12 @@ func start() error {
// TriggerUpdate queues the update task to execute ASAP.
func TriggerUpdate(force bool) error {
switch {
case !module.OnlineSoon():
return fmt.Errorf("updates module is disabled")
case !module.Online():
updateASAP = true

case !force && !enableUpdates():
return fmt.Errorf("automatic updating is disabled")

case !module.Online():
updateASAP = true

default:
forceUpdate.Set()
updateTask.StartASAP()
Expand All @@ -189,7 +186,8 @@ func TriggerUpdate(force bool) error {
// If called, updates are only checked when TriggerUpdate()
// is called.
func DisableUpdateSchedule() error {
if module.OnlineSoon() {
switch module.Status() {
case modules.StatusStarting, modules.StatusOnline, modules.StatusStopping:
return fmt.Errorf("module already online")
}

Expand Down

0 comments on commit b28e626

Please sign in to comment.