diff --git a/cmd/gg/main.go b/cmd/gg/main.go index 92ff624..5ac09f5 100644 --- a/cmd/gg/main.go +++ b/cmd/gg/main.go @@ -31,6 +31,7 @@ func main() { huh.NewOption("connect 4 (2 player)", "connect4"), huh.NewOption("pong (2 player)", "pong"), huh.NewOption("tictactoe (2 player)", "tictactoe"), + huh.NewOption("tictactoe (vs AI)", "tictactoe-ai"), ). Value(&game). Run() @@ -46,6 +47,8 @@ func main() { pong.Run() case "tictactoe": tictactoe.Run() + case "tictactoe-ai": + tictactoe.RunVsAi() case "dodger": dodger.Run() case "hangman": diff --git a/internal/app/tictactoe/engine/board.go b/internal/app/tictactoe/engine/board.go new file mode 100644 index 0000000..a06020d --- /dev/null +++ b/internal/app/tictactoe/engine/board.go @@ -0,0 +1,91 @@ +package engine + +import "fmt" + +const ( + P1 = 1 + P2 = -1 + EMPTY = 0 +) + +type Player = int + +type Board struct { + Size int + Cells []int +} + +func NewBoard(size int) *Board { + cells := make([]int, size*size) + for i := range cells { + cells[i] = EMPTY + } + + return &Board{ + Size: size, + Cells: cells, + } +} + +func (b *Board) GetCell(index int) (int, error) { + if index < 0 || index >= len(b.Cells) { + return 0, fmt.Errorf("invalid cell index: %d", index) + } + + return b.Cells[index], nil +} + +func (b *Board) SetCell(index int, player int) error { + if index < 0 || index >= len(b.Cells) { + return fmt.Errorf("invalid cell index: %d", index) + } + + b.Cells[index] = player + return nil +} + +func (b *Board) Load(cells []int) error { + if len(cells) != len(b.Cells) { + return fmt.Errorf("invalid cells length: %d", len(cells)) + } + + copy(b.Cells, cells) + return nil +} + +func (b *Board) GetRowCol(index int) (int, int, error) { + if index < 0 || index >= len(b.Cells) { + return 0, 0, fmt.Errorf("invalid cell index: %d", index) + } + + return index / b.Size, index % b.Size, nil +} + +func (b *Board) ChangePerspective() { + for i := range b.Cells { + b.Cells[i] *= -1 + } +} + +func (b *Board) Copy() *Board { + newBoard := NewBoard(b.Size) + copy(newBoard.Cells, b.Cells) + return newBoard +} + +func (b *Board) Print() { + for i := 0; i < b.Size; i++ { + for j := 0; j < b.Size; j++ { + cell, _ := b.GetCell(i*b.Size + j) + if cell == P1 { + fmt.Print("O") + } else if cell == P2 { + fmt.Print("X") + } else { + fmt.Print(".") + } + } + fmt.Println() + } + fmt.Println() +} diff --git a/internal/app/tictactoe/engine/engine.go b/internal/app/tictactoe/engine/engine.go new file mode 100644 index 0000000..6a4920c --- /dev/null +++ b/internal/app/tictactoe/engine/engine.go @@ -0,0 +1,145 @@ +package engine + +type Engine struct { + ai AI +} + +func NewEngine(depth int) *Engine { + engine := &Engine{} + mcts := NewMCTS(engine, depth) + engine.ai = mcts + + return engine +} + +func (e *Engine) GetLegalMoves(board *Board) []int { + var moves []int + for i, cell := range board.Cells { + if cell == EMPTY { + moves = append(moves, i) + } + } + return moves +} + +func (e *Engine) PlayMove(board *Board, player int, move int) error { + return board.SetCell(move, player) +} + +func (e *Engine) GetOpponent(player int) int { + return -player +} + +func (e *Engine) CheckGameOver(board *Board, lastMove int) (bool, int) { + if lastMove == -1 { + return false, 0 + } + + if e.CheckWin(board, lastMove) { + absValue := P1 * P2 * -1 + return true, absValue + } + + if len(e.GetLegalMoves(board)) == 0 { + return true, 0 + } + + return false, 0 +} + +func (e *Engine) CheckWin(board *Board, lastMove int) bool { + player, err := board.GetCell(lastMove) + if err != nil { + panic(err) + } + if player == EMPTY { + return false + } + + row, col, err := board.GetRowCol(lastMove) + if err != nil { + panic(err) + } + + if e.checkRow(board, row, player) { + return true + } + + if e.checkCol(board, col, player) { + return true + } + + if e.checkDiagonal(board, player) { + return true + } + + return false +} + +func (e *Engine) checkRow(board *Board, row, player int) bool { + for i := 0; i < board.Size; i++ { + cell, err := board.GetCell(row*board.Size + i) + if err != nil { + panic(err) + } + + if cell != player { + return false + } + } + + return true +} + +func (e *Engine) checkCol(board *Board, col, player int) bool { + for i := 0; i < board.Size; i++ { + cell, err := board.GetCell(i*board.Size + col) + if err != nil { + panic(err) + } + + if cell != player { + return false + } + } + + return true +} + +func (e *Engine) checkDiagonal(board *Board, player int) bool { + sum := 0 + // Left to right + for i := 0; i < board.Size; i++ { + cell, err := board.GetCell(i*board.Size + i) + if err != nil { + panic(err) + } + + if cell == player { + sum += player + } + } + + if sum == board.Size*player { + return true + } + + sum = 0 + // Right to left + for i := 0; i < board.Size; i++ { + cell, err := board.GetCell(i*board.Size + board.Size - i - 1) + if err != nil { + panic(err) + } + + if cell == player { + sum += player + } + } + + if sum == board.Size*player { + return true + } + + return false +} diff --git a/internal/app/tictactoe/engine/engine_test.go b/internal/app/tictactoe/engine/engine_test.go new file mode 100644 index 0000000..73e882c --- /dev/null +++ b/internal/app/tictactoe/engine/engine_test.go @@ -0,0 +1,157 @@ +package engine + +import ( + "testing" +) + +var testCases = []struct { + input []int + expected int +}{ + // #0: first row + { + input: []int{1, 1, 0, -1, 0, -1, 0, 0, 0}, + expected: 2, + }, + // #1: first col + { + input: []int{1, 0, 0, 1, -1, 0, 0, -1, 0}, + expected: 6, + }, + // #2: second col + { + input: []int{0, 1, 0, 0, 1, -1, 0, 0, -1}, + expected: 7, + }, + // #3: diagonal left (\) + { + input: []int{1, -1, 0, 0, 1, -1, 0, 0, 0}, + expected: 8, + }, + // #4: diagonal right (/) + { + input: []int{0, -1, 1, 0, 1, -1, 0, 0, 0}, + expected: 6, + }, + // #5: middle row + { + input: []int{0, 0, 0, 1, 0, 1, -1, -1, 0}, + expected: 4, + }, + // #6: last row + { + input: []int{0, 0, 0, -1, -1, 0, 1, 1, 0}, + expected: 8, + }, + // #7: last col + { + input: []int{0, 0, 1, -1, 0, 1, 0, 0, 0}, + expected: 8, + }, + // #8: No move + { + input: []int{1, -1, 1, -1, -1, 1, 1, 1, -1}, + expected: -1, // Indicates no move left to win + }, +} + +func TestEngine_Solve(t *testing.T) { + BOARD_SIZE := 3 + engine := NewEngine(DEPTH) + + for _, tc := range testCases { + t.Run("Testing solve", func(t *testing.T) { + board := NewBoard(BOARD_SIZE) + board.Load(tc.input) + + move := engine.ai.Solve(board) + + if move != tc.expected { + t.Errorf("expected move %d, got %d", tc.expected, move) + } + }) + } +} + +func TestEngine_CheckWin(t *testing.T) { + BOARD_SIZE := 3 + board := NewBoard(BOARD_SIZE) + engine := NewEngine(DEPTH) + + t.Run("Empty board", func(t *testing.T) { + if engine.CheckWin(board, 0) { + t.Error("expected no win") + } + }) + + t.Run("Horizontal win", func(t *testing.T) { + board.SetCell(0, P1) + board.SetCell(1, P1) + board.SetCell(2, P1) + if !engine.CheckWin(board, 2) { + t.Error("expected win") + } + }) + + t.Run("Vertical win", func(t *testing.T) { + board = NewBoard(BOARD_SIZE) + board.SetCell(0, P1) + board.SetCell(3, P1) + board.SetCell(6, P1) + if !engine.CheckWin(board, 6) { + t.Error("expected win") + } + }) + + t.Run("Left diagonal win", func(t *testing.T) { + board = NewBoard(BOARD_SIZE) + board.SetCell(0, P1) + board.SetCell(4, P1) + board.SetCell(8, P1) + if !engine.CheckWin(board, 8) { + t.Error("expected win") + } + }) + + t.Run("Right diagonal win", func(t *testing.T) { + board = NewBoard(BOARD_SIZE) + board.SetCell(2, P1) + board.SetCell(4, P1) + board.SetCell(6, P1) + if !engine.CheckWin(board, 6) { + t.Error("expected win") + } + }) +} + +func TestEngine_GetLegalMoves(t *testing.T) { + BOARD_SIZE := 4 + board := NewBoard(BOARD_SIZE) + engine := NewEngine(DEPTH) + moves := []int{} + + t.Run("Empty board", func(t *testing.T) { + moves = engine.GetLegalMoves(board) + if len(moves) != BOARD_SIZE*BOARD_SIZE { + t.Errorf("expected %d moves, got %d", BOARD_SIZE*BOARD_SIZE, len(moves)) + } + }) + + t.Run("Full board", func(t *testing.T) { + for _, move := range moves { + board.SetCell(move, P1) + } + moves := engine.GetLegalMoves(board) + if len(moves) != 0 { + t.Errorf("expected 0 moves, got %d", len(moves)) + } + }) + + t.Run("One empty cell", func(t *testing.T) { + board.SetCell(0, EMPTY) + moves = engine.GetLegalMoves(board) + if len(moves) != 1 { + t.Errorf("expected 1 move, got %d", len(moves)) + } + }) +} diff --git a/internal/app/tictactoe/engine/mcts.go b/internal/app/tictactoe/engine/mcts.go new file mode 100644 index 0000000..c120932 --- /dev/null +++ b/internal/app/tictactoe/engine/mcts.go @@ -0,0 +1,217 @@ +package engine + +import ( + "fmt" + "math" + "math/rand/v2" +) + +const ( + C_VALUE = 1.41 + DEPTH = 100 +) + +type AI interface { + // Returns the best move for the current player + Solve(board *Board) int +} + +type GameEngine interface { + // Returns gameover (bool) & a value if there's a winner + CheckGameOver(board *Board, lastMove int) (bool, int) + // Get all available moves + GetLegalMoves(board *Board) []int + // Get the opponent of a player + GetOpponent(player int) int + // Play a move on the board + PlayMove(board *Board, player int, move int) error +} + +type mcts struct { + engine GameEngine + depth int +} + +func NewMCTS(engine GameEngine, depth int) AI { + return &mcts{engine, depth} +} + +func (m *mcts) Solve(board *Board) int { + root := newNode(m.engine, board, -1, nil) + + for i := 0; i < m.depth; i++ { + node := root + for node.isExpanded() { + child, err := node.selectChild() + if err != nil { + panic(err) + } + node = child + } + + isOver, value := m.engine.CheckGameOver(node.board, node.move) + value = m.engine.GetOpponent(value) + + if !isOver { + child, err := node.expand() + if err != nil { + } else { + value = child.simulate() + node = child + } + } + + node.backpropagate(value) + } + + visits := make([]float64, board.Size*board.Size) + dist := make([]float64, board.Size*board.Size) + sum := 0.0 + + for _, child := range root.children { + visits[child.move] = float64(child.visitCount) + sum += visits[child.move] + } + + for i, visit := range visits { + dist[i] = visit / sum + } + + bestMove := -1 + bestValue := 0.0 + + for i, value := range dist { + if value > bestValue { + bestMove = i + bestValue = value + } + } + + return bestMove +} + +type node struct { + engine GameEngine + board *Board + move int + parent *node + children []*node + legalMoves []int + valueSum int + visitCount int +} + +func newNode(engine GameEngine, board *Board, move int, parent *node) *node { + legalMoves := engine.GetLegalMoves(board) + + return &node{ + engine: engine, + board: board, + move: move, + parent: parent, + children: []*node{}, + legalMoves: legalMoves, + valueSum: 0, + visitCount: 0, + } +} + +// Simulate all moves until game is over; +// Returns winner +func (n *node) simulate() int { + isOver, winner := n.engine.CheckGameOver(n.board, n.move) + if isOver { + return n.engine.GetOpponent(winner) + } + + board := n.board.Copy() + player := P1 + result := 0 + + for { + move, _, err := popRandomMove(n.engine.GetLegalMoves(board)) + if err != nil { + break + } + + n.engine.PlayMove(board, player, move) + isOver, winner = n.engine.CheckGameOver(board, move) + if isOver { + result = winner + break + } + + player = n.engine.GetOpponent(player) + } + + return result +} + +func (n *node) expand() (*node, error) { + move, rest, err := popRandomMove(n.legalMoves) + if err != nil { + return nil, err + } + + n.legalMoves = rest + + board := n.board.Copy() + n.engine.PlayMove(board, P1, move) + + // Every node considers itself as p1 + board.ChangePerspective() + child := newNode(n.engine, board, move, n) + n.children = append(n.children, child) + + return child, nil +} + +func (n *node) backpropagate(value int) { + n.visitCount++ + n.valueSum += value + + if n.parent != nil { + n.parent.backpropagate(n.engine.GetOpponent(value)) + } +} + +// Get next child with highest UCB +func (n *node) selectChild() (*node, error) { + if len(n.children) == 0 { + return nil, fmt.Errorf("No child nodes") + } + + var selected *node + var bestValue float64 = math.Inf(-1) + + for _, child := range n.children { + ucb := n.getUCB(child) + if selected == nil || ucb > bestValue { + selected = child + bestValue = ucb + } + } + + return selected, nil +} + +func popRandomMove(legalMoves []int) (int, []int, error) { + if len(legalMoves) == 0 { + return -1, legalMoves, fmt.Errorf("No legal moves") + } + + index := rand.IntN(len(legalMoves)) + move := legalMoves[index] + legalMoves = append(legalMoves[:index], legalMoves[index+1:]...) + + return move, legalMoves, nil +} + +func (n *node) isExpanded() bool { + return len(n.children) > 0 && len(n.legalMoves) == 0 +} + +func (n *node) getUCB(child *node) float64 { + q := 1 - ((float64(child.valueSum)/float64(child.visitCount))+1)/2 + return q + C_VALUE*math.Sqrt(math.Log(float64(n.visitCount))/float64(child.visitCount)) +} diff --git a/internal/app/tictactoe/engine/model.go b/internal/app/tictactoe/engine/model.go new file mode 100644 index 0000000..75901f1 --- /dev/null +++ b/internal/app/tictactoe/engine/model.go @@ -0,0 +1,267 @@ +package engine + +import ( + "fmt" + "log" + "math/rand/v2" + "strconv" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +type Game struct { + board *Board + engine *Engine + turn Player + winner Player + gameover bool + round int + scoreP1 int + scoreP2 int + colors map[string]lipgloss.Style +} + +const ( + size = 3 + yellow = "#FF9E3B" + dark = "#3C3A32" + gray = "#717C7C" + light = "#DCD7BA" + red = "#E63D3D" + green = "#98BB6C" + blue = "#7E9CD8" +) + +func GetModel() tea.Model { + board := NewBoard(size) + engine := NewEngine(100) + + defaultStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("#f9f6f2")) + c := func(s string) lipgloss.Color { + return lipgloss.Color(s) + } + + return Game{ + board: board, + engine: engine, + turn: P1, + winner: 0, + round: 1, + scoreP1: 0, + scoreP2: 0, + gameover: false, + colors: map[string]lipgloss.Style{ + "board": defaultStyle.Background(c(dark)), + "text": defaultStyle.Background(c(dark)).Foreground(c(light)), + "line": defaultStyle.Background(c(dark)).Foreground(c(gray)), + "p1": defaultStyle.Background(c(dark)).Foreground(c(yellow)), + "p2": defaultStyle.Background(c(dark)).Foreground(c(red)), + "hi": defaultStyle.Foreground(c(green)), + "status": defaultStyle.Foreground(c(blue)), + }, + } +} + +func (g Game) Init() tea.Cmd { + return nil +} + +type gameOverMsg struct{ winner Player } +type nextTurnMsg struct{} +type aiTurnMsg struct{} + +func (g Game) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case aiTurnMsg: + time.Sleep(time.Millisecond * 200) + return g, aiMoveCmd(&g) + + case nextTurnMsg: + g.turn = g.engine.GetOpponent(g.turn) + if g.turn == P2 { + return g, func() tea.Msg { + return aiTurnMsg{} + } + } + return g, nil + + case gameOverMsg: + g.winner = msg.winner + g.turn = g.engine.GetOpponent(g.turn) + g.gameover = true + if g.winner == P1 { + g.scoreP1 += 1 + } else if g.winner == P2 { + g.scoreP2 += 1 + } + return g, nil + + case tea.KeyMsg: + switch msg.String() { + case "ctrl+c", "q": + return g, tea.Quit + + case "n", "N": + g.nextMatch() + if g.turn == P2 { + return g, aiMoveCmd(&g) + } + return g, nil + + case "1", "2", "3", "4", "5", "6", "7", "8", "9": + // There shouldn't be an error, because this is only called for integers + index, _ := strconv.Atoi(msg.String()) + index -= 1 + cell, err := g.board.GetCell(index) + if err != nil { + log.Fatal(err) + } + + if cell == EMPTY { + g.engine.PlayMove(g.board, P1, index) + + isover, win := g.engine.CheckGameOver(g.board, index) + + if isover { + if win > 0 { + g.winner = g.turn + // Update score + if g.winner == P1 { + g.scoreP1 += 1 + } else if g.winner == P2 { + g.scoreP2 += 1 + } + } else { + g.winner = 0 + } + + g.gameover = true + g.turn = g.engine.GetOpponent(g.turn) + return g, nil + } + + return g, func() tea.Msg { + return nextTurnMsg{} + } + } + } + } + + return g, nil +} + +// Handle AI turn +func aiMoveCmd(g *Game) tea.Cmd { + return func() tea.Msg { + rollout := g.board.Copy() + move := g.engine.ai.Solve(rollout) + + g.engine.PlayMove(g.board, P2, move) + + isover, win := g.engine.CheckGameOver(g.board, move) + if isover { + if win > 0 { + return gameOverMsg{winner: P2} + } + + return gameOverMsg{winner: 0} + } + + return nextTurnMsg{} + } +} + +func (g *Game) nextMatch() { + g.board = NewBoard(size) + g.gameover = false + g.winner = 0 + g.round += 1 + + randLvl := rand.IntN(50) + 50 + g.engine = NewEngine(randLvl) +} + +func printCell(board *Board, index int) string { + cell, err := board.GetCell(index) + if err != nil { + panic(err) + } + + sign := printPlayer(cell) + + if sign == "" { + return fmt.Sprintf("%d", index+1) + } + + return sign +} + +func printPlayer(cell int) string { + if cell == P1 { + return "O" + } else if cell == P2 { + return "X" + } + + return "" +} + +func (g Game) View() string { + renderCell := func(index int) string { + cell, _ := g.board.GetCell(index) + var style lipgloss.Style + content := "" + + switch cell { + case P1: + style = g.colors["p1"] + content = "O" + case P2: + style = g.colors["p2"] + content = "X" + default: // Empty cell, show index + style = g.colors["text"] + content = strconv.Itoa(index + 1) + } + + return style.Render(content) + } + winner := "\n" + if g.gameover { + winner = "" + if g.winner != 0 { + winner += g.colors["hi"].Render(" Winner: ") + winner += g.colors["hi"].Render(printPlayer(g.winner)) + winner += "\n" + } else { + winner += g.colors["hi"].Render(" Draw!") + winner += "\n" + } + } + + board := "" + for i := 0; i < 3; i++ { + board += g.colors["board"].Render(" ") + board += renderCell(i * 3) + board += g.colors["line"].Render(" | ") + board += renderCell(i*3 + 1) + board += g.colors["line"].Render(" | ") + board += renderCell(i*3 + 2) + board += g.colors["board"].Render(" ") + + if i < 2 { + board += "\n" + g.colors["line"].Render("---+---+---") + "\n" + } + } + + status := g.colors["status"].Render(fmt.Sprintf("\n#%d:(W%d-L%d)", g.round, g.scoreP1, g.scoreP2)) + if g.gameover { + status += g.colors["status"].Render("> [Q]uit - [N]ext match") + } else { + status += g.colors["status"].Render(fmt.Sprintf("> %s's turn", printPlayer(g.turn))) + } + + return winner + board + status +} diff --git a/internal/app/tictactoe/tictactoe.go b/internal/app/tictactoe/tictactoe.go index 02c018d..b3a711f 100644 --- a/internal/app/tictactoe/tictactoe.go +++ b/internal/app/tictactoe/tictactoe.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" + "github.com/Kaamkiya/gg/internal/app/tictactoe/engine" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" ) @@ -129,3 +130,11 @@ func Run() { fmt.Printf("%c wins\n", winner) } + +func RunVsAi() { + p := tea.NewProgram(engine.GetModel()) + + if _, err := p.Run(); err != nil { + panic(err) + } +}