From 9a37506ed9f6924b1d669a0af74ac9a98bdbad75 Mon Sep 17 00:00:00 2001 From: da-z Date: Wed, 21 Feb 2024 08:49:17 +0100 Subject: [PATCH] feat: allow usage of custom schemas --- cors.go | 6 ++++++ cors_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/cors.go b/cors.go index b325222..3d3e362 100644 --- a/cors.go +++ b/cors.go @@ -51,6 +51,9 @@ type Config struct { // Allows usage of popular browser extensions schemas AllowBrowserExtensions bool + // Allows to add custom schema like tauri:// + CustomSchemas []string + // Allows usage of WebSocket protocol AllowWebSockets bool @@ -87,6 +90,9 @@ func (c Config) getAllowedSchemas() []string { if c.AllowFiles { allowedSchemas = append(allowedSchemas, FileSchemas...) } + if c.CustomSchemas != nil { + allowedSchemas = append(allowedSchemas, c.CustomSchemas...) + } return allowedSchemas } diff --git a/cors_test.go b/cors_test.go index c87d60a..50e1033 100644 --- a/cors_test.go +++ b/cors_test.go @@ -271,6 +271,22 @@ func TestValidateOrigin(t *testing.T) { assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id")) } +func TestValidateTauri(t *testing.T) { + c := Config{ + AllowOrigins: []string{"tauri://localhost:1234"}, + AllowBrowserExtensions: true, + } + err := c.Validate() + assert.Equal(t, err.Error(), "bad origin: origins must contain '*' or include http://,https://,chrome-extension://,safari-extension://,moz-extension://,ms-browser-extension://") + + c = Config{ + AllowOrigins: []string{"tauri://localhost:1234"}, + AllowBrowserExtensions: true, + CustomSchemas: []string{"tauri"}, + } + assert.Nil(t, c.Validate()) +} + func TestPassesAllowOrigins(t *testing.T) { router := newTestRouter(Config{ AllowOrigins: []string{"http://google.com"},