-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: Add mcts algo * feat: Add tictactoe ai model * feat: Print pretty board * feat: Add game state
- Loading branch information
1 parent
3849aad
commit 3085b16
Showing
7 changed files
with
889 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package engine | ||
|
||
import "fmt" | ||
|
||
const ( | ||
P1 = 1 | ||
P2 = -1 | ||
EMPTY = 0 | ||
) | ||
|
||
type Player = int | ||
|
||
type Board struct { | ||
Size int | ||
Cells []int | ||
} | ||
|
||
func NewBoard(size int) *Board { | ||
cells := make([]int, size*size) | ||
for i := range cells { | ||
cells[i] = EMPTY | ||
} | ||
|
||
return &Board{ | ||
Size: size, | ||
Cells: cells, | ||
} | ||
} | ||
|
||
func (b *Board) GetCell(index int) (int, error) { | ||
if index < 0 || index >= len(b.Cells) { | ||
return 0, fmt.Errorf("invalid cell index: %d", index) | ||
} | ||
|
||
return b.Cells[index], nil | ||
} | ||
|
||
func (b *Board) SetCell(index int, player int) error { | ||
if index < 0 || index >= len(b.Cells) { | ||
return fmt.Errorf("invalid cell index: %d", index) | ||
} | ||
|
||
b.Cells[index] = player | ||
return nil | ||
} | ||
|
||
func (b *Board) Load(cells []int) error { | ||
if len(cells) != len(b.Cells) { | ||
return fmt.Errorf("invalid cells length: %d", len(cells)) | ||
} | ||
|
||
copy(b.Cells, cells) | ||
return nil | ||
} | ||
|
||
func (b *Board) GetRowCol(index int) (int, int, error) { | ||
if index < 0 || index >= len(b.Cells) { | ||
return 0, 0, fmt.Errorf("invalid cell index: %d", index) | ||
} | ||
|
||
return index / b.Size, index % b.Size, nil | ||
} | ||
|
||
func (b *Board) ChangePerspective() { | ||
for i := range b.Cells { | ||
b.Cells[i] *= -1 | ||
} | ||
} | ||
|
||
func (b *Board) Copy() *Board { | ||
newBoard := NewBoard(b.Size) | ||
copy(newBoard.Cells, b.Cells) | ||
return newBoard | ||
} | ||
|
||
func (b *Board) Print() { | ||
for i := 0; i < b.Size; i++ { | ||
for j := 0; j < b.Size; j++ { | ||
cell, _ := b.GetCell(i*b.Size + j) | ||
if cell == P1 { | ||
fmt.Print("O") | ||
} else if cell == P2 { | ||
fmt.Print("X") | ||
} else { | ||
fmt.Print(".") | ||
} | ||
} | ||
fmt.Println() | ||
} | ||
fmt.Println() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
package engine | ||
|
||
type Engine struct { | ||
ai AI | ||
} | ||
|
||
func NewEngine(depth int) *Engine { | ||
engine := &Engine{} | ||
mcts := NewMCTS(engine, depth) | ||
engine.ai = mcts | ||
|
||
return engine | ||
} | ||
|
||
func (e *Engine) GetLegalMoves(board *Board) []int { | ||
var moves []int | ||
for i, cell := range board.Cells { | ||
if cell == EMPTY { | ||
moves = append(moves, i) | ||
} | ||
} | ||
return moves | ||
} | ||
|
||
func (e *Engine) PlayMove(board *Board, player int, move int) error { | ||
return board.SetCell(move, player) | ||
} | ||
|
||
func (e *Engine) GetOpponent(player int) int { | ||
return -player | ||
} | ||
|
||
func (e *Engine) CheckGameOver(board *Board, lastMove int) (bool, int) { | ||
if lastMove == -1 { | ||
return false, 0 | ||
} | ||
|
||
if e.CheckWin(board, lastMove) { | ||
absValue := P1 * P2 * -1 | ||
return true, absValue | ||
} | ||
|
||
if len(e.GetLegalMoves(board)) == 0 { | ||
return true, 0 | ||
} | ||
|
||
return false, 0 | ||
} | ||
|
||
func (e *Engine) CheckWin(board *Board, lastMove int) bool { | ||
player, err := board.GetCell(lastMove) | ||
if err != nil { | ||
panic(err) | ||
} | ||
if player == EMPTY { | ||
return false | ||
} | ||
|
||
row, col, err := board.GetRowCol(lastMove) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if e.checkRow(board, row, player) { | ||
return true | ||
} | ||
|
||
if e.checkCol(board, col, player) { | ||
return true | ||
} | ||
|
||
if e.checkDiagonal(board, player) { | ||
return true | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (e *Engine) checkRow(board *Board, row, player int) bool { | ||
for i := 0; i < board.Size; i++ { | ||
cell, err := board.GetCell(row*board.Size + i) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if cell != player { | ||
return false | ||
} | ||
} | ||
|
||
return true | ||
} | ||
|
||
func (e *Engine) checkCol(board *Board, col, player int) bool { | ||
for i := 0; i < board.Size; i++ { | ||
cell, err := board.GetCell(i*board.Size + col) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if cell != player { | ||
return false | ||
} | ||
} | ||
|
||
return true | ||
} | ||
|
||
func (e *Engine) checkDiagonal(board *Board, player int) bool { | ||
sum := 0 | ||
// Left to right | ||
for i := 0; i < board.Size; i++ { | ||
cell, err := board.GetCell(i*board.Size + i) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if cell == player { | ||
sum += player | ||
} | ||
} | ||
|
||
if sum == board.Size*player { | ||
return true | ||
} | ||
|
||
sum = 0 | ||
// Right to left | ||
for i := 0; i < board.Size; i++ { | ||
cell, err := board.GetCell(i*board.Size + board.Size - i - 1) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if cell == player { | ||
sum += player | ||
} | ||
} | ||
|
||
if sum == board.Size*player { | ||
return true | ||
} | ||
|
||
return false | ||
} |
Oops, something went wrong.