Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions patchwork/common/utils/input_parsing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from collections.abc import Iterable, Mapping

from typing_extensions import AnyStr, Union
Expand Down Expand Up @@ -69,3 +70,23 @@ def parse_to_list(
continue
rv.append(stripped_value)
return rv


def parse_to_dict(possible_dict, limit=-1):
if possible_dict is None and limit == 0:
return None

if isinstance(possible_dict, dict):
new_dict = dict()
for k, v in possible_dict.items():
new_dict[k] = parse_to_dict(v, limit - 1)
return new_dict
elif isinstance(possible_dict, str):
try:
new_dict = json.loads(possible_dict, strict=False)
except json.JSONDecodeError:
return possible_dict

return parse_to_dict(new_dict, limit - 1)
else:
return possible_dict
18 changes: 18 additions & 0 deletions patchwork/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import atexit
import dataclasses
import random
import signal
import string
import tempfile
from collections.abc import Mapping
from pathlib import Path

import chevron
import tiktoken
from chardet.universaldetector import UniversalDetector
from git import Head, Repo
Expand All @@ -19,6 +23,20 @@
_NEWLINES = {"\n", "\r\n", "\r"}


def mustache_render(template: str, data: Mapping) -> str:
if len(data.keys()) < 1:
return template

chevron.render.__globals__["_html_escape"] = lambda x: x
return chevron.render(
template=template,
data=data,
partials_path=None,
partials_ext="".join(random.choices(string.ascii_uppercase + string.digits, k=32)),
partials_dict=dict(),
)


def detect_newline(path: str | Path) -> str | None:
with open(path, "r", newline="") as f:
lines = f.read().splitlines(keepends=True)
Expand Down
46 changes: 30 additions & 16 deletions patchwork/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@

from enum import Enum

from typing_extensions import Any, Dict, List, Optional, Union, is_typeddict
from typing_extensions import (
Any,
Collection,
Dict,
List,
Optional,
Type,
Union,
is_typeddict,
)

from patchwork.logger import logger

Expand Down Expand Up @@ -45,10 +54,9 @@ def __init__(self, inputs: DataPoint):
"""

# check if the inputs have the required keys
if self.__input_class is not None:
missing_keys = self.__input_class.__required_keys__.difference(inputs.keys())
if len(missing_keys) > 0:
raise ValueError(f"Missing required data: {list(missing_keys)}")
missing_keys = self.find_missing_inputs(inputs)
if len(missing_keys) > 0:
raise ValueError(f"Missing required data: {list(missing_keys)}")

# store the inputs
self.inputs = inputs
Expand All @@ -64,19 +72,25 @@ def __init__(self, inputs: DataPoint):
self.original_run = self.run
self.run = self.__managed_run

def __init_subclass__(cls, **kwargs):
input_class = kwargs.get("input_class", None) or getattr(cls, "input_class", None)
output_class = kwargs.get("output_class", None) or getattr(cls, "output_class", None)
def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Optional[Type] = None, **kwargs):
if cls.__name__ == "PreparePR":
print(1)
input_class = input_class or getattr(cls, "input_class", None)
if input_class is not None and not is_typeddict(input_class):
input_class = None

if input_class is not None and is_typeddict(input_class):
cls.__input_class = input_class
else:
cls.__input_class = None
output_class = output_class or getattr(cls, "output_class", None)
if output_class is not None and not is_typeddict(output_class):
output_class = None

if output_class is not None and is_typeddict(output_class):
cls.__output_class = output_class
else:
cls.__output_class = None
cls._input_class = input_class
cls._output_class = output_class

@classmethod
def find_missing_inputs(cls, inputs: DataPoint) -> Collection:
if getattr(cls, "_input_class", None) is None:
return []
return cls._input_class.__required_keys__.difference(inputs.keys())

def __managed_run(self, *args, **kwargs) -> Any:
self.debug(self.inputs)
Expand Down
57 changes: 57 additions & 0 deletions patchwork/steps/CallSQL/CallSQL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from sqlalchemy import URL, create_engine, exc, text

from patchwork.common.utils.input_parsing import parse_to_dict
from patchwork.common.utils.utils import mustache_render
from patchwork.logger import logger
from patchwork.step import Step, StepStatus
from patchwork.steps.CallSQL.typed import CallSQLInputs, CallSQLOutputs


class CallSQL(Step, input_class=CallSQLInputs, output_class=CallSQLOutputs):
def __init__(self, inputs: dict):
super().__init__(inputs)
query_template_data = inputs.get("db_query_template_values", {})
self.query = mustache_render(inputs["db_query"], query_template_data)
self.__build_engine(inputs)

def __build_engine(self, inputs: dict):
dialect = inputs["db_dialect"]
driver = inputs.get("db_driver")
dialect_plus_driver = f"{dialect}+{driver}" if driver is not None else dialect
kwargs = dict(
username=inputs.get("db_username"),
host=inputs.get("db_host", "localhost"),
port=inputs.get("db_port", 5432),
password=inputs.get("db_password"),
database=inputs.get("db_database"),
query=parse_to_dict(inputs.get("db_params")),
)
connection_url = URL.create(
dialect_plus_driver,
**{k: v for k, v in kwargs.items() if v is not None},
)

connect_args = None
if inputs.get("db_driver_args") is not None:
connect_args = parse_to_dict(inputs.get("db_driver_args"))

self.engine = create_engine(connection_url, connect_args=connect_args)
with self.engine.connect() as conn:
conn.execute(text("SELECT 1"))
return self.engine

def run(self) -> dict:
try:
rv = []
with self.engine.begin() as conn:
cursor = conn.execute(text(self.query))
for row in cursor:
result = row._asdict()
rv.append(result)
logger.info(f"Retrieved {len(rv)} rows!")
return dict(results=rv)
except exc.InvalidRequestError as e:
self.set_status(StepStatus.FAILED, f"`{self.query}` failed with message:\n{e}")
return dict(results=[])
Empty file.
24 changes: 24 additions & 0 deletions patchwork/steps/CallSQL/typed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing_extensions import Any, TypedDict


class __RequiredCallSQLInputs(TypedDict):
db_dialect: str
db_query: str


class CallSQLInputs(__RequiredCallSQLInputs, total=False):
db_driver: str
db_username: str
db_password: str
db_host: str
db_port: int
db_name: str
db_params: dict[str, Any]
db_driver_args: dict[str, Any]
db_query_template_values: dict[str, Any]


class CallSQLOutputs(TypedDict):
results: list[dict[str, Any]]
58 changes: 58 additions & 0 deletions patchwork/steps/CallShell/CallShell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import shlex
import subprocess
from pathlib import Path

from patchwork.common.utils.utils import mustache_render
from patchwork.logger import logger
from patchwork.step import Step, StepStatus
from patchwork.steps.CallShell.typed import CallShellInputs, CallShellOutputs


class CallShell(Step, input_class=CallShellInputs, output_class=CallShellOutputs):
def __init__(self, inputs: dict):
super().__init__(inputs)
script_template_values = inputs.get("script_template_values", {})
self.script = mustache_render(inputs["script"], script_template_values)
self.working_dir = inputs.get("working_dir", Path.cwd())
self.env = self.__parse_env_text(inputs.get("env", ""))

@staticmethod
def __parse_env_text(env_text: str) -> dict[str, str]:
env_spliter = shlex.shlex(env_text, posix=True)
env_spliter.whitespace_split = True
env_spliter.whitespace += ";"

env: dict[str, str] = dict()
for env_assign in env_spliter:
env_assign_spliter = shlex.shlex(env_assign, posix=True)
env_assign_spliter.whitespace_split = True
env_assign_spliter.whitespace += "="
env_parts = list(env_assign_spliter)
if len(env_parts) < 1:
continue

env_assign_target = env_parts[0]
if len(env_parts) < 2:
logger.error(f"{env_assign_target} is not assigned anything, skipping...")
continue
if len(env_parts) > 2:
logger.error(f"{env_assign_target} has more than 1 assignment, skipping...")
continue
env[env_assign_target] = env_parts[1]

return env

def run(self) -> dict:
p = subprocess.run(self.script, shell=True, capture_output=True, text=True, cwd=self.working_dir, env=self.env)
try:
p.check_returncode()
except subprocess.CalledProcessError as e:
self.set_status(
StepStatus.FAILED,
f"Script failed.",
)
logger.info(f"stdout: \n{p.stdout}")
logger.info(f"stderr:\n{p.stderr}")
return dict(stdout_output=p.stdout, stderr_output=p.stderr)
Empty file.
19 changes: 19 additions & 0 deletions patchwork/steps/CallShell/typed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing_extensions import Annotated, Any, TypedDict

from patchwork.common.utils.step_typing import StepTypeConfig


class __RequiredCallShellInputs(TypedDict):
script: str


class CallShellInputs(__RequiredCallShellInputs, total=False):
working_dir: Annotated[str, StepTypeConfig(is_path=True)]
env: str
script_template_values: dict[str, Any]


class CallShellOutputs(TypedDict):
stdout_output: str
16 changes: 7 additions & 9 deletions patchwork/steps/FixIssue/FixIssue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import difflib
import re
from pathlib import Path
from typing import Any, Optional

from git import Repo, InvalidGitRepositoryError
from patchwork.logger import logger
from git import InvalidGitRepositoryError, Repo
from openai.types.chat import ChatCompletionMessageParam

from patchwork.common.client.llm.aio import AioLlmClient
Expand All @@ -15,6 +13,7 @@
AnalyzeImplementStrategy,
)
from patchwork.common.tools import CodeEditTool, Tool
from patchwork.logger import logger
from patchwork.step import Step
from patchwork.steps.FixIssue.typed import FixIssueInputs, FixIssueOutputs

Expand Down Expand Up @@ -100,7 +99,7 @@ def is_stop(self, messages: list[ChatCompletionMessageParam]) -> bool:
class FixIssue(Step, input_class=FixIssueInputs, output_class=FixIssueOutputs):
def __init__(self, inputs):
"""Initialize the FixIssue step.

Args:
inputs: Dictionary containing input parameters including:
- base_path: Optional path to the repository root
Expand Down Expand Up @@ -145,12 +144,12 @@ def __init__(self, inputs):

def run(self):
"""Execute the FixIssue step.

This method:
1. Executes the multi-turn LLM conversation to analyze and fix the issue
2. Tracks file modifications made by the CodeEditTool
3. Generates in-memory diffs for all modified files

Returns:
dict: Dictionary containing list of modified files with their diffs
"""
Expand All @@ -162,8 +161,7 @@ def run(self):
if not isinstance(tool, CodeEditTool):
continue
tool_modified_files = [
dict(path=str(file_path.relative_to(cwd)), diff="")
for file_path in tool.tool_records["modified_files"]
dict(path=str(file_path.relative_to(cwd)), diff="") for file_path in tool.tool_records["modified_files"]
]
modified_files.extend(tool_modified_files)

Expand All @@ -174,7 +172,7 @@ def run(self):
file = modified_file["path"]
try:
# Try to get the diff using git
diff = self.repo.git.diff('HEAD', file)
diff = self.repo.git.diff("HEAD", file)
modified_file["diff"] = diff or ""
except Exception as e:
# Git-specific errors (untracked files, etc) - keep empty diff
Expand Down
8 changes: 5 additions & 3 deletions patchwork/steps/FixIssue/typed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing_extensions import Annotated, Dict, List, TypedDict
from typing_extensions import Annotated, List, TypedDict

from patchwork.common.constants import TOKEN_URL
from patchwork.common.utils.step_typing import StepTypeConfig
Expand Down Expand Up @@ -37,19 +37,21 @@ class FixIssueInputs(__FixIssueRequiredInputs, total=False):

class ModifiedFile(TypedDict):
"""Represents a file that has been modified by the FixIssue step.

Attributes:
path: The relative path to the modified file from the repository root
diff: A unified diff string showing the changes made to the file.
Generated using Python's difflib to compare the original and
modified file contents in memory.

Note:
The diff is generated by comparing file contents before and after
modifications, without relying on version control systems.
"""

path: str
diff: str


class FixIssueOutputs(TypedDict):
modified_files: List[ModifiedFile]
Loading
Loading