1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -108,6 +108,7 @@ jobs:
poetry run patchwork AutoFix --log debug \
--patched_api_key=${{ secrets.PATCHED_API_KEY }} \
--github_api_key=${{ secrets.SCM_GITHUB_KEY }} \
--issue_url=https://github.com/patched-codes/patchwork/issues/1039 \
--force_pr_creation \
--disable_telemetry

99 changes: 50 additions & 49 deletions patchwork/common/client/llm/anthropic.py
@@ -84,6 +84,7 @@ def __get_model_limit(self, model: str) -> int:
return 200_000 - safety_margin

def __adapt_input_messages(self, messages: Iterable[ChatCompletionMessageParam]) -> list[MessageParam]:
system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
new_messages = []
for message in messages:
if message.get("role") == "system":
@@ -128,22 +129,22 @@ def __adapt_input_messages(self, messages: Iterable[ChatCompletionMessageParam])
return new_messages

def __adapt_chat_completion_request(
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
):
system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
adapted_messages = self.__adapt_input_messages(messages)
@@ -207,22 +208,22 @@ def is_model_supported(self, model: str) -> bool:
return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)

def is_prompt_supported(
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
) -> int:
model_limit = self.__get_model_limit(model)
input_kwargs = self.__adapt_chat_completion_request(
@@ -251,27 +252,27 @@ def is_prompt_supported(
return model_limit - message_token_count.input_tokens

def truncate_messages(
self, messages: Iterable[ChatCompletionMessageParam], model: str
self, messages: Iterable[ChatCompletionMessageParam], model: str
) -> Iterable[ChatCompletionMessageParam]:
return self._truncate_messages(self, messages, model)

def chat_completion(
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
input_kwargs = self.__adapt_chat_completion_request(
messages=messages,
42 changes: 32 additions & 10 deletions patchwork/common/client/scm.py
@@ -17,7 +17,8 @@
from azure.devops.released.core.core_client import CoreClient
from azure.devops.released.git.git_client import GitClient
from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, GitRepository
from github import Auth, Consts, Github, GithubException, PullRequest
from github import Auth, Consts, Github, GithubException, PullRequest, Issue
from github.GithubObject import NotSet
from github.GithubException import UnknownObjectException
from gitlab import Gitlab, GitlabAuthenticationError, GitlabError
from gitlab.v4.objects import ProjectMergeRequest
@@ -197,6 +198,7 @@ def create_pr(
body: str,
original_branch: str,
feature_branch: str,
issue_url: str | None = None,
) -> PullRequestProtocol:
...

@@ -434,18 +436,26 @@ def get_slug_and_id_from_url(self, url: str) -> tuple[str, int] | None:
return slug, resource_id

def find_issue_by_url(self, url: str) -> IssueText | None:
slug, issue_id = self.get_slug_and_id_from_url(url)
resource_slug_and_id = self.get_slug_and_id_from_url(url)
if resource_slug_and_id is None:
return None
slug, issue_id = resource_slug_and_id
return self.find_issue_by_id(slug, issue_id)

def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
repo = self.github.get_repo(slug)
issue = self.__find_issue_by_id(slug, issue_id)
if issue is None:
return None
return dict(
title=issue.title,
body=issue.body,
comments=[issue_comment.body for issue_comment in issue.get_comments()],
)

def __find_issue_by_id(self, slug: str, issue_id: int) -> Issue | None:
try:
issue = repo.get_issue(issue_id)
return dict(
title=issue.title,
body=issue.body,
comments=[issue_comment.body for issue_comment in issue.get_comments()],
)
repo = self.github.get_repo(slug)
return repo.get_issue(issue_id)
except GithubException as e:
logger.warn(f"Failed to get issue: {e}")
return None
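
With the lookup split into the public find_issue_by_url and the private __find_issue_by_id helper, callers now get None back for a URL that cannot be parsed or an issue that does not exist, instead of failing while unpacking None. A minimal usage sketch (scm_client is assumed to be an already-authenticated GithubClient; the issue URL is the one used in the workflow above):

    issue = scm_client.find_issue_by_url(
        "https://github.com/patched-codes/patchwork/issues/1039"
    )
    if issue is None:
        # malformed URL, unknown repository, or missing issue
        print("no issue found")
    else:
        # find_issue_by_url returns a dict with title, body and comments
        print(issue["title"], len(issue["comments"]))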
@@ -508,10 +518,19 @@ def create_pr(
body: str,
original_branch: str,
feature_branch: str,
issue_url: str | None = None,
) -> PullRequestProtocol:
# before creating a PR, check if one already exists
repo = self.github.get_repo(slug)
gh_pr = repo.create_pull(title=title, body=body, base=original_branch, head=feature_branch)

issue_obj = NotSet
if issue_url is not None:
resource_slug_and_id = self.get_slug_and_id_from_url(issue_url)
if resource_slug_and_id is not None:
slug, issue_id = resource_slug_and_id
issue_obj = self.__find_issue_by_id(slug, issue_id)

gh_pr = repo.create_pull(title=title, body=body, base=original_branch, head=feature_branch, issue=issue_obj)
pr = GithubPullRequest(gh_pr)
return pr

@@ -630,7 +649,9 @@ def create_pr(
body: str,
original_branch: str,
feature_branch: str,
issue_url: str | None = None,
) -> PullRequestProtocol:
# issue_url is unused here because, for GitLab, the issue is usually referenced in the MR body instead.
# before creating a PR, check if one already exists
project = self.gitlab.projects.get(slug)
gl_mr = project.mergerequests.create(
@@ -777,6 +798,7 @@ def create_pr(
body: str,
original_branch: str,
feature_branch: str,
issue_url: str | None = None,
) -> PullRequestProtocol:
# before creating a PR, check if one already exists
pr_body = GitPullRequest(
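Putting the GitHub changes together, the new parameter can be exercised roughly like this (a sketch; scm_client is assumed to be a configured GithubClient, and the slug and branch names are hypothetical example values):

    pr = scm_client.create_pr(
        slug="patched-codes/patchwork",
        title="Patchwork PR",
        body="Automated fix",
        original_branch="main",
        feature_branch="patchwork-autofix",
        issue_url="https://github.com/patched-codes/patchwork/issues/1039",
    )

When the issue URL resolves to an existing issue, it is forwarded to repo.create_pull(..., issue=...); otherwise the argument stays NotSet and the call behaves as before.
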
15 changes: 10 additions & 5 deletions patchwork/steps/CreatePR/CreatePR.py
@@ -1,4 +1,5 @@
from pathlib import Path
from typing_extensions import Optional

import git
from git.exc import GitCommandError
@@ -49,6 +50,7 @@ def __init__(self, inputs: dict):
)
self.enabled = False

self.issue_url = inputs.get("issue_url")
self.pr_body = inputs.get("pr_body", "")
self.title = inputs.get("pr_title", "Patchwork PR")
self.force = bool(inputs.get("force_pr_creation", False))
@@ -107,6 +109,7 @@ def run(self) -> dict:
base_branch_name=self.base_branch,
target_branch_name=self.target_branch,
scm_client=self.scm_client,
issue_url=self.issue_url,
force=self.force,
)

@@ -147,17 +150,19 @@ def create_pr(
base_branch_name: str,
target_branch_name: str,
scm_client: ScmPlatformClientProtocol,
issue_url: Optional[str] = None,
force: bool = False,
):
prs = scm_client.find_prs(repo_slug, original_branch=base_branch_name, feature_branch=target_branch_name)
pr = next(iter(prs), None)
if pr is None:
pr = scm_client.create_pr(
repo_slug,
title,
body,
base_branch_name,
target_branch_name,
slug=repo_slug,
title=title,
body=body,
original_branch=base_branch_name,
feature_branch=target_branch_name,
issue_url=issue_url
)

pr.set_pr_description(body)
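On the step side, the new input only needs to appear in the step's inputs dict (or as --issue_url on the CLI, as in the workflow change above). A rough sketch, assuming the repository, branch, and API-key inputs the step already requires are supplied as well:

    from patchwork.steps.CreatePR.CreatePR import CreatePR

    inputs = {
        # ...repository, branch and API-key inputs required by the step...
        "pr_title": "Patchwork PR",
        "pr_body": "Automated fix",
        "force_pr_creation": True,
        "issue_url": "https://github.com/patched-codes/patchwork/issues/1039",
    }
    CreatePR(inputs).run()

The step reads the value with inputs.get("issue_url") and forwards it unchanged to scm_client.create_pr.
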
1 change: 1 addition & 0 deletions patchwork/steps/CreatePR/typed.py
@@ -14,6 +14,7 @@ class CreatePRInputs(__CreatePRRequiredInputs, total=False):
force_pr_creation: Annotated[bool, StepTypeConfig(is_config=True)]
disable_pr: Annotated[bool, StepTypeConfig(is_config=True)]
scm_url: Annotated[str, StepTypeConfig(is_config=True)]
issue_url: Annotated[str, StepTypeConfig(is_config=True)]
gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True)]
github_api_key: Annotated[str, StepTypeConfig(is_config=True)]
azuredevops_api_key: Annotated[str, StepTypeConfig(is_config=True)]
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "patchwork-cli"
version = "0.0.84"
version = "0.0.85"
description = ""
authors = ["patched.codes"]
license = "AGPL"