Skip to content

Commit

Permalink
fix: some minor fixes and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Stockton committed Dec 5, 2024
1 parent 42ed14a commit f51f1f8
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 47 deletions.
29 changes: 26 additions & 3 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,31 @@ type MailerConfiguration struct {
ExternalHosts []string `json:"external_hosts" split_words:"true"`

// EXPERIMENTAL: May be removed in a future release.
EmailValidationExtended bool `json:"email_validation_extended" split_words:"true" default:"false"`
EmailValidationServiceURL string `json:"email_validation_service_url" split_words:"true"`
EmailValidationServiceKey string `json:"email_validation_service_key" split_words:"true"`
EmailValidationExtended bool `json:"email_validation_extended" split_words:"true" default:"false"`
EmailValidationServiceURL string `json:"email_validation_service_url" split_words:"true"`
EmailValidationServiceHeaders string `json:"email_validation_service_key" split_words:"true"`

serviceHeaders map[string][]string `json:"-"`
}

func (c *MailerConfiguration) Validate() error {
headers := make(map[string][]string)

if c.EmailValidationServiceHeaders != "" {
err := json.Unmarshal([]byte(c.EmailValidationServiceHeaders), &headers)
if err != nil {
return fmt.Errorf("conf: SMTP headers not a map[string][]string format: %w", err)
}
}

if len(headers) > 0 {
c.serviceHeaders = headers
}
return nil
}

func (c *MailerConfiguration) GetEmailValidationServiceHeaders() map[string][]string {
return c.serviceHeaders
}

type PhoneProviderConfiguration struct {
Expand Down Expand Up @@ -1025,6 +1047,7 @@ func (c *GlobalConfiguration) Validate() error {
&c.Tracing,
&c.Metrics,
&c.SMTP,
&c.Mailer,
&c.SAML,
&c.Security,
&c.Sessions,
Expand Down
7 changes: 7 additions & 0 deletions internal/conf/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func TestGlobal(t *testing.T) {
os.Setenv("GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_URI", "pg-functions://postgres/auth/count_failed_attempts")
os.Setenv("GOTRUE_HOOK_SEND_SMS_SECRETS", "v1,whsec_aWxpa2VzdXBhYmFzZXZlcnltdWNoYW5kaWhvcGV5b3Vkb3Rvbw==")
os.Setenv("GOTRUE_SMTP_HEADERS", `{"X-PM-Metadata-project-ref":["project_ref"],"X-SES-Message-Tags":["ses:feedback-id-a=project_ref,ses:feedback-id-b=$messageType"]}`)
os.Setenv("GOTRUE_MAILER_EMAIL_VALIDATION_SERVICE_HEADERS", `{"apikey":["test"]}`)
os.Setenv("GOTRUE_SMTP_LOGGING_ENABLED", "true")
gc, err := LoadGlobal("")
require.NoError(t, err)
Expand All @@ -34,6 +35,12 @@ func TestGlobal(t *testing.T) {
assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader)
assert.Equal(t, "pg-functions://postgres/auth/count_failed_attempts", gc.Hook.MFAVerificationAttempt.URI)

{
hdrs := gc.Mailer.GetEmailValidationServiceHeaders()
assert.Equal(t, 1, len(hdrs["apikey"]))
assert.Equal(t, "test", hdrs["apikey"][0])
}

}

func TestRateLimits(t *testing.T) {
Expand Down
7 changes: 5 additions & 2 deletions internal/mailer/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,10 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp
})
}

errors := make(chan error)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()

errors := make(chan error, len(emails))
for _, email := range emails {
path, err := getPath(
m.Config.Mailer.URLPaths.EmailChange,
Expand All @@ -279,7 +282,7 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp
"RedirectTo": referrerURL,
}
errors <- m.Mailer.Mail(
r.Context(),
ctx,
address,
withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"),
template,
Expand Down
43 changes: 25 additions & 18 deletions internal/mailer/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var invalidHostMap = map[string]bool{
}

const (
validateEmailTimeout = 2 * time.Second
validateEmailTimeout = 3 * time.Second
)

var (
Expand All @@ -77,23 +77,21 @@ var (
)

type EmailValidator struct {
extended bool
serviceURL string
serviceKey string
extended bool
serviceURL string
serviceHeaders map[string][]string
}

func newEmailValidator(mc conf.MailerConfiguration) *EmailValidator {
return &EmailValidator{
extended: mc.EmailValidationExtended,
serviceURL: mc.EmailValidationServiceURL,
serviceKey: mc.EmailValidationServiceKey,
extended: mc.EmailValidationExtended,
serviceURL: mc.EmailValidationServiceURL,
serviceHeaders: mc.GetEmailValidationServiceHeaders(),
}
}

func (ev *EmailValidator) isExtendedDisabled() bool { return !ev.extended }
func (ev *EmailValidator) isServiceDisabled() bool {
return ev.serviceURL == "" || ev.serviceKey == ""
}
func (ev *EmailValidator) isExtendedEnabled() bool { return ev.extended }
func (ev *EmailValidator) isServiceEnabled() bool { return ev.serviceURL != "" }

// Validate performs validation on the given email.
//
Expand All @@ -104,7 +102,7 @@ func (ev *EmailValidator) isServiceDisabled() bool {
// When serviceURL AND serviceKey are non-empty strings it uses the remote
// service to determine if the email is valid.
func (ev *EmailValidator) Validate(ctx context.Context, email string) error {
if ev.isExtendedDisabled() && ev.isServiceDisabled() {
if !ev.isExtendedEnabled() && !ev.isServiceEnabled() {
return nil
}

Expand All @@ -121,7 +119,7 @@ func (ev *EmailValidator) Validate(ctx context.Context, email string) error {

// Validate the static rules first to prevent round trips on bad emails
// and to parse the host ahead of time.
if !ev.isExtendedDisabled() {
if ev.isExtendedEnabled() {

// First validate static checks such as format, known invalid hosts
// and any other network free checks. Running this check before we
Expand All @@ -136,9 +134,9 @@ func (ev *EmailValidator) Validate(ctx context.Context, email string) error {
g.Go(func() error { return ev.validateHost(ctx, host) })
}

// If the service check is not disabled we start a goroutine to run
// If the service check is enabled we start a goroutine to run
// that check as well.
if !ev.isServiceDisabled() {
if ev.isServiceEnabled() {
g.Go(func() error { return ev.validateService(ctx, email) })
}
return g.Wait()
Expand All @@ -147,7 +145,7 @@ func (ev *EmailValidator) Validate(ctx context.Context, email string) error {
// validateStatic will validate the format and do the static checks before
// returning the host portion of the email.
func (ev *EmailValidator) validateStatic(email string) (string, error) {
if !ev.extended {
if !ev.isExtendedEnabled() {
return "", nil
}

Expand Down Expand Up @@ -185,7 +183,7 @@ func (ev *EmailValidator) validateStatic(email string) (string, error) {
}

func (ev *EmailValidator) validateService(ctx context.Context, email string) error {
if ev.isServiceDisabled() {
if !ev.isServiceEnabled() {
return nil
}

Expand All @@ -204,7 +202,11 @@ func (ev *EmailValidator) validateService(ctx context.Context, email string) err
return nil
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("apikey", ev.serviceKey)
for name, vals := range ev.serviceHeaders {
for _, val := range vals {
req.Header.Set(name, val)
}
}

res, err := http.DefaultClient.Do(req)
if err != nil {
Expand All @@ -215,6 +217,11 @@ func (ev *EmailValidator) validateService(ctx context.Context, email string) err
resObject := struct {
Valid *bool `json:"valid"`
}{}

if res.StatusCode != http.StatusOK {
return nil
}

dec := json.NewDecoder(io.LimitReader(res.Body, 1<<5))
if err := dec.Decode(&resObject); err != nil {
return nil
Expand Down
79 changes: 55 additions & 24 deletions internal/mailer/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ func TestEmalValidatorService(t *testing.T) {
testHdrsVal := new(atomic.Value)
testHdrsVal.Store(map[string]string{"apikey": "test"})

// testHeaders := map[string][]string{"apikey": []string{"test"}}
testHeaders := `{"apikey": ["test"]}`

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := r.Header.Get("apikey")
if key == "" {
Expand All @@ -40,10 +43,14 @@ func TestEmalValidatorService(t *testing.T) {
{
testResVal.Store(`{"valid": true}`)
cfg := conf.MailerConfiguration{
EmailValidationExtended: true,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceKey: "test",
EmailValidationExtended: true,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceHeaders: testHeaders,
}
if err := cfg.Validate(); err != nil {
t.Fatal(err)
}

ev := newEmailValidator(cfg)
err := ev.Validate(ctx, "[email protected]")
if err != nil {
Expand All @@ -58,10 +65,14 @@ func TestEmalValidatorService(t *testing.T) {
testResVal.Store(`{"valid": true}`)

cfg := conf.MailerConfiguration{
EmailValidationExtended: false,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceKey: "test",
EmailValidationExtended: false,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceHeaders: testHeaders,
}
if err := cfg.Validate(); err != nil {
t.Fatal(err)
}

ev := newEmailValidator(cfg)
err := ev.Validate(ctx, "[email protected]")
if err != nil {
Expand All @@ -76,10 +87,14 @@ func TestEmalValidatorService(t *testing.T) {
testResVal.Store(`{"valid": false}`)

cfg := conf.MailerConfiguration{
EmailValidationExtended: false,
EmailValidationServiceURL: "",
EmailValidationServiceKey: "",
EmailValidationExtended: false,
EmailValidationServiceURL: "",
EmailValidationServiceHeaders: "",
}
if err := cfg.Validate(); err != nil {
t.Fatal(err)
}

ev := newEmailValidator(cfg)
err := ev.Validate(ctx, "[email protected]")
if err != nil {
Expand All @@ -93,10 +108,14 @@ func TestEmalValidatorService(t *testing.T) {
{
testResVal.Store(`{"valid": true}`)
cfg := conf.MailerConfiguration{
EmailValidationExtended: true,
EmailValidationServiceURL: "",
EmailValidationServiceKey: "",
EmailValidationExtended: true,
EmailValidationServiceURL: "",
EmailValidationServiceHeaders: "",
}
if err := cfg.Validate(); err != nil {
t.Fatal(err)
}

ev := newEmailValidator(cfg)
err := ev.Validate(ctx, "[email protected]")
if err == nil {
Expand All @@ -110,10 +129,14 @@ func TestEmalValidatorService(t *testing.T) {
{
testResVal.Store(`{"valid": true}`)
cfg := conf.MailerConfiguration{
EmailValidationExtended: true,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceKey: "test",
EmailValidationExtended: true,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceHeaders: testHeaders,
}
if err := cfg.Validate(); err != nil {
t.Fatal(err)
}

ev := newEmailValidator(cfg)
err := ev.Validate(ctx, "[email protected]")
if err == nil {
Expand All @@ -127,10 +150,14 @@ func TestEmalValidatorService(t *testing.T) {
{
testResVal.Store(`{"valid": false}`)
cfg := conf.MailerConfiguration{
EmailValidationExtended: true,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceKey: "test",
EmailValidationExtended: true,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceHeaders: testHeaders,
}
if err := cfg.Validate(); err != nil {
t.Fatal(err)
}

ev := newEmailValidator(cfg)
err := ev.Validate(ctx, "[email protected]")
if err == nil {
Expand All @@ -145,10 +172,14 @@ func TestEmalValidatorService(t *testing.T) {
testResVal.Store(`{"valid": false}`)

cfg := conf.MailerConfiguration{
EmailValidationExtended: false,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceKey: "test",
EmailValidationExtended: false,
EmailValidationServiceURL: ts.URL,
EmailValidationServiceHeaders: testHeaders,
}
if err := cfg.Validate(); err != nil {
t.Fatal(err)
}

ev := newEmailValidator(cfg)
err := ev.Validate(ctx, "[email protected]")
if err == nil {
Expand Down Expand Up @@ -221,9 +252,9 @@ func TestValidateEmailExtended(t *testing.T) {
}

cfg := conf.MailerConfiguration{
EmailValidationExtended: true,
EmailValidationServiceURL: "",
EmailValidationServiceKey: "",
EmailValidationExtended: true,
EmailValidationServiceURL: "",
EmailValidationServiceHeaders: "",
}
ev := newEmailValidator(cfg)

Expand Down

0 comments on commit f51f1f8

Please sign in to comment.