Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated Hugging Face chat and magics processing with new APIs, clients #784

Merged
merged 7 commits into from
May 16, 2024
129 changes: 65 additions & 64 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
Bedrock,
Cohere,
GPT4All,
HuggingFaceHub,
HuggingFaceEndpoint,
OpenAI,
SagemakerEndpoint,
Together,
Expand Down Expand Up @@ -318,7 +318,6 @@ def __init__(self, *args, **kwargs):
),
"text": PromptTemplate.from_template("{prompt}"), # No customization
}

super().__init__(*args, **kwargs, **model_kwargs)

async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
Expand Down Expand Up @@ -582,14 +581,10 @@ def allows_concurrency(self):
return False


HUGGINGFACE_HUB_VALID_TASKS = (
"text2text-generation",
"text-generation",
"text-to-image",
)


class HfHubProvider(BaseProvider, HuggingFaceHub):
# References for using HuggingFaceEndpoint and InferenceClient:
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient
# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py
class HfHubProvider(BaseProvider, HuggingFaceEndpoint):
id = "huggingface_hub"
name = "Hugging Face Hub"
models = ["*"]
Expand All @@ -609,33 +604,35 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub.inference_api import InferenceApi
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
try:
from huggingface_hub import InferenceClient

repo_id = values["repo_id"]
client = InferenceApi(
repo_id=repo_id,
values["client"] = InferenceClient(
model=values["model"],
timeout=values["timeout"],
token=huggingfacehub_api_token,
task=values.get("task"),
**values["server_kwargs"],
)
if client.task not in HUGGINGFACE_HUB_VALID_TASKS:
raise ValueError(
f"Got invalid task {client.task}, "
f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
)
values["client"] = client
except ImportError:
raise ValueError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
return values

# Handle image outputs
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
# Handle text and image outputs
def _call(
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
"""Call out to Hugging Face Hub's inference endpoint.

Args:
Expand All @@ -650,45 +647,49 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:

response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
response = self.client(inputs=prompt, params=_model_kwargs)

if type(response) is dict and "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")

# Custom code for responding to image generation responses
if self.client.task == "text-to-image":
imageFormat = response.format # Presume it's a PIL ImageFile
mimeType = ""
if imageFormat == "JPEG":
mimeType = "image/jpeg"
elif imageFormat == "PNG":
mimeType = "image/png"
elif imageFormat == "GIF":
mimeType = "image/gif"
invocation_params = self._invocation_params(stop, **kwargs)
invocation_params["stop"] = invocation_params[
"stop_sequences"
] # porting 'stop_sequences' into the 'stop' argument
response = self.client.post(
json={"inputs": prompt, "parameters": invocation_params},
stream=False,
task=self.task,
)

try: # check if this is a text-generation task
response_text = json.loads(response.decode())[0]["generated_text"]
# Maybe the generation has stopped at one of the stop sequences:
# then we remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
except: # if fails, then try to process as a text-to-image task
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example
# Custom code for responding to image generation responses
if type(response) == bytes: # Is this an image
image = self.client.text_to_image(prompt)
imageFormat = image.format # Presume it's a PIL ImageFile
mimeType = ""
if imageFormat == "JPEG":
mimeType = "image/jpeg"
elif imageFormat == "PNG":
mimeType = "image/png"
elif imageFormat == "GIF":
mimeType = "image/gif"
else:
raise ValueError(f"Unrecognized image format {imageFormat}")
buffer = io.BytesIO()
image.save(buffer, format=imageFormat)
# Encode image data to Base64 bytes, then decode bytes to str
return (
mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
)
dlqqq marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Unrecognized image format {imageFormat}")

buffer = io.BytesIO()
response.save(buffer, format=imageFormat)
# Encode image data to Base64 bytes, then decode bytes to str
return mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()

if self.client.task == "text-generation":
# Text generation return includes the starter text.
text = response[0]["generated_text"][len(prompt) :]
elif self.client.task == "text2text-generation":
text = response[0]["generated_text"]
else:
raise ValueError(
f"Got invalid task {self.client.task}, "
f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
raise ValueError(
"Task not supported, only text-generation and text-to-image tasks are valid."
)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain_core.prompts import PromptTemplate

from .base import BaseChatHandler, SlashCommandRoutingType

Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from langchain.chains import LLMChain
from langchain.llms import BaseLLM
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import BaseOutputParser
from langchain_core.prompts import PromptTemplate


class OutlineSection(BaseModel):
Expand Down
Loading