From b63d482e54141666684f4941df33c4efd234f180 Mon Sep 17 00:00:00 2001 From: Derrick Wippler Date: Thu, 26 May 2022 20:03:22 -0500 Subject: [PATCH 1/2] Added AddOverrideHeader() --- mailgun.go | 21 ++++++++++++++++----- messages.go | 4 ++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mailgun.go b/mailgun.go index b3fd8df2..e35e9b47 100644 --- a/mailgun.go +++ b/mailgun.go @@ -125,6 +125,7 @@ type Mailgun interface { Client() *http.Client SetClient(client *http.Client) SetAPIBase(url string) + AddOverrideHeader(k string, v string) Send(ctx context.Context, m *Message) (string, string, error) ReSend(ctx context.Context, id string, recipients ...string) (string, string, error) @@ -241,11 +242,12 @@ type Mailgun interface { // MailgunImpl bundles data needed by a large number of methods in order to interact with the Mailgun API. // Colloquially, we refer to instances of this structure as "clients." type MailgunImpl struct { - apiBase string - domain string - apiKey string - client *http.Client - baseURL string + apiBase string + domain string + apiKey string + client *http.Client + baseURL string + overrideHeaders map[string]string } // NewMailGun creates a new client instance. @@ -318,6 +320,15 @@ func (mg *MailgunImpl) SetAPIBase(address string) { mg.apiBase = address } +// AddOverrideHeader allows the user to specify additional headers that will be included in the HTTP request +// This is mostly useful for testing the Mailgun API hosted at a different endpoint. +func (mg *MailgunImpl) AddOverrideHeader(k string, v string) { + if mg.overrideHeaders == nil { + mg.overrideHeaders = make(map[string]string) + } + mg.overrideHeaders[k] = v +} + // generateApiUrl renders a URL for an API endpoint using the domain and endpoint name. func generateApiUrl(m Mailgun, endpoint string) string { return fmt.Sprintf("%s/%s/%s", m.APIBase(), m.Domain(), endpoint) diff --git a/messages.go b/messages.go index a21edb87..90cfca76 100644 --- a/messages.go +++ b/messages.go @@ -672,6 +672,10 @@ func (mg *MailgunImpl) Send(ctx context.Context, message *Message) (mes string, r := newHTTPRequest(generateApiUrlWithDomain(mg, message.specific.endpoint(), message.domain)) r.setClient(mg.Client()) r.setBasicAuth(basicAuthUser, mg.APIKey()) + // Override any HTTP headers if provided + for k, v := range mg.overrideHeaders { + r.addHeader(k, v) + } var response sendMessageResponse err = postResponseFromJSON(ctx, r, payload, &response) From d25e820327acfb38c85c8b132de5f89c1c8b68d8 Mon Sep 17 00:00:00 2001 From: Derrick Wippler Date: Thu, 26 May 2022 20:20:06 -0500 Subject: [PATCH 2/2] Added tests for AddOverrideHeader() --- httphelpers.go | 8 +++++++- messages_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/httphelpers.go b/httphelpers.go index 5156b6ed..a0249b7b 100644 --- a/httphelpers.go +++ b/httphelpers.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/pkg/errors" "io" "io/ioutil" "mime/multipart" @@ -15,6 +14,8 @@ import ( "path" "regexp" "strings" + + "github.com/pkg/errors" ) var validURL = regexp.MustCompile(`/v[2-4].*`) @@ -274,6 +275,11 @@ func (r *httpRequest) NewRequest(ctx context.Context, method string, payload pay } for header, value := range r.Headers { + // Special case, override the Host header + if header == "Host" { + req.Host = value + continue + } req.Header.Add(header, value) } return req, nil diff --git a/messages_test.go b/messages_test.go index e3786be5..9f2ad8b4 100644 --- a/messages_test.go +++ b/messages_test.go @@ -536,6 +536,41 @@ func TestResendStored(t *testing.T) { ensure.DeepEqual(t, id, exampleID) } +func TestAddOverrideHeader(t *testing.T) { + const ( + exampleDomain = "testDomain" + exampleAPIKey = "testAPIKey" + toUser = "test@test.com" + exampleMessage = "Queue. Thank you" + exampleID = "<20111114174239.25659.5817@samples.mailgun.org>" + ) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ensure.DeepEqual(t, req.Method, http.MethodPost) + ensure.DeepEqual(t, req.URL.Path, fmt.Sprintf("/v3/%s/messages", exampleDomain)) + ensure.DeepEqual(t, req.Header.Get("CustomHeader"), "custom-value") + ensure.DeepEqual(t, req.Host, "example.com") + + rsp := fmt.Sprintf(`{"message":"%s", "id":"%s"}`, exampleMessage, exampleID) + fmt.Fprint(w, rsp) + })) + defer srv.Close() + + mg := NewMailgun(exampleDomain, exampleAPIKey) + mg.SetAPIBase(srv.URL + "/v3") + mg.AddOverrideHeader("Host", "example.com") + mg.AddOverrideHeader("CustomHeader", "custom-value") + ctx := context.Background() + + m := mg.NewMessage(fromUser, exampleSubject, exampleText, toUser) + m.SetRequireTLS(true) + m.SetSkipVerification(true) + + msg, id, err := mg.Send(ctx, m) + ensure.Nil(t, err) + ensure.DeepEqual(t, msg, exampleMessage) + ensure.DeepEqual(t, id, exampleID) +} + func TestSendTLSOptions(t *testing.T) { const ( exampleDomain = "testDomain"