diff --git a/src/cnaas_nms/db/base.py b/src/cnaas_nms/db/base.py index 860e5425..59be7030 100644 --- a/src/cnaas_nms/db/base.py +++ b/src/cnaas_nms/db/base.py @@ -1,3 +1,3 @@ -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base Base = declarative_base() diff --git a/src/cnaas_nms/db/job.py b/src/cnaas_nms/db/job.py index 23a9cae0..92392d01 100644 --- a/src/cnaas_nms/db/job.py +++ b/src/cnaas_nms/db/job.py @@ -1,7 +1,8 @@ import datetime import enum import json -from typing import Dict, Optional +import time +from typing import Dict, List, Optional from nornir.core.task import AggregatedResult from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, SmallInteger, Unicode @@ -267,3 +268,24 @@ def check_job_abort_status(cls, session, job_id) -> bool: if job.status != JobStatus.RUNNING: return True return False + + @classmethod + def wait_for_job_completion( + cls, + session, + job_id: int, + timeout: int = 300, + exit_status: Optional[List[JobStatus]] = None, + ) -> None: + """Wait for job to complete""" + start_time = time.time() + if not exit_status: + exit_status = [JobStatus.FINISHED, JobStatus.EXCEPTION, JobStatus.ABORTED] + while True: + job: Job = session.query(Job).filter(Job.id == job_id).one() + if job.status in exit_status: + return + if time.time() - start_time > timeout: + raise TimeoutError(f"Job {job_id} did not finish within {timeout} seconds") + time.sleep(1) + session.refresh(job) diff --git a/src/cnaas_nms/db/mgmtdomain.py b/src/cnaas_nms/db/mgmtdomain.py index 01342049..a00c0401 100644 --- a/src/cnaas_nms/db/mgmtdomain.py +++ b/src/cnaas_nms/db/mgmtdomain.py @@ -123,10 +123,10 @@ def is_taken(addr): def _get_taken_ips(session) -> Set[IPAddress]: """Returns the full set of taken (used + reserved) IP addresses""" device_query = ( - session.query(Device).filter(Device.management_ip is not None).options(load_only("management_ip")) + session.query(Device).filter(Device.management_ip is not None).options(load_only(Device.management_ip)) ) used_ips = set(device.management_ip for device in device_query) - reserved_ip_query = session.query(ReservedIP).options(load_only("ip")) + reserved_ip_query = session.query(ReservedIP).options(load_only(ReservedIP.ip)) reserved_ips = set(reserved_ip.ip for reserved_ip in reserved_ip_query) return used_ips | reserved_ips diff --git a/src/cnaas_nms/db/tests/test_device.py b/src/cnaas_nms/db/tests/test_device.py index 66ef49df..c37ea672 100644 --- a/src/cnaas_nms/db/tests/test_device.py +++ b/src/cnaas_nms/db/tests/test_device.py @@ -55,6 +55,7 @@ def test_get_linknets(self): session.add(device1) session.add(device2) test_linknet = Linknet(device_a=device1, device_b=device2) + session.add(test_linknet) device1 = session.query(Device).filter(Device.hostname == "test-device1").one() device2 = session.query(Device).filter(Device.hostname == "test-device2").one() self.assertEqual([test_linknet], device1.get_linknets(session)) @@ -67,6 +68,7 @@ def test_get_links_to(self): session.add(device1) session.add(device2) test_linknet = Linknet(device_a=device1, device_b=device2) + session.add(test_linknet) device1 = session.query(Device).filter(Device.hostname == "test-device1").one() device2 = session.query(Device).filter(Device.hostname == "test-device2").one() self.assertEqual([test_linknet], device1.get_links_to(session, device2)) @@ -78,7 +80,8 @@ def test_get_neighbors(self): with sqla_session() as session: session.add(device1) session.add(device2) - Linknet(device_a=device1, device_b=device2) + new_linknet = Linknet(device_a=device1, device_b=device2) + session.add(new_linknet) device1 = session.query(Device).filter(Device.hostname == "test-device1").one() device2 = session.query(Device).filter(Device.hostname == "test-device2").one() self.assertEqual(set([device2]), device1.get_neighbors(session)) diff --git a/src/cnaas_nms/devicehandler/sync_devices.py b/src/cnaas_nms/devicehandler/sync_devices.py index 68626bf7..2ab6ccab 100644 --- a/src/cnaas_nms/devicehandler/sync_devices.py +++ b/src/cnaas_nms/devicehandler/sync_devices.py @@ -20,6 +20,7 @@ from cnaas_nms.db.device_vars import expand_interface_settings from cnaas_nms.db.git import RepoStructureException from cnaas_nms.db.interface import Interface +from cnaas_nms.db.job import Job from cnaas_nms.db.joblock import Joblock, JoblockError from cnaas_nms.db.session import redis_session, sqla_session from cnaas_nms.db.settings import get_settings @@ -1009,6 +1010,10 @@ def sync_devices( kwargs={"prev_job_id": job_id, "hostnames": changed_hosts}, ) logger.info(f"Commit-confirm for job id {job_id} scheduled as job id {next_job_id}") + # keep this thread running until next_job has finished so the device session is not closed, + # causing cancellation of pending commits + with sqla_session() as session: + Job.wait_for_job_completion(session, next_job_id) return NornirJobResult(nrresult=nrresult, next_job_id=next_job_id, change_score=total_change_score) diff --git a/src/cnaas_nms/devicehandler/underlay.py b/src/cnaas_nms/devicehandler/underlay.py index a5fb9451..f3a116f5 100644 --- a/src/cnaas_nms/devicehandler/underlay.py +++ b/src/cnaas_nms/devicehandler/underlay.py @@ -12,7 +12,7 @@ def find_free_infra_ip(session) -> Optional[IPv4Address]: """Returns first free IPv4 infra IP.""" used_ips = [] - device_query = session.query(Device).filter(Device.infra_ip.isnot(None)).options(load_only("infra_ip")) + device_query = session.query(Device).filter(Device.infra_ip.isnot(None)).options(load_only(Device.infra_ip)) for device in device_query: used_ips.append(device.infra_ip) settings, settings_origin = get_settings(device_type=DeviceType.CORE) @@ -30,10 +30,12 @@ def find_free_mgmt_lo_ip(session) -> Optional[IPv4Address]: """Returns first free IPv4 infra IP.""" used_ips = [] reserved_ips = [] - device_query = session.query(Device).filter(Device.management_ip.isnot(None)).options(load_only("management_ip")) + device_query = ( + session.query(Device).filter(Device.management_ip.isnot(None)).options(load_only(Device.management_ip)) + ) for device in device_query: used_ips.append(device.management_ip) - reserved_ip_query = session.query(ReservedIP).options(load_only("ip")) + reserved_ip_query = session.query(ReservedIP).options(load_only(ReservedIP.ip)) for reserved_ip in reserved_ip_query: reserved_ips.append(reserved_ip.ip) diff --git a/src/cnaas_nms/devicehandler/update.py b/src/cnaas_nms/devicehandler/update.py index a5586ba8..109e4b03 100644 --- a/src/cnaas_nms/devicehandler/update.py +++ b/src/cnaas_nms/devicehandler/update.py @@ -354,9 +354,6 @@ def update_linknets( if not dry_run: session.add(new_link) session.commit() - else: - # Make sure linknet object is not added to session because of foreign key load - session.expunge(new_link) # Make return data pretty ret_dict = { **new_link.as_dict(),