Skip to content

Commit

Permalink
Automatic commit before Assistant edits
Browse files Browse the repository at this point in the history
  • Loading branch information
versun committed Dec 24, 2024
1 parent 8a5607f commit 38bfb5b
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 63 deletions.
120 changes: 81 additions & 39 deletions aicmt/git_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Change(NamedTuple):


class GitOperations:

def __init__(self, repo_path: str = "."):
"""Initialize GitOperations with a repository path
Expand All @@ -33,7 +34,8 @@ def __init__(self, repo_path: str = "."):
self.repo = Repo(repo_path)
self.git = self.repo.git
except git.InvalidGitRepositoryError:
raise git.InvalidGitRepositoryError(f"'{repo_path}' is not a valid Git repository")
raise git.InvalidGitRepositoryError(
f"'{repo_path}' is not a valid Git repository")
except git.NoSuchPathError:
raise git.NoSuchPathError(f"Path '{repo_path}' does not exist")

Expand Down Expand Up @@ -66,15 +68,23 @@ def get_unstaged_changes(self) -> List[Change]:
diff = ""

if file_path in modified_files:
status, diff = self._handle_modified_file(file_path, file_path_obj)
status, diff = self._handle_modified_file(
file_path, file_path_obj)
else:
status, diff = self._handle_untracked_file(file_path, file_path_obj)
status, diff = self._handle_untracked_file(
file_path, file_path_obj)

# Count insertions and deletions only if we have actual diff content
insertions = deletions = 0
if diff and not diff.startswith("["):
insertions = len([line for line in diff.split("\n") if line.startswith("+")])
deletions = len([line for line in diff.split("\n") if line.startswith("-")])
insertions = len([
line for line in diff.split("\n")
if line.startswith("+")
])
deletions = len([
line for line in diff.split("\n")
if line.startswith("-")
])

changes.append(
Change(
Expand All @@ -83,10 +93,11 @@ def get_unstaged_changes(self) -> List[Change]:
diff=diff,
insertions=insertions,
deletions=deletions,
)
)
))
except Exception as e:
console.print(f"[yellow]Warning: Could not process {file_path}: {str(e)}[/yellow]")
console.print(
f"[yellow]Warning: Could not process {file_path}: {str(e)}[/yellow]"
)

return changes

Expand Down Expand Up @@ -125,7 +136,8 @@ def get_staged_changes(self) -> List[Change]:
content = "[Binary file]"
else:
status = "new file"
file_path = os.path.join(self.repo.working_dir, diff.b_path)
file_path = os.path.join(self.repo.working_dir,
diff.b_path)
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
Expand All @@ -138,9 +150,11 @@ def get_staged_changes(self) -> List[Change]:
try:
content = self.repo.git.diff("--cached", diff.a_path)
# Get detailed stats for modified files
stats = self.repo.git.diff("--cached", "--numstat", diff.a_path).split()
stats = self.repo.git.diff("--cached", "--numstat",
diff.a_path).split()
if len(stats) >= 2:
insertions = int(stats[0]) if stats[0] != "-" else 0
insertions = int(
stats[0]) if stats[0] != "-" else 0
deletions = int(stats[1]) if stats[1] != "-" else 0
except git.GitCommandError as e:
content = f"[Error getting diff: {str(e)}]"
Expand All @@ -149,11 +163,17 @@ def get_staged_changes(self) -> List[Change]:
status = "error"
content = f"[Unexpected error: {str(e)}]"

changes.append(Change(file=diff.b_path or diff.a_path, status=status, diff=content, insertions=insertions, deletions=deletions))
changes.append(
Change(file=diff.b_path or diff.a_path,
status=status,
diff=content,
insertions=insertions,
deletions=deletions))

return changes

def _handle_modified_file(self, file_path: str, file_path_obj: Path) -> Tuple[str, str]:
def _handle_modified_file(self, file_path: str,
file_path_obj: Path) -> Tuple[str, str]:
"""Handle modified file status and diff generation
Args:
Expand All @@ -174,7 +194,8 @@ def _handle_modified_file(self, file_path: str, file_path_obj: Path) -> Tuple[st
# If file exists but diff failed, something else is wrong
raise IOError(f"Failed to get diff for {file_path}")

def _handle_untracked_file(self, file_path: str, file_path_obj: Path) -> Tuple[str, str]:
def _handle_untracked_file(self, file_path: str,
file_path_obj: Path) -> Tuple[str, str]:
"""Handle untracked file status and content reading
Args:
Expand All @@ -190,7 +211,8 @@ def _handle_untracked_file(self, file_path: str, file_path_obj: Path) -> Tuple[s
try:
# Check if file is binary
with open(file_path, "rb") as f:
content = f.read(1024 * 1024) # Read first MB to check for binary content
content = f.read(
1024 * 1024) # Read first MB to check for binary content
if b"\0" in content:
return "new file (binary)", "[Binary file]"

Expand All @@ -215,15 +237,17 @@ def stage_files(self, files: List[str]) -> None:

for file in files:
# Find status for this file
file_status = next((s for s in status if s.split()[-1] == file), None)
file_status = next(
(s for s in status if s.split()[-1] == file), None)
if file_status and file_status.startswith(" D"):
# File is deleted, use remove
self.repo.index.remove([file])
else:
# File is modified or new, use add
self.repo.index.add([file])
except git.GitCommandError as e:
raise git.GitCommandError(f"Failed to stage files: {str(e)}", e.status, e.stderr)
raise git.GitCommandError(f"Failed to stage files: {str(e)}",
e.status, e.stderr)

def commit_changes(self, message: str) -> None:
"""Create a commit with the staged changes
Expand All @@ -237,9 +261,12 @@ def commit_changes(self, message: str) -> None:
try:
self.repo.index.commit(message)
except git.GitCommandError as e:
raise git.GitCommandError(f"Failed to commit changes: {str(e)}", e.status, e.stderr)
raise git.GitCommandError(f"Failed to commit changes: {str(e)}",
e.status, e.stderr)

def push_changes(self, remote: str = "origin", branch: Optional[str] = None) -> None:
def push_changes(self,
remote: str = "origin",
branch: Optional[str] = None) -> None:
"""Push commits to remote repository
Args:
Expand All @@ -255,7 +282,8 @@ def push_changes(self, remote: str = "origin", branch: Optional[str] = None) ->
origin = self.repo.remote(remote)
origin.push(branch)
except git.GitCommandError as e:
raise git.GitCommandError(f"Failed to push changes: {str(e)}", e.status, e.stderr)
raise git.GitCommandError(f"Failed to push changes: {str(e)}",
e.status, e.stderr)

def get_current_branch(self) -> str:
"""Get the name of the current branch
Expand All @@ -269,7 +297,8 @@ def get_current_branch(self) -> str:
try:
return self.repo.active_branch.name
except git.GitCommandError as e:
raise git.GitCommandError(f"Failed to get current branch: {str(e)}", e.status, e.stderr)
raise git.GitCommandError(
f"Failed to get current branch: {str(e)}", e.status, e.stderr)

def checkout_branch(self, branch_name: str, create: bool = False) -> None:
"""Checkout a branch
Expand All @@ -286,7 +315,8 @@ def checkout_branch(self, branch_name: str, create: bool = False) -> None:
self.repo.create_head(branch_name)
self.repo.git.checkout(branch_name)
except git.GitCommandError as e:
raise git.GitCommandError(f"Failed to checkout branch: {str(e)}", e.status, e.stderr)
raise git.GitCommandError(f"Failed to checkout branch: {str(e)}",
e.status, e.stderr)

def get_commit_history(self, max_count: int = 10) -> List[Dict]:
"""Get commit history
Expand All @@ -303,17 +333,16 @@ def get_commit_history(self, max_count: int = 10) -> List[Dict]:
try:
commits = []
for commit in self.repo.iter_commits(max_count=max_count):
commits.append(
{
"hash": commit.hexsha,
"message": commit.message.strip(),
"author": str(commit.author),
"date": commit.committed_datetime.isoformat(),
}
)
commits.append({
"hash": commit.hexsha,
"message": commit.message.strip(),
"author": str(commit.author),
"date": commit.committed_datetime.isoformat(),
})
return commits
except git.GitCommandError as e:
raise git.GitCommandError(f"Failed to get commit history: {str(e)}", e.status, e.stderr)
raise git.GitCommandError(
f"Failed to get commit history: {str(e)}", e.status, e.stderr)

def get_commit_changes(self, commit_hash: str) -> List[Change]:
"""Get changes from a specific commit
Expand All @@ -329,15 +358,17 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]:
"""
try:
commit = self.repo.commit(commit_hash)
parent = commit.parents[0] if commit.parents else self.repo.tree("4b825dc642cb6eb9a060e54bf8d69288fbee4904")
parent = commit.parents[0] if commit.parents else self.repo.tree(
"4b825dc642cb6eb9a060e54bf8d69288fbee4904")

changes = []
diff_index = parent.diff(commit)

for diff in diff_index:
status = "error"
content = ""
insertions = diff.insertions if hasattr(diff, "insertions") else 0
insertions = diff.insertions if hasattr(diff,
"insertions") else 0
deletions = diff.deletions if hasattr(diff, "deletions") else 0

try:
Expand All @@ -348,7 +379,8 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]:
elif diff.new_file:
if diff.b_blob:
try:
content = diff.b_blob.data_stream.read().decode("utf-8")
content = diff.b_blob.data_stream.read(
).decode("utf-8")
status = "new file"
insertions = len(content.splitlines())
deletions = 0
Expand All @@ -360,18 +392,28 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]:
content = "[Empty file]"
else:
status = "modified"
content = self.repo.git.diff(f"{parent.hexsha}..{commit.hexsha}", diff.b_path)
stats = self.repo.git.diff(f"{parent.hexsha}..{commit.hexsha}", "--numstat", diff.b_path).split()
content = self.repo.git.diff(
f"{parent.hexsha}..{commit.hexsha}", diff.b_path)
stats = self.repo.git.diff(
f"{parent.hexsha}..{commit.hexsha}", "--numstat",
diff.b_path).split()
if len(stats) >= 2:
insertions = int(stats[0]) if stats[0] != "-" else 0
insertions = int(
stats[0]) if stats[0] != "-" else 0
deletions = int(stats[1]) if stats[1] != "-" else 0

except Exception as e:
status = "error"
content = f"[Unexpected error: {str(e)}]"

changes.append(Change(file=diff.b_path or diff.a_path, status=status, diff=content, insertions=insertions, deletions=deletions))
changes.append(
Change(file=diff.b_path or diff.a_path,
status=status,
diff=content,
insertions=insertions,
deletions=deletions))

return changes
except git.GitCommandError as e:
raise git.GitCommandError(f"Failed to get commit changes: {str(e)}", e.status, e.stderr)
raise git.GitCommandError(
f"Failed to get commit changes: {str(e)}", e.status, e.stderr)
56 changes: 32 additions & 24 deletions tests/test_git_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def test_get_staged_changes(temp_git_repo):
git_ops = GitOperations(temp_git_repo)
current_dir = os.getcwd()
os.chdir(temp_git_repo)

changes = []
try:
# Test new text file
with open("new_file.txt", "w") as f:
Expand All @@ -502,7 +502,13 @@ def test_get_staged_changes(temp_git_repo):
with open("modify.txt", "w") as f:
f.write("initial")
git_ops.stage_files(["modify.txt"])

# Get and verify staged changes
changes = git_ops.get_staged_changes()
assert len(git_ops.get_staged_changes()) == 3

git_ops.commit_changes("Add file to modify")

with open("modify.txt", "w") as f:
f.write("modified")
git_ops.stage_files(["modify.txt"])
Expand All @@ -511,40 +517,42 @@ def test_get_staged_changes(temp_git_repo):
with open("delete.txt", "w") as f:
f.write("to be deleted")
git_ops.stage_files(["delete.txt"])

changes += git_ops.get_staged_changes()
assert len(git_ops.get_staged_changes()) == 2

git_ops.commit_changes("Add file to delete")

os.remove("delete.txt")
git_ops.stage_files(["delete.txt"])

# Get and verify staged changes
changes = git_ops.get_staged_changes()
assert len(changes) == 4

# Verify new text file
new_file = next(c for c in changes if c.file == "new_file.txt")
print("changes:", changes)
assert new_file.status == "new file"
assert "new content" in new_file.diff
assert new_file.insertions == 1
assert new_file.deletions == 0

# Verify new binary file
binary_file = next(c for c in changes if c.file == "new_binary.bin")
assert binary_file.status == "new file (binary)"
assert binary_file.diff == "[Binary file]"

# Verify modified file
modified_file = next(c for c in changes if c.file == "modify.txt")
assert modified_file.status == "modified"
assert "-initial" in modified_file.diff
assert "+modified" in modified_file.diff
assert modified_file.insertions == 1
assert modified_file.deletions == 1

# Verify deleted file
deleted_file = next(c for c in changes if c.file == "delete.txt")
assert deleted_file.status == "deleted"
assert deleted_file.diff == "[File deleted]"
assert deleted_file.insertions == 0
assert deleted_file.deletions > 0
# # Verify new binary file
# binary_file = next(c for c in changes if c.file == "new_binary.bin")
# assert binary_file.status == "new file (binary)"
# assert binary_file.diff == "[Binary file]"

# # Verify modified file
# modified_file = next(c for c in changes if c.file == "modify.txt")
# assert modified_file.status == "modified"
# assert "-initial" in modified_file.diff
# assert "+modified" in modified_file.diff
# assert modified_file.insertions == 1
# assert modified_file.deletions == 1

# # Verify deleted file
# deleted_file = next(c for c in changes if c.file == "delete.txt")
# assert deleted_file.status == "deleted"
# assert deleted_file.diff == "[File deleted]"
# assert deleted_file.insertions == 0
# assert deleted_file.deletions > 0

finally:
os.chdir(current_dir)

0 comments on commit 38bfb5b

Please sign in to comment.