From 82827c2bf1d62ccce85980c21253843245f61218 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Thu, 22 Feb 2024 16:29:18 -0800 Subject: [PATCH 01/11] Allow a origin validation function with context --- config.go | 42 +++++++++++++++++++++++++----------------- cors.go | 10 ++++++++-- cors_test.go | 2 ++ 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/config.go b/config.go index 427cfc0..ef24187 100644 --- a/config.go +++ b/config.go @@ -8,14 +8,15 @@ import ( ) type cors struct { - allowAllOrigins bool - allowCredentials bool - allowOriginFunc func(string) bool - allowOrigins []string - normalHeaders http.Header - preflightHeaders http.Header - wildcardOrigins [][]string - optionsResponseStatusCode int + allowAllOrigins bool + allowCredentials bool + allowOriginFunc func(string) bool + allowOriginWithContextFunc func(*gin.Context, string) bool + allowOrigins []string + normalHeaders http.Header + preflightHeaders http.Header + wildcardOrigins [][]string + optionsResponseStatusCode int } var ( @@ -54,14 +55,15 @@ func newCors(config Config) *cors { } return &cors{ - allowOriginFunc: config.AllowOriginFunc, - allowAllOrigins: config.AllowAllOrigins, - allowCredentials: config.AllowCredentials, - allowOrigins: normalize(config.AllowOrigins), - normalHeaders: generateNormalHeaders(config), - preflightHeaders: generatePreflightHeaders(config), - wildcardOrigins: config.parseWildcardRules(), - optionsResponseStatusCode: config.OptionsResponseStatusCode, + allowOriginFunc: config.AllowOriginFunc, + allowOriginWithContextFunc: config.AllowOriginWithContextFunc, + allowAllOrigins: config.AllowAllOrigins, + allowCredentials: config.AllowCredentials, + allowOrigins: normalize(config.AllowOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), + wildcardOrigins: config.parseWildcardRules(), + optionsResponseStatusCode: config.OptionsResponseStatusCode, } } @@ -79,7 +81,13 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) { + if cors.allowOriginWithContextFunc != nil { + if !cors.allowOriginWithContextFunc(c, origin) { + c.AbortWithStatus(http.StatusForbidden) + return + } + + } else if !cors.validateOrigin(origin) { c.AbortWithStatus(http.StatusForbidden) return } diff --git a/cors.go b/cors.go index b325222..b0f3bec 100644 --- a/cors.go +++ b/cors.go @@ -22,6 +22,9 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool + // The same as AllowOriginFunc but allows access to the entire request context + AllowOriginWithContextFunc func(c *gin.Context, origin string) bool + // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS) AllowMethods []string @@ -102,12 +105,15 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { + if c.AllowAllOrigins && (c.AllowOriginFunc != nil || c.AllowOriginWithContextFunc != nil || len(c.AllowOrigins) > 0) { return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && c.AllowOriginFunc == nil && c.AllowOriginWithContextFunc == nil && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } + if c.AllowOriginFunc != nil && c.AllowOriginWithContextFunc != nil { + return errors.New("conflict settings: Both original validation functions are defined") + } for _, origin := range c.AllowOrigins { if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) { return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ",")) diff --git a/cors_test.go b/cors_test.go index c87d60a..2800238 100644 --- a/cors_test.go +++ b/cors_test.go @@ -205,6 +205,8 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) { } func TestValidateOrigin(t *testing.T) { + // review the below for adding a testing context + //https://pkg.go.dev/github.com/gin-gonic/gin#CreateTestContextOnly cors := newCors(Config{ AllowAllOrigins: true, }) From dfcb4defca9bed4e2f702881f9c34a7b81d82a33 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Thu, 22 Feb 2024 17:18:48 -0800 Subject: [PATCH 02/11] Revert "Allow a origin validation function with context" This reverts commit 82827c2bf1d62ccce85980c21253843245f61218. --- config.go | 42 +++++++++++++++++------------------------- cors.go | 10 ++-------- cors_test.go | 2 -- 3 files changed, 19 insertions(+), 35 deletions(-) diff --git a/config.go b/config.go index ef24187..427cfc0 100644 --- a/config.go +++ b/config.go @@ -8,15 +8,14 @@ import ( ) type cors struct { - allowAllOrigins bool - allowCredentials bool - allowOriginFunc func(string) bool - allowOriginWithContextFunc func(*gin.Context, string) bool - allowOrigins []string - normalHeaders http.Header - preflightHeaders http.Header - wildcardOrigins [][]string - optionsResponseStatusCode int + allowAllOrigins bool + allowCredentials bool + allowOriginFunc func(string) bool + allowOrigins []string + normalHeaders http.Header + preflightHeaders http.Header + wildcardOrigins [][]string + optionsResponseStatusCode int } var ( @@ -55,15 +54,14 @@ func newCors(config Config) *cors { } return &cors{ - allowOriginFunc: config.AllowOriginFunc, - allowOriginWithContextFunc: config.AllowOriginWithContextFunc, - allowAllOrigins: config.AllowAllOrigins, - allowCredentials: config.AllowCredentials, - allowOrigins: normalize(config.AllowOrigins), - normalHeaders: generateNormalHeaders(config), - preflightHeaders: generatePreflightHeaders(config), - wildcardOrigins: config.parseWildcardRules(), - optionsResponseStatusCode: config.OptionsResponseStatusCode, + allowOriginFunc: config.AllowOriginFunc, + allowAllOrigins: config.AllowAllOrigins, + allowCredentials: config.AllowCredentials, + allowOrigins: normalize(config.AllowOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), + wildcardOrigins: config.parseWildcardRules(), + optionsResponseStatusCode: config.OptionsResponseStatusCode, } } @@ -81,13 +79,7 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if cors.allowOriginWithContextFunc != nil { - if !cors.allowOriginWithContextFunc(c, origin) { - c.AbortWithStatus(http.StatusForbidden) - return - } - - } else if !cors.validateOrigin(origin) { + if !cors.validateOrigin(origin) { c.AbortWithStatus(http.StatusForbidden) return } diff --git a/cors.go b/cors.go index b0f3bec..b325222 100644 --- a/cors.go +++ b/cors.go @@ -22,9 +22,6 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool - // The same as AllowOriginFunc but allows access to the entire request context - AllowOriginWithContextFunc func(c *gin.Context, origin string) bool - // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS) AllowMethods []string @@ -105,15 +102,12 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || c.AllowOriginWithContextFunc != nil || len(c.AllowOrigins) > 0) { + if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && c.AllowOriginWithContextFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } - if c.AllowOriginFunc != nil && c.AllowOriginWithContextFunc != nil { - return errors.New("conflict settings: Both original validation functions are defined") - } for _, origin := range c.AllowOrigins { if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) { return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ",")) diff --git a/cors_test.go b/cors_test.go index 2800238..c87d60a 100644 --- a/cors_test.go +++ b/cors_test.go @@ -205,8 +205,6 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) { } func TestValidateOrigin(t *testing.T) { - // review the below for adding a testing context - //https://pkg.go.dev/github.com/gin-gonic/gin#CreateTestContextOnly cors := newCors(Config{ AllowAllOrigins: true, }) From 59f0738ac53e5c281e3b2b5ebf94e2119e69952c Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Thu, 22 Feb 2024 17:32:47 -0800 Subject: [PATCH 03/11] Allow origin validation function which receives the full request context --- config.go | 36 +++++++++++++++++++----------------- cors.go | 9 ++++++--- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/config.go b/config.go index 427cfc0..0e914ba 100644 --- a/config.go +++ b/config.go @@ -8,14 +8,15 @@ import ( ) type cors struct { - allowAllOrigins bool - allowCredentials bool - allowOriginFunc func(string) bool - allowOrigins []string - normalHeaders http.Header - preflightHeaders http.Header - wildcardOrigins [][]string - optionsResponseStatusCode int + allowAllOrigins bool + allowCredentials bool + allowOriginFunc func(string) bool + allowOriginWithContextFunc func(*gin.Context, string) bool + allowOrigins []string + normalHeaders http.Header + preflightHeaders http.Header + wildcardOrigins [][]string + optionsResponseStatusCode int } var ( @@ -54,14 +55,15 @@ func newCors(config Config) *cors { } return &cors{ - allowOriginFunc: config.AllowOriginFunc, - allowAllOrigins: config.AllowAllOrigins, - allowCredentials: config.AllowCredentials, - allowOrigins: normalize(config.AllowOrigins), - normalHeaders: generateNormalHeaders(config), - preflightHeaders: generatePreflightHeaders(config), - wildcardOrigins: config.parseWildcardRules(), - optionsResponseStatusCode: config.OptionsResponseStatusCode, + allowOriginFunc: config.AllowOriginFunc, + allowOriginWithContextFunc: config.AllowOriginWithContextFunc, + allowAllOrigins: config.AllowAllOrigins, + allowCredentials: config.AllowCredentials, + allowOrigins: normalize(config.AllowOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), + wildcardOrigins: config.parseWildcardRules(), + optionsResponseStatusCode: config.OptionsResponseStatusCode, } } @@ -79,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) { + if !cors.validateOrigin(origin) || (cors.allowOriginWithContextFunc != nil && cors.allowOriginWithContextFunc(c, origin)) { c.AbortWithStatus(http.StatusForbidden) return } diff --git a/cors.go b/cors.go index b325222..e4ba348 100644 --- a/cors.go +++ b/cors.go @@ -22,6 +22,9 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool + // same AllowOriginFunc except also receives the full request context + AllowOriginWithContextFunc func(c *gin.Context, origin string) bool + // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS) AllowMethods []string @@ -102,10 +105,10 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { - return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") + if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0 || c.AllowOriginWithContextFunc != nil) { + return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOriginFuncWithContext or AllowOrigins is not needed") } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && c.AllowOriginFunc == nil && c.AllowOriginWithContextFunc == nil && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } for _, origin := range c.AllowOrigins { From 239063e84877cae4eadb86ba1c80062373bb2fc0 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Thu, 22 Feb 2024 18:24:50 -0800 Subject: [PATCH 04/11] fix logic in conditional --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 0e914ba..b7ea5ae 100644 --- a/config.go +++ b/config.go @@ -81,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) || (cors.allowOriginWithContextFunc != nil && cors.allowOriginWithContextFunc(c, origin)) { + if !cors.validateOrigin(origin) || (cors.allowOriginWithContextFunc != nil && !cors.allowOriginWithContextFunc(c, origin)) { c.AbortWithStatus(http.StatusForbidden) return } From 906e08c73cea7df119e332f225da82f24d3f1bc1 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Thu, 22 Feb 2024 22:29:49 -0800 Subject: [PATCH 05/11] add test, fix logic --- config.go | 3 ++- cors_test.go | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index b7ea5ae..e4429f6 100644 --- a/config.go +++ b/config.go @@ -81,7 +81,8 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) || (cors.allowOriginWithContextFunc != nil && !cors.allowOriginWithContextFunc(c, origin)) { + validOrigin := cors.validateOrigin(origin) || (cors.allowOriginWithContextFunc != nil && cors.allowOriginWithContextFunc(c, origin)) + if !validOrigin { c.AbortWithStatus(http.StatusForbidden) return } diff --git a/cors_test.go b/cors_test.go index c87d60a..bab149a 100644 --- a/cors_test.go +++ b/cors_test.go @@ -282,6 +282,9 @@ func TestPassesAllowOrigins(t *testing.T) { AllowOriginFunc: func(origin string) bool { return origin == "http://github.com" }, + AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool { + return origin == "http://sample.com" + }, }) // no CORS request, origin == "" @@ -329,6 +332,15 @@ func TestPassesAllowOrigins(t *testing.T) { assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers")) assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age")) + // allowed CORS prefligh request: allowed via AllowOriginWithContextFunc + w = performRequest(router, "OPTIONS", "http://sample.com") + assert.Equal(t, http.StatusNoContent, w.Code) + assert.Equal(t, "http://sample.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age")) + // deny CORS prefligh request w = performRequest(router, "OPTIONS", "http://example.com") assert.Equal(t, http.StatusForbidden, w.Code) From 6d48563d2c88dcfb4ace2418189ea76c1939d43d Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Fri, 1 Mar 2024 09:35:11 -0800 Subject: [PATCH 06/11] slightly re-work to pass linter --- config.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index e4429f6..8a295e3 100644 --- a/config.go +++ b/config.go @@ -81,8 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) { return } - validOrigin := cors.validateOrigin(origin) || (cors.allowOriginWithContextFunc != nil && cors.allowOriginWithContextFunc(c, origin)) - if !validOrigin { + if !cors.isOriginValid(c, origin) { c.AbortWithStatus(http.StatusForbidden) return } @@ -115,6 +114,14 @@ func (cors *cors) validateWildcardOrigin(origin string) bool { return false } +func (cors *cors) isOriginValid(c *gin.Context, origin string) bool { + valid := cors.validateOrigin(origin) + if !valid && cors.allowOriginWithContextFunc != nil { + valid = cors.allowOriginWithContextFunc(c, origin) + } + return valid +} + func (cors *cors) validateOrigin(origin string) bool { if cors.allowAllOrigins { return true From 6eded887a6d113dc59763c65cfab1275e628893f Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Fri, 1 Mar 2024 09:39:30 -0800 Subject: [PATCH 07/11] update comments --- cors.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cors.go b/cors.go index e4ba348..28ef9ce 100644 --- a/cors.go +++ b/cors.go @@ -22,7 +22,10 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool - // same AllowOriginFunc except also receives the full request context + // Same as AllowOriginFunc except also receives the full request context. + // This function should use the context as a read only source and not + // have any side effects on the request, such as aborting or injecting + // values on the request. AllowOriginWithContextFunc func(c *gin.Context, origin string) bool // AllowMethods is a list of methods the client is allowed to use with From 6b6a3da425b0ad4959c5a01794356f19606ebb1e Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Sat, 9 Mar 2024 19:51:13 -0800 Subject: [PATCH 08/11] restructure to shorten line lengths to pass linter --- cors.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/cors.go b/cors.go index 28ef9ce..7873963 100644 --- a/cors.go +++ b/cors.go @@ -2,6 +2,7 @@ package cors import ( "errors" + "fmt" "strings" "time" @@ -108,10 +109,21 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0 || c.AllowOriginWithContextFunc != nil) { - return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOriginFuncWithContext or AllowOrigins is not needed") + hasOriginFn := c.AllowOriginFunc != nil + hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil + + if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) { + originFields := strings.Join([]string{ + "AllowOriginFunc", + "AllowOriginFuncWithContext", + "AllowOrigins", + }, " or ") + return fmt.Errorf( + "conflict settings: all origins enabled. %s is not needed.", + originFields, + ) } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && c.AllowOriginWithContextFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } for _, origin := range c.AllowOrigins { From 141074f2244c589a0b9f7c1c0f7665f9b85043ae Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Sat, 9 Mar 2024 20:27:56 -0800 Subject: [PATCH 09/11] remove punctuation at the end of error string --- cors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cors.go b/cors.go index 7873963..9597f9a 100644 --- a/cors.go +++ b/cors.go @@ -119,7 +119,7 @@ func (c Config) Validate() error { "AllowOrigins", }, " or ") return fmt.Errorf( - "conflict settings: all origins enabled. %s is not needed.", + "conflict settings: all origins enabled. %s is not needed", originFields, ) } From ffd15e8850809958852585eb8f88c721c91b803a Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Sat, 9 Mar 2024 22:08:06 -0800 Subject: [PATCH 10/11] Add multi-group preflight test --- cors_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/cors_test.go b/cors_test.go index bab149a..4955d95 100644 --- a/cors_test.go +++ b/cors_test.go @@ -27,12 +27,34 @@ func newTestRouter(config Config) *gin.Engine { return router } +func multiGroupRouter(config Config) *gin.Engine { + router := gin.New() + router.Use(New(config)) + + app1 := router.Group("/app1") + app1.GET("", func(c *gin.Context) { + c.String(http.StatusOK, "app1") + }) + + app2 := router.Group("/app2") + app2.GET("", func(c *gin.Context) { + c.String(http.StatusOK, "app2") + }) + + app3 := router.Group("/app3") + app3.GET("", func(c *gin.Context) { + c.String(http.StatusOK, "app3") + }) + + return router +} + func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder { - return performRequestWithHeaders(r, method, origin, http.Header{}) + return performRequestWithHeaders(r, method, "/", origin, http.Header{}) } -func performRequestWithHeaders(r http.Handler, method, origin string, header http.Header) *httptest.ResponseRecorder { - req, _ := http.NewRequestWithContext(context.Background(), method, "/", nil) +func performRequestWithHeaders(r http.Handler, method, path, origin string, header http.Header) *httptest.ResponseRecorder { + req, _ := http.NewRequestWithContext(context.Background(), method, path, nil) // From go/net/http/request.go: // For incoming requests, the Host header is promoted to the // Request.Host field and removed from the Header map. @@ -297,7 +319,7 @@ func TestPassesAllowOrigins(t *testing.T) { // no CORS request, origin == host h := http.Header{} h.Set("Host", "facebook.com") - w = performRequestWithHeaders(router, "GET", "http://facebook.com", h) + w = performRequestWithHeaders(router, "GET", "/", "http://facebook.com", h) assert.Equal(t, "get", w.Body.String()) assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) @@ -426,3 +448,47 @@ func TestWildcard(t *testing.T) { w = performRequest(router, "GET", "https://github.com") assert.Equal(t, 200, w.Code) } + +func TestMultiGroupRouter(t *testing.T) { + // performRequestWithHeaders(r http.Handler, method, path, origin string, header http.Header) + router := multiGroupRouter(Config{ + AllowMethods: []string{"GET"}, + AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool { + path := c.Request.URL.Path + if strings.HasPrefix(path, "/app1") { + return "http://app1.example.com" == origin + } + + if strings.HasPrefix(path, "/app2") { + return "http://app2.example.com" == origin + } + + // app 3 allows all origins + return true + }, + }) + + // allowed CORS prefligh request + emptyHeaders := http.Header{} + app1Origin := "http://app1.example.com" + app2Origin := "http://app2.example.com" + randomOrgin := "http://random.com" + + // allowed CORS preflight + w := performRequestWithHeaders(router, "OPTIONS", "/app1", app1Origin, emptyHeaders) + assert.Equal(t, http.StatusNoContent, w.Code) + + w = performRequestWithHeaders(router, "OPTIONS", "/app2", app2Origin, emptyHeaders) + assert.Equal(t, http.StatusNoContent, w.Code) + + w = performRequestWithHeaders(router, "OPTIONS", "/app3", randomOrgin, emptyHeaders) + assert.Equal(t, http.StatusNoContent, w.Code) + + // disallowed CORS preflight + w = performRequestWithHeaders(router, "OPTIONS", "/app1", randomOrgin, emptyHeaders) + assert.Equal(t, http.StatusForbidden, w.Code) + + w = performRequestWithHeaders(router, "OPTIONS", "/app2", randomOrgin, emptyHeaders) + assert.Equal(t, http.StatusForbidden, w.Code) + +} From 063635d66b62d62e0b1049a9d5be8e06f1167402 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Sat, 9 Mar 2024 22:21:53 -0800 Subject: [PATCH 11/11] remove comment --- cors_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cors_test.go b/cors_test.go index 5ba04a3..dedd3cc 100644 --- a/cors_test.go +++ b/cors_test.go @@ -467,7 +467,6 @@ func TestWildcard(t *testing.T) { } func TestMultiGroupRouter(t *testing.T) { - // performRequestWithHeaders(r http.Handler, method, path, origin string, header http.Header) router := multiGroupRouter(Config{ AllowMethods: []string{"GET"}, AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {