Add support for Hugging Face Serverless Inference #85

Merged (2 commits) on Feb 20, 2025

README.md (9 changes: 8 additions & 1 deletion)
@@ -127,7 +127,7 @@ bigcodebench.evaluate \
--execution [e2b|gradio|local] \
--split [complete|instruct] \
--subset [full|hard] \
--backend [vllm|openai|anthropic|google|mistral|hf]
--backend [vllm|openai|anthropic|google|mistral|hf|hf-inference]
```

- All the resulting files will be stored in a folder named `bcb_results`.
@@ -177,6 +177,13 @@ Access Gemini APIs from [Google AI Studio](https://aistudio.google.com/)
export GOOGLE_API_KEY=<your_google_api_key>
```

Access the [Hugging Face Serverless Inference API](https://huggingface.co/docs/api-inference/en/index)
```bash
export HF_INFERENCE_API_KEY=<your_hf_api_key>
```

Please make sure your HF access token has the `Make calls to inference providers` permission.
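
For reference, a minimal Python sketch of how this variable is consumed; the model id is only an example, and the decoder added by this PR builds an equivalent client internally:

```python
import os

from huggingface_hub import InferenceClient

# The hf-inference backend authenticates with HF_INFERENCE_API_KEY.
client = InferenceClient(
    provider="hf-inference", api_key=os.getenv("HF_INFERENCE_API_KEY")
)
# Example model id; any text-generation model served via HF Inference should work.
print(
    client.text_generation(
        prompt="def add(a, b):", model="bigcode/starcoder2-15b", max_new_tokens=32
    )
)
```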

## 💻 LLM-generated Code

We share pre-generated code samples from LLMs we have [evaluated](https://huggingface.co/spaces/bigcode/bigcodebench-leaderboard) on the full set:

bigcodebench/gen/util/hf_inference_request.py (34 changes: 34 additions & 0 deletions)
@@ -0,0 +1,34 @@
import time

from huggingface_hub import InferenceClient
from huggingface_hub.inference._generated.types import TextGenerationOutput


def make_request(
client: InferenceClient,
message: str,
model: str,
temperature: float,
n: int,
max_new_tokens: int = 2048,
) -> TextGenerationOutput:
    # temperature and n are accepted to mirror the other backends' request helpers,
    # but generation is greedy (do_sample=False) and returns a single completion.
    response = client.text_generation(
        model=model,
        prompt=message,
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )

return response


def make_auto_request(*args, **kwargs) -> TextGenerationOutput:
    # Retry indefinitely: transient Serverless Inference errors (rate limits,
    # cold starts, timeouts) are printed and the request is retried after 1 s.
    ret = None
    while ret is None:
try:
ret = make_request(*args, **kwargs)
except Exception as e:
print("Unknown error. Waiting...")
print(e)
time.sleep(1)
return ret
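
A hedged usage sketch of the helper above; the model id and prompt are illustrative, and the token comes from the environment variable documented in the README:

```python
import os

from huggingface_hub import InferenceClient

from bigcodebench.gen.util.hf_inference_request import make_auto_request

client = InferenceClient(
    provider="hf-inference", api_key=os.getenv("HF_INFERENCE_API_KEY")
)
# Retries until the Serverless Inference API returns a completion.
completion = make_auto_request(
    client,
    message="def fibonacci(n):",
    model="bigcode/starcoder2-15b",  # example model id
    temperature=0.0,  # accepted for interface parity; decoding is greedy
    n=1,              # accepted for interface parity; a single sample is returned
)
print(completion)
```
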
bigcodebench/provider/__init__.py (13 changes: 13 additions & 0 deletions)
@@ -68,6 +68,19 @@ def make_model(
tokenizer_name=tokenizer_name,
tokenizer_legacy=tokenizer_legacy,
)
elif backend == "hf-inference":
from bigcodebench.provider.hf_inference import HuggingFaceInferenceDecoder

return HuggingFaceInferenceDecoder(
name=model,
subset=subset,
split=split,
temperature=temperature,
max_new_tokens=max_new_tokens,
direct_completion=direct_completion,
instruction_prefix=instruction_prefix,
response_prefix=response_prefix,
)
elif backend == "openai":
from bigcodebench.provider.openai import OpenAIChatDecoder

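For orientation, a sketch of how this branch is reached; the keyword names mirror the arguments visible in the diff above, while the values, and any further options make_model may require, are assumptions:

```python
from bigcodebench.provider import make_model

# Illustrative call; the real call site wires these from the CLI flags
# (--backend hf-inference, --model, --subset, --split, ...).
decoder = make_model(
    model="bigcode/starcoder2-15b",
    backend="hf-inference",
    subset="full",
    split="complete",
    temperature=0.0,
    max_new_tokens=1280,
)
```
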
bigcodebench/provider/hf_inference.py (54 changes: 54 additions & 0 deletions)
@@ -0,0 +1,54 @@
import os
from typing import List
from tqdm import tqdm

from huggingface_hub import InferenceClient

from bigcodebench.provider.base import DecoderBase
from bigcodebench.gen.util.hf_inference_request import make_auto_request
from bigcodebench.provider.utility import make_raw_chat_prompt


class HuggingFaceInferenceDecoder(DecoderBase):
def __init__(self, name: str, **kwargs):
super().__init__(name, **kwargs)
        # Authenticates against the Serverless Inference API with the token
        # exported as HF_INFERENCE_API_KEY (see README).
        self.client = InferenceClient(
            provider="hf-inference", api_key=os.getenv("HF_INFERENCE_API_KEY")
        )

def codegen(
self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
) -> List[str]:
if do_sample:
assert self.temperature > 0, "Temperature must be positive for sampling"

all_outputs = []

for prompt in tqdm(prompts):
outputs = []
message = (
prompt
if self.is_direct_completion()
else make_raw_chat_prompt(
task_prompt=prompt,
subset=self.subset,
split=self.split,
instruction_prefix=self.instruction_prefix,
response_prefix=self.response_prefix,
tokenizer=None,
)
)
            # One greedy request per prompt: the helper currently ignores n and
            # temperature, so each prompt yields a single completion.
            ret = make_auto_request(
                self.client,
                message=message,
                model=self.name,
                n=num_samples,
                temperature=self.temperature,
                max_new_tokens=self.max_new_tokens,
            )
outputs.append(ret)
all_outputs.append(outputs)
return all_outputs

def is_direct_completion(self) -> bool:
return self.direct_completion
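
A hedged end-to-end sketch of the decoder above; the constructor keywords follow those forwarded from provider/__init__.py, and the model id, prompt, and prefix strings are illustrative:

```python
from bigcodebench.provider.hf_inference import HuggingFaceInferenceDecoder

# Illustrative wiring; make_model(backend="hf-inference", ...) normally builds this.
decoder = HuggingFaceInferenceDecoder(
    name="bigcode/starcoder2-15b",  # example model id
    subset="full",
    split="complete",
    temperature=0.0,
    max_new_tokens=1280,
    direct_completion=True,  # send raw prompts instead of chat-formatted ones
    instruction_prefix="",
    response_prefix="",
)
# One greedy completion per prompt; the result is a list of per-prompt output lists.
samples = decoder.codegen(["def add(a, b):"], do_sample=False, num_samples=1)
print(samples[0][0])
```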