Skip to content

Commit

Permalink
oai model support #1
Browse files Browse the repository at this point in the history
  • Loading branch information
victordibia committed Sep 9, 2023
1 parent fac9b42 commit bb7bdf3
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions llmx/generators/text/openai_textgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,22 @@ def __init__(
openai.api_key = api_key
if organization:
openai.organization = organization
if api_type and api_type == "azure":
openai.api_base = api_base
if api_version:
openai.api_version = api_version
self.api_key = api_key
self.api_type = api_type
self.api_base = api_base
self.api_version = api_version
if api_base:
openai.api_base = api_base
if api_type:
openai.api_type = api_type

def generate(
self, messages: Union[List[dict],
str],
config: TextGenerationConfig = TextGenerationConfig(),
**kwargs) -> TextGenerationResponse:
# print content of class fields
# print(vars(openai))

def generate(
self,
messages: Union[List[dict], str],
config: TextGenerationConfig = TextGenerationConfig(),
**kwargs,
) -> TextGenerationResponse:
use_cache = config.use_cache
model = config.model or "gpt-3.5-turbo-0301"
prompt_tokens = num_tokens_from_messages(messages)
Expand All @@ -70,8 +72,9 @@ def generate(
"messages": messages,
}

if self.api_type and self.api_type == "azure":
oai_config["deployment_id"] = config.model
if openai.api_type and openai.api_type == "azure":
oai_config["engine"] = config.model

self.model_name = model
cache_key_params = (oai_config) | {"messages": messages}
if use_cache:
Expand All @@ -88,7 +91,9 @@ def generate(
usage=dict(oai_response.usage),
)
# if use_cache:
cache_request(cache=self.cache, params=cache_key_params, values=asdict(response))
cache_request(
cache=self.cache, params=cache_key_params, values=asdict(response)
)
return response

def count_tokens(self, text) -> int:
Expand Down

0 comments on commit bb7bdf3

Please sign in to comment.