Skip to content

Commit

Permalink
Add serial groups to internal jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Jan 10, 2025
1 parent a7ee70c commit 5b0b97f
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Add internal job serial group
Revision ID: 3690c677f8d1
Revises: 5f6f804e11d3
Create Date: 2025-01-10 16:08:36.541807
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "3690c677f8d1"
down_revision = "5f6f804e11d3"
branch_labels = None
depends_on = None


def upgrade():
    """Apply the serial-group schema change.

    Adds the nullable ``serial_group`` column to ``internal_jobs`` and a
    partial unique index so that, among rows whose status is 'running',
    each (status, serial_group) pair appears at most once — i.e. only one
    job per serial group may be running at a time.
    """
    op.add_column(
        "internal_jobs",
        sa.Column("serial_group", sa.String(), nullable=True),
    )

    # Partial index: only 'running' rows participate in the uniqueness
    # check.  NULL serial_group values never collide (NULL != NULL), so
    # ungrouped jobs are unrestricted.
    op.create_index(
        "ux_internal_jobs_status_serial_group",
        "internal_jobs",
        ["status", "serial_group"],
        unique=True,
        postgresql_where=sa.text("status = 'running'"),
    )


def downgrade():
    """Revert the serial-group schema change.

    Drops the partial unique index first, then removes the
    ``serial_group`` column from ``internal_jobs``.
    """
    # The postgresql_where clause mirrors the one used at creation time
    # so the index definition round-trips exactly.
    op.drop_index(
        "ux_internal_jobs_status_serial_group",
        table_name="internal_jobs",
        postgresql_where=sa.text("status = 'running'"),
    )
    op.drop_column("internal_jobs", "serial_group")
10 changes: 8 additions & 2 deletions qcfractal/qcfractal/components/dataset_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING

from sqlalchemy import select, delete, func, union, text, and_
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic
from sqlalchemy.orm.attributes import flag_modified

Expand Down Expand Up @@ -1923,9 +1924,14 @@ def add_create_view_attachment_job(
},
user_id=None,
unique_name=True,
serial_group="ds_create_view",
session=session,
)

ds_job_orm = DatasetInternalJobORM(dataset_id=dataset_id, internal_job_id=job_id)
session.add(ds_job_orm)
stmt = (
insert(DatasetInternalJobORM)
.values(dataset_id=dataset_id, internal_job_id=job_id)
.on_conflict_do_nothing()
)
session.execute(stmt)
return job_id
11 changes: 11 additions & 0 deletions qcfractal/qcfractal/components/internal_jobs/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class InternalJobORM(BaseORM):
# it must be unique. null != null always
unique_name = Column(String, nullable=True)

# If this job is part of a serial group (only one may run at a time)
serial_group = Column(String, nullable=True)

__table_args__ = (
Index("ix_internal_jobs_added_date", "added_date", postgresql_using="brin"),
Index("ix_internal_jobs_scheduled_date", "scheduled_date", postgresql_using="brin"),
Expand All @@ -64,6 +67,14 @@ class InternalJobORM(BaseORM):
Index("ix_internal_jobs_name", "name"),
Index("ix_internal_jobs_user_id", "user_id"),
UniqueConstraint("unique_name", name="ux_internal_jobs_unique_name"),
# Enforces only one running per serial group
Index(
"ux_internal_jobs_status_serial_group",
"status",
"serial_group",
unique=True,
postgresql_where=(status == InternalJobStatusEnum.running),
),
)

_qcportal_model_excludes = ["unique_name", "user_id"]
Expand Down
48 changes: 45 additions & 3 deletions qcfractal/qcfractal/components/internal_jobs/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import psycopg2.extensions
from sqlalchemy import select, delete, update, and_, or_
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError

from qcfractal.components.auth.db_models import UserIDMapSubquery
from qcfractal.db_socket.helpers import get_query_proj_options
Expand Down Expand Up @@ -71,6 +72,7 @@ def add(
after_function: Optional[str] = None,
after_function_kwargs: Optional[Dict[str, Any]] = None,
repeat_delay: Optional[int] = None,
serial_group: Optional[str] = None,
*,
session: Optional[Session] = None,
) -> int:
Expand Down Expand Up @@ -99,6 +101,8 @@ def add(
Arguments to use when calling `after_function`
repeat_delay
If set, will submit a new, identical job to be run repeat_delay seconds after this one finishes
serial_group
Only one job within this group may be run at once. If None, there is no limit
session
An existing SQLAlchemy session to use. If None, one will be created. If an existing session
is used, it will be flushed (but not committed) before returning from this function.
Expand All @@ -122,6 +126,7 @@ def add(
after_function=after_function,
after_function_kwargs=after_function_kwargs,
repeat_delay=repeat_delay,
serial_group=serial_group,
user_id=user_id,
)
stmt = stmt.on_conflict_do_update(
Expand Down Expand Up @@ -153,6 +158,7 @@ def add(
after_function=after_function,
after_function_kwargs=after_function_kwargs,
repeat_delay=repeat_delay,
serial_group=serial_group,
user_id=user_id,
)
if unique_name:
Expand Down Expand Up @@ -413,8 +419,16 @@ def _wait_for_job(session: Session, logger, conn, end_event):
Blocks until a job is possibly available to run
"""

serial_cte = select(InternalJobORM.serial_group).distinct()
serial_cte = serial_cte.where(InternalJobORM.status == InternalJobStatusEnum.running)
serial_cte = serial_cte.where(InternalJobORM.serial_group.is_not(None))
serial_cte = serial_cte.cte()

next_job_stmt = select(InternalJobORM.scheduled_date)
next_job_stmt = next_job_stmt.where(InternalJobORM.status == InternalJobStatusEnum.waiting)
next_job_stmt = next_job_stmt.where(
or_(InternalJobORM.serial_group.is_(None), InternalJobORM.serial_group.not_in(select(serial_cte)))
)
next_job_stmt = next_job_stmt.order_by(InternalJobORM.scheduled_date.asc())

# Skip any that are being claimed for running right now
Expand Down Expand Up @@ -516,6 +530,11 @@ def run_loop(self, end_event):
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

# Prepare a statement for finding jobs. Filters will be added in the loop
serial_cte = select(InternalJobORM.serial_group).distinct()
serial_cte = serial_cte.where(InternalJobORM.status == InternalJobStatusEnum.running)
serial_cte = serial_cte.where(InternalJobORM.serial_group.is_not(None))
serial_cte = serial_cte.cte()

stmt = select(InternalJobORM)
stmt = stmt.order_by(InternalJobORM.scheduled_date.asc()).limit(1)
stmt = stmt.with_for_update(skip_locked=True)
Expand All @@ -528,10 +547,20 @@ def run_loop(self, end_event):
logger.debug("checking for jobs")

# Pick up anything waiting, or anything that hasn't been updated in a while (12 update periods)

now = now_at_utc()
dead = now - timedelta(seconds=(self._update_frequency * 12))
logger.debug(f"checking for jobs before date {now}")
cond1 = and_(InternalJobORM.status == InternalJobStatusEnum.waiting, InternalJobORM.scheduled_date <= now)

# Job is waiting, schedule to be run now or in the past,
# Serial group does not have any running or is not set
cond1 = and_(
InternalJobORM.status == InternalJobStatusEnum.waiting,
InternalJobORM.scheduled_date <= now,
or_(InternalJobORM.serial_group.is_(None), InternalJobORM.serial_group.not_in(select(serial_cte))),
)

# Job is running but runner is determined to be dead. Serial group doesn't matter
cond2 = and_(InternalJobORM.status == InternalJobStatusEnum.running, InternalJobORM.last_updated < dead)

stmt_now = stmt.where(or_(cond1, cond2))
Expand Down Expand Up @@ -560,8 +589,21 @@ def run_loop(self, end_event):
job_orm.runner_uuid = runner_uuid
job_orm.status = InternalJobStatusEnum.running

# Releases the row-level lock (from the with_for_update() in the original query)
session_main.commit()
# For logging below - the object might be in an odd state after the exception, where accessing
# it results in further exceptions
serial_group = job_orm.serial_group

# Violation of the unique constraint may occur - two runners attempting to take
# different jobs of the same serial group at the same time
try:
# Releases the row-level lock (from the with_for_update() in the original query)
session_main.commit()
except IntegrityError:
logger.info(
f"Attempting to run job from serial group '{serial_group}', but seems like another runner got to another job from the same group first"
)
session_main.rollback()
continue

job_progress = JobProgress(job_orm.id, runner_uuid, session_status, self._update_frequency, end_event)
self._run_single(session_main, job_orm, logger, job_progress=job_progress)
Expand Down
44 changes: 44 additions & 0 deletions qcfractal/qcfractal/components/internal_jobs/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,50 @@ def test_internal_jobs_socket_run(storage_socket: SQLAlchemySocket, session: Ses
th.join()


@pytest.mark.parametrize("job_func", ("internal_jobs.dummy_job", "internal_jobs.dummy_job_2"))
def test_internal_jobs_socket_run_serial(storage_socket: SQLAlchemySocket, session: Session, job_func: str):
    """Jobs sharing a serial group must not run concurrently.

    Submits two jobs in serial group "test" and one ungrouped job, then
    starts three runner threads.  Only one of the grouped jobs may be
    running at once; the ungrouped job runs independently.
    """
    # Two jobs in the same serial group...
    id_1 = storage_socket.internal_jobs.add(
        "dummy_job", now_at_utc(), job_func, {"iterations": 10}, None, unique_name=False, serial_group="test"
    )
    id_2 = storage_socket.internal_jobs.add(
        "dummy_job", now_at_utc(), job_func, {"iterations": 10}, None, unique_name=False, serial_group="test"
    )
    # ...and one job with no serial group
    id_3 = storage_socket.internal_jobs.add(
        "dummy_job", now_at_utc(), job_func, {"iterations": 10}, None, unique_name=False
    )

    # Faster updates for testing
    storage_socket.internal_jobs._update_frequency = 1

    end_event = threading.Event()
    runners = [threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) for _ in range(3)]
    for runner in runners:
        runner.start()

    # Give the runners time to claim jobs
    time.sleep(8)

    try:
        # id_1 was scheduled first, so it should be the grouped job that
        # got picked up; id_2 must still be waiting on the serial group.
        assert session.get(InternalJobORM, id_1).status == InternalJobStatusEnum.running
        assert session.get(InternalJobORM, id_2).status == InternalJobStatusEnum.waiting
        assert session.get(InternalJobORM, id_3).status == InternalJobStatusEnum.running

    finally:
        end_event.set()
        for runner in runners:
            runner.join()


def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: Session):
id_1 = storage_socket.internal_jobs.add(
"dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False
Expand Down
1 change: 1 addition & 0 deletions qcportal/qcportal/internal_jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Config:
runner_hostname: Optional[str]
runner_uuid: Optional[str]
repeat_delay: Optional[int]
serial_group: Optional[str]

progress: int
progress_description: Optional[str] = None
Expand Down

0 comments on commit 5b0b97f

Please sign in to comment.