Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for Baidu Qianfan and ERNIE #823

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
172 changes: 172 additions & 0 deletions guidance/models/_qianfan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import copy

import typing

from ._grammarless import Grammarless, GrammarlessEngine


# The qianfan SDK is an optional dependency: record the chat client class when
# it is importable, otherwise leave a None sentinel that the model classes
# check before use.
try:
    import qianfan  # type: ignore
except ImportError:
    client_class = None
else:
    client_class: typing.Optional[typing.Type[qianfan.ChatCompletion]] = qianfan.ChatCompletion


class ClassUnavailableException(Exception):
    """Raised when the optional ``qianfan`` SDK is not installed."""


class QianfanAI(Grammarless):
    """Guidance model backed by Baidu Qianfan hosted models (e.g. ERNIE)."""

    def __init__(
        self,
        model=None,
        echo=True,
        max_streaming_tokens=None,
        timeout=0.5,
        compute_log_probs=False,
        is_chat_model=True,
        **kwargs,
    ):
        """Build a new QianfanAI model object that represents a model in a given state.

        Parameters
        ----------
        model : str or None
            Qianfan model name, forwarded to :class:`QianfanAIEngine`.
        echo : bool
            Whether to echo streamed output (passed to the ``Grammarless`` base).
        max_streaming_tokens : int or None
            Forwarded to the engine.
        timeout : float
            Forwarded to the engine.
        compute_log_probs : bool
            Not supported yet; must be ``False``.
        is_chat_model : bool
            Select the chat API (``True``) or the plain completion API.
        **kwargs
            Extra arguments forwarded to the qianfan client; credentials are
            normally supplied via environment variables or a ``.env`` file
            rather than here.

        Raises
        ------
        ClassUnavailableException
            If the ``qianfan`` SDK is not installed.
        """
        if client_class is None:
            raise ClassUnavailableException("Please execute `pip install qianfan` before using QianfanAI component")

        super().__init__(
            engine=QianfanAIEngine(
                model=model,
                max_streaming_tokens=max_streaming_tokens,
                timeout=timeout,
                compute_log_probs=compute_log_probs,
                is_chat_model=is_chat_model,
                **kwargs,
            ),
            echo=echo,
        )


class QianfanAIEngine(GrammarlessEngine):
    """Grammarless engine that streams generations from the Qianfan API.

    Chat prompts are expected in the ChatML-style byte format produced by
    guidance role blocks (``<|im_start|>user\\n...<|im_end|>``); completion
    prompts are sent as-is.
    """

    def __init__(
        self,
        model,
        max_streaming_tokens,
        timeout,
        compute_log_probs,
        is_chat_model=True,
        **kwargs,
    ):
        """Create the engine and the underlying qianfan client.

        Raises
        ------
        ClassUnavailableException
            If the ``qianfan`` SDK is not installed.
        """
        if client_class is None:
            raise ClassUnavailableException("Please execute `pip install qianfan` before using QianfanAI component")

        assert (
            not compute_log_probs
        ), "We don't support compute_log_probs=True yet for QianfanAIEngine!"

        self.model_name = model

        self.is_chat_model = is_chat_model
        self.model_obj = qianfan.ChatCompletion(model=model, **kwargs) if self.is_chat_model else qianfan.Completion(model=model, **kwargs)

        self.extra_arguments = copy.deepcopy(kwargs)
        # "endpoint" is consumed by the client constructor above and must not
        # be forwarded again as a per-request argument.
        self.extra_arguments.pop("endpoint", None)

        super().__init__(None, max_streaming_tokens, timeout, compute_log_probs)

    def _generator(self, prompt, temperature):
        """Dispatch to the chat or completion streaming generator."""
        if self.is_chat_model:
            return self._chat_generator(prompt, temperature)

        return self._completion_generator(prompt, temperature)

    def _chat_generator(self, prompt, temperature):
        """Parse a ChatML-formatted byte prompt and stream a chat response.

        Yields UTF-8 encoded chunks of the model's reply. Raises if the
        prompt is not a well-formed sequence of role blocks ending in an
        open ``assistant`` block.
        """
        system_start = b"<|im_start|>system\n"
        user_start = b"<|im_start|>user\n"
        assistant_start = b"<|im_start|>assistant\n"
        role_end = b"<|im_end|>"

        bad_format_message = "It looks like your prompt is not a well formed chat prompt! Please enclose all model state appends inside chat role blocks like `user()` or `assistant()`."

        # find the (optional) system text
        pos = 0
        system_text = ""
        if prompt.startswith(system_start):
            pos = len(system_start)
            system_end_pos = prompt.find(role_end, pos)
            if system_end_pos < 0:
                raise Exception(bad_format_message)
            system_text = prompt[pos:system_end_pos].decode("utf8")
            pos = system_end_pos + len(role_end)

        # find the user/assistant pairs; use offset-based startswith/find so
        # we never copy the remaining prompt tail on each iteration
        messages = []
        valid_end = False
        while True:

            # find the user text
            if prompt.startswith(user_start, pos):
                pos += len(user_start)
                end_pos = prompt.find(role_end, pos)
                if end_pos < 0:
                    break
                messages.append(
                    dict(
                        role="user",
                        content=prompt[pos:end_pos].decode("utf8"),
                    )
                )
                pos = end_pos + len(role_end)
            elif prompt.startswith(assistant_start, pos):
                pos += len(assistant_start)
                end_pos = prompt.find(role_end, pos)
                if end_pos < 0:
                    # an unterminated assistant block is where generation happens
                    valid_end = True
                    break
                messages.append(
                    dict(
                        role="assistant",
                        content=prompt[pos:end_pos].decode("utf8"),
                    )
                )
                pos = end_pos + len(role_end)
            else:
                raise Exception(bad_format_message)

        self._data = prompt[:pos]

        assert len(messages) > 0, "Bad chat format! No chat blocks were defined."
        assert (
            messages[-1]["role"] == "user"
        ), "Bad chat format! There must be a user() role before the last assistant() role."
        assert valid_end, "Bad chat format! You must generate inside assistant() roles."

        # Qianfan rejects temperature == 0, so nudge it to a tiny positive value
        if temperature == 0.0:
            temperature = 0.0001

        input_kwargs = {"temperature": temperature}
        input_kwargs.update(self.extra_arguments)

        if system_text:
            input_kwargs["system"] = system_text

        input_kwargs["stream"] = True

        result_iter = self.model_obj.do(messages, **input_kwargs)
        for response in result_iter:
            yield response.body["result"].encode("utf8")

    def _completion_generator(self, prompt, temperature):
        """Stream a plain (non-chat) completion for the raw byte prompt."""
        # Qianfan rejects temperature == 0, so nudge it to a tiny positive value
        if temperature == 0.0:
            temperature = 0.0001

        input_kwargs = {"temperature": temperature}
        input_kwargs.update(self.extra_arguments)
        input_kwargs["stream"] = True

        self._data = prompt

        result_iter = self.model_obj.do(prompt.decode("utf8"), **input_kwargs)
        for response in result_iter:
            yield response.body["result"].encode("utf8")