Skip to content

Commit

Permalink
fix: improve diff handling and file processing
Browse files Browse the repository at this point in the history
This commit refactors the diff handling logic in `GitOperations` to correctly identify and process different types of file changes (new, modified, deleted), including binary files. It also adds more comprehensive tests to verify the changes.
  • Loading branch information
runner committed Dec 24, 2024
1 parent cfd06dc commit 2ec4742
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 108 deletions.
180 changes: 98 additions & 82 deletions aicmt/git_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def get_unstaged_changes(self) -> List[Change]:
changes = []

# Get modified files
modified_files = [item.a_path for item in self.repo.index.diff(None)]
modified_files = {item.a_path for item in self.repo.index.diff(None)}

# Get untracked files
untracked_files = self.repo.untracked_files
untracked_files = set(self.repo.untracked_files)

for file_path in modified_files + untracked_files:
for file_path in modified_files.union(untracked_files):
try:
file_path_obj = Path(file_path)
status = ""
Expand All @@ -74,17 +74,10 @@ def get_unstaged_changes(self) -> List[Change]:
status, diff = self._handle_untracked_file(
file_path, file_path_obj)

# Count insertions and deletions only if we have actual diff content
insertions = deletions = 0
if diff and not diff.startswith("["):
insertions = len([
line for line in diff.split("\n")
if line.startswith("+")
])
deletions = len([
line for line in diff.split("\n")
if line.startswith("-")
])
insertions, deletions = self._calculate_diff_stats(diff)
else:
insertions = deletions = 0

changes.append(
Change(
Expand Down Expand Up @@ -122,43 +115,14 @@ def get_staged_changes(self) -> List[Change]:
for diff in diff_index:
status = "error"
content = ""
insertions = diff.insertions if hasattr(diff, "insertions") else 0
deletions = diff.deletions if hasattr(diff, "deletions") else 0
# insertions = diff.insertions if hasattr(diff, "insertions") else 0
# deletions = diff.deletions if hasattr(diff, "deletions") else 0
insertions = 0
deletions = 0

try:
if diff.deleted_file:
status = "deleted"
content = "[File deleted]"
insertions, deletions = 0, diff.a_blob.size if diff.a_blob else 0
elif diff.new_file:
if diff.b_blob and diff.b_blob.is_binary:
status = "new file (binary)"
content = "[Binary file]"
else:
status = "new file"
file_path = os.path.join(self.repo.working_dir,
diff.b_path)
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
insertions = len(content.splitlines())
deletions = 0
except IOError as e:
content = f"[Error reading file: {str(e)}]"
else:
status = "modified"
try:
content = self.repo.git.diff("--cached", diff.a_path)
# Get detailed stats for modified files
stats = self.repo.git.diff("--cached", "--numstat",
diff.a_path).split()
if len(stats) >= 2:
insertions = int(
stats[0]) if stats[0] != "-" else 0
deletions = int(stats[1]) if stats[1] != "-" else 0
except git.GitCommandError as e:
content = f"[Error getting diff: {str(e)}]"

status, content, insertions, deletions = self._process_file_diff(
diff)
except Exception as e:
status = "error"
content = f"[Unexpected error: {str(e)}]"
Expand All @@ -172,6 +136,73 @@ def get_staged_changes(self) -> List[Change]:

return changes

def _process_file_diff(self, diff) -> Tuple:
"""
Handles file differences in Git repositories.
Args.
diff: Git diff object containing information about file changes.
Returns.
tuple: (status, content, insertions, deletions)
- status: File status (deleted/new file/modified, etc.)
- content: file content or differences
- insertions: number of lines inserted
- deletions: number of lines deleted
"""
status = ""
content = ""
insertions = 0
deletions = 0

if diff.deleted_file:
status = "deleted"
content = "[File deleted]"
insertions, deletions = 0, len(
diff.a_blob.data_stream.read().decode('utf-8').splitlines())
elif diff.new_file:
if diff.b_blob and diff.b_blob.mime_type != "text/plain":
status = "new file (binary)"
content = "[Binary file]"
else:
status = "new file"
file_path = os.path.join(self.repo.working_dir, diff.b_path)
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
insertions = len(content.splitlines())
deletions = 0
except IOError as e:
content = f"[Error reading file: {str(e)}]"
else:
status = "modified"
try:
# Check if the file is modified in the staging area
staged_diff = self.repo.git.diff("--cached", diff.a_path)
if staged_diff:
content = staged_diff
#stats = self.repo.git.diff("--cached", "--numstat",diff.a_path).split()
else:
# If the file is not modified in the staging area, compare with the parent commit
content = self.repo.git.diff("HEAD^", "HEAD", diff.a_path)
#stats = self.repo.git.diff("HEAD^", "HEAD", "--numstat",diff.a_path).split()

insertions, deletions = self._calculate_diff_stats(content)
except git.GitCommandError as e:
content = f"[Error getting diff: {str(e)}]"

return status, content, insertions, deletions

def _calculate_diff_stats(self, diff_content: str) -> Tuple[int, int]:
"""Caculates the number of inserted and deleted lines in a diff content"""
insertions = deletions = 0
for line in diff_content.split('\n'):
if line.startswith('+') and not line.startswith('+++'):
insertions += 1
elif line.startswith('-') and not line.startswith('---'):
deletions += 1
return insertions, deletions

def _handle_modified_file(self, file_path: str,
file_path_obj: Path) -> Tuple[str, str]:
"""Handle modified file status and diff generation
Expand Down Expand Up @@ -222,6 +253,18 @@ def _handle_untracked_file(self, file_path: str,
except UnicodeDecodeError:
return "new file (binary)", "[Binary file]"

def _handle_file_content(self, file_path: Path) -> Tuple[str, str]:
if not file_path.exists():
return "deleted", "[文件已删除]"

try:
content = file_path.read_bytes()
if b"\0" in content[:1024]:
return "new file (binary)", "[二进制文件]"
return "new file", file_path.read_text(encoding="utf-8")
except UnicodeDecodeError:
return "new file (binary)", "[二进制文件]"

def stage_files(self, files: List[str]) -> None:
"""Stage specified files
Expand Down Expand Up @@ -367,41 +410,14 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]:
for diff in diff_index:
status = "error"
content = ""
insertions = diff.insertions if hasattr(diff,
"insertions") else 0
deletions = diff.deletions if hasattr(diff, "deletions") else 0

# insertions = diff.insertions if hasattr(diff,
# "insertions") else 0
# deletions = diff.deletions if hasattr(diff, "deletions") else 0
insertions = 0
deletions = 0
try:
if diff.deleted_file:
status = "deleted"
content = "[File deleted]"
insertions, deletions = 0, diff.a_blob.size if diff.a_blob else 0
elif diff.new_file:
if diff.b_blob:
try:
content = diff.b_blob.data_stream.read(
).decode("utf-8")
status = "new file"
insertions = len(content.splitlines())
deletions = 0
except UnicodeDecodeError:
status = "new file (binary)"
content = "[Binary file]"
else:
status = "new file"
content = "[Empty file]"
else:
status = "modified"
content = self.repo.git.diff(
f"{parent.hexsha}..{commit.hexsha}", diff.b_path)
stats = self.repo.git.diff(
f"{parent.hexsha}..{commit.hexsha}", "--numstat",
diff.b_path).split()
if len(stats) >= 2:
insertions = int(
stats[0]) if stats[0] != "-" else 0
deletions = int(stats[1]) if stats[1] != "-" else 0

status, content, insertions, deletions = self._process_file_diff(
diff)
except Exception as e:
status = "error"
content = f"[Unexpected error: {str(e)}]"
Expand Down
89 changes: 63 additions & 26 deletions tests/test_git_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,6 @@ def test_get_staged_changes(temp_git_repo):
git_ops = GitOperations(temp_git_repo)
current_dir = os.getcwd()
os.chdir(temp_git_repo)
changes = []
try:
# Test new text file
with open("new_file.txt", "w") as f:
Expand All @@ -504,8 +503,8 @@ def test_get_staged_changes(temp_git_repo):
git_ops.stage_files(["modify.txt"])

# Get and verify staged changes
changes = git_ops.get_staged_changes()
assert len(git_ops.get_staged_changes()) == 3
first_changes = git_ops.get_staged_changes()
assert len(first_changes) == 3

git_ops.commit_changes("Add file to modify")

Expand All @@ -518,41 +517,79 @@ def test_get_staged_changes(temp_git_repo):
f.write("to be deleted")
git_ops.stage_files(["delete.txt"])

changes += git_ops.get_staged_changes()
assert len(git_ops.get_staged_changes()) == 2
second_changes = git_ops.get_staged_changes()
assert len(second_changes) == 2

git_ops.commit_changes("Add file to delete")

os.remove("delete.txt")
git_ops.stage_files(["delete.txt"])
third_changes = git_ops.get_staged_changes()

# Verify new text file
new_file = next(c for c in changes if c.file == "new_file.txt")
print("changes:", changes)
new_file = next(c for c in first_changes if c.file == "new_file.txt")
assert new_file.status == "new file"
assert "new content" in new_file.diff
assert new_file.insertions == 1
assert new_file.deletions == 0

# # Verify new binary file
# binary_file = next(c for c in changes if c.file == "new_binary.bin")
# assert binary_file.status == "new file (binary)"
# assert binary_file.diff == "[Binary file]"

# # Verify modified file
# modified_file = next(c for c in changes if c.file == "modify.txt")
# assert modified_file.status == "modified"
# assert "-initial" in modified_file.diff
# assert "+modified" in modified_file.diff
# assert modified_file.insertions == 1
# assert modified_file.deletions == 1

# # Verify deleted file
# deleted_file = next(c for c in changes if c.file == "delete.txt")
# assert deleted_file.status == "deleted"
# assert deleted_file.diff == "[File deleted]"
# assert deleted_file.insertions == 0
# assert deleted_file.deletions > 0
# Verify new binary file
binary_file = next(c for c in first_changes
if c.file == "new_binary.bin")
assert binary_file.status == "new file (binary)"
assert binary_file.diff == "[Binary file]"

# Verify modified file
modified_file = next(c for c in second_changes
if c.file == "modify.txt")
assert modified_file.status == "modified"
assert "-initial" in modified_file.diff
assert "+modified" in modified_file.diff
assert modified_file.insertions == 1
assert modified_file.deletions == 1

# Verify deleted file
deleted_file = next(c for c in third_changes if c.file == "delete.txt")
assert deleted_file.status == "deleted"
assert deleted_file.diff == "[File deleted]"
assert deleted_file.insertions == 0
assert deleted_file.deletions > 0

finally:
os.chdir(current_dir)


def test_file_reading(temp_git_repo):
"""Test file reading"""
git_ops = GitOperations(temp_git_repo)
current_dir = os.getcwd()
os.chdir(temp_git_repo)
try:
# Test new text file
with open("new_file.txt", "w") as f:
f.write("new content")
git_ops.stage_files(["new_file.txt"])

# Test new binary file
with open("new_binary.bin", "wb") as f:
f.write(bytes([0x00, 0x01, 0x02, 0x03]))
git_ops.stage_files(["new_binary.bin"])

# Test modified file
with open("modify.txt", "w") as f:
f.write("initial")
git_ops.stage_files(["modify.txt"])

# Remove the file
os.remove("new_file.txt")

changes = git_ops.get_staged_changes()
# Test that IOError is raised when trying to get diff
new_file = next(c for c in changes if c.file == "new_file.txt")
assert new_file.status == "new file"
assert "Error reading file:" in new_file.diff

git_ops.commit_changes("Add file to modify")

finally:
os.chdir(current_dir)

0 comments on commit 2ec4742

Please sign in to comment.