Skip to content

Commit

Permalink
Add service token authentication mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
abs51295 committed Jul 6, 2018
1 parent b61c652 commit ccb0de6
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 11 deletions.
5 changes: 5 additions & 0 deletions openshift/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ objects:
configMapKeyRef:
name: bayesian-config
key: keycloak-url
- name: BAYESIAN_AUTH_PUBLIC_KEYS_URL
valueFrom:
configMapKeyRef:
name: bayesian-config
key: get-auth-public-key-url
- name: BAYESIAN_JWT_AUDIENCE
value: "fabric8-online-platform,openshiftio-public"
image: "${DOCKER_REGISTRY}/${DOCKER_IMAGE}:${IMAGE_TAG}"
Expand Down
70 changes: 66 additions & 4 deletions src/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import jwt
from os import getenv


from exceptions import HTTPError
from utils import fetch_public_key
from utils import fetch_public_key, fetch_service_public_keys


def decode_token(token):
def decode_user_token(token):
"""Decode the authorization token read from the request header."""
if token is None:
return {}
Expand Down Expand Up @@ -38,6 +37,39 @@ def decode_token(token):
return decoded_token


def decode_service_token(token): # pragma: no cover
"""Decode OSIO service token."""
# TODO: Merge this function and user token function once audience is removed from user tokens.
if token is None:
return {}

if token.startswith('Bearer '):
_, token = token.split(' ', 1)

pub_keys = fetch_service_public_keys(current_app)
decoded_token = None

# Since we have multiple public keys, we need to verify against every public key.
# Token can be decoded by any one of the available public keys.
for pub_key in pub_keys:
try:
pub_key = '-----BEGIN PUBLIC KEY-----\n{pkey}\n-----END PUBLIC KEY-----'\
.format(pkey=pub_key)
decoded_token = jwt.decode(token, pub_key, algorithms=['RS256'])
except jwt.InvalidTokenError:
current_app.logger.error("Auth token couldn't be decoded for public key: {}"
.format(pub_key))
decoded_token = None

if decoded_token:
break

if not decoded_token:
raise jwt.InvalidTokenError('Auth token cannot be verified.')

return decoded_token


def get_token_from_auth_header():
"""Get the authorization token read from the request header."""
return request.headers.get('Authorization')
Expand All @@ -62,7 +94,37 @@ def wrapper(*args, **kwargs):
lgr = current_app.logger

try:
decoded = decode_token(get_token_from_auth_header())
decoded = decode_user_token(get_token_from_auth_header())
if not decoded:
lgr.exception('Provide an Authorization token with the API request')
raise HTTPError(401, 'Authentication failed - token missing')

lgr.info('Successfuly authenticated user {e} using JWT'.
format(e=decoded.get('email')))
except jwt.ExpiredSignatureError as exc:
lgr.exception('Expired JWT token')
raise HTTPError(401, 'Authentication failed - token has expired') from exc
except Exception as exc:
lgr.exception('Failed decoding JWT token')
raise HTTPError(401, 'Authentication failed - could not decode JWT token') from exc

return view(*args, **kwargs)

return wrapper


def service_token_required(view): # pragma: no cover
"""Check if the request contains a valid service token."""
@wraps(view)
def wrapper(*args, **kwargs):
# Disable authentication for local setup
if getenv('DISABLE_AUTHENTICATION') in ('1', 'True', 'true'):
return view(*args, **kwargs)

lgr = current_app.logger

try:
decoded = decode_user_token(get_token_from_auth_header())
if not decoded:
lgr.exception('Provide an Authorization token with the API request')
raise HTTPError(401, 'Authentication failed - token missing')
Expand Down
9 changes: 8 additions & 1 deletion src/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask_cors import CORS
from utils import DatabaseIngestion, scan_repo, validate_request_data, retrieve_worker_result
from f8a_worker.setup_celery import init_selinon
from auth import login_required
from auth import login_required, service_token_required
from exceptions import HTTPError

app = Flask(__name__)
Expand Down Expand Up @@ -215,5 +215,12 @@ def handle_error(e): # pragma: no cover
}), e.status_code


@app.route('/test-service-token')
@service_token_required
def test_service_token(): # pragma: no cover
"""Test the service token authentication mechanism."""
return flask.jsonify({'token': 'is_valid'}), 200


if __name__ == "__main__": # pragma: no cover
app.run()
24 changes: 24 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,27 @@ def fetch_public_key(app):
app.public_key = None

return app.public_key


def fetch_service_public_keys(app): # pragma: no cover
"""Get public keys for OSIO service account. Currently, there are three public keys."""
if not getattr(app, "service_public_keys", []):
auth_url = os.getenv('BAYESIAN_AUTH_PUBLIC_KEYS_URL', '')
if auth_url:
try:
result = requests.get(auth_url, timeout=0.5)
app.logger.info('Fetching public key from %s, status %d, result: %s',
auth_url, result.status_code, result.text)
except requests.exceptions.Timeout:
app.logger.error('Timeout fetching public key from %s', auth_url)
return ''
if result.status_code != 200:
return ''

keys = result.json().get('keys', [])
app.service_public_keys = keys

else:
app.service_public_keys = None

return app.service_public_keys
12 changes: 6 additions & 6 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,39 +68,39 @@ def mocked_get_audiences_3():
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
def test_decode_token_invalid_input_1(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
assert decode_token(None) == {}
assert decode_user_token(None) == {}


@patch("auth.get_audiences", side_effect=mocked_get_audiences)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
def test_decode_token_invalid_input_2(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
with pytest.raises(Exception):
assert decode_token("Foobar") is None
assert decode_user_token("Foobar") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
def test_decode_token_invalid_input_3(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
with pytest.raises(Exception):
assert decode_token("Bearer ") is None
assert decode_user_token("Bearer ") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_2)
def test_decode_token_invalid_input_4(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
with pytest.raises(Exception):
assert decode_token("Bearer ") is None
assert decode_user_token("Bearer ") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences_2)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_2)
def test_decode_token_invalid_input_5(mocked_fetch_public_key, mocked_get_audiences):
"""Test the handling wrong JWT tokens."""
with pytest.raises(Exception):
assert decode_token("Bearer something") is None
assert decode_user_token("Bearer something") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences_3)
Expand All @@ -112,7 +112,7 @@ def test_decode_token_invalid_input_6(mocked_fetch_public_key, mocked_get_audien
'aud': 'foo:bar'
}
token = jwt.encode(payload, PRIVATE_KEY, algorithm='RS256').decode("utf-8")
assert decode_token(token) is not None
assert decode_user_token(token) is not None


def test_audiences():
Expand Down

0 comments on commit ccb0de6

Please sign in to comment.