Skip to content

Commit

Permalink
wip cancel engine commands
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasf committed Dec 20, 2024
1 parent aa98f31 commit bf7d169
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 15 deletions.
27 changes: 12 additions & 15 deletions chess/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,7 @@ async def communicate(self, command_factory: Callable[[Self], BaseCommand[T]]) -
assert command.state == CommandState.NEW, command.state

if self.next_command is not None:
self.next_command.result.cancel()
self.next_command.finished.cancel()
self.next_command.set_finished()
self.next_command._cancel()

self.next_command = command

Expand All @@ -957,20 +955,14 @@ def previous_command_finished() -> None:
if self.command is not None:
cmd = self.command

def cancel_if_cancelled(result: asyncio.Future[T]) -> None:
if result.cancelled():
cmd._cancel()

cmd.result.add_done_callback(cancel_if_cancelled)
cmd._start()
if cmd.state == CommandState.NEW:
cmd._start()
cmd.add_finished_callback(previous_command_finished)

if self.command is None:
previous_command_finished()
elif not self.command.result.done():
self.command.result.cancel()
elif not self.command.result.cancelled():
if self.command is not None:
self.command._cancel()
else:
previous_command_finished()

return await command.result

Expand Down Expand Up @@ -1233,7 +1225,12 @@ def set_finished(self) -> None:
self._dispatch_finished()

def _cancel(self) -> None:
if self.state != CommandState.CANCELLING and self.state != CommandState.DONE:
if self.state == CommandState.NEW:
self.state = CommandState.DONE
self.result.cancel()
self.finished.cancel()
self._dispatch_finished()
elif self.state != CommandState.CANCELLING and self.state != CommandState.DONE:
assert self.state == CommandState.ACTIVE, self.state
self.state = CommandState.CANCELLING
self.cancel()
Expand Down
21 changes: 21 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,27 @@ def test_sf_quit(self):
with self.assertRaises(chess.engine.EngineTerminatedError), engine:
engine.ping()

@catchAndSkip(FileNotFoundError, "need stockfish")
def test_sf_cancel(self):
class TerminateTaskGroup(Exception):
pass

async def terminate_task_group():
await asyncio.sleep(0.001)
raise TerminateTaskGroup()

async def main():
try:
async with asyncio.TaskGroup() as group:
_, engine = await chess.engine.popen_uci("stockfish")
group.create_task(engine.analyse(chess.Board(), chess.engine.Limit()))
group.create_task(engine.analyse(chess.Board(), chess.engine.Limit()))
group.create_task(terminate_task_group())
except* TerminateTaskGroup:
pass

asyncio.run(main())

@catchAndSkip(FileNotFoundError, "need fairy-stockfish")
def test_fairy_sf_initialize(self):
with chess.engine.SimpleEngine.popen_uci("fairy-stockfish", setpgrp=True, debug=True):
Expand Down

0 comments on commit bf7d169

Please sign in to comment.