Skip to content

Commit

Permalink
Merge pull request SUNET#338 from SUNET/bugfix.confirm_commit_2_py311…
Browse files Browse the repository at this point in the history
…_race

Bugfix.confirm commit 2 py311 race
  • Loading branch information
indy-independence authored Jan 30, 2024
2 parents 0d3b8fa + 2af08d8 commit 6151451
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/cnaas_nms/db/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declarative_base

Base = declarative_base()
24 changes: 23 additions & 1 deletion src/cnaas_nms/db/job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import datetime
import enum
import json
from typing import Dict, Optional
import time
from typing import Dict, List, Optional

from nornir.core.task import AggregatedResult
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, SmallInteger, Unicode
Expand Down Expand Up @@ -267,3 +268,24 @@ def check_job_abort_status(cls, session, job_id) -> bool:
if job.status != JobStatus.RUNNING:
return True
return False

@classmethod
def wait_for_job_completion(
cls,
session,
job_id: int,
timeout: int = 300,
exit_status: Optional[List[JobStatus]] = None,
) -> None:
"""Wait for job to complete"""
start_time = time.time()
if not exit_status:
exit_status = [JobStatus.FINISHED, JobStatus.EXCEPTION, JobStatus.ABORTED]
while True:
job: Job = session.query(Job).filter(Job.id == job_id).one()
if job.status in exit_status:
return
if time.time() - start_time > timeout:
raise TimeoutError(f"Job {job_id} did not finish within {timeout} seconds")
time.sleep(1)
session.refresh(job)
4 changes: 2 additions & 2 deletions src/cnaas_nms/db/mgmtdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def is_taken(addr):
def _get_taken_ips(session) -> Set[IPAddress]:
"""Returns the full set of taken (used + reserved) IP addresses"""
device_query = (
session.query(Device).filter(Device.management_ip is not None).options(load_only("management_ip"))
session.query(Device).filter(Device.management_ip is not None).options(load_only(Device.management_ip))
)
used_ips = set(device.management_ip for device in device_query)
reserved_ip_query = session.query(ReservedIP).options(load_only("ip"))
reserved_ip_query = session.query(ReservedIP).options(load_only(ReservedIP.ip))
reserved_ips = set(reserved_ip.ip for reserved_ip in reserved_ip_query)

return used_ips | reserved_ips
5 changes: 4 additions & 1 deletion src/cnaas_nms/db/tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_get_linknets(self):
session.add(device1)
session.add(device2)
test_linknet = Linknet(device_a=device1, device_b=device2)
session.add(test_linknet)
device1 = session.query(Device).filter(Device.hostname == "test-device1").one()
device2 = session.query(Device).filter(Device.hostname == "test-device2").one()
self.assertEqual([test_linknet], device1.get_linknets(session))
Expand All @@ -67,6 +68,7 @@ def test_get_links_to(self):
session.add(device1)
session.add(device2)
test_linknet = Linknet(device_a=device1, device_b=device2)
session.add(test_linknet)
device1 = session.query(Device).filter(Device.hostname == "test-device1").one()
device2 = session.query(Device).filter(Device.hostname == "test-device2").one()
self.assertEqual([test_linknet], device1.get_links_to(session, device2))
Expand All @@ -78,7 +80,8 @@ def test_get_neighbors(self):
with sqla_session() as session:
session.add(device1)
session.add(device2)
Linknet(device_a=device1, device_b=device2)
new_linknet = Linknet(device_a=device1, device_b=device2)
session.add(new_linknet)
device1 = session.query(Device).filter(Device.hostname == "test-device1").one()
device2 = session.query(Device).filter(Device.hostname == "test-device2").one()
self.assertEqual(set([device2]), device1.get_neighbors(session))
Expand Down
5 changes: 5 additions & 0 deletions src/cnaas_nms/devicehandler/sync_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cnaas_nms.db.device_vars import expand_interface_settings
from cnaas_nms.db.git import RepoStructureException
from cnaas_nms.db.interface import Interface
from cnaas_nms.db.job import Job
from cnaas_nms.db.joblock import Joblock, JoblockError
from cnaas_nms.db.session import redis_session, sqla_session
from cnaas_nms.db.settings import get_settings
Expand Down Expand Up @@ -1009,6 +1010,10 @@ def sync_devices(
kwargs={"prev_job_id": job_id, "hostnames": changed_hosts},
)
logger.info(f"Commit-confirm for job id {job_id} scheduled as job id {next_job_id}")
# keep this thread running until next_job has finished so the device session is not closed,
# causing cancellation of pending commits
with sqla_session() as session:
Job.wait_for_job_completion(session, next_job_id)

return NornirJobResult(nrresult=nrresult, next_job_id=next_job_id, change_score=total_change_score)

Expand Down
8 changes: 5 additions & 3 deletions src/cnaas_nms/devicehandler/underlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def find_free_infra_ip(session) -> Optional[IPv4Address]:
"""Returns first free IPv4 infra IP."""
used_ips = []
device_query = session.query(Device).filter(Device.infra_ip.isnot(None)).options(load_only("infra_ip"))
device_query = session.query(Device).filter(Device.infra_ip.isnot(None)).options(load_only(Device.infra_ip))
for device in device_query:
used_ips.append(device.infra_ip)
settings, settings_origin = get_settings(device_type=DeviceType.CORE)
Expand All @@ -30,10 +30,12 @@ def find_free_mgmt_lo_ip(session) -> Optional[IPv4Address]:
"""Returns first free IPv4 infra IP."""
used_ips = []
reserved_ips = []
device_query = session.query(Device).filter(Device.management_ip.isnot(None)).options(load_only("management_ip"))
device_query = (
session.query(Device).filter(Device.management_ip.isnot(None)).options(load_only(Device.management_ip))
)
for device in device_query:
used_ips.append(device.management_ip)
reserved_ip_query = session.query(ReservedIP).options(load_only("ip"))
reserved_ip_query = session.query(ReservedIP).options(load_only(ReservedIP.ip))
for reserved_ip in reserved_ip_query:
reserved_ips.append(reserved_ip.ip)

Expand Down
3 changes: 0 additions & 3 deletions src/cnaas_nms/devicehandler/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,6 @@ def update_linknets(
if not dry_run:
session.add(new_link)
session.commit()
else:
# Make sure linknet object is not added to session because of foreign key load
session.expunge(new_link)
# Make return data pretty
ret_dict = {
**new_link.as_dict(),
Expand Down

0 comments on commit 6151451

Please sign in to comment.