Skip to content

Commit

Permalink
Don't end computations until cluster is truly idle (#7790)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored May 23, 2023
1 parent 755f768 commit fcd921c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 37 deletions.
20 changes: 18 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,7 +1518,8 @@ class SchedulerState:
running: set[WorkerState]
#: Workers that are currently in running state and not fully utilized
#: Definition based on occupancy
#: (actually a SortedDict, but the sortedcontainers package isn't annotated)
#: (actually a SortedDict, but the sortedcontainers package isn't annotated).
#: Not to be confused with :attr:`is_idle`.
idle: dict[str, WorkerState]
#: Similar to `idle`
#: Definition based on assigned tasks
Expand Down Expand Up @@ -1775,6 +1776,21 @@ def _clear_task_state(self) -> None:
):
collection.clear() # type: ignore

@property
def is_idle(self) -> bool:
    """Return True iff there are no tasks that haven't finished computing.

    Unlike testing ``self.total_occupancy``, this property returns False if there
    are long-running tasks, no-worker, or queued tasks (due to not having any
    workers).

    Not to be confused with :attr:`idle`.

    Returns
    -------
    bool
        True when every task group has a zero count for every state other than
        "memory", "error", "released" and "forgotten"; this is also the result
        when there are no task groups at all.
    """
    # Task states that do not count as unfinished work
    finished = {"memory", "error", "released", "forgotten"}
    return all(
        count == 0 or state in finished
        for tg in self.task_groups.values()
        for state, count in tg.states.items()
    )

@property
def total_occupancy(self) -> float:
return self._calc_occupancy(
Expand Down Expand Up @@ -4325,7 +4341,7 @@ def update_graph(
keys=lost_keys, client=client, stimulus_id=stimulus_id
)

if self.total_occupancy > 1e-9 and self.computations:
if not self.is_idle and self.computations:
# Still working on something. Assign new tasks to same computation
computation = self.computations[-1]
else:
Expand Down
90 changes: 90 additions & 0 deletions distributed/tests/test_computations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Tests for distributed.scheduler.Computation objects"""
from __future__ import annotations

import pytest

from distributed import Event, Worker, secede
from distributed.utils_test import async_poll_for, gen_cluster, inc, wait_for_state


@gen_cluster(client=True)
async def test_computations(c, s, a, b):
    """Two collections persisted one after the other yield two computations."""
    da = pytest.importorskip("dask.array")

    ones = da.ones(100, chunks=(10,))
    added = (ones + 1).persist()
    await added

    subtracted = (ones - 2).persist()
    await subtracted

    assert len(s.computations) == 2
    first = s.computations[0]
    second = s.computations[1]

    # Each computation only holds the task groups of its own persist() call
    first_groups = str(first.groups)
    assert "add" in first_groups
    assert "sub" in str(second.groups)
    assert "sub" not in first_groups

    assert isinstance(repr(second), str)

    latest_stop = max(tg.stop for tg in s.task_groups.values())
    assert second.stop == latest_stop

    assert first.states["memory"] == added.npartitions


@gen_cluster(client=True)
async def test_computations_futures(c, s, a, b):
    """Futures and the reduction submitted over them share one computation."""
    increments = [c.submit(inc, n) for n in range(10)]
    # Keep the future bound to a name so it is not released before the checks
    total = c.submit(sum, increments)
    await total

    (computation,) = s.computations
    groups = str(computation.groups)
    assert "sum" in groups
    assert "inc" in groups


@gen_cluster(client=True, nthreads=[])
async def test_computations_no_workers(c, s):
    """If a computation is stuck due to lack of workers, don't create a new one"""
    waiting_states = ("queued", "no-worker")

    fut_x = c.submit(inc, 1, key="x")
    await wait_for_state("x", waiting_states, s)
    fut_y = c.submit(inc, 2, key="y")
    await wait_for_state("y", waiting_states, s)

    # Occupancy reads zero even though both tasks are still pending
    assert s.total_occupancy == 0

    async with Worker(s.address):
        assert await fut_x == 2
        assert await fut_y == 3

    (computation,) = s.computations
    expected_groups = {s.task_groups["x"], s.task_groups["y"]}
    assert computation.groups == expected_groups


@gen_cluster(client=True)
async def test_computations_no_resources(c, s, a, b):
    """If a computation is stuck due to lack of resources, don't create a new one"""
    needed = {"A": 1}

    fut_x = c.submit(inc, 1, key="x", resources=needed)
    await wait_for_state("x", "no-worker", s)

    fut_y = c.submit(inc, 2, key="y")
    assert await fut_y == 3

    # Occupancy reads zero even though "x" is still waiting for a worker
    assert s.total_occupancy == 0

    async with Worker(s.address, resources={"A": 1}):
        assert await fut_x == 2

    (computation,) = s.computations
    expected_groups = {s.task_groups["x"], s.task_groups["y"]}
    assert computation.groups == expected_groups


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_computations_long_running(c, s, a):
    """Don't create new computations if there are long-running tasks"""
    ev = Event()

    def secede_and_wait(ev):
        secede()
        ev.wait()

    def occupancy_is_zero():
        return s.total_occupancy == 0

    fut_x = c.submit(secede_and_wait, ev, key="x")
    await wait_for_state("x", "long-running", a)
    # Wait until occupancy has dropped to zero before submitting more work
    await async_poll_for(occupancy_is_zero, timeout=5)

    fut_y = c.submit(inc, 1, key="y")
    assert await fut_y == 2

    await ev.set()
    await fut_x

    (computation,) = s.computations
    expected_groups = {s.task_groups["x"], s.task_groups["y"]}
    assert computation.groups == expected_groups
37 changes: 2 additions & 35 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2544,7 +2544,8 @@ async def test_default_task_duration_splits(c, s, a, b):
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")

# We don't care about the actual computation here but we'll schedule one anyhow to verify that we're looking for the correct key
# We don't care about the actual computation here but we'll schedule one anyhow to
# verify that we're looking for the correct key
npart = 10
df = dd.from_pandas(pd.DataFrame({"A": range(100), "B": 1}), npartitions=npart)
graph = df.shuffle(
Expand Down Expand Up @@ -3748,40 +3749,6 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first):
assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes


@gen_cluster(client=True)
async def test_computations(c, s, a, b):
    """Two collections persisted one after the other yield two computations."""
    da = pytest.importorskip("dask.array")

    ones = da.ones(100, chunks=(10,))
    added = (ones + 1).persist()
    await added

    subtracted = (ones - 2).persist()
    await subtracted

    assert len(s.computations) == 2
    first = s.computations[0]
    second = s.computations[1]

    # Each computation only holds the task groups of its own persist() call
    first_groups = str(first.groups)
    assert "add" in first_groups
    assert "sub" in str(second.groups)
    assert "sub" not in first_groups

    assert isinstance(repr(second), str)

    latest_stop = max(tg.stop for tg in s.task_groups.values())
    assert second.stop == latest_stop

    assert first.states["memory"] == added.npartitions


@gen_cluster(client=True)
async def test_computations_futures(c, s, a, b):
    """Futures and the reduction submitted over them share one computation."""
    increments = [c.submit(inc, n) for n in range(10)]
    # Keep the future bound to a name so it is not released before the checks
    total = c.submit(sum, increments)
    await total

    (computation,) = s.computations
    groups = str(computation.groups)
    assert "sum" in groups
    assert "inc" in groups


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_transition_counter(c, s, a):
assert s.transition_counter == 0
Expand Down

0 comments on commit fcd921c

Please sign in to comment.