Add functionality to ping AI2 InferD endpoints for tulu 2 (#2832)
Co-authored-by: Sam Skjonsberg <sams@allenai.org>
natolambert and codeviking authored Dec 24, 2023
1 parent 05cc60c commit 1ffdaee
Showing 1 changed file with 75 additions and 0 deletions.

fastchat/serve/api_provider.py
@@ -147,3 +147,78 @@ def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens):
"error_code": 0,
}
yield data


def ai2_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    # `os` and `logger` are defined at module level in api_provider.py.
    from requests import post
    from json import loads

    # Resolve the API key, endpoint, and the InferD model ID for Tulu 2.
    ai2_key = api_key or os.environ.get("AI2_API_KEY")
    api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
    model_id = "mod_01hhgcga70c91402r9ssyxekan"

    # Log the generation parameters for debugging; the actual request body
    # is constructed below.
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    res = post(
        api_base,
        stream=True,
        headers={"Authorization": f"Bearer {ai2_key}"},
        json={
            "model_id": model_id,
            # This input format is specific to the Tulu2 model. Other models
            # may require different input formats. See the model's schema
            # documentation on InferD for more information.
            "input": {
                "messages": messages,
                "opts": {
                    "max_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "logprobs": 1,  # increase for more choices
                },
            },
        },
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

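    # Based on the parsing below, each streamed line appears to be a JSON
    # object shaped roughly like (schema inferred from this handler, not
    # from InferD documentation):
    #   {"result": {"output": {"text": ["chunk", ...]}}}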
    text = ""
    for line in res.iter_lines():
        if line:
            part = loads(line)
            if "result" in part and "output" in part["result"]:
                # Concatenate the text chunks in this part.
                for t in part["result"]["output"]["text"]:
                    text += t
            else:
                logger.error(f"unexpected part: {part}")
                raise ValueError("unexpected part in InferD response")

            # Yield the accumulated text so callers can stream partial output.
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
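
A minimal usage sketch (not part of the commit) showing how a caller might consume this generator; the model name and the message below are illustrative assumptions, and AI2_API_KEY must be set in the environment:

messages = [{"role": "user", "content": "What is the capital of France?"}]

for chunk in ai2_api_stream_iter(
    model_name="tulu-2-dpo-70b",  # assumed name, for illustration only
    messages=messages,
    temperature=0.0,
    top_p=1.0,  # must be 1.0 when temperature is 0.0 (vLLM greedy sampling)
    max_new_tokens=64,
):
    if chunk["error_code"] == 0:
        print(chunk["text"])  # accumulated text so far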
