Skip to content

Commit

Permalink
chore(cors): Allow a custom validation function which receives the fu…
Browse files Browse the repository at this point in the history
…ll gin context (#140)

* Allow a origin validation function with context

* Revert "Allow a origin validation function with context"

This reverts commit 82827c2.

* Allow origin validation function which receives the full request context

* fix logic in conditional

* add test, fix logic

* slightly re-work to pass linter

* update comments

* restructure to shorten line lengths to pass linter

* remove punctuation at the end of error string

* Add multi-group preflight test

* remove comment
  • Loading branch information
dbhoot authored Mar 10, 2024
1 parent 7f30a1f commit 9d49f16
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 24 deletions.
44 changes: 27 additions & 17 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
)

type cors struct {
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
optionsResponseStatusCode int
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOriginWithContextFunc func(*gin.Context, string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
optionsResponseStatusCode int
}

var (
Expand Down Expand Up @@ -54,14 +55,15 @@ func newCors(config Config) *cors {
}

return &cors{
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
optionsResponseStatusCode: config.OptionsResponseStatusCode,
allowOriginFunc: config.AllowOriginFunc,
allowOriginWithContextFunc: config.AllowOriginWithContextFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
optionsResponseStatusCode: config.OptionsResponseStatusCode,
}
}

Expand All @@ -79,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) {
return
}

if !cors.validateOrigin(origin) {
if !cors.isOriginValid(c, origin) {
c.AbortWithStatus(http.StatusForbidden)
return
}
Expand Down Expand Up @@ -112,6 +114,14 @@ func (cors *cors) validateWildcardOrigin(origin string) bool {
return false
}

func (cors *cors) isOriginValid(c *gin.Context, origin string) bool {
valid := cors.validateOrigin(origin)
if !valid && cors.allowOriginWithContextFunc != nil {
valid = cors.allowOriginWithContextFunc(c, origin)
}
return valid
}

func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
Expand Down
24 changes: 21 additions & 3 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cors

import (
"errors"
"fmt"
"strings"
"time"

Expand All @@ -22,6 +23,12 @@ type Config struct {
// set, the content of AllowOrigins is ignored.
AllowOriginFunc func(origin string) bool

// Same as AllowOriginFunc except also receives the full request context.
// This function should use the context as a read only source and not
// have any side effects on the request, such as aborting or injecting
// values on the request.
AllowOriginWithContextFunc func(c *gin.Context, origin string) bool

// AllowMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS)
AllowMethods []string
Expand Down Expand Up @@ -108,10 +115,21 @@ func (c Config) validateAllowedSchemas(origin string) bool {

// Validate is check configuration of user defined.
func (c Config) Validate() error {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed")
hasOriginFn := c.AllowOriginFunc != nil
hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil

if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) {
originFields := strings.Join([]string{
"AllowOriginFunc",
"AllowOriginFuncWithContext",
"AllowOrigins",
}, " or ")
return fmt.Errorf(
"conflict settings: all origins enabled. %s is not needed",
originFields,
)
}
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
for _, origin := range c.AllowOrigins {
Expand Down
84 changes: 80 additions & 4 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,34 @@ func newTestRouter(config Config) *gin.Engine {
return router
}

func multiGroupRouter(config Config) *gin.Engine {
router := gin.New()
router.Use(New(config))

app1 := router.Group("/app1")
app1.GET("", func(c *gin.Context) {
c.String(http.StatusOK, "app1")
})

app2 := router.Group("/app2")
app2.GET("", func(c *gin.Context) {
c.String(http.StatusOK, "app2")
})

app3 := router.Group("/app3")
app3.GET("", func(c *gin.Context) {
c.String(http.StatusOK, "app3")
})

return router
}

func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
return performRequestWithHeaders(r, method, origin, http.Header{})
return performRequestWithHeaders(r, method, "/", origin, http.Header{})
}

func performRequestWithHeaders(r http.Handler, method, origin string, header http.Header) *httptest.ResponseRecorder {
req, _ := http.NewRequestWithContext(context.Background(), method, "/", nil)
func performRequestWithHeaders(r http.Handler, method, path, origin string, header http.Header) *httptest.ResponseRecorder {
req, _ := http.NewRequestWithContext(context.Background(), method, path, nil)
// From go/net/http/request.go:
// For incoming requests, the Host header is promoted to the
// Request.Host field and removed from the Header map.
Expand Down Expand Up @@ -299,6 +321,9 @@ func TestPassesAllowOrigins(t *testing.T) {
AllowOriginFunc: func(origin string) bool {
return origin == "http://github.com"
},
AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
return origin == "http://sample.com"
},
})

// no CORS request, origin == ""
Expand All @@ -311,7 +336,7 @@ func TestPassesAllowOrigins(t *testing.T) {
// no CORS request, origin == host
h := http.Header{}
h.Set("Host", "facebook.com")
w = performRequestWithHeaders(router, "GET", "http://facebook.com", h)
w = performRequestWithHeaders(router, "GET", "/", "http://facebook.com", h)
assert.Equal(t, "get", w.Body.String())
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
Expand Down Expand Up @@ -346,6 +371,15 @@ func TestPassesAllowOrigins(t *testing.T) {
assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))

// allowed CORS prefligh request: allowed via AllowOriginWithContextFunc
w = performRequest(router, "OPTIONS", "http://sample.com")
assert.Equal(t, http.StatusNoContent, w.Code)
assert.Equal(t, "http://sample.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods"))
assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))

// deny CORS prefligh request
w = performRequest(router, "OPTIONS", "http://example.com")
assert.Equal(t, http.StatusForbidden, w.Code)
Expand Down Expand Up @@ -432,6 +466,48 @@ func TestWildcard(t *testing.T) {
assert.Equal(t, 200, w.Code)
}

func TestMultiGroupRouter(t *testing.T) {
router := multiGroupRouter(Config{
AllowMethods: []string{"GET"},
AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/app1") {
return "http://app1.example.com" == origin
}

if strings.HasPrefix(path, "/app2") {
return "http://app2.example.com" == origin
}

// app 3 allows all origins
return true
},
})

// allowed CORS prefligh request
emptyHeaders := http.Header{}
app1Origin := "http://app1.example.com"
app2Origin := "http://app2.example.com"
randomOrgin := "http://random.com"

// allowed CORS preflight
w := performRequestWithHeaders(router, "OPTIONS", "/app1", app1Origin, emptyHeaders)
assert.Equal(t, http.StatusNoContent, w.Code)

w = performRequestWithHeaders(router, "OPTIONS", "/app2", app2Origin, emptyHeaders)
assert.Equal(t, http.StatusNoContent, w.Code)

w = performRequestWithHeaders(router, "OPTIONS", "/app3", randomOrgin, emptyHeaders)
assert.Equal(t, http.StatusNoContent, w.Code)

// disallowed CORS preflight
w = performRequestWithHeaders(router, "OPTIONS", "/app1", randomOrgin, emptyHeaders)
assert.Equal(t, http.StatusForbidden, w.Code)

w = performRequestWithHeaders(router, "OPTIONS", "/app2", randomOrgin, emptyHeaders)
assert.Equal(t, http.StatusForbidden, w.Code)
}

func TestParseWildcardRules_NoWildcard(t *testing.T) {
config := Config{
AllowOrigins: []string{
Expand Down

0 comments on commit 9d49f16

Please sign in to comment.