Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add sync.RWMutex to headerforwarder map #511

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions headerforwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
Expand All @@ -36,7 +37,9 @@ import (
// TODO: this should expire entries after a certain amount of time
type HeaderForwarder struct {
incomingHeaders map[string]http.Header
incomingHeaderLock sync.RWMutex
outgoingHeaders map[string]http.Header
outgoingHeaderLock sync.RWMutex
interestingHeaders []string
actualTransport http.RoundTripper
}
Expand Down Expand Up @@ -66,14 +69,20 @@ func (hf *HeaderForwarder) captureOutgoingHeaders(req *http.Request) {
ctx := req.Context()
rosettaRequestID := RosettaIDFromContext(ctx)

hf.outgoingHeaders[rosettaRequestID] = make(http.Header)
// We don't worry about overwriting headers here because we only handle "outgoing" headers
// once: when the rosetta request is made
outgoingRequestHeaders := make(http.Header)

// Only capture interesting headers
for _, interestingHeader := range hf.interestingHeaders {
if _, requestHasHeader := req.Header[http.CanonicalHeaderKey(interestingHeader)]; requestHasHeader {
hf.outgoingHeaders[rosettaRequestID].Set(interestingHeader, req.Header.Get(interestingHeader))
outgoingRequestHeaders.Set(interestingHeader, req.Header.Get(interestingHeader))
}
}

hf.outgoingHeaderLock.Lock()
hf.outgoingHeaders[rosettaRequestID] = outgoingRequestHeaders
hf.outgoingHeaderLock.Unlock()
}

// shouldRememberHeaders reports whether response headers should be remembered for a
Expand Down Expand Up @@ -113,7 +122,10 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons

// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
hf.incomingHeaderLock.RLock()
headersToRemember, exists := hf.incomingHeaders[rosettaRequestID]
hf.incomingHeaderLock.RUnlock()

if !exists {
headersToRemember = make(http.Header)
}
Expand All @@ -123,7 +135,9 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons
headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader))
}

hf.incomingHeaderLock.Lock()
hf.incomingHeaders[rosettaRequestID] = headersToRemember
hf.incomingHeaderLock.Unlock()
}

// shouldRememberMetadata reports whether response metadata should be remembered for a grpc unary
Expand Down Expand Up @@ -152,7 +166,10 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M

// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
hf.incomingHeaderLock.RLock()
headersToRemember, exists := hf.incomingHeaders[rosettaID]
hf.incomingHeaderLock.RUnlock()

if !exists {
headersToRemember = make(http.Header)
}
Expand All @@ -163,20 +180,28 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M
}
}

hf.incomingHeaderLock.Lock()
hf.incomingHeaders[rosettaID] = headersToRemember
hf.incomingHeaderLock.Unlock()
}

// GetResponseHeaders returns any headers that should be returned to a rosetta response. These
// consist of native node response headers/metadata that were remembered for a request ID.
func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) {
hf.incomingHeaderLock.RLock()
headers, ok := hf.incomingHeaders[rosettaRequestID]
hf.incomingHeaderLock.RUnlock()

// Delete the headers from the map after they are retrieved
// This is safe to call even if the key doesn't exist
hf.incomingHeaderLock.Lock()
delete(hf.incomingHeaders, rosettaRequestID)
hf.incomingHeaderLock.Unlock()

// Also delete the outgoing headers from the map since we are done with them
hf.outgoingHeaderLock.Lock()
delete(hf.outgoingHeaders, rosettaRequestID)
hf.outgoingHeaderLock.Unlock()

return headers, ok
}
Expand Down Expand Up @@ -209,8 +234,12 @@ func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handle
// RoundTrip implements http.RoundTripper and will be used to construct an http Client which
// saves the native node response headers if necessary.
func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) {
hf.outgoingHeaderLock.RLock()
outgoingHeaders, hasOutgoingHeaders := hf.outgoingHeaders[RosettaIDFromRequest(req)]
hf.outgoingHeaderLock.RUnlock()

// add outgoing headers to the request
if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromRequest(req)]; ok {
if hasOutgoingHeaders {
for header, values := range outgoingHeaders {
for _, value := range values {
req.Header.Add(header, value)
Expand All @@ -227,13 +256,18 @@ func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error)
return resp, err
}

func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// Capture incoming headers from the grpc call
var header metadata.MD
opts = append(opts, grpc.Header(&header))

// Get outgoing headers from the request ID in context
hf.outgoingHeaderLock.RLock()
outgoingHeaders, hasOutgoingHeaders := hf.outgoingHeaders[RosettaIDFromContext(ctx)]
hf.outgoingHeaderLock.RUnlock()

// Add outgoing headers to the context
if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromContext(ctx)]; ok {
if hasOutgoingHeaders {
for header, values := range outgoingHeaders {
for _, value := range values {
ctx = metadata.AppendToOutgoingContext(ctx, strings.ToLower(header), value)
Expand Down
Loading