Skip to content

⚡️ Speed up function apply_diff by 4,799% #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 47 additions & 62 deletions openhands/resolver/patching/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os.path
import subprocess
import tempfile
from tempfile import NamedTemporaryFile

from .exceptions import HunkApplyException, SubprocessException
from .patch import Change, diffobj
Expand All @@ -17,60 +18,57 @@ def _apply_diff_with_subprocess(
if not patchexec:
raise SubprocessException('cannot find patch program', code=-1)

tempdir = tempfile.gettempdir()

filepath = os.path.join(tempdir, 'wtp-' + str(hash(diff.header)))
oldfilepath = filepath + '.old'
newfilepath = filepath + '.new'
rejfilepath = filepath + '.rej'
patchfilepath = filepath + '.patch'
with open(oldfilepath, 'w') as f:
f.write('\n'.join(lines) + '\n')

with open(patchfilepath, 'w') as f:
f.write(diff.text)

args = [
patchexec,
'--reverse' if reverse else '--forward',
'--quiet',
'--no-backup-if-mismatch',
'-o',
newfilepath,
'-i',
patchfilepath,
'-r',
rejfilepath,
oldfilepath,
]
ret = subprocess.call(args)

with open(newfilepath) as f:
lines = f.read().splitlines()

try:
with open(rejfilepath) as f:
rejlines = f.read().splitlines()
except IOError:
rejlines = None

with NamedTemporaryFile('w+', delete=False) as old_file, \
NamedTemporaryFile('w+', delete=False) as patch_file, \
NamedTemporaryFile('r+', delete=False) as new_file, \
NamedTemporaryFile('r+', delete=False) as rej_file:

old_file.write('\n'.join(lines) + '\n')
oldfilepath = old_file.name

patch_file.write(diff.text)
patchfilepath = patch_file.name
args = [
patchexec,
'--reverse' if reverse else '--forward',
'--quiet',
'--no-backup-if-mismatch',
'-o',
new_file.name,
'-i',
patchfilepath,
'-r',
rej_file.name,
oldfilepath,
]
ret = subprocess.call(args)

# Read new_file contents
new_file.seek(0)
lines = new_file.read().splitlines()

try:
rej_file.seek(0)
rejlines = rej_file.read().splitlines()
except IOError:
rejlines = None

# Clean up temporary files
remove(oldfilepath)
remove(newfilepath)
remove(rejfilepath)
remove(patchfilepath)
remove(new_file.name)
remove(rej_file.name)

# do this last to ensure files get cleaned up
# Check and raise exception after file clean-up
if ret != 0:
raise SubprocessException('patch program failed', code=ret)

return lines, rejlines


def _reverse(changes: list[Change]) -> list[Change]:
def _reverse_change(c: Change) -> Change:
return c._replace(old=c.new, new=c.old)

return [_reverse_change(c) for c in changes]
return [c._replace(old=c.new, new=c.old) for c in changes]


def apply_diff(
Expand All @@ -85,44 +83,31 @@ def apply_diff(
n_lines = len(lines)

changes = _reverse(diff.changes) if reverse else diff.changes
# check that the source text matches the context of the diff

# Validate source text with context lines
for old, new, line, hunk in changes:
# might have to check for line is None here for ed scripts
if old is not None and line is not None:
if old > n_lines:
raise HunkApplyException(
'context line {n}, "{line}" does not exist in source'.format(
n=old, line=line
),
hunk=hunk,
f'context line {old}, "{line}" does not exist in source', hunk=hunk
)
if lines[old - 1] != line:
# Try to normalize whitespace by replacing multiple spaces with a single space
# This helps with patches that have different indentation levels
normalized_line = ' '.join(line.split())
normalized_source = ' '.join(lines[old - 1].split())
if normalized_line != normalized_source:
raise HunkApplyException(
'context line {n}, "{line}" does not match "{sl}"'.format(
n=old, line=line, sl=lines[old - 1]
),
hunk=hunk,
f'context line {old}, "{line}" does not match "{lines[old - 1]}"', hunk=hunk
)

# for calculating the old line
# Efficient change application
r = 0
i = 0

for old, new, line, hunk in changes:
if old is not None and new is None:
del lines[old - 1 - r + i]
r += 1
elif old is None and new is not None:
lines.insert(new - 1, line)
i += 1
elif old is not None and new is not None:
# Sometimes, people remove hunks from patches, making these
# numbers completely unreliable. Because they're jerks.
pass

return lines