Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more bug fixes on diff #34

Merged
merged 1 commit into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 52 additions & 33 deletions openelm/utils/diff_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from enum import Enum
from typing import Optional

line_number_pattern = re.compile(r"(?m)^@@ -(\d*?),(\d*?) \+(\d*?),(\d*?) @@")
line_number_pattern = re.compile(r"(?m)^@@ -(?P<l1>\d*),*?(?P<s1>\d*?) \+(?P<l2>\d*),*?(?P<s2>\d*?) @@")
diff_pattern = re.compile(
r"""<NME> (?P<name>.*?)
<BEF> (?P<file>(.|\n)*?)
<MSG> (?P<message>(.|\n)*?)
<DFF> (?P<diff>(.|\n)*)"""
)
hunk_split_pattern = re.compile(r"(?m)^(@@ .*? @@).*\n")
ignored = re.compile(r"(?m)^\\ No newline at end of file$\n?")


class DiffState(Enum):
Expand Down Expand Up @@ -66,13 +67,14 @@ def parse_line_info(content: str) -> tuple:
match = line_number_pattern.match(content)
if match is None:
return ()
match = match.groups()
if len(match) >= 4:
# shouldn't be more than 4, but in case of anything weird, we take the first 4 matching elements.
return tuple([int(num) for num in match][:4])
else:
# incorrect format => return nothing
match_dict = match.groupdict()
# line numbers are mandatory
if not match_dict['l1'] or not match_dict['l2']:
return ()
for s in ['s1', 's2']:
# line ranges are optional and default to 1
match_dict[s] = match_dict[s] if match_dict[s] else '1'
return int(match_dict['l1']), int(match_dict['s1']), int(match_dict['l2']), int(match_dict['s2'])


def parse_diff_content(
Expand All @@ -90,10 +92,10 @@ def parse_diff_content(
(before_diff, after_diff);
None if reject_invalid==True and the diff hunk contains invalid format.
"""
# Remove trailing \n at the beginning and the end.
hunk = hunk.split("\n")
before_diff, after_diff = [], []
for line in hunk:
# Ignore invalid trailing '\n'. An empty line in the diff hunk should at least be '\n ' with the space.
if not line:
continue
if line[0] == "-" or line[0] == " ":
Expand All @@ -109,9 +111,11 @@ def parse_diff_content(
return "\n".join(before_diff), "\n".join(after_diff)


def replace_text(
text: str, before: str, after: str, start_pointer: int
) -> tuple[str, int]:
def replace_text(text: str,
before: str,
after: str,
start_pointer: int,
reject_incomplete_line: bool = True) -> tuple[str, int]:
"""
Try to match `before` within `text` and replace the content into `after`.
If not found, return the original text.
Expand All @@ -121,16 +125,24 @@ def replace_text(
before: the text to be matched.
after: the text to be replaced into.
start_pointer: the index where we start to match (inclusive).
reject_incomplete_line: (Optional) reject the patch if `before` does not match till the end of a line.
Returns:
(diff_result, new_start_pointer)
the text after the match-and-replace and the new index at the end of the change.
"""
idx = text.find(before)
if idx < start_pointer:
idx = text[start_pointer:].find(before)
start_idx = start_pointer + idx

if reject_incomplete_line:
# If the end of the match is neither EOF nor \n, reject the patch.
if idx >= 0 and start_idx + len(before) < len(text) and text[start_idx + len(before)] != '\n':
return text, start_pointer

if idx < 0:
return text, start_pointer
else:
# Even if idx + len(before) is out-of-bound, the list slicing would return []
return text[:idx] + after + text[idx + len(before) :], idx + len(after)
# Even if start_idx + len(before) is out-of-bound, the list slicing would return ""
return text[:start_idx] + after + text[start_idx + len(before):], start_idx + len(after)


def apply_diff(file: str, diff: str, use_line_number=False, allow_add_file=True) -> str:
Expand All @@ -149,8 +161,7 @@ def apply_diff(file: str, diff: str, use_line_number=False, allow_add_file=True)
Return:
the maximally patched file content.
"""
diff = hunk_split_pattern.split(diff.lstrip().lstrip("\n"))

diff = hunk_split_pattern.split(ignored.sub("", diff))
# If we use the line numbers, we match-and-replace in a line-by-line fashion.
file_by_line = file.split("\n") if use_line_number else None
line_offset = 0 # the offset between pre-/post-patching line numbers
Expand All @@ -170,14 +181,16 @@ def apply_diff(file: str, diff: str, use_line_number=False, allow_add_file=True)
i += 2

# Generate the pre-/post-diff string based on the first character being '+' or '-'
# (Note: parse_diff_content will strip trailing \n at the beginning and the end)
# (Note: parse_diff_content will ignore trailing \n at the beginning and at the end)
parsed_diff = parse_diff_content(diff_content, separate_lines=use_line_number)

# If we allow the recognition of "ADDFILE" and encounter such file, special treatment is needed.
# If we allow the recognition of "ADDFILE", special treatment is needed.
if allow_add_file and file == "ADDFILE":
if use_line_number:
# Immediately apply the first hunk but also check the partial validity of line numbers.
return parsed_diff[1] if line_info == (0, 0) else ""
else:
# Immediately apply the first hunk and ignore the rest.
return parsed_diff[1]

if use_line_number:
Expand All @@ -196,17 +209,23 @@ def apply_diff(file: str, diff: str, use_line_number=False, allow_add_file=True)
# We ignore the second pair "+a, b" just to be lenient.
if valid and len(parsed_diff[0]) == line_info[1]:
# Update the list of lines
file_by_line = (
file_by_line[: start_idx - 1]
+ parsed_diff[1]
+ file_by_line[start_idx - 1 + line_info[1] :]
)
if start_idx == 0: # Add lines to the beginning.
file_by_line = parsed_diff[1] + file_by_line
else:
file_by_line = file_by_line[: start_idx - 1] + parsed_diff[1] + \
file_by_line[start_idx - 1 + line_info[1]:]
line_offset += len(parsed_diff[1]) - line_info[1]
else:
# Directly (and naively) apply patch by match-and-replace.
file, patch_pointer = replace_text(
file, parsed_diff[0], parsed_diff[1], patch_pointer
)
# CAUTION: this way of handling empty context is being very lenient and could lead to
# undesirable behaviors. Only do this when you want to be as tolerant as possible.
if parsed_diff[0] == "":
if patch_pointer != 0: # Lack of matching context can only happen at the beginning of file.
continue
file = parsed_diff[1] + "\n" + file
patch_pointer = len(parsed_diff[0]) + 1
else:
# Directly (and naively) apply patch by match-and-replace.
file, patch_pointer = replace_text(file, parsed_diff[0], parsed_diff[1], patch_pointer)

if use_line_number:
file = "\n".join(file_by_line)
Expand All @@ -224,15 +243,15 @@ def verify_diff(diff_text: str) -> DiffState:
Returns:
A DiffState (see above).
"""
diff_dict = split_diff(diff_text)
diff_dict = split_diff(ignored.sub("", diff_text)) # Ignore the GitHub warning on the end of file
line_offset = 0

keys = ["name", "file", "message", "diff"]
for key in keys:
if key not in diff_dict:
return DiffState(0b100) # Invalid overall format

diff_parts = hunk_split_pattern.split(diff_dict["diff"].lstrip())
diff_parts = hunk_split_pattern.split(diff_dict["diff"])
if not diff_parts:
return DiffState(0b100) # Invalid overall format

Expand All @@ -253,7 +272,7 @@ def verify_diff(diff_text: str) -> DiffState:
len(diff_parts) != i
or not line_info
or line_info[:3] != (0, 0, 1)
or line_info[3] != len(diff_content[1].strip("\n").split("\n"))
or line_info[3] != len(diff_content[1].split("\n"))
or diff_content[0]
):
return DiffState(0b110)
Expand Down Expand Up @@ -283,8 +302,8 @@ def verify_diff(diff_text: str) -> DiffState:
line_number_mismatch = True
else:
# Check the line numbers regardless of whether the context matches.
pre_diff_line_number = len(diff_content[0].strip("\n").split("\n"))
post_diff_line_number = len(diff_content[1].strip("\n").split("\n"))
pre_diff_line_number = len(diff_content[0].split("\n"))
post_diff_line_number = len(diff_content[1].split("\n"))
if (pre_diff_line_number, post_diff_line_number) != (
line_info[1],
line_info[3],
Expand Down
Loading