Skip to content

Commit

Permalink
Restore retire workers API
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Nov 15, 2024
1 parent d7eff77 commit e16728f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
26 changes: 13 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7443,7 +7443,7 @@ async def retire_workers(
close_workers: bool = False,
remove: bool = True,
stimulus_id: str | None = None,
) -> list[str]: ...
) -> dict[str, Any]: ...

@overload
async def retire_workers(
Expand All @@ -7453,7 +7453,7 @@ async def retire_workers(
close_workers: bool = False,
remove: bool = True,
stimulus_id: str | None = None,
) -> list[str]: ...
) -> dict[str, Any]: ...

@overload
async def retire_workers(
Expand All @@ -7469,7 +7469,7 @@ async def retire_workers(
minimum: int | None = None,
target: int | None = None,
attribute: str = "address",
) -> list[str]: ...
) -> dict[str, Any]: ...

@log_errors
async def retire_workers(
Expand All @@ -7481,7 +7481,7 @@ async def retire_workers(
remove: bool = True,
stimulus_id: str | None = None,
**kwargs: Any,
) -> list[str]:
) -> dict[str, Any]:
"""Gracefully retire workers from cluster. Any key that is in memory exclusively
on the retired workers is replicated somewhere else.
Expand Down Expand Up @@ -7559,7 +7559,7 @@ async def retire_workers(
self.workers[address] for address in self.workers_to_close(**kwargs)
}
if not wss:
return []
return {}

stop_amm = False
amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm")
Expand Down Expand Up @@ -7609,13 +7609,13 @@ async def retire_workers(
# time (depending on interval settings)
amm.run_once()

workers_info_ok = []
workers_info_abort = []
for addr, result in await asyncio.gather(*coros):
workers_info_ok = {}
workers_info_abort = {}
for addr, result, info in await asyncio.gather(*coros):
if result == "OK":
workers_info_ok.append(addr)
workers_info_ok[addr] = info
else:
workers_info_abort.append(addr)
workers_info_abort[addr] = info

finally:
if stop_amm:
Expand Down Expand Up @@ -7649,7 +7649,7 @@ async def _track_retire_worker(
close: bool,
remove: bool,
stimulus_id: str,
) -> tuple[str, Literal["OK", "no-recipients"]]:
) -> tuple[str, Literal["OK", "no-recipients"], dict]:
while not policy.done():
# Sleep 0.01s when there are 4 tasks or less
# Sleep 0.5s when there are 200 or more
Expand All @@ -7671,7 +7671,7 @@ async def _track_retire_worker(
f"Could not retire worker {ws.address!r}: unique data could not be "
f"moved to any other worker ({stimulus_id=!r})"
)
return ws.address, "no-recipients"
return ws.address, "no-recipients", ws.identity()

logger.debug(
f"All unique keys on worker {ws.address!r} have been replicated elsewhere"
Expand All @@ -7685,7 +7685,7 @@ async def _track_retire_worker(
self.close_worker(ws.address)

logger.info(f"Retired worker {ws.address!r} ({stimulus_id=!r})")
return ws.address, "OK"
return ws.address, "OK", ws.identity()

def add_keys(
self, worker: str, keys: Collection[Key] = (), stimulus_id: str | None = None
Expand Down
19 changes: 16 additions & 3 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4385,7 +4385,11 @@ async def test_scatter_type(c, s, a, b):
async def test_retire_workers_2(c, s, a, b):
[x] = await c.scatter([1], workers=a.address)

await s.retire_workers(workers=[a.address])
info = await s.retire_workers(workers=[a.address])
assert info
assert info[a.address]
assert "name" in info[a.address]
assert a.address not in s.workers
assert b.data == {x.key: 1}

assert {ws.address for ws in s.tasks[x.key].who_has} == {b.address}
Expand All @@ -4398,7 +4402,8 @@ async def test_retire_workers_2(c, s, a, b):
async def test_retire_many_workers(c, s, *workers):
futures = await c.scatter(list(range(100)))

await s.retire_workers(workers=[w.address for w in workers[:7]])
info = await s.retire_workers(workers=[w.address for w in workers[:7]])
assert len(info) == 7

results = await c.gather(futures)
assert results == list(range(100))
Expand Down Expand Up @@ -4760,7 +4765,15 @@ def test_recreate_task_sync(c):
@gen_cluster(client=True)
async def test_retire_workers(c, s, a, b):
assert set(s.workers) == {a.address, b.address}
await c.retire_workers(workers=[a.address], close_workers=True)
info = await c.retire_workers(workers=[a.address], close_workers=True)

# Deployment tooling sometimes relies on this information being returned.
# This currently mirrors WorkerState.identity() but may be slimmed down in
# the future.
assert info
assert info[a.address]
assert "name" in info[a.address]

assert set(s.workers) == {b.address}

while a.status != Status.closed:
Expand Down

0 comments on commit e16728f

Please sign in to comment.