Skip to content

Commit

Permalink
Pull request 370: 402 Fix ID race
Browse files Browse the repository at this point in the history
Updates #402.
Updates #403.

Squashed commit of the following:

commit 0cf20f1
Author: Eugene Burkov <[email protected]>
Date:   Tue Jul 16 18:17:59 2024 +0300

    upstream: fix cache race

commit da41431
Author: Eugene Burkov <[email protected]>
Date:   Tue Jul 16 17:53:28 2024 +0300

    upstream: decrease external timeout

commit f39f02e
Author: Eugene Burkov <[email protected]>
Date:   Tue Jul 16 17:46:57 2024 +0300

    upstream: imp docs

commit e3d9d20
Author: Eugene Burkov <[email protected]>
Date:   Tue Jul 16 15:50:24 2024 +0300

    upstream: fix id race
  • Loading branch information
EugeneOne1 committed Jul 17, 2024
1 parent 59666b4 commit 50640fb
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 54 deletions.
16 changes: 8 additions & 8 deletions upstream/dnscrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ var _ Upstream = (*dnsCrypt)(nil)
func (p *dnsCrypt) Address() string { return p.addr.String() }

// Exchange implements the [Upstream] interface for *dnsCrypt.
func (p *dnsCrypt) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp, err = p.exchangeDNSCrypt(m)
func (p *dnsCrypt) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
resp, err = p.exchangeDNSCrypt(req)
if errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF) {
// If request times out, it is possible that the server configuration
// has been changed. It is safe to assume that the key was rotated, see
Expand All @@ -68,7 +68,7 @@ func (p *dnsCrypt) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return nil, err
}

return p.exchangeDNSCrypt(m)
return p.exchangeDNSCrypt(req)
}

return resp, err
Expand All @@ -80,7 +80,7 @@ func (p *dnsCrypt) Close() (err error) {
}

// exchangeDNSCrypt attempts to send the DNS query and returns the response.
func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) {
func (p *dnsCrypt) exchangeDNSCrypt(req *dns.Msg) (resp *dns.Msg, err error) {
var client *dnscrypt.Client
var resolverInfo *dnscrypt.ResolverInfo
func() {
Expand Down Expand Up @@ -108,19 +108,19 @@ func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) {
// Go on.
}

resp, err = client.Exchange(m, resolverInfo)
resp, err = client.Exchange(req, resolverInfo)
if resp != nil && resp.Truncated {
q := &m.Question[0]
q := &req.Question[0]
p.logger.Debug(
"dnscrypt received truncated, falling back to tcp",
"addr", p.addr,
"question", q,
)

tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP}
resp, err = tcpClient.Exchange(m, resolverInfo)
resp, err = tcpClient.Exchange(req, resolverInfo)
}
if err == nil && resp != nil && resp.Id != m.Id {
if err == nil && resp != nil && resp.Id != req.Id {
err = dns.ErrId
}

Expand Down
29 changes: 17 additions & 12 deletions upstream/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
Expand Down Expand Up @@ -140,19 +141,21 @@ var _ Upstream = (*dnsOverHTTPS)(nil)
// password, the password is replaced with "xxxxx".
func (p *dnsOverHTTPS) Address() string { return p.addrRedacted }

// Exchange implements the Upstream interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
// Exchange implements the [Upstream] interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
// TODO(e.burkov): Use some smarter cloning approach.
req = req.Copy()

// In order to maximize HTTP cache friendliness, DoH clients using media
// formats that include the ID field from the DNS message header, such
// as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS
// request.
// formats that include the ID field from the DNS message header, such as
// "application/dns-message", SHOULD use a DNS ID of 0 in every DNS request.
//
// See https://www.rfc-editor.org/rfc/rfc8484.html.
id := m.Id
m.Id = 0
id := req.Id
req.Id = 0
defer func() {
// Restore the original ID to not break compatibility with proxies.
m.Id = id
req.Id = id
if resp != nil {
resp.Id = id
}
Expand All @@ -166,7 +169,7 @@ func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
}

// Make the first attempt to send the DNS query.
resp, err = p.exchangeHTTPS(client, m)
resp, err = p.exchangeHTTPS(client, req)

// Make up to 2 attempts to re-create the HTTP client and send the request
// again. There are several cases (mostly, with QUIC) where this workaround
Expand All @@ -179,7 +182,7 @@ func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return nil, fmt.Errorf("failed to reset http client: %w", err)
}

resp, err = p.exchangeHTTPS(client, m)
resp, err = p.exchangeHTTPS(client, req)
}

if err != nil {
Expand Down Expand Up @@ -266,8 +269,10 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(
return nil, fmt.Errorf("creating http request to %s: %w", p.addrRedacted, err)
}

httpReq.Header.Set("Accept", "application/dns-message")
httpReq.Header.Set("User-Agent", "")
// Prevent the client from sending User-Agent header, see
// https://github.com/AdguardTeam/dnsproxy/issues/211.
httpReq.Header.Set(httphdr.UserAgent, "")
httpReq.Header.Set(httphdr.Accept, "application/dns-message")

httpResp, err := client.Do(httpReq)
if err != nil {
Expand Down
20 changes: 13 additions & 7 deletions upstream/doq.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,20 @@ var _ Upstream = (*dnsOverQUIC)(nil)
func (p *dnsOverQUIC) Address() string { return p.addr.String() }

// Exchange implements the [Upstream] interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
func (p *dnsOverQUIC) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
// TODO(e.burkov): Use some smarter cloning approach.
req = req.Copy()

// When sending queries over a QUIC connection, the DNS Message ID MUST be
// set to zero.
id := m.Id
m.Id = 0
// set to 0. The stream mapping for DoQ allows for unambiguous correlation
// of queries and responses, so the Message ID field is not required.
//
// See https://www.rfc-editor.org/rfc/rfc9250#section-4.2.1.
id := req.Id
req.Id = 0
defer func() {
// Restore the original ID to not break compatibility with proxies.
m.Id = id
req.Id = id
if resp != nil {
resp.Id = id
}
Expand All @@ -162,7 +168,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
}

// Make the first attempt to send the DNS query.
resp, err = p.exchangeQUIC(m, conn)
resp, err = p.exchangeQUIC(req, conn)

// Failure to use a cached connection should be handled gracefully as this
// connection could have been closed by the server or simply be broken due
Expand All @@ -182,7 +188,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
}

// Retry sending the request through the new connection.
resp, err = p.exchangeQUIC(m, conn)
resp, err = p.exchangeQUIC(req, conn)
}

if err != nil {
Expand Down
14 changes: 7 additions & 7 deletions upstream/dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ var _ Upstream = (*dnsOverTLS)(nil)
func (p *dnsOverTLS) Address() string { return p.addr.String() }

// Exchange implements the [Upstream] interface for *dnsOverTLS.
func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
func (p *dnsOverTLS) Exchange(req *dns.Msg) (reply *dns.Msg, err error) {
h, err := p.getDialer()
if err != nil {
return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
Expand All @@ -100,7 +100,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
}

reply, err = p.exchangeWithConn(conn, m)
reply, err = p.exchangeWithConn(conn, req)
if err != nil {
// The pooled connection might have been closed already, see
// https://github.com/AdguardTeam/dnsproxy/issues/3. The following
Expand All @@ -120,7 +120,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
)
}

reply, err = p.exchangeWithConn(conn, m)
reply, err = p.exchangeWithConn(conn, req)
if err != nil {
return reply, errors.WithDeferred(err, conn.Close())
}
Expand Down Expand Up @@ -192,23 +192,23 @@ func (p *dnsOverTLS) putBack(conn net.Conn) {
}

// exchangeWithConn tries to exchange the query using conn.
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) {
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, req *dns.Msg) (reply *dns.Msg, err error) {
addr := p.Address()

logBegin(p.logger, addr, networkTCP, m)
logBegin(p.logger, addr, networkTCP, req)
defer func() { logFinish(p.logger, addr, networkTCP, err) }()

dnsConn := dns.Conn{Conn: conn}

err = dnsConn.WriteMsg(m)
err = dnsConn.WriteMsg(req)
if err != nil {
return nil, fmt.Errorf("sending request to %s: %w", addr, err)
}

reply, err = dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("reading response from %s: %w", addr, err)
} else if reply.Id != m.Id {
} else if reply.Id != req.Id {
return reply, dns.ErrId
}

Expand Down
33 changes: 20 additions & 13 deletions upstream/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math"
"net/netip"
"net/url"
"slices"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -265,21 +266,21 @@ type CachingResolver struct {
// resolver is the underlying resolver to use for lookups.
resolver *UpstreamResolver

// mu protects cached and it's elements.
// mu protects cache and it's elements.
mu *sync.RWMutex

// cached is the set of cached results sorted by [resolveResult.name].
// cache is the set of resolved hostnames mapped to cached addresses.
//
// TODO(e.burkov): Use expiration cache.
cached map[string]*ipResult
cache map[string]*ipResult
}

// NewCachingResolver creates a new caching resolver that uses r for lookups.
func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) {
return &CachingResolver{
resolver: r,
mu: &sync.RWMutex{},
cached: map[string]*ipResult{},
cache: map[string]*ipResult{},
}
}

Expand All @@ -300,32 +301,38 @@ func (r *CachingResolver) LookupNetIP(

addrs = r.findCached(host, now)
if addrs != nil {
return addrs, nil
return slices.Clone(addrs), nil
}

newRes, err := r.resolver.lookupNetIP(ctx, network, host)
res, err := r.resolver.lookupNetIP(ctx, network, host)
if err != nil {
return []netip.Addr{}, err
}

r.mu.Lock()
defer r.mu.Unlock()
r.setCached(host, res)

r.cached[host] = newRes

return newRes.addrs, nil
return slices.Clone(res.addrs), nil
}

// findCached returns the cached addresses for host if it's not expired yet, and
// the corresponding cached result, if any.
// the corresponding cached result, if any. It's safe for concurrent use.
func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.Addr) {
r.mu.RLock()
defer r.mu.RUnlock()

res, ok := r.cached[host]
res, ok := r.cache[host]
if !ok || res.expire.Before(now) {
return nil
}

return res.addrs
}

// setCached sets the result into the address cache for host. It's safe for
// concurrent use.
func (r *CachingResolver) setCached(host string, res *ipResult) {
r.mu.Lock()
defer r.mu.Unlock()

r.cache[host] = res
}
15 changes: 9 additions & 6 deletions upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,20 @@ import (
"github.com/quic-go/quic-go/logging"
)

// Upstream is an interface for a DNS resolver.
// Upstream is an interface for a DNS resolver. All the methods must be safe
// for concurrent use.
type Upstream interface {
// Exchange sends the DNS query req to this upstream and returns the
// response that has been received or an error if something went wrong.
// Exchange sends req to this upstream and returns the response that has
// been received or an error if something went wrong. The implementations
// must not modify req as well as the caller must not modify it until the
// method returns. It shouldn't be called after closing.
Exchange(req *dns.Msg) (resp *dns.Msg, err error)

// Address returns the address of the upstream DNS resolver.
// Address returns the human-readable address of the upstream DNS resolver.
// It may differ from what was passed to [AddressToUpstream].
Address() (addr string)

// Closer used to close the upstreams properly. Exchange shouldn't be
// called after calling Close.
// Closer used to close the upstreams properly.
io.Closer
}

Expand Down
4 changes: 3 additions & 1 deletion upstream/upstream_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestUpstream_bootstrapTimeout(t *testing.T) {
func TestUpstreams(t *testing.T) {
t.Parallel()

const upsTimeout = 500 * time.Second
const upsTimeout = 10 * time.Second

l := slogutil.NewDiscardLogger()

Expand Down Expand Up @@ -209,6 +209,8 @@ func TestUpstreams(t *testing.T) {

for _, test := range upstreams {
t.Run(test.address, func(t *testing.T) {
t.Parallel()

u, upsErr := AddressToUpstream(
test.address,
&Options{Logger: l, Bootstrap: test.bootstrap, Timeout: upsTimeout},
Expand Down

0 comments on commit 50640fb

Please sign in to comment.