diff --git a/.env.testing b/.env.testing index 578cfd8..bc16a53 100644 --- a/.env.testing +++ b/.env.testing @@ -1,4 +1,4 @@ -ALLOWED_DOMAINS=.*\.geo\.admin\.ch,.*\.bgdi\.ch,.*\.swisstopo\.cloud +ALLOWED_DOMAINS=.*\.geo\.admin\.ch,.*\.bgdi\.ch,http://localhost AWS_ACCESS_KEY_ID=testing AWS_SECRET_ACCESS_KEY=testing AWS_SECURITY_TOKEN=testing diff --git a/adr/2022_05_17_short_id_algorithm.md b/adr/2022_05_17_short_id_algorithm.md new file mode 100644 index 0000000..08de2e5 --- /dev/null +++ b/adr/2022_05_17_short_id_algorithm.md @@ -0,0 +1,173 @@ +# Short ID algorithm + +> `Status: proposed` +> +> `Date: 2022-05-17` +> +> `Author: Brice Schaffner` + +## Context + +The current short ID algorithm for `service-shortlink` uses a kind of counter based on the computer time. +It takes the current timestamp rounded to the millisecond (milliseconds since 1970.01.01 00:00:00.000); to +reduce the size of this timestamp, `1'000'000'000'000` is subtracted from it, which gives us the number +of milliseconds since `2001.09.09 03:46:40.000`. + +### PROS of current algorithm + +- Very simple +- With a request rate less than `1 rps` (current rate is `~0.3 rps`) collisions are quite unlikely and can be easily avoided with a very small number of retries. + +### CONS of current algorithm + +- Size of short ID is dynamic, current is 10 characters but will increase in the near future. Around `2037.01.01` we will have 11 characters. + +## Nano ID - random short ID + +We could reduce the size of the ID to 8 characters by using [Nano ID](https://github.com/ai/nanoid). +Here however we have an issue with collisions! Based on [Nano ID collision calculator](https://zelark.github.io/nano-id-cc/) +and our current request rate of ~1050 rph (requests per hour), we will have a 1% collision risk in 99 days! 
+Now looking closer to the mathematics (note I might be wrong here as I'm not a mathematician) we can +compute the collision probability as follows: + +- d := number of different possible IDs (see [Permutation with Replacement](https://www.calculatorsoup.com/calculators/discretemathematics/permutationsreplacement.php)) +- n := number of IDs +- `1-((d-1)/d)**(n*(n-1)/2)` [Birthday Paradox / Probability of a shared birthday (collision)](https://en.wikipedia.org/wiki/Birthday_problem) + +```python +d = 64**8 +print(f"{d:,}") +281,474,976,710,656 + +# Number of IDs after 100 days +n = 1050 * 24 * 100 + +collision = 1-((d-1)/d)**(n*(n-1)/2) +print(str(int(collision * 100)) + '%') +1% + +# Number of IDs after 1 years +n = 1050 * 24 * 365 * 1 + +collision = 1-((d-1)/d)**(n*(n-1)/2) +print(str(int(collision * 100)) + '%') +13% + +# Number of IDs after 3 years +n = 1050 * 24 * 365 * 3 + +collision = 1-((d-1)/d)**(n*(n-1)/2) +print(str(int(collision * 100)) + '%') +74% + +# Number of IDs after 5 years +n = 1050 * 24 * 365 * 5 + +collision = 1-((d-1)/d)**(n*(n-1)/2) +print(str(int(collision * 100)) + '%') +97% + +# Number of IDs after 10 years +n = 1050 * 24 * 365 * 10 + +collision = 1-((d-1)/d)**(n*(n-1)/2) +print(str(int(collision * 100)) + '%') +99% +``` + +### Nano ID tests + +I tested Nano ID with 1 and 2 characters with the following code + +```python +# app/helpers/utils.py +def generate_short_id(): + return generate(size=8) + +# tests/unit_tests/test_helpers.py +class TestDynamoDb(BaseShortlinkTestCase): + @params(1, 2) + @patch('app.helpers.dynamo_db.generate_short_id') + def test_duplicate_short_id_end_of_ids(self, m, mock_generate_short_id): + regex = re.compile(r'^[0-9a-zA-Z-_]{' + str(m) + '}$') + + def generate_short_id_mocker(): + return generate(size=m) + + mock_generate_short_id.side_effect = generate_short_id_mocker + # with generate(size=1) we have 64 different possible IDs, as we get closer to this number + # the collision will increase. 
Here we make sure that we can generate at least the half + # of the maximal number of unique ID with less than the max retry. + n = 64 + max_ids = int(factorial(n) / (factorial(m) * factorial(n - m))) + logger.debug('Try to generate %d entries', max_ids) + for i in range(max_ids): + logger.debug('-' * 80) + logger.debug('Add entry %d', i) + if i < max_ids / 2: + + next_entry = add_url_to_table( + f'https://www.example/test-duplicate-id-end-of-ids-{i}-url' + ) + self.assertIsNotNone( + regex.match(next_entry['shortlink_id']), + msg=f"short ID {next_entry['shortlink_id']} don't match regex" + ) + else: + # more thant the half of max ids might fail due to more than COLLISION_MAX_RETRY + # retries, therefore ignore those errors + try: + next_entry = add_url_to_table( + f'https://www.example/test-duplicate-id-end-of-ids-{i}-url' + ) + except db_table.meta.client.exceptions.ConditionalCheckFailedException: + pass + # Make sure that generating a 65 ID fails. + with self.assertRaises(db_table.meta.client.exceptions.ConditionalCheckFailedException): + add_url_to_table('https://www.example/test-duplicate-id-end-of-ids-65-url') + +``` + +The test with 1 character passed, but the test with 2 characters did not! This means that with 1 character we could generate +up to half of the available IDs without having more than 10 retries, whereas with 2 characters we could not! +Note also that the formula used here to compute the maximal number of IDs was wrong and generated fewer +IDs than the correct formula `max_ids = n**m`: + +- `max_ids = n**m` + - `max_ids = 64**1 = 64` + - `max_ids = 64**2 = 4096` +- `int(factorial(n) / (factorial(m) * factorial(n - m)))` + - `n = 64; m = 1; int(factorial(n) / (factorial(m) * factorial(n - m))) = 64` + - `n = 64; m = 2; int(factorial(n) / (factorial(m) * factorial(n - m))) = 2016` + +### Nano ID conclusion + +Based on the formula and computation above we have a high risk of having too many collisions already +after 3 years! 
So this algorithm cannot be used with 8 characters. + +## Other algorithms + +After some research on shortlink algorithms, I found out that there are two categories of algorithms: + +1. Random ID generator (e.g. NanoID) +2. Counter + +While the first is very easy to implement, the size of the ID highly depends on the generation rates and +max life of the IDs. For our use case we have a quite high generation rate and an infinite life of the IDs. +This means that it is not the best algorithm. + +However the second algorithm is more robust for our use case. Starting a counter from 0 we could reduce the ID significantly (less than 6 characters). However it would require changing the backend to have an atomic counter. With our +current implementation (k8s with DynamoDB) this is not feasible. So we would need to change the DB (maybe PSQL?) +and rewrite the whole python service. + +## Decision (to be accepted by others) + +I think with the current algorithm, which we have used for the past years, we are good until 2037, when we will have one more character. +This algo is quite robust, not the most efficient in terms of ID length but very simple and fast. + +Changing to a Random ID generator is, based on our generation rate and ID life cycle, way too risky and brittle. + +Changing to a real counter approach would require a lot of effort, starting from scratch. + +So IMHO sticking to the current algorithm is the best for the moment. In the future we can reduce the size of +the shortlink by reducing the size of the host name which is quite long; e.g. `s.bgdi.ch` instead of `s.geo.admin.ch`. diff --git a/app/__init__.py b/app/__init__.py index 223d3de..c03ed3a 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -25,6 +25,10 @@ app.config.from_mapping({"TRAP_HTTP_EXCEPTIONS": True}) +def is_domain_allowed(domain): + return re.match(ALLOWED_DOMAINS_PATTERN, domain) is not None + + @app.before_request # Add quick log of the routes used to all request. 
# Important: this should be the first before_request method, to ensure @@ -43,13 +47,36 @@ def validate_origin(): # any origin (anyone) return - if 'Origin' not in request.headers: - logger.error('Origin header is not set') + # The Origin headers is automatically set by the browser and cannot be changed by the javascript + # application. Unfortunately this header is only set if the request comes from another origin. + # Sec-Fetch-Site header is set to `same-origin` by most of the browser except by Safari ! + # The best protection would be to use the Sec-Fetch-Site and Origin header, however this is + # not supported by Safari. Therefore we added a fallback to the Referer header for Safari. + sec_fetch_site = request.headers.get('Sec-Fetch-Site', None) + origin = request.headers.get('Origin', None) + referrer = request.headers.get('Referer', None) + + if origin is not None: + if is_domain_allowed(origin): + return + logger.error('Origin=%s is not allowed', origin) abort(403, 'Permission denied') - if not re.match(ALLOWED_DOMAINS_PATTERN, request.headers['Origin']): - logger.error('Origin %s is not allowed', request.headers['Origin']) + + if sec_fetch_site is not None: + if sec_fetch_site in ['same-origin', 'same-site']: + return + logger.error('Sec-Fetch-Site=%s is not allowed', sec_fetch_site) abort(403, 'Permission denied') + if referrer is not None: + if is_domain_allowed(referrer): + return + logger.error('Referer=%s is not allowed', referrer) + abort(403, 'Permission denied') + + logger.error('Referer and/or Origin and/or Sec-Fetch-Site headers not set') + abort(403, 'Permission denied') + @app.after_request def add_charset(response): @@ -66,13 +93,15 @@ def add_generic_cors_header(response): if request.endpoint == 'checker': return response - if ( - 'Origin' in request.headers and - re.match(ALLOWED_DOMAINS_PATTERN, request.headers['Origin']) - ): - # Don't add the allow origin if the origin is not allowed, otherwise that would give - # a hint to the user on 
how to missused this service - response.headers.set('Access-Control-Allow-Origin', request.headers['Origin']) + if request.endpoint == 'get_shortlink' and get_redirect_param(ignore_errors=True): + # redirect endpoint are allowed from all origins + response.headers['Access-Control-Allow-Origin'] = "*" + else: + response.headers['Access-Control-Allow-Origin'] = request.host_url + if 'Origin' in request.headers and is_domain_allowed(request.headers['Origin']): + response.headers['Access-Control-Allow-Origin'] = request.headers['Origin'] + response.headers['Vary'] = 'Origin' + # Always add the allowed methods. response.headers.set( 'Access-Control-Allow-Methods', ', '.join(get_registered_method(app, request.url_rule)) diff --git a/app/helpers/utils.py b/app/helpers/utils.py index 5fe39c0..8428ed0 100644 --- a/app/helpers/utils.py +++ b/app/helpers/utils.py @@ -59,11 +59,13 @@ def get_registered_method(app, url_rule): ) -def get_redirect_param(): +def get_redirect_param(ignore_errors=False): try: redirect = strtobool(request.args.get('redirect', 'true')) except ValueError as error: - abort(400, f'Invalid "redirect" arg: {error}') + redirect = False + if not ignore_errors: + abort(400, f'Invalid "redirect" arg: {error}') return redirect diff --git a/tests/unit_tests/base.py b/tests/unit_tests/base.py index d3738f2..a2454e2 100644 --- a/tests/unit_tests/base.py +++ b/tests/unit_tests/base.py @@ -83,12 +83,18 @@ def setUp(self): def tearDown(self): self.table.delete() - def assertCors(self, response, expected_allowed_methods, check_origin=True): # pylint: disable=invalid-name - if check_origin: - self.assertIn('Access-Control-Allow-Origin', response.headers) - self.assertTrue( - re.match(ALLOWED_DOMAINS_PATTERN, response.headers['Access-Control-Allow-Origin']) - ) + def assertCors( + self, + response, + expected_allowed_methods, + origin_pattern=ALLOWED_DOMAINS_PATTERN + ): # pylint: disable=invalid-name + self.assertIn('Access-Control-Allow-Origin', response.headers) + 
self.assertIsNotNone( + re.match(origin_pattern, response.headers['Access-Control-Allow-Origin']), + msg=f"Access-Control-Allow-Origin={response.headers['Access-Control-Allow-Origin']}" + f" doesn't match {origin_pattern}" + ) self.assertIn('Access-Control-Allow-Methods', response.headers) self.assertListEqual( sorted(expected_allowed_methods), diff --git a/tests/unit_tests/test_helpers.py b/tests/unit_tests/test_helpers.py index 1266169..6deb0ee 100644 --- a/tests/unit_tests/test_helpers.py +++ b/tests/unit_tests/test_helpers.py @@ -116,43 +116,3 @@ def test_one_duplicate_short_id(self, mock_generate_short_id): self.assertEqual(entry1['shortlink_id'], '2') entry2 = self.db.add_url_to_table(url2) self.assertEqual(entry2['shortlink_id'], '3') - - # @params(1, 2) - # @patch('app.helpers.dynamo_db.generate_short_id') - # def test_duplicate_short_id_end_of_ids(self, m, mock_generate_short_id): - # regex = re.compile(r'^[0-9a-zA-Z-_]{' + str(m) + '}$') - - # def generate_short_id_mocker(): - # return generate(size=m) - - # mock_generate_short_id.side_effect = generate_short_id_mocker - # # with generate(size=1) we have 64 different possible IDs, as we get closer to this number - # # the collision will increase. Here we make sure that we can generate at least the half - # # of the maximal number of unique ID with less than the max retry. 
- # n = 64 - # max_ids = int(factorial(n) / (factorial(m) * factorial(n - m))) - # logger.debug('Try to generate %d entries', max_ids) - # for i in range(max_ids): - # logger.debug('-' * 80) - # logger.debug('Add entry %d', i) - # if i < max_ids / 2: - - # next_entry = add_url_to_table( - # f'https://www.example/test-duplicate-id-end-of-ids-{i}-url' - # ) - # self.assertIsNotNone( - # regex.match(next_entry['shortlink_id']), - # msg=f"short ID {next_entry['shortlink_id']} don't match regex" - # ) - # else: - # # more thant the half of max ids might fail due to more than COLLISION_MAX_RETRY - # # retries, therefore ignore those errors - # try: - # next_entry = add_url_to_table( - # f'https://www.example/test-duplicate-id-end-of-ids-{i}-url' - # ) - # except db_table.meta.client.exceptions.ConditionalCheckFailedException: - # pass - # # Make sure that generating a 65 ID fails. - # with self.assertRaises(db_table.meta.client.exceptions.ConditionalCheckFailedException): - # add_url_to_table('https://www.example/test-duplicate-id-end-of-ids-65-url') diff --git a/tests/unit_tests/test_routes.py b/tests/unit_tests/test_routes.py index 4fb7205..a6cfd4a 100644 --- a/tests/unit_tests/test_routes.py +++ b/tests/unit_tests/test_routes.py @@ -2,6 +2,8 @@ import logging.config import re +from nose2.tools import params + from flask import url_for from app.settings import SHORT_ID_SIZE @@ -54,14 +56,16 @@ def test_create_shortlink_no_json(self): self.assertEqual(415, response.status_code) self.assertCors(response, ['POST', 'OPTIONS']) self.assertIn('application/json', response.content_type) - self.assertEqual({ - 'success': False, - 'error': { - 'code': 415, - 'message': 'Input data missing or from wrong type, must be application/json' - } - }, - response.json) + self.assertEqual( + { + 'success': False, + 'error': { + 'code': 415, + 'message': 'Input data missing or from wrong type, must be application/json' + } + }, + response.json, + ) def test_create_shortlink_no_url(self): 
response = self.app.post( @@ -70,13 +74,15 @@ def test_create_shortlink_no_url(self): self.assertEqual(400, response.status_code) self.assertCors(response, ['POST', 'OPTIONS']) self.assertIn('application/json', response.content_type) - self.assertEqual({ - 'success': False, - 'error': { - 'code': 400, 'message': 'Url parameter missing from request' - } - }, - response.json) + self.assertEqual( + { + 'success': False, + 'error': { + 'code': 400, 'message': 'Url parameter missing from request' + } + }, + response.json, + ) def test_create_shortlink_no_hostname(self): wrong_url = "/test" @@ -145,7 +151,7 @@ def test_redirect_shortlink_ok(self): for short_id, url in self.uuid_to_url_dict.items(): response = self.app.get(url_for('get_shortlink', shortlink_id=short_id)) self.assertEqual(response.status_code, 301) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], check_origin=False) + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") self.assertIn('Cache-Control', response.headers) self.assertIn('max-age=', response.headers['Cache-Control']) self.assertEqual(response.content_type, "text/html; charset=utf-8") @@ -159,7 +165,7 @@ def test_redirect_shortlink_ok_with_query(self): headers={"Origin": "www.example.com"} ) self.assertEqual(response.status_code, 301) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], check_origin=False) + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") self.assertIn('Cache-Control', response.headers) self.assertIn('max-age=', response.headers['Cache-Control']) self.assertEqual(response.content_type, "text/html; charset=utf-8") @@ -181,7 +187,10 @@ def test_shortlink_fetch_nok_invalid_redirect_parameter(self): } } self.assertEqual(response.status_code, 400) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS']) + self.assertCors( + response, + ['GET', 'HEAD', 'OPTIONS'], + ) self.assertIn('Cache-Control', response.headers) self.assertIn('max-age=3600', 
response.headers['Cache-Control']) self.assertIn('application/json', response.content_type) @@ -199,7 +208,7 @@ def test_redirect_shortlink_url_not_found(self): } } self.assertEqual(response.status_code, 404) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS']) + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") self.assertIn('Cache-Control', response.headers) self.assertIn('max-age=3600', response.headers['Cache-Control']) self.assertIn('application/json', response.content_type) @@ -256,26 +265,128 @@ def test_fetch_full_url_from_shortlink_url_not_found(self): } self.assertEqual(response.json, expected_json) - def test_create_shortlink_no_origin_header(self): - response = self.app.post("/") + @params( + None, + {'Origin': 'www.example'}, + {'Origin': ''}, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'cross-site' + }, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'same-site' + }, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'same-origin' + }, + {'Referer': 'http://www.example'}, + {'Referer': ''}, + ) + def test_create_shortlink_origin_not_allowed(self, headers): + response = self.app.post("/", headers=headers) self.assertEqual(403, response.status_code) - self.assertCors(response, ['POST', 'OPTIONS'], check_origin=False) + self.assertCors(response, ['POST', 'OPTIONS']) self.assertIn('application/json', response.content_type) - self.assertEqual({ - 'success': False, 'error': { - 'code': 403, 'message': 'Permission denied' - } + self.assertEqual( + { + 'success': False, 'error': { + 'code': 403, 'message': 'Permission denied' + } + }, + response.json, + ) + + @params( + {'Origin': 'map.geo.admin.ch'}, + { + 'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' + }, + { + 'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' }, - response.json) + { + 'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site' + }, + {'Sec-Fetch-Site': 'same-origin'}, + {'Referer': 'https://map.geo.admin.ch'}, + ) + def 
test_create_shortlink_origin_allowed(self, headers): + url = "https://map.geo.admin.ch/test" + response = self.app.post(url_for('create_shortlink'), json={"url": url}, headers=headers) + self.assertEqual(response.status_code, 201) + self.assertCors(response, ['POST', 'OPTIONS']) + self.assertEqual(response.content_type, "application/json; charset=utf-8") + self.assertEqual(response.json.get('success'), True) - def test_create_shortlink_non_allowed_origin_header(self): - response = self.app.post("/", headers={"Origin": "big-bad-wolf.com"}) - self.assertEqual(403, response.status_code) - self.assertCors(response, ['POST', 'OPTIONS'], check_origin=False) - self.assertIn('application/json', response.content_type) - self.assertEqual({ - 'success': False, 'error': { - 'code': 403, 'message': 'Permission denied' - } + @params( + None, + {}, + {'Origin': "www.example"}, + {'Origin': 'map.geo.admin.ch'}, + {'Origin': ''}, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'cross-site' + }, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'same-site' }, - response.json) + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'same-origin' + }, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'none' + }, + { + 'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' + }, + { + 'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' + }, + { + 'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site' + }, + {'Referer': 'http://www.example'}, + {'Referer': 'https://map.geo.admin.ch'}, + {'Referer': ''}, + {'Sec-Fetch-Site': 'same-origin'}, + {'Sec-Fetch-Site': 'same-site'}, + {'Sec-Fetch-Site': 'none'}, + {'Sec-Fetch-Site': 'cross-site'}, + ) + def test_get_shortlink_redirect_origin_allowed(self, headers): + short_id = next(iter(self.uuid_to_url_dict.keys())) + response = self.app.get( + url_for('get_shortlink', shortlink_id=short_id), + query_string={'redirect': 'true'}, + headers=headers + ) + self.assertEqual(response.status_code, 301) + self.assertCors(response, ['GET', 
'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") + + response = self.app.get(url_for('get_shortlink', shortlink_id=short_id), headers=headers) + self.assertEqual(response.status_code, 301) + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") + + @params( + {'Origin': 'map.geo.admin.ch'}, + { + 'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' + }, + { + 'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' + }, + { + 'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site' + }, + {'Sec-Fetch-Site': 'same-origin'}, + {'Sec-Fetch-Site': 'same-site'}, + {'Referer': 'https://map.geo.admin.ch'}, + ) + def test_get_shortlink_origin_allowed(self, headers): + short_id = next(iter(self.uuid_to_url_dict.keys())) + response = self.app.get( + url_for('get_shortlink', shortlink_id=short_id), + query_string={'redirect': 'false'}, + headers=headers + ) + self.assertEqual(response.status_code, 200) + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'])