Commit b698b58

feat: add anthropic support and make vllm optional

1 parent e49ab16

7 files changed: +159 -36 lines

README.md (+41 -29)

````diff
@@ -9,14 +9,18 @@
 ## 🚀 Installation
 
 ```bash
-pip install repoqa
+# without vLLM (can run openai, anthropic, and huggingface backends)
+pip install --upgrade repoqa
+# with vLLM
+pip install --upgrade "repoqa[vllm]"
 ```
 
 <details><summary>⏬ Install nightly version <i>:: click to expand ::</i></summary>
 <div>
 
 ```bash
-pip install "git+https://github.com/evalplus/repoqa.git" --upgrade
+pip install --upgrade "git+https://github.com/evalplus/repoqa.git" # without vLLM
+pip install --upgrade "repoqa[vllm] @ git+https://github.com/evalplus/repoqa@main" # with vLLM
 ```
 
 </div>
@@ -35,46 +39,55 @@ pip install -r requirements.txt
 </div>
 </details>
 
-
 ## 🏁 Search Needle Function
 
-### Inference with vLLM
+### Inference with OpenAI Compatible Servers
 
 ```bash
-repoqa.search_needle_function --model "Qwen/CodeQwen1.5-7B-Chat" --caching --backend vllm
+repoqa.search_needle_function --model "gpt4-turbo" --caching --backend openai
+# 💡 If you use a customized server such as vLLM:
+# repoqa.search_needle_function --base-url "http://url.to.vllm.server/v1" \
+#   --model "gpt4-turbo" --caching --backend openai
 ```
 
-### Inference with OpenAI Compatible Servers
+### Inference with Anthropic Compatible Servers
+
+```bash
+repoqa.search_needle_function --model "claude-3-haiku-20240307" --caching --backend anthropic
+```
+
+### Inference with vLLM
 
 ```bash
-repoqa.search_needle_function --base-url "http://api.openai.com/v1" \
-    --model "gpt4-turbo" --caching --backend openai
+repoqa.search_needle_function --model "Qwen/CodeQwen1.5-7B-Chat" \
+    --caching --backend vllm
 ```
 
 ### Inference with HuggingFace transformers
 
 ```bash
-repoqa.search_needle_function --model "gpt2" --caching --backend hf
+repoqa.search_needle_function --model "Qwen/CodeQwen1.5-7B-Chat" \
+    --caching --backend hf --trust-remote-code
 ```
 
 ### Usage
 
 > [!Tip]
 >
-> * **Input**:
->   * `--model`: Hugging-Face model ID, such as `ise-uiuc/Magicoder-S-DS-6.7B`
->   * `--backend`: `vllm` (default) or `openai`
->   * `--base-url`: OpenAI API base URL
->   * `--code-context-size` (default: 16384): Number of tokens (using the DeepSeekCoder tokenizer) of code in the long context
->   * `--caching` (default: False): if enabled, the tokenization and chunking results are cached to accelerate subsequent runs
->   * `--max-new-tokens` (default: 1024): Maximum number of new tokens to generate
->   * `--system-message` (default: None): if given, the model uses a system message (but note some models don't support system messages)
->   * `--tensor-parallel-size`: Degree of tensor parallelism (only for vLLM)
->   * `--languages` (default: None): List of languages to evaluate (None means all)
->   * `--result-dir` (default: "results"): Directory to save the model outputs and evaluation results
-> * **Output**:
->   * `results/ntoken_{code-context-size}/{model}.jsonl`: Model-generated outputs
->   * `results/ntoken_{code-context-size}/{model}-SCORE.json`: Evaluation scores (also see [Compute Scores](#compute-scores))
+> - **Input**:
+>   - `--model`: Hugging-Face model ID, such as `ise-uiuc/Magicoder-S-DS-6.7B`
+>   - `--backend`: `vllm` (default), `openai`, `anthropic`, or `hf`
+>   - `--base-url`: OpenAI API base URL
+>   - `--code-context-size` (default: 16384): Number of tokens (using the DeepSeekCoder tokenizer) of code in the long context
+>   - `--caching` (default: False): if enabled, the tokenization and chunking results are cached to accelerate subsequent runs
+>   - `--max-new-tokens` (default: 1024): Maximum number of new tokens to generate
+>   - `--system-message` (default: None): if given, the model uses a system message (but note some models don't support system messages)
+>   - `--tensor-parallel-size`: Degree of tensor parallelism (only for vLLM)
+>   - `--languages` (default: None): List of languages to evaluate (None means all)
+>   - `--result-dir` (default: "results"): Directory to save the model outputs and evaluation results
+> - **Output**:
+>   - `results/ntoken_{code-context-size}/{model}.jsonl`: Model-generated outputs
+>   - `results/ntoken_{code-context-size}/{model}-SCORE.json`: Evaluation scores (also see [Compute Scores](#compute-scores))
 
 ### Compute Scores
 
@@ -87,12 +100,11 @@ repoqa.compute_score --model-output-path={model-output}.jsonl
 
 > [!Tip]
 >
-> * **Input**: Path to the model-generated outputs.
-> * **Output**: The evaluation scores will be stored in `{model-output}-SCORES.json`
-
+> - **Input**: Path to the model-generated outputs.
+> - **Output**: The evaluation scores will be stored in `{model-output}-SCORES.json`
 
 ## 📚 Read More
 
-* [RepoQA Homepage](https://evalplus.github.io/repoqa.html)
-* [RepoQA Dataset Curation](docs/curate_dataset.md)
-* [RepoQA Development Notes](docs/dev_note.md)
+- [RepoQA Homepage](https://evalplus.github.io/repoqa.html)
+- [RepoQA Dataset Curation](docs/curate_dataset.md)
+- [RepoQA Development Notes](docs/dev_note.md)
````

repoqa/provider/__init__.py (+0 -3)

```diff
@@ -3,6 +3,3 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from repoqa.provider.base import BaseProvider
-from repoqa.provider.hf import HfProvider
-from repoqa.provider.openai import OpenAIProvider
-from repoqa.provider.vllm import VllmProvider
```
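
Dropping these re-exports is what actually makes vLLM optional: previously, `import repoqa.provider` eagerly imported `repoqa.provider.vllm`, which in turn imports `vllm` and fails when that extra is not installed. Callers now import each provider from its submodule on demand (see `search_needle_function.py` below). A minimal sketch of the idea, with a hypothetical `pick_provider` helper used purely for illustration:

```python
# Illustrative sketch (not repo code): defer the heavy import until the
# backend is actually chosen, so plain `import repoqa.provider` stays cheap.
def pick_provider(backend: str, model: str):
    if backend == "vllm":
        # Only users who asked for vLLM pay the `import vllm` cost,
        # and only they need it installed.
        from repoqa.provider.vllm import VllmProvider

        return VllmProvider(model)
    from repoqa.provider.openai import OpenAIProvider

    return OpenAIProvider(model)
```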

repoqa/provider/anthropic.py (new file, +35)

```python
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List

from anthropic import Client

from repoqa.provider.base import BaseProvider
from repoqa.provider.request.anthropic import make_auto_request


class AnthropicProvider(BaseProvider):
    def __init__(self, model):
        self.model = model
        self.client = Client(api_key=os.getenv("ANTHROPIC_KEY"))

    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
    ) -> List[str]:
        assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
        replies = []
        for _ in range(n):
            reply = make_auto_request(
                self.client,
                message=question,
                model=self.model,
                temperature=temperature,
                max_tokens=max_tokens,
                system_msg=system_msg,
            )
            replies.append(reply.content[0].text)

        return replies
```
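
Note that the provider reads its credential from the `ANTHROPIC_KEY` environment variable rather than the `ANTHROPIC_API_KEY` name the `anthropic` SDK looks up on its own. A minimal usage sketch, assuming the `anthropic` package is installed and a valid key is set (the key value below is a placeholder):

```python
import os

from repoqa.provider.anthropic import AnthropicProvider

# Placeholder credential; the provider reads ANTHROPIC_KEY in __init__ above.
os.environ["ANTHROPIC_KEY"] = "sk-ant-..."

engine = AnthropicProvider("claude-3-haiku-20240307")
# With the default temperature=0, the assert in generate_reply requires n=1.
replies = engine.generate_reply("What does a needle function test?", n=1)
print(replies[0])
```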

repoqa/provider/request/anthropic.py (new file, +71)

```python
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import signal
import time

import anthropic
from anthropic.types import Message

from repoqa.provider.request import construct_message_list


def make_request(
    client: anthropic.Client,
    message: str,
    model: str,
    max_tokens: int = 512,
    temperature: float = 1,
    system_msg="You are a helpful assistant good at coding.",
    **kwargs,
) -> Message:
    return client.messages.create(
        model=model,
        messages=construct_message_list(message, system_message=system_msg),
        max_tokens=max_tokens,
        temperature=temperature,
        **kwargs,
    )


def handler(signum, frame):
    # swallow signum and frame
    raise Exception("end of time")


def make_auto_request(client: anthropic.Client, *args, **kwargs) -> Message:
    ret = None
    while ret is None:
        try:
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(100)
            ret = make_request(client, *args, **kwargs)
            signal.alarm(0)
        except anthropic.RateLimitError:
            print("Rate limit exceeded. Waiting...")
            signal.alarm(0)
            time.sleep(5)
        except anthropic.APIConnectionError:
            print("API connection error. Waiting...")
            signal.alarm(0)
            time.sleep(5)
        except anthropic.InternalServerError:
            print("Internal server error. Waiting...")
            signal.alarm(0)
            time.sleep(5)
        except anthropic.APIError as e:
            print("Unknown API error")
            print(e)
            if (
                e.body["error"]["message"]
                == "Output blocked by content filtering policy"
            ):
                raise Exception("Content filtering policy blocked output")
            signal.alarm(0)
        except Exception as e:
            print("Unknown error. Waiting...")
            print(e)
            signal.alarm(0)
            time.sleep(1)
    return ret
```
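
`make_auto_request` pairs a `SIGALRM` watchdog with per-exception backoff: the alarm turns a hung request into an exception after 100 seconds, and every handler path disarms the alarm before sleeping and retrying. A self-contained sketch of just that timeout-and-retry skeleton (the `flaky_call` argument and the timeout value are illustrative, not repo code); note `SIGALRM` is Unix-only, so this pattern does not work on Windows:

```python
import signal
import time


def _on_alarm(signum, frame):
    # Fired by the watchdog; turns a hung call into a retryable exception.
    raise TimeoutError("attempt exceeded its time budget")


def call_with_retry(flaky_call, timeout_s=100, backoff_s=5):
    """Retry `flaky_call` until it returns, bounding each attempt's runtime."""
    result = None
    while result is None:
        try:
            signal.signal(signal.SIGALRM, _on_alarm)
            signal.alarm(timeout_s)  # arm the watchdog for this attempt
            result = flaky_call()
            signal.alarm(0)  # success: disarm
        except Exception as exc:
            signal.alarm(0)  # always disarm before backing off
            print(f"retrying after error: {exc}")
            time.sleep(backoff_s)
    return result
```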

repoqa/search_needle_function.py (+7 -3)

```diff
@@ -302,20 +302,24 @@ def evaluate_model(
         return
 
     if backend == "openai":
-        from repoqa.provider import OpenAIProvider
+        from repoqa.provider.openai import OpenAIProvider
 
         engine = OpenAIProvider(model, base_url=base_url)
     elif backend == "vllm":
-        from repoqa.provider import VllmProvider
+        from repoqa.provider.vllm import VllmProvider
 
         engine = VllmProvider(
             model,
             tensor_parallel_size=tensor_parallel_size,
             max_model_len=int(code_context_size * 1.25),  # Magic number
             trust_remote_code=trust_remote_code,
         )
+    elif backend == "anthropic":
+        from repoqa.provider.anthropic import AnthropicProvider
+
+        engine = AnthropicProvider(model)
     elif backend == "hf":
-        from repoqa.provider import HfProvider
+        from repoqa.provider.hf import HfProvider
 
         engine = HfProvider(model, trust_remote_code=trust_remote_code)
 
```
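
Because each branch imports its provider lazily, `--backend anthropic` only needs the `anthropic` package, while `--backend vllm` raises `ImportError` unless the `repoqa[vllm]` extra is installed. A hypothetical guard a caller could wrap around the vLLM branch to surface friendlier install guidance (illustrative, not part of this commit):

```python
# Illustrative only: translate the lazy-import failure into install guidance.
try:
    from repoqa.provider.vllm import VllmProvider
except ImportError as exc:
    raise SystemExit(
        'backend "vllm" requested but vllm is not installed; '
        'run: pip install --upgrade "repoqa[vllm]"'
    ) from exc
```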

requirements.txt (+1 -0)

```diff
@@ -6,3 +6,4 @@ rich
 vllm
 numpy
 tree_sitter_languages
+anthropic
```

setup.cfg (+4 -1)

```diff
@@ -22,11 +22,14 @@ install_requires =
     openai>=1.23.2
     nltk>=3.8.1
     rich>=13.5.2
-    vllm>=0.3.3
    tree_sitter_languages>=1.10.2
     numpy>=1.25.2
+    anthropic>=0.25.6
 
 [options.entry_points]
 console_scripts =
     repoqa.search_needle_function = repoqa.search_needle_function:main
     repoqa.compute_score = repoqa.compute_score:main
+
+[options.extras_require]
+vllm = vllm>=0.3.3
```
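
With `vllm` demoted from a hard requirement to an `[options.extras_require]` extra, code that wants to adapt to either installation can probe for the package without importing it. A small standard-library sketch (the default-backend rule is hypothetical, for illustration only):

```python
import importlib.util

# find_spec locates the package without importing it, so probing for the
# optional extra stays cheap even when vllm is installed.
HAS_VLLM = importlib.util.find_spec("vllm") is not None

# Hypothetical policy: prefer local vLLM inference when available.
DEFAULT_BACKEND = "vllm" if HAS_VLLM else "openai"
print(f"default backend: {DEFAULT_BACKEND}")
```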
