Skip to content

Commit

Permalink
DNS v2 Multiple fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmurret committed Feb 7, 2024
1 parent 3c1ee84 commit b9f5668
Show file tree
Hide file tree
Showing 11 changed files with 445 additions and 103 deletions.
11 changes: 10 additions & 1 deletion agent/config/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ type RuntimeConfig struct {
// Records returned in the ANSWER section of a DNS response for UDP
// responses without EDNS support (limited to 512 bytes).
// This parameter is deprecated, if you want to limit the number of
// records returned by A or AAAA questions, please use DNSARecordLimit
// records returned by A or AAAA questions, please use DNSARecordLimit
// instead.
//
// hcl: dns_config { udp_answer_limit = int }
Expand Down Expand Up @@ -564,6 +564,15 @@ type RuntimeConfig struct {
// flag: -data-dir string
DataDir string

// DefaultIntentionPolicy is used to define a default intention action for all
// sources and destinations. Possible values are "allow", "deny", or "" (blank).
// For compatibility, falls back to ACLResolverSettings.ACLDefaultPolicy (which
// itself has a default of "allow") if left blank. Future versions of Consul
// will default this field to "deny" to be secure by default.
//
// hcl: default_intention_policy = string
DefaultIntentionPolicy string

// DefaultQueryTime is the amount of time a blocking query will wait before
// Consul will force a response. This value can be overridden by the 'wait'
// query parameter.
Expand Down
26 changes: 18 additions & 8 deletions agent/discovery/query_fetcher_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e

// If we have no out.NodeServices, return not found!
if out.NodeServices == nil {
return nil, errors.New("no nodes found")
return nil, ErrNotFound
}

results := make([]*Result, 0, 1)
Expand Down Expand Up @@ -302,7 +302,11 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R

out, err := f.executePreparedQuery(cfg, args)
if err != nil {
return nil, err
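// errors.Is() does not match structs.ErrQueryNotFound here (presumably because the
// error value is re-created rather than passed through unchanged — TODO confirm),
// so we need to compare the error messages instead.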
if err.Error() == structs.ErrQueryNotFound.Error() {
err = ErrNotFound
}
return nil, ECSNotGlobalError{err}
}

// (v2-dns) TODO: (v2-dns) get TTLS working. They come from the database so not having
Expand Down Expand Up @@ -337,12 +341,12 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R

// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
return nil, ErrNoData
return nil, ECSNotGlobalError{ErrNoData}
}

// Perform a random shuffle
out.Nodes.Shuffle()
return f.buildResultsFromServiceNodes(out.Nodes), nil
return f.buildResultsFromServiceNodes(out.Nodes, req), ECSNotGlobalError{}
}

// executePreparedQuery is used to execute a PreparedQuery against the Consul catalog.
Expand Down Expand Up @@ -399,10 +403,16 @@ func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error {
}

// buildResultsFromServiceNodes builds a list of results from a list of nodes.
func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServiceNode) []*Result {
results := make([]*Result, 0)
for _, n := range nodes {
func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServiceNode, req *QueryPayload) []*Result {
// Convert the service endpoints to results up to the limit
limit := req.Limit
if len(nodes) < limit || limit == 0 {
limit = len(nodes)
}

results := make([]*Result, 0, limit)
for idx := 0; idx < limit; idx++ {
n := nodes[idx]
results = append(results, &Result{
Service: &Location{
Name: n.Service.Service,
Expand Down Expand Up @@ -534,7 +544,7 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa

// Perform a random shuffle
out.Nodes.Shuffle()
return f.buildResultsFromServiceNodes(out.Nodes), nil
return f.buildResultsFromServiceNodes(out.Nodes, req), nil
}

// findWeight returns the weight of a service node.
Expand Down
2 changes: 1 addition & 1 deletion agent/discovery/query_fetcher_v1_ce.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (f *V1DataFetcher) NormalizeRequest(req *QueryPayload) {
}

func validateEnterpriseTenancy(req QueryTenancy) error {
if req.Namespace != "" || req.Partition != "" {
if req.Namespace != "" || req.Partition != acl.DefaultPartitionName {
return ErrNotSupported
}
return nil
Expand Down
161 changes: 117 additions & 44 deletions agent/dns/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/hashicorp/consul/acl"
"net"
"regexp"
"strings"
Expand All @@ -22,7 +23,6 @@ import (
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/dnsutil"
"github.com/hashicorp/consul/internal/resource"
"github.com/hashicorp/consul/logging"
)

Expand All @@ -42,7 +42,6 @@ var (
errInvalidQuestion = fmt.Errorf("invalid question")
errNameNotFound = fmt.Errorf("name not found")
errNotImplemented = fmt.Errorf("not implemented")
errQueryNotFound = fmt.Errorf("query not found")
errRecursionFailed = fmt.Errorf("recursion failed")

trailingSpacesRE = regexp.MustCompile(" +$")
Expand Down Expand Up @@ -147,6 +146,14 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.A
return r.handleRequestRecursively(req, reqCtx, remoteAddress, maxRecursionLevelDefault)
}

// getErrorFromECSNotGlobalError returns the underlying error from an
// ECSNotGlobalError, if one is present in err's chain. Any other error
// (including nil) is returned unchanged.
//
// errors.As is used instead of a direct type assertion so that an
// ECSNotGlobalError that has itself been wrapped by another error is still
// found, rather than causing a panic on a failed assertion.
func getErrorFromECSNotGlobalError(err error) error {
	var ecsErr discovery.ECSNotGlobalError
	if errors.As(err, &ecsErr) {
		// The wrapped error may be nil, which signals "no real error, but
		// the ECS response must not be marked global".
		return ecsErr.Unwrap()
	}
	return err
}

// handleRequestRecursively is used to process an individual DNS request. It will recurse as needed
// a maximum number of times and returns a message in success or fail cases.
func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context,
Expand Down Expand Up @@ -190,35 +197,47 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context,

reqType := parseRequestType(req)
results, query, err := r.getQueryResults(req, reqCtx, reqType, qName, remoteAddress)
switch {
case errors.Is(err, errNameNotFound):
r.logger.Error("name not found", "name", qName)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, errNotImplemented):
r.logger.Error("query not implemented", "name", qName, "type", dns.Type(req.Question[0].Qtype).String())

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNotImplemented, ecsGlobal)
case errors.Is(err, discovery.ErrNotSupported):
r.logger.Debug("query name syntax not supported", "name", req.Question[0].Name)
// In case of a wrapped ECSNotGlobalError, extract the underlying error from it.
isECSGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
err = getErrorFromECSNotGlobalError(err)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, discovery.ErrNotFound):
r.logger.Debug("query name not found", "name", req.Question[0].Name)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, discovery.ErrNoData):
r.logger.Debug("no data available", "name", qName)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeSuccess, ecsGlobal)
case err != nil:
r.logger.Error("error processing discovery query", "error", err)
return createServerFailureResponse(req, configCtx, canRecurse(configCtx))
if err != nil {
switch {
case errors.Is(err, errInvalidQuestion):
r.logger.Error("invalid question", "name", qName)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, errNameNotFound):
r.logger.Error("name not found", "name", qName)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, errNotImplemented):
r.logger.Error("query not implemented", "name", qName, "type", dns.Type(req.Question[0].Qtype).String())

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNotImplemented, ecsGlobal)
case errors.Is(err, discovery.ErrNotSupported):
r.logger.Debug("query name syntax not supported", "name", req.Question[0].Name)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, discovery.ErrNotFound):
r.logger.Debug("query name not found", "name", req.Question[0].Name)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal)
case errors.Is(err, discovery.ErrNoData):
r.logger.Debug("no data available", "name", qName)

ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal)
return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeSuccess, ecsGlobal)
default:
r.logger.Error("error processing discovery query", "error", err)
return createServerFailureResponse(req, configCtx, canRecurse(configCtx))
}
}

// This needs the question information because it affects the serialization format.
Expand All @@ -228,6 +247,16 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context,
r.logger.Error("error serializing DNS results", "error", err)
return createServerFailureResponse(req, configCtx, false)
}

// Switch to TCP if the client is connected over TCP; otherwise default to UDP.
network := "udp"
if _, ok := remoteAddress.(*net.TCPAddr); ok {
network = "tcp"
}

trimDNSResponse(configCtx, network, req, resp, r.logger)

setEDNS(req, resp, isECSGlobal)
return resp
}

Expand Down Expand Up @@ -289,7 +318,7 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy
// We don't want the query processors default partition to be used.
// This is a small hack because for V1 CE, this is not the correct default partition name, but we
// need to add something to disambiguate the empty field.
Partition: resource.DefaultPartitionName,
Partition: acl.DefaultPartitionName, //NOTE: note this won't work if we ever have V2 client agents
},
Limit: 3,
},
Expand All @@ -304,18 +333,12 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy
return nil, query, err
}
results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token})
if err != nil {
r.logger.Error("error processing discovery query", "error", err)
switch err.Error() {
case errNameNotFound.Error():
return nil, query, errNameNotFound
case errQueryNotFound.Error():
return nil, query, errQueryNotFound
}

if getErrorFromECSNotGlobalError(err) != nil {
r.logger.Error("error processing discovery query", "error", err)
return nil, query, err
}
return results, query, nil
return results, query, err
case requestTypeIP:
ip := dnsutil.IPFromARPA(qName)
if ip == nil {
Expand All @@ -332,7 +355,9 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy
}
return results, nil, nil
}
return nil, nil, errors.New("invalid request type")

r.logger.Error("error parsing discovery query type", "requestType", reqType)
return nil, nil, errInvalidQuestion
}

// ServeDNS implements the miekg/dns.Handler interface.
Expand Down Expand Up @@ -452,8 +477,30 @@ func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context,
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
case qType == dns.TypeSRV, reqType == requestTypeAddress:
case reqType == requestTypeAddress:
for _, result := range results {
ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
case qType == dns.TypeSRV:
handled := make(map[string]struct{})
for _, result := range results {
// Avoid duplicate entries, which are possible if a node has
// the same service on the same port, etc.

// The datacenter should be empty during translation if it is a peering lookup.
// This should be fine because we should always prefer the WAN address.
//serviceAddress := d.agent.TranslateServiceAddress(lookup.Datacenter, node.Service.Address, node.Service.TaggedAddresses, TranslateAddressAcceptAny)
//servicePort := d.agent.TranslateServicePort(lookup.Datacenter, node.Service.Port, node.Service.TaggedAddresses)
//tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort)

tuple := fmt.Sprintf("%s:%s:%d", result.Node.Name, result.Service.Address, result.PortNumber)
if _, ok := handled[tuple]; ok {
continue
}
handled[tuple] = struct{}{}
ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
Expand Down Expand Up @@ -695,6 +742,7 @@ func createServerFailureResponse(req *dns.Msg, cfg *RouterDynamicConfig, recursi
if edns := req.IsEdns0(); edns != nil {
setEDNS(req, m, true)
}

return m
}

Expand Down Expand Up @@ -844,7 +892,12 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
answer = append(answer, ptr)
case qType == dns.TypeNS:
// TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result
fqdn := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName)
resultType := result.Type
target := result.Node.Name
if parseRequestType(req) == requestTypeConsul && resultType == discovery.ResultTypeService {
resultType = discovery.ResultTypeNode
}
fqdn := canonicalNameForResult(resultType, target, domain, result.Tenancy, result.PortName)
extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported

answer = append(answer, makeNSRecord(domain, fqdn, ttl))
Expand All @@ -871,7 +924,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req
extra = append(extra, e...)
}

a, e := getAnswerAndExtraTXT(req, cfg, qName, result, ttl, domain)
a, e := getAnswerAndExtraTXT(req, cfg, qName, result, ttl, domain, query)
answer = append(answer, a...)
extra = append(extra, e...)
return
Expand Down Expand Up @@ -954,7 +1007,10 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, ser
// getAnswerAndExtraTXT determines whether a TXT needs to be create and then
// returns the TXT record in the answer or extra depending on the question type.
func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string,
result *discovery.Result, ttl uint32, domain string) (answer []dns.RR, extra []dns.RR) {
result *discovery.Result, ttl uint32, domain string, query *discovery.Query) (answer []dns.RR, extra []dns.RR) {
if !shouldAppendTXTRecord(query, cfg, req) {
return
}
recordHeaderName := qName
serviceAddress := newDNSAddress("")
if result.Service != nil {
Expand Down Expand Up @@ -989,6 +1045,23 @@ func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string,
return answer, extra
}

// shouldAppendTXTRecord reports whether a TXT record should be added to the
// response for the given query, configuration and DNS request.
//
// The decision depends on the query type:
//   - node queries: when NodeMetaTXT is enabled, or the question type is ANY or TXT;
//   - service and prepared queries: only when NodeMetaTXT is enabled and the
//     question type is SRV.
//
// A nil query, or any other query type, never gets a TXT record.
func shouldAppendTXTRecord(query *discovery.Query, cfg *RouterDynamicConfig, req *dns.Msg) bool {
	questionType := req.Question[0].Qtype
	if query == nil {
		return false
	}

	switch query.QueryType {
	case discovery.QueryTypeNode:
		return cfg.NodeMetaTXT || questionType == dns.TypeANY || questionType == dns.TypeTXT
	case discovery.QueryTypeService, discovery.QueryTypePreparedQuery:
		return cfg.NodeMetaTXT && questionType == dns.TypeSRV
	default:
		return false
	}
}

// getAnswerExtrasForIP creates the dns answer and extra from IP dnsAddress pairs.
func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question,
reqType requestType, result *discovery.Result, ttl uint32, _ string) (answer []dns.RR, extra []dns.RR) {
Expand Down
5 changes: 5 additions & 0 deletions agent/dns/router_ce.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ func canonicalNameForResult(resultType discovery.ResultType, target, domain stri
}
return ""
}

// getDefaultPartitionName returns the default partition name, which in this
// build is always the empty string.
func getDefaultPartitionName() string {
	const defaultPartition = ""
	return defaultPartition
}
5 changes: 4 additions & 1 deletion agent/dns/router_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain st

portName := parsePort(queryParts)

if queryType == discovery.QueryTypeWorkload && req.Question[0].Qtype == dns.TypeSRV {
switch {
case queryType == discovery.QueryTypeWorkload && req.Question[0].Qtype == dns.TypeSRV:
// Currently we do not support SRV records for workloads
return nil, errNotImplemented
case queryType == discovery.QueryTypeInvalid, name == "":
return nil, errInvalidQuestion
}

return &discovery.Query{
Expand Down
Loading

0 comments on commit b9f5668

Please sign in to comment.