diff --git a/.github/workflows/run-unit-tests.yml b/.github/workflows/run-unit-tests.yml index 81517ded..2a8d6baa 100644 --- a/.github/workflows/run-unit-tests.yml +++ b/.github/workflows/run-unit-tests.yml @@ -4,6 +4,26 @@ on: workflow_dispatch: pull_request: jobs: + free-disk-space: + runs-on: ubuntu-latest + steps: + + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + # this might remove tools that are actually needed, + # if set to "true" but frees about 6 GB + tool-cache: false + + # all of these default to true, but feel free to set to + # "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: false + swap-storage: false + docker-tests: name: "Run unit tests in docker" runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8b732b1c..11c07a28 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,20 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 23.11.0 hooks: - id: black - repo: https://github.com/pycqa/isort - rev: 5.11.2 + rev: 5.12.0 hooks: - id: isort name: isort (python) - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 6.1.0 hooks: - id: flake8 diff --git a/README.md b/README.md index 5b317bab..3b646bba 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,10 @@ python3 -m cnaas_nms.api.tests.test_api python3 -m cnaas_nms.confpush.tests.test_get ``` +## Authorization + +Currently we can use two styles for the authorization. We can use the original style or use OIDC style. For OIDC we need to define some env variables or add a auth_config.yaml in the config. The needed variables are: OIDC_CONF_WELL_KNOWN_URL, OIDC_CLIENT_SECRET, OIDC_CLIENT_ID, FRONTEND_CALLBACK_URL and OIDC_ENABLED. To use the OIDC style the last variable needs to be set to true. + ## License Copyright (c) 2019 - 2020, SUNET (BSD 2-clause license) diff --git a/docker/api/Dockerfile b/docker/api/Dockerfile index 9cb74ef1..625fd398 100644 --- a/docker/api/Dockerfile +++ b/docker/api/Dockerfile @@ -63,7 +63,7 @@ RUN mkdir -p /opt/cnaas/templates /opt/cnaas/settings /opt/cnaas/venv \ COPY --chown=root:www-data cnaas-setup.sh createca.sh exec-pre-app.sh pytest.sh coverage.sh /opt/cnaas/ # Copy cnaas configuration files -COPY --chown=www-data:www-data config/api.yml config/db_config.yml config/plugins.yml config/repository.yml /etc/cnaas-nms/ +COPY --chown=www-data:www-data config/api.yml config/auth_config.yml config/db_config.yml config/plugins.yml config/repository.yml /etc/cnaas-nms/ USER www-data diff --git a/docker/api/config/auth_config.yml b/docker/api/config/auth_config.yml new file mode 100644 index 00000000..b5f681d5 --- /dev/null +++ b/docker/api/config/auth_config.yml @@ -0,0 +1,5 @@ +oidc_conf_well_known_url: "well-known-openid-configuration-endpoint" +oidc_client_secret: "xxx" +oidc_client_id: "client-id" +frontend_callback_url: "http://localhost/callback" +oidc_enabled: False diff --git a/docs/configuration/index.rst b/docs/configuration/index.rst index 4ca4dcc4..fc0c340e 100644 --- a/docs/configuration/index.rst +++ b/docs/configuration/index.rst @@ -48,6 +48,17 @@ Defines parameters for the API: - commit_confirmed_wait: Time to wait between comitting configuration and checking that the device is still reachable, specified in seconds. Defaults to 1. +/etc/cnaas-nms/auth_config.yml +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Define parameters for the authentication: + +- oidc_conf_well_known_url: set the url for the oidc +- oidc_client_secret: set the secret of the oidc +- oidc_client_id: set the client_id of the oidc +- frontend_callback_url: set the frontend url the oidc client should link to after the login process +- oidc_enabled: set True to enabled the oidc login. Default: False + /etc/cnaas-nms/repository.yml ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/requirements.txt b/requirements.txt index d89c44c1..48a70468 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,4 +33,6 @@ pydantic==2.3.0 Werkzeug==3.0.1 greenlet==3.0.1 pyyaml!=6.0.0,!=5.4.0,!=5.4.1 -pydantic_settings==2.1.0 \ No newline at end of file +pydantic_settings==2.1.0 +Authlib==1.0.1 +python-jose==3.1.0 diff --git a/src/cnaas_nms/api/app.py b/src/cnaas_nms/api/app.py index 799cd61d..9f3d6468 100644 --- a/src/cnaas_nms/api/app.py +++ b/src/cnaas_nms/api/app.py @@ -1,8 +1,8 @@ import os import re import sys -from typing import Optional +from typing import Optional from engineio.payload import Payload from flask import Flask, jsonify, request from flask_cors import CORS @@ -10,7 +10,11 @@ from flask_jwt_extended.exceptions import InvalidHeaderError, NoAuthorizationError from flask_restx import Api from flask_socketio import SocketIO, join_room -from jwt.exceptions import DecodeError, InvalidSignatureError, InvalidTokenError +from jwt import decode +from jwt.exceptions import DecodeError, InvalidSignatureError, InvalidTokenError, ExpiredSignatureError, InvalidKeyError +from authlib.integrations.flask_client import OAuth +from authlib.oauth2.rfc6749 import MissingAuthorizationError + from cnaas_nms.api.device import ( device_api, @@ -24,6 +28,7 @@ device_update_interfaces_api, devices_api, ) +from cnaas_nms.api.auth import api as auth_api from cnaas_nms.api.firmware import api as firmware_api from cnaas_nms.api.groups import api as groups_api from cnaas_nms.api.interface import api as interfaces_api @@ -35,11 +40,15 @@ from cnaas_nms.api.repository import api as repository_api from cnaas_nms.api.settings import api as settings_api from cnaas_nms.api.system import api as system_api + +from cnaas_nms.app_settings import auth_settings from cnaas_nms.app_settings import api_settings + from cnaas_nms.tools.log import get_logger -from cnaas_nms.tools.security import get_jwt_identity, jwt_required +from cnaas_nms.tools.security import get_oauth_userinfo from cnaas_nms.version import __api_version__ + logger = get_logger() @@ -52,31 +61,53 @@ } } -jwt_query_r = re.compile(r"jwt=[^ &]+") +jwt_query_r = re.compile(r"code=[^ &]+") class CnaasApi(Api): def handle_error(self, e): if isinstance(e, DecodeError): - data = {"status": "error", "data": "Could not decode JWT token"} + data = {"status": "error", "message": "Could not decode JWT token"} + elif isinstance(e, InvalidKeyError): + data = {"status": "error", "message": "Invalid keys {}".format(e)} elif isinstance(e, InvalidTokenError): - data = {"status": "error", "data": "Invalid authentication header: {}".format(e)} + data = {"status": "error", "message": "Invalid authentication header: {}".format(e)} elif isinstance(e, InvalidSignatureError): - data = {"status": "error", "data": "Invalid token signature"} + data = {"status": "error", "message": "Invalid token signature"} elif isinstance(e, IndexError): - # We might catch IndexErrors which are not cuased by JWT, + # We might catch IndexErrors which are not caused by JWT, # but this is better than nothing. - data = {"status": "error", "data": "JWT token missing?"} + data = {"status": "error", "message": "JWT token missing?"} elif isinstance(e, NoAuthorizationError): - data = {"status": "error", "data": "JWT token missing?"} + data = {"status": "error", "message": "JWT token missing?"} elif isinstance(e, InvalidHeaderError): - data = {"status": "error", "data": "Invalid header, JWT token missing? {}".format(e)} + data = {"status": "error", "message": "Invalid header, JWT token missing? {}".format(e)} + elif isinstance(e, ExpiredSignatureError): + data = {"status": "error", "message": "The JWT token is expired"} + elif isinstance(e, MissingAuthorizationError): + data = {"status": "error", "message": "JWT token missing?"} + elif isinstance(e, ConnectionError): + data = {"status": "error", "message": "ConnectionError: {}".format(e)} + return jsonify(data), 500 else: return super(CnaasApi, self).handle_error(e) return jsonify(data), 401 app = Flask(__name__) + +# To register the OAuth client +oauth = OAuth(app) +client = oauth.register( + "connext", + server_metadata_url=auth_settings.OIDC_CONF_WELL_KNOWN_URL, + client_id=auth_settings.OIDC_CLIENT_ID, + client_secret=auth_settings.OIDC_CLIENT_SECRET, + client_kwargs={"scope": auth_settings.OIDC_CLIENT_SCOPE}, + response_type="code", + response_mode="query", +) + app.config["RESTX_JSON"] = {"cls": CNaaSJSONEncoder} # TODO: make origins configurable @@ -88,6 +119,9 @@ def handle_error(self, e): Payload.max_decode_packets = 500 socketio = SocketIO(app, cors_allowed_origins="*") + +if api_settings.JWT_ENABLED or auth_settings.OIDC_ENABLED: + app.config["SECRET_KEY"] = os.urandom(128) if api_settings.JWT_ENABLED: try: jwt_pubkey = open(api_settings.JWT_CERT).read() @@ -95,7 +129,6 @@ def handle_error(self, e): print("Could not load public JWT cert from api.yml config: {}".format(e)) sys.exit(1) - app.config["SECRET_KEY"] = os.urandom(128) app.config["JWT_PUBLIC_KEY"] = jwt_pubkey app.config["JWT_IDENTITY_CLAIM"] = "sub" app.config["JWT_ALGORITHM"] = "ES256" @@ -108,6 +141,7 @@ def handle_error(self, e): app, prefix="/api/{}".format(__api_version__), authorizations=authorizations, security="apikey", doc="/api/doc/" ) +api.add_namespace(auth_api) api.add_namespace(device_api) api.add_namespace(devices_api) api.add_namespace(device_init_api) @@ -133,12 +167,28 @@ def handle_error(self, e): api.add_namespace(plugins_api) api.add_namespace(system_api) - # SocketIO on connect @socketio.on("connect") -@jwt_required def socketio_on_connect(): - user = get_jwt_identity() + # get te token string + token_string = request.args.get('jwt') + if not token_string: + return False + #if oidc, get userinfo + if auth_settings.OIDC_ENABLED: + try: + user = get_oauth_userinfo(token_string)['email'] + except InvalidTokenError as e: + logger.debug('InvalidTokenError: ' + format(e)) + return False + # else decode the token and get the sub there + else: + try: + user = decode(token_string, app.config["JWT_PUBLIC_KEY"], algorithms=[app.config["JWT_ALGORITHM"]])['sub'] + except DecodeError as e: + logger.debug('DecodeError: ' + format(e)) + return False + if user: logger.info("User: {} connected via socketio".format(user)) return True @@ -165,18 +215,32 @@ def socketio_on_events(data): # Log all requests, include username etc @app.after_request def log_request(response): - try: - token = request.headers.get("Authorization").split(" ")[-1] - user = decode_token(token).get("sub") - except Exception: - user = "unknown" + user = "" + if request.method in ["POST", "PUT", "DELETE", "PATCH"]: + try: + if auth_settings.OIDC_ENABLED: + token_string = request.headers.get("Authorization").split(" ")[-1] + user = "User: {}, ".format(get_oauth_userinfo(token_string)['email']) + else: + token = request.headers.get("Authorization").split(" ")[-1] + user = "User: {}, ".format(decode_token(token).get("sub")) + except Exception: + user = "User: unknown, " + try: url = re.sub(jwt_query_r, "", request.url) - logger.info( - "User: {}, Method: {}, Status: {}, URL: {}, JSON: {}".format( - user, request.method, response.status_code, url, request.json + if request.headers.get('content-type') == 'application/json': + logger.info( + "{}Method: {}, Status: {}, URL: {}, JSON: {}".format( + user, request.method, response.status_code, url, request.json + ) + ) + else: + logger.info( + "{}Method: {}, Status: {}, URL: {}".format( + user, request.method, response.status_code, url + ) ) - ) except Exception: pass return response diff --git a/src/cnaas_nms/api/auth.py b/src/cnaas_nms/api/auth.py new file mode 100644 index 00000000..9d4db682 --- /dev/null +++ b/src/cnaas_nms/api/auth.py @@ -0,0 +1,94 @@ +from authlib.integrations.base_client.errors import MismatchingStateError, OAuthError +from flask import current_app, redirect, url_for +from flask_restx import Namespace, Resource +from requests.models import PreparedRequest + +from cnaas_nms.api.generic import empty_result +from cnaas_nms.app_settings import auth_settings +from cnaas_nms.tools.log import get_logger +from cnaas_nms.tools.security import get_identity, login_required +from cnaas_nms.version import __api_version__ + +logger = get_logger() +api = Namespace("auth", description="API for handling auth", prefix="/api/{}".format(__api_version__)) + + +class LoginApi(Resource): + def get(self): + """Function to initiate a login of the user. + The user will be sent to the page to login. + Our client info will also be checked. + + Note: + We also discussed adding state to this function. + That way you could be sent to the same page once you logged in. + We would put the relevant information in a dictionary, + base64 encode it and sent it around as a parameter. + For now the application is small and it didn't seem needed. + + Returns: + A HTTP redirect response to OIDC_CONF_WELL_KNOWN_URL we have defined. + We give the auth call as a parameter to redirect after login is successfull. + + """ + if not auth_settings.OIDC_ENABLED: + return empty_result(status="error", data="Can't login when OIDC disabled"), 500 + oauth_client = current_app.extensions["authlib.integrations.flask_client"] + redirect_uri = url_for("auth_auth_api", _external=True) + + return oauth_client.connext.authorize_redirect(redirect_uri) + + +class AuthApi(Resource): + def get(self): + """Function to authenticate the user. + This API call is called by the OAUTH login after the user has logged in. + We get the users token and redirect them to right page in the frontend. + + Returns: + A HTTP redirect response to the url in the frontend that handles the repsonse after login. + The access token is a parameter in the url + + """ + + oauth_client = current_app.extensions["authlib.integrations.flask_client"] + + try: + token = oauth_client.connext.authorize_access_token() + except MismatchingStateError as e: + logger.error("Exception during authorization of the access token: {}".format(str(e))) + return ( + empty_result( + status="error", + data="Exception during authorization of the access token. Please try to login again.", + ), + 502, + ) + except OAuthError as e: + logger.error("Missing information needed for authorization: {}".format(str(e))) + return ( + empty_result( + status="error", + data="The server is missing some information that is needed for authorization.", + ), + 500, + ) + + url = auth_settings.FRONTEND_CALLBACK_URL + parameters = {"token": token["access_token"]} + + req = PreparedRequest() + req.prepare_url(url, parameters) + return redirect(req.url, code=302) + + +class IdentityApi(Resource): + @login_required + def get(self): + identity = get_identity() + return identity + + +api.add_resource(LoginApi, "/login") +api.add_resource(AuthApi, "/auth") +api.add_resource(IdentityApi, "/identity") diff --git a/src/cnaas_nms/api/device.py b/src/cnaas_nms/api/device.py index 7de7d721..5e69f9b6 100644 --- a/src/cnaas_nms/api/device.py +++ b/src/cnaas_nms/api/device.py @@ -39,7 +39,7 @@ ) from cnaas_nms.scheduler.scheduler import Scheduler from cnaas_nms.tools.log import get_logger -from cnaas_nms.tools.security import get_jwt_identity, jwt_required +from cnaas_nms.tools.security import get_identity, login_required from cnaas_nms.version import __api_version__ logger = get_logger() @@ -226,7 +226,7 @@ def device_data_postprocess(device_list: List[Device]) -> List[dict]: class DeviceByIdApi(Resource): - @jwt_required + @login_required def get(self, device_id): """Get a device from ID""" result = empty_result() @@ -239,7 +239,7 @@ def get(self, device_id): return empty_result("error", "Device not found"), 404 return result - @jwt_required + @login_required @device_api.expect(delete_device_model) def delete(self, device_id): """Delete device from ID""" @@ -254,7 +254,7 @@ def delete(self, device_id): job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.erase:device_erase", when=1, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs={"device_id": device_id}, ) res = empty_result(data="Scheduled job {} to factory default device".format(job_id)) @@ -270,7 +270,7 @@ def delete(self, device_id): remove_sync_events(dev.hostname) for nei in dev.get_neighbors(session): nei.synchronized = False - add_sync_event(nei.hostname, "neighbor_deleted", get_jwt_identity()) + add_sync_event(nei.hostname, "neighbor_deleted", get_identity()) except Exception as e: logger.warning("Could not mark neighbor as unsync after deleting {}: {}".format(dev.hostname, e)) try: @@ -289,7 +289,7 @@ def delete(self, device_id): return empty_result(status="error", data="Could not remove device: {}".format(e)), 500 return empty_result(status="success", data={"deleted_device": dev.as_dict()}), 200 - @jwt_required + @login_required @device_api.expect(device_model) def put(self, device_id): """Modify device from ID""" @@ -326,7 +326,7 @@ def put(self, device_id): and json_data["state"].upper() == "UNMANAGED" and dev_prev_state == DeviceState.MANAGED ): - add_sync_event(dev.hostname, "was_unmanaged", by=get_jwt_identity()) + add_sync_event(dev.hostname, "was_unmanaged", by=get_identity()) session.commit() update_device_primary_groups() dev_dict = device_data_postprocess([dev])[0] @@ -334,7 +334,7 @@ def put(self, device_id): class DeviceByHostnameApi(Resource): - @jwt_required + @login_required def get(self, hostname): """Get a device from hostname""" result = empty_result() @@ -349,7 +349,7 @@ def get(self, hostname): class DeviceApi(Resource): - @jwt_required + @login_required @device_api.expect(device_model) def post(self): """Add a device""" @@ -386,9 +386,10 @@ def post(self): class DevicesApi(Resource): - @jwt_required + @login_required def get(self): """Get all devices""" + logger.info("started get devices") device_list: List[Device] = [] total_count = 0 with sqla_session() as session: @@ -409,7 +410,7 @@ def get(self): class DeviceInitApi(Resource): - @jwt_required + @login_required @device_init_api.expect(device_init_model) def post(self, device_id: int): """Init a device""" @@ -433,7 +434,7 @@ def post(self, device_id: int): job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.init_device:init_device_step2", when=1, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs={"device_id": device_id, "iteration": 1}, ) @@ -449,7 +450,7 @@ def post(self, device_id: int): job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.init_device:init_access_device_step1", when=1, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs=job_kwargs, ) elif job_kwargs["device_type"] in [DeviceType.CORE.name, DeviceType.DIST.name]: @@ -457,7 +458,7 @@ def post(self, device_id: int): job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.init_device:init_fabric_device_step1", when=1, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs=job_kwargs, ) else: @@ -523,7 +524,7 @@ def arg_check(cls, device_id: int, json_data: dict) -> dict: class DeviceInitCheckApi(Resource): - @jwt_required + @login_required @device_init_api.expect(device_initcheck_model) def post(self, device_id: int): """Perform init check on a device""" @@ -623,7 +624,7 @@ def post(self, device_id: int): class DeviceDiscoverApi(Resource): - @jwt_required + @login_required @device_discover_api.expect(device_discover_model) def post(self): """Discover device""" @@ -637,7 +638,7 @@ def post(self): dhcp_ip = json_data["dhcp_ip"] job_id = cnaas_nms.devicehandler.init_device.schedule_discover_device( - ztp_mac=ztp_mac, dhcp_ip=dhcp_ip, iteration=1, scheduled_by=get_jwt_identity() + ztp_mac=ztp_mac, dhcp_ip=dhcp_ip, iteration=1, scheduled_by=get_identity() ) logger.debug(f"Discover device for ztp_mac {ztp_mac} scheduled as ID {job_id}") @@ -649,7 +650,7 @@ def post(self): class DeviceSyncApi(Resource): - @jwt_required + @login_required @device_syncto_api.expect(device_syncto_model) def post(self): """Start sync of device(s)""" @@ -720,7 +721,7 @@ def post(self): return empty_result(status="error", data="No devices to synchronize were specified"), 400 scheduler = Scheduler() job_id = scheduler.add_onetime_job( - "cnaas_nms.devicehandler.sync_devices:sync_devices", when=1, scheduled_by=get_jwt_identity(), kwargs=kwargs + "cnaas_nms.devicehandler.sync_devices:sync_devices", when=1, scheduled_by=get_identity(), kwargs=kwargs ) res = empty_result(data=f"Scheduled job to synchronize {what}") @@ -734,7 +735,7 @@ def post(self): class DeviceUpdateFactsApi(Resource): - @jwt_required + @login_required @device_update_facts_api.expect(device_update_facts_model) def post(self): """Start update facts of device(s)""" @@ -761,7 +762,7 @@ def post(self): scheduler = Scheduler() job_id = scheduler.add_onetime_job( - "cnaas_nms.devicehandler.update:update_facts", when=1, scheduled_by=get_jwt_identity(), kwargs=kwargs + "cnaas_nms.devicehandler.update:update_facts", when=1, scheduled_by=get_identity(), kwargs=kwargs ) res = empty_result(data=f"Scheduled job to update facts for {hostname}") @@ -775,7 +776,7 @@ def post(self): class DeviceUpdateInterfacesApi(Resource): - @jwt_required + @login_required @device_update_interfaces_api.expect(device_update_interfaces_model) def post(self): """Update/scan interfaces of device""" @@ -840,7 +841,7 @@ def post(self): scheduler = Scheduler() job_id = scheduler.add_onetime_job( - "cnaas_nms.devicehandler.update:update_interfacedb", when=1, scheduled_by=get_jwt_identity(), kwargs=kwargs + "cnaas_nms.devicehandler.update:update_interfacedb", when=1, scheduled_by=get_identity(), kwargs=kwargs ) res = empty_result(data=f"Scheduled job to update interfaces for {hostname}") @@ -854,7 +855,7 @@ def post(self): class DeviceGenerateConfigApi(Resource): - @jwt_required + @login_required @device_api.doc(model=device_generate_config_model) def get(self, hostname: str): """Get device configuration""" @@ -888,7 +889,7 @@ def get(self, hostname: str): class DeviceRunningConfigApi(Resource): - @jwt_required + @login_required @device_api.param("interface") def get(self, hostname: str): args = request.args @@ -917,7 +918,7 @@ def get(self, hostname: str): class DevicePreviousConfigApi(Resource): - @jwt_required + @login_required @device_api.param("job_id") @device_api.param("previous") @device_api.param("before") @@ -957,7 +958,7 @@ def get(self, hostname: str): return result - @jwt_required + @login_required @device_api.expect(device_restore_model) def post(self, hostname: str): """Restore configuration to previous version""" @@ -1005,7 +1006,7 @@ def post(self, hostname: str): job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.sync_devices:apply_config", when=1, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs=apply_kwargs, ) @@ -1016,7 +1017,7 @@ def post(self, hostname: str): class DeviceApplyConfigApi(Resource): - @jwt_required + @login_required @device_api.expect(device_apply_config_model) def post(self, hostname: str): """Apply exact specified configuration to device without using templates""" @@ -1043,7 +1044,7 @@ def post(self, hostname: str): job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.sync_devices:apply_config", when=1, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs=apply_kwargs, ) @@ -1054,7 +1055,7 @@ def post(self, hostname: str): class DeviceCertApi(Resource): - @jwt_required + @login_required @device_api.expect(device_cert_model) def post(self): """Execute certificate related actions on device""" @@ -1098,7 +1099,7 @@ def post(self): if action == "RENEW": scheduler = Scheduler() job_id = scheduler.add_onetime_job( - "cnaas_nms.devicehandler.cert:renew_cert", when=1, scheduled_by=get_jwt_identity(), kwargs=kwargs + "cnaas_nms.devicehandler.cert:renew_cert", when=1, scheduled_by=get_identity(), kwargs=kwargs ) res = empty_result(data="Scheduled job to renew certificates") @@ -1114,7 +1115,7 @@ def post(self): class DeviceStackmembersApi(Resource): - @jwt_required + @login_required def get(self, hostname): """Get stackmembers for device""" result = empty_result(data={"stackmembers": []}) @@ -1127,7 +1128,7 @@ def get(self, hostname): result["data"]["stackmembers"].append(stackmember.as_dict()) return result - @jwt_required + @login_required @device_api.expect(stackmembers_model) def put(self, hostname): try: @@ -1167,7 +1168,7 @@ def format_errors(cls, errors): class DeviceSyncHistoryApi(Resource): - @jwt_required + @login_required @device_synchistory_api.param("hostname") def get(self): args = request.args @@ -1184,7 +1185,7 @@ def get(self): result["data"]["hostnames"] = sync_history.asdict() return result - @jwt_required + @login_required @device_synchistory_api.expect(device_synchistory_api) def post(self): try: diff --git a/src/cnaas_nms/api/firmware.py b/src/cnaas_nms/api/firmware.py index c009dbef..78ca8e72 100644 --- a/src/cnaas_nms/api/firmware.py +++ b/src/cnaas_nms/api/firmware.py @@ -4,7 +4,6 @@ import requests from flask import make_response, request -from flask_jwt_extended import get_jwt_identity from flask_restx import Namespace, Resource, fields from cnaas_nms.api.generic import empty_result @@ -15,7 +14,7 @@ from cnaas_nms.scheduler.scheduler import Scheduler from cnaas_nms.scheduler.wrapper import job_wrapper from cnaas_nms.tools.log import get_logger -from cnaas_nms.tools.security import jwt_required +from cnaas_nms.tools.security import get_identity, login_required from cnaas_nms.version import __api_version__ logger = get_logger() @@ -93,7 +92,7 @@ def remove_file(**kwargs: dict) -> str: class FirmwareApi(Resource): - @jwt_required + @login_required @api.expect(firmware_model) def post(self) -> tuple: """Download new firmware""" @@ -116,14 +115,14 @@ def post(self) -> tuple: scheduler = Scheduler() job_id = scheduler.add_onetime_job( - "cnaas_nms.api.firmware:get_firmware", when=1, scheduled_by=get_jwt_identity(), kwargs=kwargs + "cnaas_nms.api.firmware:get_firmware", when=1, scheduled_by=get_identity(), kwargs=kwargs ) res = empty_result(data="Scheduled job to download firmware") res["job_id"] = job_id return res - @jwt_required + @login_required def get(self) -> tuple: """Get firmwares""" try: @@ -136,14 +135,14 @@ def get(self) -> tuple: class FirmwareImageApi(Resource): - @jwt_required + @login_required def get(self, filename: str) -> dict: """Get information about a single firmware""" scheduler = Scheduler() job_id = scheduler.add_onetime_job( "cnaas_nms.api.firmware:get_firmware_chksum", when=1, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs={"filename": filename}, ) res = empty_result(data="Scheduled job get firmware information") @@ -151,12 +150,12 @@ def get(self, filename: str) -> dict: return res - @jwt_required + @login_required def delete(self, filename: str) -> dict: """Remove firmware""" scheduler = Scheduler() job_id = scheduler.add_onetime_job( - "cnaas_nms.api.firmware:remove_file", when=1, scheduled_by=get_jwt_identity(), kwargs={"filename": filename} + "cnaas_nms.api.firmware:remove_file", when=1, scheduled_by=get_identity(), kwargs={"filename": filename} ) res = empty_result(data="Scheduled job to remove firmware") res["job_id"] = job_id @@ -165,7 +164,7 @@ def delete(self, filename: str) -> dict: class FirmwareUpgradeApi(Resource): - @jwt_required + @login_required @api.expect(firmware_upgrade_model) def post(self): """Upgrade firmware on device""" @@ -277,7 +276,7 @@ def post(self): job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.firmware:device_upgrade", when=seconds, - scheduled_by=get_jwt_identity(), + scheduled_by=get_identity(), kwargs=kwargs, ) res = empty_result(data="Scheduled job to upgrade devices") diff --git a/src/cnaas_nms/api/groups.py b/src/cnaas_nms/api/groups.py index e6083f71..cd7a84ce 100644 --- a/src/cnaas_nms/api/groups.py +++ b/src/cnaas_nms/api/groups.py @@ -7,7 +7,7 @@ from cnaas_nms.db.device import Device, DeviceState from cnaas_nms.db.session import sqla_session from cnaas_nms.db.settings import get_group_regex, get_group_settings, get_groups -from cnaas_nms.tools.security import jwt_required +from cnaas_nms.tools.security import login_required from cnaas_nms.version import __api_version__ api = Namespace("groups", description="API for handling groups", prefix="/api/{}".format(__api_version__)) @@ -68,7 +68,7 @@ def groups_osversion_populate(group_name: str): class GroupsApi(Resource): - @jwt_required + @login_required def get(self): """Get all groups""" result = {"groups": groups_populate(), "group_settings": groups_settings_populate()} @@ -76,7 +76,7 @@ def get(self): class GroupsApiByName(Resource): - @jwt_required + @login_required def get(self, group_name): """Get a single group by name""" if group_name not in get_groups(): @@ -89,7 +89,7 @@ def get(self, group_name): class GroupsApiByNameOsversion(Resource): - @jwt_required + @login_required def get(self, group_name): """Get os version of all devices in a group""" try: diff --git a/src/cnaas_nms/api/interface.py b/src/cnaas_nms/api/interface.py index 9ec424f7..76f20993 100644 --- a/src/cnaas_nms/api/interface.py +++ b/src/cnaas_nms/api/interface.py @@ -11,14 +11,14 @@ from cnaas_nms.devicehandler.interface_state import bounce_interfaces, get_interface_states from cnaas_nms.devicehandler.sync_devices import resolve_vlanid, resolve_vlanid_list from cnaas_nms.devicehandler.sync_history import add_sync_event -from cnaas_nms.tools.security import get_jwt_identity, jwt_required +from cnaas_nms.tools.security import get_identity, login_required from cnaas_nms.version import __api_version__ api = Namespace("device", description="API for handling interfaces", prefix="/api/{}".format(__api_version__)) class InterfaceApi(Resource): - @jwt_required + @login_required def get(self, hostname): """List all interfaces""" result = empty_result() @@ -38,7 +38,7 @@ def get(self, hostname): result["data"]["interfaces"] = sorted(interfaces, key=lambda i: i["indexnum"]) return result - @jwt_required + @login_required def put(self, hostname): """Take a map of interfaces and associated values to update. Example: @@ -237,7 +237,7 @@ def put(self, hostname): if updated: dev.synchronized = False - add_sync_event(hostname, "interface_updated", get_jwt_identity()) + add_sync_event(hostname, "interface_updated", get_identity()) if errors: if data: @@ -250,7 +250,7 @@ def put(self, hostname): class InterfaceStatusApi(Resource): - @jwt_required + @login_required def get(self, hostname): """List all interfaces status""" result = empty_result() @@ -262,7 +262,7 @@ def get(self, hostname): return empty_result("error", "Could not get interface states, unknon exception: {}".format(e)), 400 return result - @jwt_required + @login_required def put(self, hostname): """Bounce selected interfaces by appling bounce-down/bounce-up template""" json_data = request.get_json() diff --git a/src/cnaas_nms/api/jobs.py b/src/cnaas_nms/api/jobs.py index 8c5cd831..6f0e854b 100644 --- a/src/cnaas_nms/api/jobs.py +++ b/src/cnaas_nms/api/jobs.py @@ -11,7 +11,7 @@ from cnaas_nms.db.session import sqla_session from cnaas_nms.scheduler.scheduler import Scheduler from cnaas_nms.tools.log import get_logger -from cnaas_nms.tools.security import get_jwt_identity, jwt_required +from cnaas_nms.tools.security import get_identity, login_required from cnaas_nms.version import __api_version__ job_api = Namespace("job", description="API for handling jobs", prefix="/api/{}".format(__api_version__)) @@ -58,7 +58,7 @@ def filter_job_dict(job_dict: dict, args: dict) -> dict: class JobsApi(Resource): - @jwt_required + @login_required def get(self): """Get one or more jobs""" data = {"jobs": []} @@ -83,7 +83,7 @@ def get(self): class JobByIdApi(Resource): - @jwt_required + @login_required def get(self, job_id): """Get job information by ID""" args = request.args @@ -96,7 +96,7 @@ def get(self, job_id): else: return empty_result(status="error", data="No job with id {} found".format(job_id)), 400 - @jwt_required + @login_required def put(self, job_id): json_data = request.get_json() if "action" not in json_data: @@ -125,7 +125,7 @@ def put(self, job_id): if "abort_reason" in json_data and isinstance(json_data["abort_reason"], str): abort_reason = json_data["abort_reason"][:255] - abort_reason += " (aborted by {})".format(get_jwt_identity()) + abort_reason += " (aborted by {})".format(get_identity()) if job_status == JobStatus.SCHEDULED: scheduler = Scheduler() @@ -144,7 +144,7 @@ def put(self, job_id): class JobLockApi(Resource): - @jwt_required + @login_required def get(self): """Get job locks""" locks = [] @@ -153,7 +153,7 @@ def get(self): locks.append(lock.as_dict()) return empty_result("success", data={"locks": locks}) - @jwt_required + @login_required @job_api.expect(job_model) def delete(self): """Remove job locks""" diff --git a/src/cnaas_nms/api/linknet.py b/src/cnaas_nms/api/linknet.py index 008ab73f..fbb20a30 100644 --- a/src/cnaas_nms/api/linknet.py +++ b/src/cnaas_nms/api/linknet.py @@ -12,7 +12,7 @@ from cnaas_nms.db.session import sqla_session from cnaas_nms.devicehandler.sync_history import add_sync_event from cnaas_nms.devicehandler.underlay import find_free_infra_linknet -from cnaas_nms.tools.security import get_jwt_identity, jwt_required +from cnaas_nms.tools.security import get_identity, login_required from cnaas_nms.version import __api_version__ linknets_api = Namespace("linknets", description="API for handling linknets", prefix="/api/{}".format(__api_version__)) @@ -100,7 +100,7 @@ def validate_hostname(hostname): # Allow pre-provisioning of linknet to device that is not yet # managed or not assigned device_type, so no further checks here - @jwt_required + @login_required def get(self): """Get all linksnets""" result = {"linknets": []} @@ -110,7 +110,7 @@ def get(self): result["linknets"].append(instance.as_dict()) return empty_result(status="success", data=result) - @jwt_required + @login_required @linknets_api.expect(linknets_model) def post(self): """Add a new linknet""" @@ -182,7 +182,7 @@ def post(self): return empty_result(status="success", data=data), 201 - @jwt_required + @login_required def delete(self): """Remove linknet""" json_data = request.get_json() @@ -199,16 +199,16 @@ def delete(self): if not cur_linknet: return empty_result(status="error", data="No such linknet found in database"), 404 cur_linknet.device_a.synchronized = False - add_sync_event(cur_linknet.device_a.hostname, "linknet_deleted", get_jwt_identity()) + add_sync_event(cur_linknet.device_a.hostname, "linknet_deleted", get_identity()) cur_linknet.device_b.synchronized = False - add_sync_event(cur_linknet.device_b.hostname, "linknet_deleted", get_jwt_identity()) + add_sync_event(cur_linknet.device_b.hostname, "linknet_deleted", get_identity()) session.delete(cur_linknet) session.commit() return empty_result(status="success", data={"deleted_linknet": cur_linknet.as_dict()}), 200 class LinknetByIdApi(Resource): - @jwt_required + @login_required def get(self, linknet_id): """Get a single specified linknet""" result = empty_result() @@ -221,23 +221,23 @@ def get(self, linknet_id): return empty_result("error", "Linknet not found"), 404 return result - @jwt_required + @login_required def delete(self, linknet_id): """Remove a linknet""" with sqla_session() as session: instance: Linknet = session.query(Linknet).filter(Linknet.id == linknet_id).one_or_none() if instance: instance.device_a.synchronized = False - add_sync_event(instance.device_a.hostname, "linknet_deleted", get_jwt_identity()) + add_sync_event(instance.device_a.hostname, "linknet_deleted", get_identity()) instance.device_b.synchronized = False - add_sync_event(instance.device_b.hostname, "linknet_deleted", get_jwt_identity()) + add_sync_event(instance.device_b.hostname, "linknet_deleted", get_identity()) session.delete(instance) session.commit() return empty_result(status="success", data={"deleted_linknet": instance.as_dict()}), 200 else: return empty_result("error", "No such linknet found in database"), 404 - @jwt_required + @login_required @linknets_api.expect(linknet_model) def put(self, linknet_id): """Update data on existing linknet""" @@ -267,9 +267,9 @@ def put(self, linknet_id): changed: bool = update_sqla_object(instance, json_data) if changed: instance.device_a.synchronized = False - add_sync_event(instance.device_a.hostname, "linknet_updated", get_jwt_identity()) + add_sync_event(instance.device_a.hostname, "linknet_updated", get_identity()) instance.device_b.synchronized = False - add_sync_event(instance.device_b.hostname, "linknet_updated", get_jwt_identity()) + add_sync_event(instance.device_b.hostname, "linknet_updated", get_identity()) return empty_result(status="success", data={"updated_linknet": instance.as_dict()}), 200 else: return empty_result(status="success", data={"unchanged_linknet": instance.as_dict()}), 200 diff --git a/src/cnaas_nms/api/mgmtdomain.py b/src/cnaas_nms/api/mgmtdomain.py index f702122e..7fca5fed 100644 --- a/src/cnaas_nms/api/mgmtdomain.py +++ b/src/cnaas_nms/api/mgmtdomain.py @@ -13,7 +13,7 @@ from cnaas_nms.db.session import sqla_session from cnaas_nms.db.settings_fields import vlan_id_schema_optional from cnaas_nms.devicehandler.sync_history import add_sync_event -from cnaas_nms.tools.security import get_jwt_identity, jwt_required +from cnaas_nms.tools.security import get_identity, login_required from cnaas_nms.version import __api_version__ mgmtdomains_api = Namespace( @@ -79,7 +79,7 @@ def ipv6_gw_valid_address(cls, v, values, **kwargs): class MgmtdomainByIdApi(Resource): - @jwt_required + @login_required def get(self, mgmtdomain_id): """Get management domain by ID""" result = empty_result() @@ -92,23 +92,23 @@ def get(self, mgmtdomain_id): return empty_result("error", "Management domain not found"), 404 return result - @jwt_required + @login_required def delete(self, mgmtdomain_id): """Remove management domain""" with sqla_session() as session: instance: Mgmtdomain = session.query(Mgmtdomain).filter(Mgmtdomain.id == mgmtdomain_id).one_or_none() if instance: instance.device_a.synchronized = False - add_sync_event(instance.device_a.hostname, "mgmtdomain_deleted", get_jwt_identity()) + add_sync_event(instance.device_a.hostname, "mgmtdomain_deleted", get_identity()) instance.device_b.synchronized = False - add_sync_event(instance.device_b.hostname, "mgmtdomain_deleted", get_jwt_identity()) + add_sync_event(instance.device_b.hostname, "mgmtdomain_deleted", get_identity()) session.delete(instance) session.commit() return empty_result(status="success", data={"deleted_mgmtdomain": instance.as_dict()}), 200 else: return empty_result("error", "Management domain not found"), 404 - @jwt_required + @login_required @mgmtdomain_api.expect(mgmtdomain_model) def put(self, mgmtdomain_id): """Modify management domain""" @@ -128,9 +128,9 @@ def put(self, mgmtdomain_id): changed: bool = update_sqla_object(instance, json_data) if changed: instance.device_a.synchronized = False - add_sync_event(instance.device_a.hostname, "mgmtdomain_updated", get_jwt_identity()) + add_sync_event(instance.device_a.hostname, "mgmtdomain_updated", get_identity()) instance.device_b.synchronized = False - add_sync_event(instance.device_b.hostname, "mgmtdomain_updated", get_jwt_identity()) + add_sync_event(instance.device_b.hostname, "mgmtdomain_updated", get_identity()) return empty_result(status="success", data={"updated_mgmtdomain": instance.as_dict()}), 200 else: return empty_result(status="success", data={"unchanged_mgmtdomain": instance.as_dict()}), 200 @@ -139,7 +139,7 @@ def put(self, mgmtdomain_id): class MgmtdomainsApi(Resource): - @jwt_required + @login_required def get(self): """Get all management domains""" result = empty_result() @@ -154,7 +154,7 @@ def get(self): result["data"]["mgmtdomains"].append(instance.as_dict()) return result - @jwt_required + @login_required @mgmtdomain_api.expect(mgmtdomain_model) def post(self): """Add management domain""" @@ -214,9 +214,9 @@ def post(self): return empty_result("error", "Integrity error: {}".format(e)), 400 device_a.synchronized = False - add_sync_event(device_a.hostname, "mgmtdomain_created", get_jwt_identity()) + add_sync_event(device_a.hostname, "mgmtdomain_created", get_identity()) device_b.synchronized = False - add_sync_event(device_b.hostname, "mgmtdomain_created", get_jwt_identity()) + add_sync_event(device_b.hostname, "mgmtdomain_created", get_identity()) return empty_result(status="success", data={"added_mgmtdomain": new_mgmtd.as_dict()}), 200 else: errors.append( diff --git a/src/cnaas_nms/api/plugins.py b/src/cnaas_nms/api/plugins.py index a04b73c9..01c36f4a 100644 --- a/src/cnaas_nms/api/plugins.py +++ b/src/cnaas_nms/api/plugins.py @@ -3,7 +3,7 @@ from cnaas_nms.api.generic import empty_result from cnaas_nms.plugins.pluginmanager import PluginManagerHandler -from cnaas_nms.tools.security import jwt_required +from cnaas_nms.tools.security import login_required from cnaas_nms.version import __api_version__ api = Namespace("plugins", description="API for handling plugins", prefix="/api/{}".format(__api_version__)) @@ -17,7 +17,7 @@ class PluginsApi(Resource): - @jwt_required + @login_required def get(self): """List all plugins""" try: @@ -29,7 +29,7 @@ def get(self): else: return empty_result("success", {"loaded_plugins": plugin_module_names, "plugindata": plugindata}) - @jwt_required + @login_required @api.expect(plugin_model) def put(self): """Modify plugins""" diff --git a/src/cnaas_nms/api/repository.py b/src/cnaas_nms/api/repository.py index afdffb72..dd8e61c8 100644 --- a/src/cnaas_nms/api/repository.py +++ b/src/cnaas_nms/api/repository.py @@ -5,7 +5,7 @@ from cnaas_nms.db.git import RepoType, get_repo_status, refresh_repo from cnaas_nms.db.joblock import JoblockError from cnaas_nms.db.settings import SettingsSyntaxError, VerifyPathException -from cnaas_nms.tools.security import get_jwt_identity, jwt_required +from cnaas_nms.tools.security import get_identity, login_required from cnaas_nms.version import __api_version__ api = Namespace("repository", description="API for handling repositories", prefix="/api/{}".format(__api_version__)) @@ -19,7 +19,7 @@ class RepositoryApi(Resource): - @jwt_required + @login_required def get(self, repo): """Get repository information""" try: @@ -28,7 +28,7 @@ def get(self, repo): return empty_result("error", "Invalid repository type"), 400 return empty_result("success", get_repo_status(repo_type)) - @jwt_required + @login_required @api.expect(repository_model) def put(self, repo): """Modify repository""" @@ -42,7 +42,7 @@ def put(self, repo): if str(json_data["action"]).upper() == "REFRESH": # TODO: consider doing as scheduled job? try: - res = refresh_repo(repo_type, get_jwt_identity()) + res = refresh_repo(repo_type, get_identity()) return empty_result("success", res) except VerifyPathException as e: return ( diff --git a/src/cnaas_nms/api/settings.py b/src/cnaas_nms/api/settings.py index f0d2ad11..65c97d22 100644 --- a/src/cnaas_nms/api/settings.py +++ b/src/cnaas_nms/api/settings.py @@ -10,7 +10,7 @@ from cnaas_nms.db.session import sqla_session from cnaas_nms.db.settings import SettingsSyntaxError, check_settings_syntax, get_settings, get_settings_root from cnaas_nms.tools.mergedict import merge_dict_origin -from cnaas_nms.tools.security import jwt_required +from cnaas_nms.tools.security import login_required from cnaas_nms.version import __api_version__ settings_root_model = get_settings_root() @@ -20,7 +20,7 @@ class SettingsApi(Resource): - @jwt_required + @login_required @api.param("hostname") @api.param("device_type") def get(self): @@ -73,7 +73,7 @@ def post(self): class SettingsServerApI(Resource): - @jwt_required + @login_required def get(self): ret_dict = {"api": api_settings.dict()} response = make_response(json.dumps(ret_dict, default=json_dumper)) diff --git a/src/cnaas_nms/api/system.py b/src/cnaas_nms/api/system.py index 62e7c6e2..5e878db8 100644 --- a/src/cnaas_nms/api/system.py +++ b/src/cnaas_nms/api/system.py @@ -6,7 +6,7 @@ from cnaas_nms.api import app from cnaas_nms.api.generic import empty_result from cnaas_nms.scheduler.scheduler import Scheduler -from cnaas_nms.tools.security import jwt_required +from cnaas_nms.tools.security import login_required from cnaas_nms.version import __api_version__ from git import InvalidGitRepositoryError, NoSuchPathError, Repo @@ -16,7 +16,7 @@ class ShutdownApi(Resource): - @jwt_required + @login_required def post(self): print("System shutdown API called, exiting...") scheduler = Scheduler() diff --git a/src/cnaas_nms/app_settings.py b/src/cnaas_nms/app_settings.py index e0916053..3323b935 100644 --- a/src/cnaas_nms/app_settings.py +++ b/src/cnaas_nms/app_settings.py @@ -8,7 +8,6 @@ class AppSettings(BaseSettings): # Database settings - CNAAS_DB_HOSTNAME: str = "127.0.0.1" CNAAS_DB_USERNAME: str = "cnaas" CNAAS_DB_PASSWORD: str = "cnaas" @@ -39,14 +38,14 @@ class ApiSettings(BaseSettings): HTTPD_URL: str = "https://cnaas_httpd:1443/api/v1.0/firmware" VERIFY_TLS: bool = False VERIFY_TLS_DEVICE: bool = False - JWT_CERT: Path = "/opt/cnaas/jwtcert/public.pem" - CAFILE: Optional[Path] = "/opt/cnaas/cacert/rootCA.crt" - CAKEYFILE: Path = "/opt/cnaas/cacert/rootCA.key" - CERTPATH: Path = "/tmp/devicecerts/" + JWT_CERT: Path = Path("/opt/cnaas/jwtcert/public.pem") + CAFILE: Optional[Path] = Path("/opt/cnaas/cacert/rootCA.crt") + CAKEYFILE: Path = Path("/opt/cnaas/cacert/rootCA.key") + CERTPATH: Path = Path("/tmp/devicecerts/") ALLOW_APPLY_CONFIG_LIVERUN: bool = False FIRMWARE_URL: str = HTTPD_URL JWT_ENABLED: bool = True - PLUGIN_FILE: Path = "/etc/cnaas-nms/plugins.yml" + PLUGIN_FILE: Path = Path("/etc/cnaas-nms/plugins.yml") GLOBAL_UNIQUE_VLANS: bool = True INIT_MGMT_TIMEOUT: int = 30 MGMTDOMAIN_RESERVED_COUNT: int = 5 @@ -64,6 +63,17 @@ def primary_ip_version_is_valid(cls, version: int) -> int: return version +class AuthSettings(BaseSettings): + # Authorization settings + FRONTEND_CALLBACK_URL: str = "http://localhost/callback" + OIDC_CONF_WELL_KNOWN_URL: str = "well-known-openid-configuration-endpoint" + OIDC_CLIENT_SECRET: str = "xxx" + OIDC_CLIENT_ID: str = "client-id" + OIDC_ENABLED: bool = False + OIDC_CLIENT_SCOPE: str = "openid" + AUDIENCE: str = OIDC_CLIENT_ID + + def construct_api_settings() -> ApiSettings: api_config = Path("/etc/cnaas-nms/api.yml") @@ -133,5 +143,24 @@ def _create_repo_config(settings: AppSettings, config: dict) -> None: return app_settings +def construct_auth_settings() -> AuthSettings: + auth_config = Path("/etc/cnaas-nms/auth_config.yml") + if auth_config.is_file(): + with open(auth_config, "r") as auth_file: + config = yaml.safe_load(auth_file) + return AuthSettings( + OIDC_ENABLED=config.get("oidc_enabled", AuthSettings().OIDC_ENABLED), + FRONTEND_CALLBACK_URL=config.get("frontend_callback_url", AuthSettings().FRONTEND_CALLBACK_URL), + OIDC_CONF_WELL_KNOWN_URL=config.get("oidc_conf_well_known_url", AuthSettings().OIDC_CONF_WELL_KNOWN_URL), + OIDC_CLIENT_SECRET=config.get("oidc_client_secret", AuthSettings().OIDC_CLIENT_SECRET), + OIDC_CLIENT_ID=config.get("oidc_client_id", AuthSettings().OIDC_CLIENT_ID), + OIDC_CLIENT_SCOPE=config.get("oidc_client_scope", AuthSettings().OIDC_CLIENT_SCOPE), + AUDIENCE=config.get("audience", AuthSettings().AUDIENCE), + ) + else: + return AuthSettings() + + app_settings = construct_app_settings() api_settings = construct_api_settings() +auth_settings = construct_auth_settings() diff --git a/src/cnaas_nms/db/git.py b/src/cnaas_nms/db/git.py index 5e8fdd0e..e2837a2a 100644 --- a/src/cnaas_nms/db/git.py +++ b/src/cnaas_nms/db/git.py @@ -6,6 +6,7 @@ from urllib.parse import urldefrag import yaml +from git.exc import GitCommandError, NoSuchPathError from cnaas_nms.app_settings import app_settings from cnaas_nms.db.device import Device, DeviceType @@ -17,7 +18,6 @@ from cnaas_nms.devicehandler.sync_history import add_sync_event from cnaas_nms.tools.log import get_logger from git import InvalidGitRepositoryError, Repo -from git.exc import GitCommandError, NoSuchPathError logger = get_logger() diff --git a/src/cnaas_nms/devicehandler/cert.py b/src/cnaas_nms/devicehandler/cert.py index 58d1f11f..7a5d457b 100644 --- a/src/cnaas_nms/devicehandler/cert.py +++ b/src/cnaas_nms/devicehandler/cert.py @@ -114,7 +114,6 @@ def renew_cert( job_id: Optional[str] = None, scheduled_by: Optional[str] = None, ) -> NornirJobResult: - logger = get_logger() nr = cnaas_init() if hostname: diff --git a/src/cnaas_nms/devicehandler/firmware.py b/src/cnaas_nms/devicehandler/firmware.py index 9a3c68ae..0bd9b6a1 100644 --- a/src/cnaas_nms/devicehandler/firmware.py +++ b/src/cnaas_nms/devicehandler/firmware.py @@ -258,7 +258,6 @@ def device_upgrade_task( post_waittime: Optional[int] = 0, activate: Optional[bool] = False, ) -> NornirJobResult: - # If pre-flight is selected, execute the pre-flight task which # will verify the amount of disk space and so on. set_thread_data(job_id) @@ -357,7 +356,6 @@ def device_upgrade( reboot: Optional[bool] = False, scheduled_by: Optional[str] = None, ) -> NornirJobResult: - logger = get_logger() nr = cnaas_init() if hostname: diff --git a/src/cnaas_nms/tools/security.py b/src/cnaas_nms/tools/security.py index 99ccff3a..44536ec5 100644 --- a/src/cnaas_nms/tools/security.py +++ b/src/cnaas_nms/tools/security.py @@ -1,13 +1,22 @@ +from typing import Any, Mapping + +import requests +from authlib.integrations.flask_oauth2 import ResourceProtector, current_token +from authlib.oauth2.rfc6750 import BearerTokenValidator from flask_jwt_extended import get_jwt_identity as get_jwt_identity_orig from flask_jwt_extended import jwt_required as jwt_orig +from jose import exceptions, jwt +from jwt.exceptions import ExpiredSignatureError, InvalidKeyError, InvalidTokenError +import json +from cnaas_nms.app_settings import api_settings, auth_settings +from cnaas_nms.tools.log import get_logger -from cnaas_nms.app_settings import api_settings +logger = get_logger() def jwt_required(fn): """ This function enables development without Oauth. - """ if api_settings.JWT_ENABLED: return jwt_orig()(fn) @@ -18,6 +27,182 @@ def jwt_required(fn): def get_jwt_identity(): """ This function overides the identity when needed. - """ return get_jwt_identity_orig() if api_settings.JWT_ENABLED else "admin" + + +def get_oauth_userinfo(token_string): + """Give back the user info of the OAUTH account + + If JWT is disabled, we return "admin". + + We do an api call to request userinfo. This gives back all the userinfo. + We get the right info from there and return this to the user. + + Returns: + resp.json(): Object of the user info + + """ + # For now unnecersary, useful when we only use one log in method + if not auth_settings.OIDC_ENABLED: + return "Admin" + # Request the userinfo + try: + metadata = requests.get(auth_settings.OIDC_CONF_WELL_KNOWN_URL) + metadata.raise_for_status() + except requests.exceptions.HTTPError as e: + raise ConnectionError("Can't reach the OIDC URL") + except requests.exceptions.ConnectionError as e: + raise ConnectionError("OIDC metadata unavailable") + user_info_endpoint = metadata.json()["userinfo_endpoint"] + + data = {"token_type_hint": "access_token"} + headers = {"Authorization": "Bearer " + token_string} + try: + resp = requests.post(user_info_endpoint, data=data, headers=headers) + resp.raise_for_status() + except requests.exceptions.HTTPError as e: + body = json.loads(e.response.content) + logger.debug("Request not successful: " + body['error_description']) + raise InvalidTokenError(body['error_description']) + return resp.json() + +class MyBearerTokenValidator(BearerTokenValidator): + keys: Mapping[str, Any] = {} + + def get_keys(self): + """Get the keys for the OIDC decoding""" + try: + metadata = requests.get(auth_settings.OIDC_CONF_WELL_KNOWN_URL) + keys_endpoint = metadata.json()["jwks_uri"] + response = requests.get(url=keys_endpoint) + self.keys = response.json()["keys"] + except KeyError as e: + raise InvalidKeyError(e) + except requests.exceptions.HTTPError as e: + raise InvalidKeyError(e) + + + def get_key(self, kid): + """Get the key based on the kid""" + key = [k for k in self.keys if k['kid'] == kid] + if len(key) == 0: + logger.debug("Key not found. Get the keys.") + self.get_keys() + if len(self.keys) == 0: + logger.error("Keys not downloaded") + raise InvalidKeyError() + try: + key = [k for k in self.keys if k['kid'] == kid] + except KeyError as e: + logger.error("Keys in different format?") + raise InvalidKeyError(e) + if len(key) == 0: + logger.error("Key not in keys") + raise InvalidKeyError() + return key + + def authenticate_token(self, token_string: str): + """Check if token is active. + + If JWT is disabled, we return because no token is needed. + + We decode the header and check if it's good. + + We decode the token using the keys. + We first check if we can decode it, if not we request the keys. + The decode function also checks if it's not expired. + We get de decoded _token back, but for now we do nothing with this. + + Input + token_string(str): The tokenstring + Returns: + token(dict): Dictionary with access_token, decoded_token, token_type, audience, expires_at + + """ + # If OIDC is disabled, no token is needed (for future use) + if not auth_settings.OIDC_ENABLED: + return "no-token-needed" + + # First decode the header + try: + unverified_header = jwt.get_unverified_header(token_string) + except exceptions.JWSError as e: + raise InvalidTokenError(e) + except exceptions.JWTError as e: + # check if we can still get the user info + get_oauth_userinfo(token_string) + token = { + "access_token": token_string + } + return token + + # get the key + key = self.get_key(unverified_header.get("kid")) + + # decode the token + algorithm = unverified_header.get("alg") + try: + decoded_token = jwt.decode( + token_string, key, algorithms=algorithm, audience=auth_settings.AUDIENCE + ) + except exceptions.ExpiredSignatureError as e: + raise ExpiredSignatureError(e) + except exceptions.JWKError: + logger.error("Invalid Key") + raise InvalidKeyError(e) + except exceptions.JWTError as e: + logger.error("Invalid Token") + raise InvalidTokenError(e) + + # make an token object to make it easier to validate + token = { + "access_token": token_string, + "decoded_token": decoded_token, + "token_type": algorithm, + "audience": auth_settings.AUDIENCE, + "expires_at": decoded_token["exp"], + } + return token + + def validate_token(self, token, scopes, request): + """Check if token matches the requested scopes.""" + # For now we don't have a scope yet + # When needed, look at implementation example here: + # https://github.com/lepture/authlib/blob/master/authlib/oauth2/rfc6750/validator.py + return token + + +def get_oauth_identity(): + """Give back the email address of the OAUTH account + + If JWT is disabled, we return "admin". + + We do an api call to request userinfo. This gives back all the userinfo. + We get the right info from there and return this to the user. + + Returns: + email(str): Email of the logged in user + + """ + # For now unnecersary, useful when we nly use one log in method + if not auth_settings.OIDC_ENABLED: + return "Admin" + # Request the userinfo + userinfo = get_oauth_userinfo(current_token["access_token"]) + if "email" not in userinfo: + logger.error("Email is a required claim for oauth") + raise KeyError("Email is a required claim for oauth") + return userinfo["email"] + + +# check which method we use to log in and load vars needed for that +if auth_settings.OIDC_ENABLED is True: + oauth_required = ResourceProtector() + oauth_required.register_token_validator(MyBearerTokenValidator()) + login_required = oauth_required(optional=not auth_settings.OIDC_ENABLED) + get_identity = get_oauth_identity +else: + login_required = jwt_required + get_identity = get_jwt_identity + \ No newline at end of file