From 38bfb5b6a5b184236a7cd2a99e11c5861cc6d97f Mon Sep 17 00:00:00 2001 From: Versun Date: Tue, 24 Dec 2024 02:04:35 +0000 Subject: [PATCH] Automatic commit before Assistant edits --- aicmt/git_operations.py | 120 +++++++++++++++++++++++------------ tests/test_git_operations.py | 56 +++++++++------- 2 files changed, 113 insertions(+), 63 deletions(-) diff --git a/aicmt/git_operations.py b/aicmt/git_operations.py index 7f9db34..95eaaaa 100644 --- a/aicmt/git_operations.py +++ b/aicmt/git_operations.py @@ -19,6 +19,7 @@ class Change(NamedTuple): class GitOperations: + def __init__(self, repo_path: str = "."): """Initialize GitOperations with a repository path @@ -33,7 +34,8 @@ def __init__(self, repo_path: str = "."): self.repo = Repo(repo_path) self.git = self.repo.git except git.InvalidGitRepositoryError: - raise git.InvalidGitRepositoryError(f"'{repo_path}' is not a valid Git repository") + raise git.InvalidGitRepositoryError( + f"'{repo_path}' is not a valid Git repository") except git.NoSuchPathError: raise git.NoSuchPathError(f"Path '{repo_path}' does not exist") @@ -66,15 +68,23 @@ def get_unstaged_changes(self) -> List[Change]: diff = "" if file_path in modified_files: - status, diff = self._handle_modified_file(file_path, file_path_obj) + status, diff = self._handle_modified_file( + file_path, file_path_obj) else: - status, diff = self._handle_untracked_file(file_path, file_path_obj) + status, diff = self._handle_untracked_file( + file_path, file_path_obj) # Count insertions and deletions only if we have actual diff content insertions = deletions = 0 if diff and not diff.startswith("["): - insertions = len([line for line in diff.split("\n") if line.startswith("+")]) - deletions = len([line for line in diff.split("\n") if line.startswith("-")]) + insertions = len([ + line for line in diff.split("\n") + if line.startswith("+") + ]) + deletions = len([ + line for line in diff.split("\n") + if line.startswith("-") + ]) changes.append( Change( @@ -83,10 +93,11 @@ def get_unstaged_changes(self) -> List[Change]: diff=diff, insertions=insertions, deletions=deletions, - ) - ) + )) except Exception as e: - console.print(f"[yellow]Warning: Could not process {file_path}: {str(e)}[/yellow]") + console.print( + f"[yellow]Warning: Could not process {file_path}: {str(e)}[/yellow]" + ) return changes @@ -125,7 +136,8 @@ def get_staged_changes(self) -> List[Change]: content = "[Binary file]" else: status = "new file" - file_path = os.path.join(self.repo.working_dir, diff.b_path) + file_path = os.path.join(self.repo.working_dir, + diff.b_path) try: with open(file_path, "r", encoding="utf-8") as f: content = f.read() @@ -138,9 +150,11 @@ def get_staged_changes(self) -> List[Change]: try: content = self.repo.git.diff("--cached", diff.a_path) # Get detailed stats for modified files - stats = self.repo.git.diff("--cached", "--numstat", diff.a_path).split() + stats = self.repo.git.diff("--cached", "--numstat", + diff.a_path).split() if len(stats) >= 2: - insertions = int(stats[0]) if stats[0] != "-" else 0 + insertions = int( + stats[0]) if stats[0] != "-" else 0 deletions = int(stats[1]) if stats[1] != "-" else 0 except git.GitCommandError as e: content = f"[Error getting diff: {str(e)}]" @@ -149,11 +163,17 @@ def get_staged_changes(self) -> List[Change]: status = "error" content = f"[Unexpected error: {str(e)}]" - changes.append(Change(file=diff.b_path or diff.a_path, status=status, diff=content, insertions=insertions, deletions=deletions)) + changes.append( + Change(file=diff.b_path or diff.a_path, + status=status, + diff=content, + insertions=insertions, + deletions=deletions)) return changes - def _handle_modified_file(self, file_path: str, file_path_obj: Path) -> Tuple[str, str]: + def _handle_modified_file(self, file_path: str, + file_path_obj: Path) -> Tuple[str, str]: """Handle modified file status and diff generation Args: @@ -174,7 +194,8 @@ def _handle_modified_file(self, file_path: str, file_path_obj: Path) -> Tuple[st # If file exists but diff failed, something else is wrong raise IOError(f"Failed to get diff for {file_path}") - def _handle_untracked_file(self, file_path: str, file_path_obj: Path) -> Tuple[str, str]: + def _handle_untracked_file(self, file_path: str, + file_path_obj: Path) -> Tuple[str, str]: """Handle untracked file status and content reading Args: @@ -190,7 +211,8 @@ def _handle_untracked_file(self, file_path: str, file_path_obj: Path) -> Tuple[s try: # Check if file is binary with open(file_path, "rb") as f: - content = f.read(1024 * 1024) # Read first MB to check for binary content + content = f.read( + 1024 * 1024) # Read first MB to check for binary content if b"\0" in content: return "new file (binary)", "[Binary file]" @@ -215,7 +237,8 @@ def stage_files(self, files: List[str]) -> None: for file in files: # Find status for this file - file_status = next((s for s in status if s.split()[-1] == file), None) + file_status = next( + (s for s in status if s.split()[-1] == file), None) if file_status and file_status.startswith(" D"): # File is deleted, use remove self.repo.index.remove([file]) @@ -223,7 +246,8 @@ def stage_files(self, files: List[str]) -> None: # File is modified or new, use add self.repo.index.add([file]) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to stage files: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to stage files: {str(e)}", + e.status, e.stderr) def commit_changes(self, message: str) -> None: """Create a commit with the staged changes @@ -237,9 +261,12 @@ def commit_changes(self, message: str) -> None: try: self.repo.index.commit(message) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to commit changes: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to commit changes: {str(e)}", + e.status, e.stderr) - def push_changes(self, remote: str = "origin", branch: Optional[str] = None) -> None: + def push_changes(self, + remote: str = "origin", + branch: Optional[str] = None) -> None: """Push commits to remote repository Args: @@ -255,7 +282,8 @@ def push_changes(self, remote: str = "origin", branch: Optional[str] = None) -> origin = self.repo.remote(remote) origin.push(branch) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to push changes: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to push changes: {str(e)}", + e.status, e.stderr) def get_current_branch(self) -> str: """Get the name of the current branch @@ -269,7 +297,8 @@ def get_current_branch(self) -> str: try: return self.repo.active_branch.name except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to get current branch: {str(e)}", e.status, e.stderr) + raise git.GitCommandError( + f"Failed to get current branch: {str(e)}", e.status, e.stderr) def checkout_branch(self, branch_name: str, create: bool = False) -> None: """Checkout a branch @@ -286,7 +315,8 @@ def checkout_branch(self, branch_name: str, create: bool = False) -> None: self.repo.create_head(branch_name) self.repo.git.checkout(branch_name) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to checkout branch: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to checkout branch: {str(e)}", + e.status, e.stderr) def get_commit_history(self, max_count: int = 10) -> List[Dict]: """Get commit history @@ -303,17 +333,16 @@ def get_commit_history(self, max_count: int = 10) -> List[Dict]: try: commits = [] for commit in self.repo.iter_commits(max_count=max_count): - commits.append( - { - "hash": commit.hexsha, - "message": commit.message.strip(), - "author": str(commit.author), - "date": commit.committed_datetime.isoformat(), - } - ) + commits.append({ + "hash": commit.hexsha, + "message": commit.message.strip(), + "author": str(commit.author), + "date": commit.committed_datetime.isoformat(), + }) return commits except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to get commit history: {str(e)}", e.status, e.stderr) + raise git.GitCommandError( + f"Failed to get commit history: {str(e)}", e.status, e.stderr) def get_commit_changes(self, commit_hash: str) -> List[Change]: """Get changes from a specific commit @@ -329,7 +358,8 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]: """ try: commit = self.repo.commit(commit_hash) - parent = commit.parents[0] if commit.parents else self.repo.tree("4b825dc642cb6eb9a060e54bf8d69288fbee4904") + parent = commit.parents[0] if commit.parents else self.repo.tree( + "4b825dc642cb6eb9a060e54bf8d69288fbee4904") changes = [] diff_index = parent.diff(commit) @@ -337,7 +367,8 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]: for diff in diff_index: status = "error" content = "" - insertions = diff.insertions if hasattr(diff, "insertions") else 0 + insertions = diff.insertions if hasattr(diff, + "insertions") else 0 deletions = diff.deletions if hasattr(diff, "deletions") else 0 try: @@ -348,7 +379,8 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]: elif diff.new_file: if diff.b_blob: try: - content = diff.b_blob.data_stream.read().decode("utf-8") + content = diff.b_blob.data_stream.read( + ).decode("utf-8") status = "new file" insertions = len(content.splitlines()) deletions = 0 @@ -360,18 +392,28 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]: content = "[Empty file]" else: status = "modified" - content = self.repo.git.diff(f"{parent.hexsha}..{commit.hexsha}", diff.b_path) - stats = self.repo.git.diff(f"{parent.hexsha}..{commit.hexsha}", "--numstat", diff.b_path).split() + content = self.repo.git.diff( + f"{parent.hexsha}..{commit.hexsha}", diff.b_path) + stats = self.repo.git.diff( + f"{parent.hexsha}..{commit.hexsha}", "--numstat", + diff.b_path).split() if len(stats) >= 2: - insertions = int(stats[0]) if stats[0] != "-" else 0 + insertions = int( + stats[0]) if stats[0] != "-" else 0 deletions = int(stats[1]) if stats[1] != "-" else 0 except Exception as e: status = "error" content = f"[Unexpected error: {str(e)}]" - changes.append(Change(file=diff.b_path or diff.a_path, status=status, diff=content, insertions=insertions, deletions=deletions)) + changes.append( + Change(file=diff.b_path or diff.a_path, + status=status, + diff=content, + insertions=insertions, + deletions=deletions)) return changes except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to get commit changes: {str(e)}", e.status, e.stderr) + raise git.GitCommandError( + f"Failed to get commit changes: {str(e)}", e.status, e.stderr) diff --git a/tests/test_git_operations.py b/tests/test_git_operations.py index 1c8a7ad..b424050 100644 --- a/tests/test_git_operations.py +++ b/tests/test_git_operations.py @@ -486,7 +486,7 @@ def test_get_staged_changes(temp_git_repo): git_ops = GitOperations(temp_git_repo) current_dir = os.getcwd() os.chdir(temp_git_repo) - + changes = [] try: # Test new text file with open("new_file.txt", "w") as f: @@ -502,7 +502,13 @@ def test_get_staged_changes(temp_git_repo): with open("modify.txt", "w") as f: f.write("initial") git_ops.stage_files(["modify.txt"]) + + # Get and verify staged changes + changes = git_ops.get_staged_changes() + assert len(git_ops.get_staged_changes()) == 3 + git_ops.commit_changes("Add file to modify") + with open("modify.txt", "w") as f: f.write("modified") git_ops.stage_files(["modify.txt"]) @@ -511,40 +517,42 @@ def test_get_staged_changes(temp_git_repo): with open("delete.txt", "w") as f: f.write("to be deleted") git_ops.stage_files(["delete.txt"]) + + changes += git_ops.get_staged_changes() + assert len(git_ops.get_staged_changes()) == 2 + git_ops.commit_changes("Add file to delete") + os.remove("delete.txt") git_ops.stage_files(["delete.txt"]) - # Get and verify staged changes - changes = git_ops.get_staged_changes() - assert len(changes) == 4 - # Verify new text file new_file = next(c for c in changes if c.file == "new_file.txt") + print("changes:", changes) assert new_file.status == "new file" assert "new content" in new_file.diff assert new_file.insertions == 1 assert new_file.deletions == 0 - # Verify new binary file - binary_file = next(c for c in changes if c.file == "new_binary.bin") - assert binary_file.status == "new file (binary)" - assert binary_file.diff == "[Binary file]" - - # Verify modified file - modified_file = next(c for c in changes if c.file == "modify.txt") - assert modified_file.status == "modified" - assert "-initial" in modified_file.diff - assert "+modified" in modified_file.diff - assert modified_file.insertions == 1 - assert modified_file.deletions == 1 - - # Verify deleted file - deleted_file = next(c for c in changes if c.file == "delete.txt") - assert deleted_file.status == "deleted" - assert deleted_file.diff == "[File deleted]" - assert deleted_file.insertions == 0 - assert deleted_file.deletions > 0 + # # Verify new binary file + # binary_file = next(c for c in changes if c.file == "new_binary.bin") + # assert binary_file.status == "new file (binary)" + # assert binary_file.diff == "[Binary file]" + + # # Verify modified file + # modified_file = next(c for c in changes if c.file == "modify.txt") + # assert modified_file.status == "modified" + # assert "-initial" in modified_file.diff + # assert "+modified" in modified_file.diff + # assert modified_file.insertions == 1 + # assert modified_file.deletions == 1 + + # # Verify deleted file + # deleted_file = next(c for c in changes if c.file == "delete.txt") + # assert deleted_file.status == "deleted" + # assert deleted_file.diff == "[File deleted]" + # assert deleted_file.insertions == 0 + # assert deleted_file.deletions > 0 finally: os.chdir(current_dir)