From 50640fb5f7b417bcf40eda32500df7173c22977c Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 17 Jul 2024 12:35:04 +0300 Subject: [PATCH] Pull request 370: 402 Fix ID race Updates #402. Updates #403. Squashed commit of the following: commit 0cf20f15e7e6a9f2d79ab0b2979cb0d65117d823 Author: Eugene Burkov Date: Tue Jul 16 18:17:59 2024 +0300 upstream: fix cache race commit da41431aecf8e9f24e05517d5114a3ff58cf3f74 Author: Eugene Burkov Date: Tue Jul 16 17:53:28 2024 +0300 upstream: decrease external timeout commit f39f02e5d0e1d186656d166a4f6ba72d91446338 Author: Eugene Burkov Date: Tue Jul 16 17:46:57 2024 +0300 upstream: imp docs commit e3d9d20c0976f3bd52c1f517ce673a61e9a680f4 Author: Eugene Burkov Date: Tue Jul 16 15:50:24 2024 +0300 upstream: fix id race --- upstream/dnscrypt.go | 16 +++++++-------- upstream/doh.go | 29 +++++++++++++++----------- upstream/doq.go | 20 +++++++++++------- upstream/dot.go | 14 ++++++------- upstream/resolver.go | 33 ++++++++++++++++++------------ upstream/upstream.go | 15 ++++++++------ upstream/upstream_internal_test.go | 4 +++- 7 files changed, 77 insertions(+), 54 deletions(-) diff --git a/upstream/dnscrypt.go b/upstream/dnscrypt.go index 3a06376c6..574597789 100644 --- a/upstream/dnscrypt.go +++ b/upstream/dnscrypt.go @@ -56,8 +56,8 @@ var _ Upstream = (*dnsCrypt)(nil) func (p *dnsCrypt) Address() string { return p.addr.String() } // Exchange implements the [Upstream] interface for *dnsCrypt. -func (p *dnsCrypt) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { - resp, err = p.exchangeDNSCrypt(m) +func (p *dnsCrypt) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + resp, err = p.exchangeDNSCrypt(req) if errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF) { // If request times out, it is possible that the server configuration // has been changed. It is safe to assume that the key was rotated, see @@ -68,7 +68,7 @@ func (p *dnsCrypt) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return nil, err } - return p.exchangeDNSCrypt(m) + return p.exchangeDNSCrypt(req) } return resp, err @@ -80,7 +80,7 @@ func (p *dnsCrypt) Close() (err error) { } // exchangeDNSCrypt attempts to send the DNS query and returns the response. -func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) { +func (p *dnsCrypt) exchangeDNSCrypt(req *dns.Msg) (resp *dns.Msg, err error) { var client *dnscrypt.Client var resolverInfo *dnscrypt.ResolverInfo func() { @@ -108,9 +108,9 @@ func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) { // Go on. } - resp, err = client.Exchange(m, resolverInfo) + resp, err = client.Exchange(req, resolverInfo) if resp != nil && resp.Truncated { - q := &m.Question[0] + q := &req.Question[0] p.logger.Debug( "dnscrypt received truncated, falling back to tcp", "addr", p.addr, @@ -118,9 +118,9 @@ func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) { ) tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP} - resp, err = tcpClient.Exchange(m, resolverInfo) + resp, err = tcpClient.Exchange(req, resolverInfo) } - if err == nil && resp != nil && resp.Id != m.Id { + if err == nil && resp != nil && resp.Id != req.Id { err = dns.ErrId } diff --git a/upstream/doh.go b/upstream/doh.go index 1863f83aa..bffe555a3 100644 --- a/upstream/doh.go +++ b/upstream/doh.go @@ -16,6 +16,7 @@ import ( "github.com/AdguardTeam/dnsproxy/internal/bootstrap" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -140,19 +141,21 @@ var _ Upstream = (*dnsOverHTTPS)(nil) // password, the password is replaced with "xxxxx". func (p *dnsOverHTTPS) Address() string { return p.addrRedacted } -// Exchange implements the Upstream interface for *dnsOverHTTPS. -func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { +// Exchange implements the [Upstream] interface for *dnsOverHTTPS. +func (p *dnsOverHTTPS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + // TODO(e.burkov): Use some smarter cloning approach. + req = req.Copy() + // In order to maximize HTTP cache friendliness, DoH clients using media - // formats that include the ID field from the DNS message header, such - // as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS - // request. + // formats that include the ID field from the DNS message header, such as + // "application/dns-message", SHOULD use a DNS ID of 0 in every DNS request. // // See https://www.rfc-editor.org/rfc/rfc8484.html. - id := m.Id - m.Id = 0 + id := req.Id + req.Id = 0 defer func() { // Restore the original ID to not break compatibility with proxies. - m.Id = id + req.Id = id if resp != nil { resp.Id = id } @@ -166,7 +169,7 @@ func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { } // Make the first attempt to send the DNS query. - resp, err = p.exchangeHTTPS(client, m) + resp, err = p.exchangeHTTPS(client, req) // Make up to 2 attempts to re-create the HTTP client and send the request // again. There are several cases (mostly, with QUIC) where this workaround @@ -179,7 +182,7 @@ func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return nil, fmt.Errorf("failed to reset http client: %w", err) } - resp, err = p.exchangeHTTPS(client, m) + resp, err = p.exchangeHTTPS(client, req) } if err != nil { @@ -266,8 +269,10 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient( return nil, fmt.Errorf("creating http request to %s: %w", p.addrRedacted, err) } - httpReq.Header.Set("Accept", "application/dns-message") - httpReq.Header.Set("User-Agent", "") + // Prevent the client from sending User-Agent header, see + // https://github.com/AdguardTeam/dnsproxy/issues/211. + httpReq.Header.Set(httphdr.UserAgent, "") + httpReq.Header.Set(httphdr.Accept, "application/dns-message") httpResp, err := client.Do(httpReq) if err != nil { diff --git a/upstream/doq.go b/upstream/doq.go index fffbad486..468373eed 100644 --- a/upstream/doq.go +++ b/upstream/doq.go @@ -142,14 +142,20 @@ var _ Upstream = (*dnsOverQUIC)(nil) func (p *dnsOverQUIC) Address() string { return p.addr.String() } // Exchange implements the [Upstream] interface for *dnsOverQUIC. -func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { +func (p *dnsOverQUIC) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + // TODO(e.burkov): Use some smarter cloning approach. + req = req.Copy() + // When sending queries over a QUIC connection, the DNS Message ID MUST be - // set to zero. - id := m.Id - m.Id = 0 + // set to 0. The stream mapping for DoQ allows for unambiguous correlation + // of queries and responses, so the Message ID field is not required. + // + // See https://www.rfc-editor.org/rfc/rfc9250#section-4.2.1. + id := req.Id + req.Id = 0 defer func() { // Restore the original ID to not break compatibility with proxies. - m.Id = id + req.Id = id if resp != nil { resp.Id = id } @@ -162,7 +168,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { } // Make the first attempt to send the DNS query. - resp, err = p.exchangeQUIC(m, conn) + resp, err = p.exchangeQUIC(req, conn) // Failure to use a cached connection should be handled gracefully as this // connection could have been closed by the server or simply be broken due @@ -182,7 +188,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { } // Retry sending the request through the new connection. - resp, err = p.exchangeQUIC(m, conn) + resp, err = p.exchangeQUIC(req, conn) } if err != nil { diff --git a/upstream/dot.go b/upstream/dot.go index fd2f2af1d..0d7c3abb5 100644 --- a/upstream/dot.go +++ b/upstream/dot.go @@ -89,7 +89,7 @@ var _ Upstream = (*dnsOverTLS)(nil) func (p *dnsOverTLS) Address() string { return p.addr.String() } // Exchange implements the [Upstream] interface for *dnsOverTLS. -func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { +func (p *dnsOverTLS) Exchange(req *dns.Msg) (reply *dns.Msg, err error) { h, err := p.getDialer() if err != nil { return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err) @@ -100,7 +100,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err) } - reply, err = p.exchangeWithConn(conn, m) + reply, err = p.exchangeWithConn(conn, req) if err != nil { // The pooled connection might have been closed already, see // https://github.com/AdguardTeam/dnsproxy/issues/3. The following @@ -120,7 +120,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { ) } - reply, err = p.exchangeWithConn(conn, m) + reply, err = p.exchangeWithConn(conn, req) if err != nil { return reply, errors.WithDeferred(err, conn.Close()) } @@ -192,15 +192,15 @@ func (p *dnsOverTLS) putBack(conn net.Conn) { } // exchangeWithConn tries to exchange the query using conn. -func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) { +func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, req *dns.Msg) (reply *dns.Msg, err error) { addr := p.Address() - logBegin(p.logger, addr, networkTCP, m) + logBegin(p.logger, addr, networkTCP, req) defer func() { logFinish(p.logger, addr, networkTCP, err) }() dnsConn := dns.Conn{Conn: conn} - err = dnsConn.WriteMsg(m) + err = dnsConn.WriteMsg(req) if err != nil { return nil, fmt.Errorf("sending request to %s: %w", addr, err) } @@ -208,7 +208,7 @@ func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg reply, err = dnsConn.ReadMsg() if err != nil { return nil, fmt.Errorf("reading response from %s: %w", addr, err) - } else if reply.Id != m.Id { + } else if reply.Id != req.Id { return reply, dns.ErrId } diff --git a/upstream/resolver.go b/upstream/resolver.go index 451416072..2dbe8f887 100644 --- a/upstream/resolver.go +++ b/upstream/resolver.go @@ -6,6 +6,7 @@ import ( "math" "net/netip" "net/url" + "slices" "strings" "sync" "time" @@ -265,13 +266,13 @@ type CachingResolver struct { // resolver is the underlying resolver to use for lookups. resolver *UpstreamResolver - // mu protects cached and it's elements. + // mu protects cache and it's elements. mu *sync.RWMutex - // cached is the set of cached results sorted by [resolveResult.name]. + // cache is the set of resolved hostnames mapped to cached addresses. // // TODO(e.burkov): Use expiration cache. - cached map[string]*ipResult + cache map[string]*ipResult } // NewCachingResolver creates a new caching resolver that uses r for lookups. @@ -279,7 +280,7 @@ func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) { return &CachingResolver{ resolver: r, mu: &sync.RWMutex{}, - cached: map[string]*ipResult{}, + cache: map[string]*ipResult{}, } } @@ -300,32 +301,38 @@ func (r *CachingResolver) LookupNetIP( addrs = r.findCached(host, now) if addrs != nil { - return addrs, nil + return slices.Clone(addrs), nil } - newRes, err := r.resolver.lookupNetIP(ctx, network, host) + res, err := r.resolver.lookupNetIP(ctx, network, host) if err != nil { return []netip.Addr{}, err } - r.mu.Lock() - defer r.mu.Unlock() + r.setCached(host, res) - r.cached[host] = newRes - - return newRes.addrs, nil + return slices.Clone(res.addrs), nil } // findCached returns the cached addresses for host if it's not expired yet, and -// the corresponding cached result, if any. +// the corresponding cached result, if any. It's safe for concurrent use. func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.Addr) { r.mu.RLock() defer r.mu.RUnlock() - res, ok := r.cached[host] + res, ok := r.cache[host] if !ok || res.expire.Before(now) { return nil } return res.addrs } + +// setCached sets the result into the address cache for host. It's safe for +// concurrent use. +func (r *CachingResolver) setCached(host string, res *ipResult) { + r.mu.Lock() + defer r.mu.Unlock() + + r.cache[host] = res +} diff --git a/upstream/upstream.go b/upstream/upstream.go index 559d54366..bb57a4725 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -28,17 +28,20 @@ import ( "github.com/quic-go/quic-go/logging" ) -// Upstream is an interface for a DNS resolver. +// Upstream is an interface for a DNS resolver. All the methods must be safe +// for concurrent use. type Upstream interface { - // Exchange sends the DNS query req to this upstream and returns the - // response that has been received or an error if something went wrong. + // Exchange sends req to this upstream and returns the response that has + // been received or an error if something went wrong. The implementations + // must not modify req as well as the caller must not modify it until the + // method returns. It shouldn't be called after closing. Exchange(req *dns.Msg) (resp *dns.Msg, err error) - // Address returns the address of the upstream DNS resolver. + // Address returns the human-readable address of the upstream DNS resolver. + // It may differ from what was passed to [AddressToUpstream]. Address() (addr string) - // Closer used to close the upstreams properly. Exchange shouldn't be - // called after calling Close. + // Closer used to close the upstreams properly. io.Closer } diff --git a/upstream/upstream_internal_test.go b/upstream/upstream_internal_test.go index 4d7d8f599..78e5d1fe9 100644 --- a/upstream/upstream_internal_test.go +++ b/upstream/upstream_internal_test.go @@ -106,7 +106,7 @@ func TestUpstream_bootstrapTimeout(t *testing.T) { func TestUpstreams(t *testing.T) { t.Parallel() - const upsTimeout = 500 * time.Second + const upsTimeout = 10 * time.Second l := slogutil.NewDiscardLogger() @@ -209,6 +209,8 @@ func TestUpstreams(t *testing.T) { for _, test := range upstreams { t.Run(test.address, func(t *testing.T) { + t.Parallel() + u, upsErr := AddressToUpstream( test.address, &Options{Logger: l, Bootstrap: test.bootstrap, Timeout: upsTimeout},