Skip to content

Commit

Permalink
Merge branch 'feature.authorization' of github.com:Josephine-Rutten/c…
Browse files Browse the repository at this point in the history
…naas-nms into feature.authorization
  • Loading branch information
indy-independence committed Feb 5, 2024
2 parents 79a47ec + be3da89 commit f0560d9
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/cnaas_nms/db/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declarative_base

Base = declarative_base()
24 changes: 23 additions & 1 deletion src/cnaas_nms/db/job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import datetime
import enum
import json
from typing import Dict, Optional
import time
from typing import Dict, List, Optional

from nornir.core.task import AggregatedResult
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, SmallInteger, Unicode
Expand Down Expand Up @@ -267,3 +268,24 @@ def check_job_abort_status(cls, session, job_id) -> bool:
if job.status != JobStatus.RUNNING:
return True
return False

@classmethod
def wait_for_job_completion(
cls,
session,
job_id: int,
timeout: int = 300,
exit_status: Optional[List[JobStatus]] = None,
) -> None:
"""Wait for job to complete"""
start_time = time.time()
if not exit_status:
exit_status = [JobStatus.FINISHED, JobStatus.EXCEPTION, JobStatus.ABORTED]
while True:
job: Job = session.query(Job).filter(Job.id == job_id).one()
if job.status in exit_status:
return
if time.time() - start_time > timeout:
raise TimeoutError(f"Job {job_id} did not finish within {timeout} seconds")
time.sleep(1)
session.refresh(job)
4 changes: 2 additions & 2 deletions src/cnaas_nms/db/mgmtdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def is_taken(addr):
def _get_taken_ips(session) -> Set[IPAddress]:
"""Returns the full set of taken (used + reserved) IP addresses"""
device_query = (
session.query(Device).filter(Device.management_ip is not None).options(load_only("management_ip"))
session.query(Device).filter(Device.management_ip is not None).options(load_only(Device.management_ip))
)
used_ips = set(device.management_ip for device in device_query)
reserved_ip_query = session.query(ReservedIP).options(load_only("ip"))
reserved_ip_query = session.query(ReservedIP).options(load_only(ReservedIP.ip))
reserved_ips = set(reserved_ip.ip for reserved_ip in reserved_ip_query)

return used_ips | reserved_ips
5 changes: 4 additions & 1 deletion src/cnaas_nms/db/tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_get_linknets(self):
session.add(device1)
session.add(device2)
test_linknet = Linknet(device_a=device1, device_b=device2)
session.add(test_linknet)
device1 = session.query(Device).filter(Device.hostname == "test-device1").one()
device2 = session.query(Device).filter(Device.hostname == "test-device2").one()
self.assertEqual([test_linknet], device1.get_linknets(session))
Expand All @@ -67,6 +68,7 @@ def test_get_links_to(self):
session.add(device1)
session.add(device2)
test_linknet = Linknet(device_a=device1, device_b=device2)
session.add(test_linknet)
device1 = session.query(Device).filter(Device.hostname == "test-device1").one()
device2 = session.query(Device).filter(Device.hostname == "test-device2").one()
self.assertEqual([test_linknet], device1.get_links_to(session, device2))
Expand All @@ -78,7 +80,8 @@ def test_get_neighbors(self):
with sqla_session() as session:
session.add(device1)
session.add(device2)
Linknet(device_a=device1, device_b=device2)
new_linknet = Linknet(device_a=device1, device_b=device2)
session.add(new_linknet)
device1 = session.query(Device).filter(Device.hostname == "test-device1").one()
device2 = session.query(Device).filter(Device.hostname == "test-device2").one()
self.assertEqual(set([device2]), device1.get_neighbors(session))
Expand Down
5 changes: 5 additions & 0 deletions src/cnaas_nms/devicehandler/sync_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cnaas_nms.db.device_vars import expand_interface_settings
from cnaas_nms.db.git import RepoStructureException
from cnaas_nms.db.interface import Interface
from cnaas_nms.db.job import Job
from cnaas_nms.db.joblock import Joblock, JoblockError
from cnaas_nms.db.session import redis_session, sqla_session
from cnaas_nms.db.settings import get_settings
Expand Down Expand Up @@ -1009,6 +1010,10 @@ def sync_devices(
kwargs={"prev_job_id": job_id, "hostnames": changed_hosts},
)
logger.info(f"Commit-confirm for job id {job_id} scheduled as job id {next_job_id}")
# keep this thread running until next_job has finished so the device session is not closed,
# causing cancellation of pending commits
with sqla_session() as session:
Job.wait_for_job_completion(session, next_job_id)

return NornirJobResult(nrresult=nrresult, next_job_id=next_job_id, change_score=total_change_score)

Expand Down
8 changes: 5 additions & 3 deletions src/cnaas_nms/devicehandler/underlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def find_free_infra_ip(session) -> Optional[IPv4Address]:
"""Returns first free IPv4 infra IP."""
used_ips = []
device_query = session.query(Device).filter(Device.infra_ip.isnot(None)).options(load_only("infra_ip"))
device_query = session.query(Device).filter(Device.infra_ip.isnot(None)).options(load_only(Device.infra_ip))
for device in device_query:
used_ips.append(device.infra_ip)
settings, settings_origin = get_settings(device_type=DeviceType.CORE)
Expand All @@ -30,10 +30,12 @@ def find_free_mgmt_lo_ip(session) -> Optional[IPv4Address]:
"""Returns first free IPv4 infra IP."""
used_ips = []
reserved_ips = []
device_query = session.query(Device).filter(Device.management_ip.isnot(None)).options(load_only("management_ip"))
device_query = (
session.query(Device).filter(Device.management_ip.isnot(None)).options(load_only(Device.management_ip))
)
for device in device_query:
used_ips.append(device.management_ip)
reserved_ip_query = session.query(ReservedIP).options(load_only("ip"))
reserved_ip_query = session.query(ReservedIP).options(load_only(ReservedIP.ip))
for reserved_ip in reserved_ip_query:
reserved_ips.append(reserved_ip.ip)

Expand Down
3 changes: 0 additions & 3 deletions src/cnaas_nms/devicehandler/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,6 @@ def update_linknets(
if not dry_run:
session.add(new_link)
session.commit()
else:
# Make sure linknet object is not added to session because of foreign key load
session.expunge(new_link)
# Make return data pretty
ret_dict = {
**new_link.as_dict(),
Expand Down
9 changes: 7 additions & 2 deletions src/cnaas_nms/tools/rbac/rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@ def get_permissions_user(permissions_rules: PermissionsModel, user_info: dict):
mappings: dict[str, list[str]]
for map_type, mappings in permissions_rules.group_mappings.items():
for value, groups in mappings.items():
if map_type in user_info and value in user_info[map_type]:
user_roles.extend(groups)
if map_type in user_info:
# if the type is a list in userinfo, we check if the value is in the list
# if not a list, we assume it's a string and compare it directly
if (isinstance(user_info[map_type], list) and value in user_info[map_type]) or value == user_info[
map_type
]:
user_roles.extend(groups)

# find the relevant roles and add permissions
relevant_roles = list(set(permissions_rules.roles) & set(user_roles))
Expand Down
26 changes: 6 additions & 20 deletions src/cnaas_nms/tools/rbac/tests/test_security_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_role_permissions_with_default(self):

def test_role_permissions(self):
user_info = {
"edumember_is_member_of": "urn:collab:group:test.surfconext.nl:nl:surfnet:diensten:surfwired-admin"
"edumember_is_member_of": ["urn:collab:group:test.surfconext.nl:nl:surfnet:diensten:surfwired-admin"]
}
permissions_of_user = get_permissions_user(self.permissions_rules_without_default, user_info)
expected_result = [
Expand All @@ -130,30 +130,16 @@ def test_role_permissions(self):
self.assertEqual(permissions_of_user, expected_result)

def test_role_permissions_only_default(self):
user_info = {"edumember_is_member_of": "notarealrole"}
user_info = {
"edumember_is_member_of": "not-a-real-role-urn:collab:group:test.surfconext.nl:nl:surfnet:diensten:surfwired-admin-test"
}
permissions_of_user = get_permissions_user(self.permissions_rules_with_default, user_info)
expected_result = [PermissionModel(methods=["GET"], endpoints=["/devices**"])]
self.assertEqual(permissions_of_user, expected_result)

def test_role_permissions_zero(self):
permissions_rules = PermissionsModel(
group_mappings={
"edumember_is_member_of": {
"urn:collab:group:test.surfconext.nl:nl:surfnet:diensten:surfwired-admin": ["admin"]
}
},
roles={
"admin": RoleModel(
permissions=[
PermissionModel(methods=["GET", "PUT"], endpoints=["/devices/**/interfaces", "/repository"]),
PermissionModel(methods=["POST"], endpoints=["/auth/*", "/devices"]),
]
),
"default": RoleModel(permissions=[PermissionModel(methods=["GET"], endpoints=["/devices**"])]),
},
)
user_info = {"edumember_is_member_of": "notarealrole"}
permissions_of_user = get_permissions_user(permissions_rules, user_info)
user_info = {"edumember_is_member_of": ["admin"]}
permissions_of_user = get_permissions_user(self.permissions_rules_without_default, user_info)
expected_result = []
self.assertEqual(permissions_of_user, expected_result)

Expand Down

0 comments on commit f0560d9

Please sign in to comment.