-
Notifications
You must be signed in to change notification settings - Fork 17
/
huggingface-connector.py
88 lines (73 loc) · 3.53 KB
/
huggingface-connector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import logging
import aiohttp
from aiohttp import ClientResponse
from moonshot.src.connectors.connector import Connector, perform_retry
from moonshot.src.connectors.connector_response import ConnectorResponse
from moonshot.src.connectors_endpoints.connector_endpoint_arguments import (
ConnectorEndpointArguments,
)
from moonshot.src.utils.log import configure_logger
# Create a logger for this module.
# NOTE(review): logging.basicConfig runs at import time and configures the
# *root* logger (it is a no-op if the root logger already has handlers), so
# merely importing this connector can change logging for the whole process.
# Consider leaving root-logger setup to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = configure_logger(__name__)
class HuggingFaceConnector(Connector):
    """
    Connector that sends prompts to a HuggingFace inference endpoint over HTTP.

    The endpoint URL, bearer token, pre/post prompt strings and optional request
    parameters are read from instance attributes (``endpoint``, ``token``,
    ``pre_prompt``, ``post_prompt``, ``optional_params``) that this class never
    assigns itself — presumably they are populated by ``Connector.__init__`` from
    the supplied ``ConnectorEndpointArguments``.
    """

    def __init__(self, ep_arguments: ConnectorEndpointArguments):
        # Initialize super class
        super().__init__(ep_arguments)

    @Connector.rate_limited
    @perform_retry
    async def get_response(self, prompt: str) -> ConnectorResponse:
        """
        Asynchronously send a prompt to the HuggingFace API and return the generated text.

        The prompt is wrapped as ``pre_prompt + prompt + post_prompt`` and POSTed as
        JSON. The request body contains every entry of ``self.optional_params`` plus
        an ``"inputs"`` key holding the wrapped prompt; if ``optional_params`` also
        defines ``"inputs"``, the wrapped prompt takes precedence.

        Args:
            prompt (str): The input prompt to send to the HuggingFace API.

        Returns:
            ConnectorResponse: An object wrapping the text generated by the model.

        Raises:
            Exception: Any error from the HTTP request itself, or from
                ``_process_response`` when the reply is not in the expected
                ``[{"generated_text": ...}]`` shape. Rate limiting and retries are
                presumably applied by the ``rate_limited`` / ``perform_retry``
                decorators.
        """
        connector_prompt = f"{self.pre_prompt}{prompt}{self.post_prompt}"
        # Merge self.optional_params with additional parameters; "inputs" is listed
        # last so the wrapped prompt always overrides any same-named optional param.
        new_params = {**self.optional_params, "inputs": connector_prompt}
        # A fresh session per call keeps the connector stateless, at the cost of
        # not reusing connections between requests.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.endpoint,
                headers=self._prepare_headers(),
                json=new_params,
            ) as response:
                return ConnectorResponse(
                    response=await self._process_response(response)
                )

    def _prepare_headers(self) -> dict:
        """
        Build the HTTP headers for a HuggingFace API request.

        Takes no arguments; the bearer token is read from ``self.token``.

        Returns:
            dict: Headers with ``Authorization`` set to ``"Bearer <self.token>"`` and
            ``Content-Type`` set to ``"application/json"``.
        """
        return {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }

    @staticmethod
    async def _read_body_for_log(response: ClientResponse) -> str:
        """
        Best-effort read of the raw response body, used only for diagnostic logging.

        Never raises: if the body cannot be read or decoded, a placeholder describing
        the failure is returned instead, so the caller can still log and re-raise
        its own original error.
        """
        try:
            return await response.text()
        except Exception as read_error:  # diagnostic aid only; must not mask caller's error
            return f"<unable to read response body: {read_error}>"

    async def _process_response(self, response: ClientResponse) -> str:
        """
        Extract the generated text from a HuggingFace API response.

        The response body is expected to be JSON of the form
        ``[{"generated_text": "..."}, ...]``; the ``generated_text`` of the first
        element is returned.

        Args:
            response (ClientResponse): The HTTP response returned by the API.

        Returns:
            str: The ``generated_text`` field of the first element of the JSON body.

        Raises:
            Exception: Re-raises, unchanged, whatever error occurred while decoding
                the body or indexing into it (e.g. a decode error for a non-JSON
                body, or ``KeyError`` for a JSON error object such as
                ``{"error": ...}``), after logging it together with the body.
        """
        json_response = None
        try:
            json_response = await response.json()
            return json_response[0]["generated_text"]
        except Exception as exception:
            # Do NOT call response.json() again here: when the failure was the JSON
            # decode itself (non-JSON content type or malformed body), a second call
            # would raise inside this handler, the log line would never be written,
            # and the original error would be replaced by a new one.
            if json_response is not None:
                details = json_response
            else:
                details = await self._read_body_for_log(response)
            logger.error(
                "[HuggingFaceConnector] An exception has occurred: %s, %s",
                exception,
                details,
            )
            raise