Skip to content

Commit

Permalink
Let variant boards manage their own stack
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasf committed Jul 31, 2024
1 parent dda9e8f commit 5d8e82d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 22 deletions.
9 changes: 3 additions & 6 deletions chess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,7 +1542,7 @@ def from_chess960_pos(cls: Type[BaseBoardT], scharnagl: int) -> BaseBoardT:

BoardT = TypeVar("BoardT", bound="Board")

class _BoardState(Generic[BoardT]):
class _BoardState:

def __init__(self, board: BoardT) -> None:
self.pawns = board.pawns
Expand Down Expand Up @@ -1701,7 +1701,7 @@ def __init__(self: BoardT, fen: Optional[str] = STARTING_FEN, *, chess960: bool

self.ep_square = None
self.move_stack = []
self._stack: List[_BoardState[BoardT]] = []
self._stack: List[_BoardState] = []

if fen is None:
self.clear()
Expand Down Expand Up @@ -2304,9 +2304,6 @@ def is_repetition(self, count: int = 3) -> bool:

return False

def _board_state(self: BoardT) -> _BoardState[BoardT]:
return _BoardState(self)

def _push_capture(self, move: Move, capture_square: Square, piece_type: PieceType, was_promoted: bool) -> None:
pass

Expand Down Expand Up @@ -2335,7 +2332,7 @@ def push(self: BoardT, move: Move) -> None:
"""
# Push move and remember board state.
move = self._to_chess960(move)
board_state = self._board_state()
board_state = _BoardState(self)
self.castling_rights = self.clean_castling_rights() # Before pushing stack
self.move_stack.append(self._from_chess960(self.chess960, move.from_square, move.to_square, move.promotion, move.drop))
self._stack.append(board_state)
Expand Down
66 changes: 50 additions & 16 deletions chess/variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,14 +673,12 @@ def status(self) -> chess.Status:

ThreeCheckBoardT = TypeVar("ThreeCheckBoardT", bound="ThreeCheckBoard")

class _ThreeCheckBoardState(Generic[ThreeCheckBoardT], chess._BoardState[ThreeCheckBoardT]):
def __init__(self, board: ThreeCheckBoardT) -> None:
super().__init__(board)
class _ThreeCheckBoardState:
def __init__(self, board: ThreeCheckBoard) -> None:
self.remaining_checks_w = board.remaining_checks[chess.WHITE]
self.remaining_checks_b = board.remaining_checks[chess.BLACK]

def restore(self, board: ThreeCheckBoardT) -> None:
super().restore(board)
def restore(self, board: ThreeCheckBoard) -> None:
board.remaining_checks[chess.WHITE] = self.remaining_checks_w
board.remaining_checks[chess.BLACK] = self.remaining_checks_b

Expand All @@ -698,8 +696,13 @@ class ThreeCheckBoard(chess.Board):

def __init__(self, fen: Optional[str] = starting_fen, chess960: bool = False) -> None:
self.remaining_checks = [3, 3]
self._three_check_stack: List[_ThreeCheckBoardState] = []
super().__init__(fen, chess960=chess960)

def clear_stack(self) -> None:
super().clear_stack()
self._three_check_stack.clear()

def reset_board(self) -> None:
super().reset_board()
self.remaining_checks[chess.WHITE] = 3
Expand All @@ -710,14 +713,17 @@ def clear_board(self) -> None:
self.remaining_checks[chess.WHITE] = 3
self.remaining_checks[chess.BLACK] = 3

def _board_state(self: ThreeCheckBoardT) -> _ThreeCheckBoardState[ThreeCheckBoardT]:
return _ThreeCheckBoardState(self)

def push(self, move: chess.Move) -> None:
self._three_check_stack.append(_ThreeCheckBoardState(self))
super().push(move)
if self.is_check():
self.remaining_checks[not self.turn] -= 1

def pop(self) -> chess.Move:
move = super().pop()
self._three_check_stack.pop().restore(self)
return move

def has_insufficient_material(self, color: chess.Color) -> bool:
# Any remaining piece can give check.
return not (self.occupied_co[color] & ~self.kings)
Expand Down Expand Up @@ -792,8 +798,19 @@ def _transposition_key(self) -> Hashable:
def copy(self: ThreeCheckBoardT, stack: Union[bool, int] = True) -> ThreeCheckBoardT:
board = super().copy(stack=stack)
board.remaining_checks = self.remaining_checks.copy()
if stack:
stack = len(self.move_stack) if stack is True else stack
board._three_check_stack = self._three_check_stack[-stack:]
return board

def root(self: ThreeCheckBoardT) -> ThreeCheckBoardT:
if self._three_check_stack:
board = super().root()
self._three_check_stack[0].restore(board)
return board
else:
return self.copy(stack=False)

def mirror(self: ThreeCheckBoardT) -> ThreeCheckBoardT:
board = super().mirror()
board.remaining_checks[chess.WHITE] = self.remaining_checks[chess.BLACK]
Expand All @@ -803,14 +820,12 @@ def mirror(self: ThreeCheckBoardT) -> ThreeCheckBoardT:

CrazyhouseBoardT = TypeVar("CrazyhouseBoardT", bound="CrazyhouseBoard")

class _CrazyhouseBoardState(Generic[CrazyhouseBoardT], chess._BoardState[CrazyhouseBoardT]):
def __init__(self, board: CrazyhouseBoardT) -> None:
super().__init__(board)
class _CrazyhouseBoardState:
def __init__(self, board: CrazyhouseBoard) -> None:
self.pockets_w = board.pockets[chess.WHITE].copy()
self.pockets_b = board.pockets[chess.BLACK].copy()

def restore(self, board: CrazyhouseBoardT) -> None:
super().restore(board)
def restore(self, board: CrazyhouseBoard) -> None:
board.pockets[chess.WHITE] = self.pockets_w
board.pockets[chess.BLACK] = self.pockets_b

Expand Down Expand Up @@ -870,8 +885,13 @@ class CrazyhouseBoard(chess.Board):

def __init__(self, fen: Optional[str] = starting_fen, chess960: bool = False) -> None:
self.pockets = [CrazyhousePocket(), CrazyhousePocket()]
self._crazyhouse_stack: List[_CrazyhouseBoardState] = []
super().__init__(fen, chess960=chess960)

def clear_stack(self) -> None:
super().clear_stack()
self._crazyhouse_stack.clear()

def reset_board(self) -> None:
super().reset_board()
self.pockets[chess.WHITE].reset()
Expand All @@ -882,10 +902,8 @@ def clear_board(self) -> None:
self.pockets[chess.WHITE].reset()
self.pockets[chess.BLACK].reset()

def _board_state(self: CrazyhouseBoardT) -> _CrazyhouseBoardState[CrazyhouseBoardT]:
return _CrazyhouseBoardState(self)

def push(self, move: chess.Move) -> None:
self._crazyhouse_stack.append(_CrazyhouseBoardState(self))
super().push(move)
if move.drop:
self.pockets[not self.turn].remove(move.drop)
Expand All @@ -896,6 +914,11 @@ def _push_capture(self, move: chess.Move, capture_square: chess.Square, piece_ty
else:
self.pockets[self.turn].add(piece_type)

def pop(self) -> chess.Move:
move = super().pop()
self._crazyhouse_stack.pop().restore(self)
return move

def _is_halfmoves(self, n: int) -> bool:
# No draw by 50-move rule or 75-move rule.
return False
Expand Down Expand Up @@ -1028,8 +1051,19 @@ def copy(self: CrazyhouseBoardT, stack: Union[bool, int] = True) -> CrazyhouseBo
board = super().copy(stack=stack)
board.pockets[chess.WHITE] = self.pockets[chess.WHITE].copy()
board.pockets[chess.BLACK] = self.pockets[chess.BLACK].copy()
if stack:
stack = len(self.move_stack) if stack is True else stack
board._crazyhouse_stack = self._crazyhouse_stack[-stack:]
return board

def root(self: CrazyhouseBoardT) -> CrazyhouseBoardT:
if self._crazyhouse_stack:
board = super().root()
self._crazyhouse_stack[0].restore(board)
return board
else:
return self.copy(stack=False)

def mirror(self: CrazyhouseBoardT) -> CrazyhouseBoardT:
board = super().mirror()
board.pockets[chess.WHITE] = self.pockets[chess.BLACK].copy()
Expand Down

0 comments on commit 5d8e82d

Please sign in to comment.