add candidate count support
victordibia committed Sep 4, 2023
1 parent 3926aa8 commit 0e2bc00
Showing 1 changed file with 33 additions and 24 deletions.
llmx/generators/text/palm_textgen.py (33 additions, 24 deletions)
@@ -3,7 +3,12 @@
 from typing import Union
 from .base_textgen import TextGenerator
 from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message
-from ...utils import cache_request, gcp_request, num_tokens_from_messages, get_gcp_credentials
+from ...utils import (
+    cache_request,
+    gcp_request,
+    num_tokens_from_messages,
+    get_gcp_credentials,
+)
 from ..text.providers import providers

@@ -29,28 +34,30 @@ def format_messages(self, messages):
             if message["role"] == "system":
                 system_messages += message["content"] + "\n"
             else:
-                if not palm_messages or palm_messages[-1]['author'] != message['role']:
+                if not palm_messages or palm_messages[-1]["author"] != message["role"]:
                     palm_message = {
                         "author": message["role"],
                         "content": message["content"],
                     }
                     palm_messages.append(palm_message)
                 else:
-                    palm_messages[-1]['content'] += '\n' + message['content']
+                    palm_messages[-1]["content"] += "\n" + message["content"]

         if palm_messages and len(palm_messages) % 2 == 0:
-            merged_content = palm_messages[-2]['content'] + '\n' + palm_messages[-1]['content']
-            palm_messages[-2]['content'] = merged_content
+            merged_content = (
+                palm_messages[-2]["content"] + "\n" + palm_messages[-1]["content"]
+            )
+            palm_messages[-2]["content"] = merged_content
             palm_messages.pop()

         return system_messages, palm_messages

     def generate(
-        self, messages: Union[list[dict],
-                              str],
-            config: TextGenerationConfig = TextGenerationConfig(),
-            **kwargs) -> TextGenerationResponse:
-
+        self,
+        messages: Union[list[dict], str],
+        config: TextGenerationConfig = TextGenerationConfig(),
+        **kwargs,
+    ) -> TextGenerationResponse:
         use_cache = config.use_cache
         model = config.model or "codechat-bison"
         system_messages, messages = self.format_messages(messages)
@@ -62,31 +69,31 @@ def generate(

         max_tokens = self.model_list[config.model] if model in self.model_list else 1024
         palm_config = {
-            'temperature': config.temperature,
-            'maxOutputTokens': config.max_tokens or max_tokens
+            "temperature": config.temperature,
+            "maxOutputTokens": config.max_tokens or max_tokens,
+            "candidateCount": config.n,
         }
         palm_payload = {
-            'instances': [
-                {'messages': messages,
-                 'context': system_messages,
-                 'examples': [],
-                 }
+            "instances": [
+                {
+                    "messages": messages,
+                    "context": system_messages,
+                    "examples": [],
+                }
             ],
-            'parameters': palm_config
+            "parameters": palm_config,
         }
         # print("*********", use_cache, palm_payload)

-        cache_key_params = palm_payload
+        cache_key_params = {**palm_payload, "model": model}
         if use_cache:
             response = cache_request(cache=self.cache, params=cache_key_params)
             if response:
                 return TextGenerationResponse(**response)

         palm_response = gcp_request(
-            url=api_url,
-            body=palm_payload,
-            method="POST",
-            credentials=self.credentials)
+            url=api_url, body=palm_payload, method="POST", credentials=self.credentials
+        )

         response_text = [
             Message(
@@ -107,7 +114,9 @@
             },
         )

-        cache_request(cache=self.cache, params=(cache_key_params), values=asdict(response))
+        cache_request(
+            cache=self.cache, params=(cache_key_params), values=asdict(response)
+        )
         return response

     def count_tokens(self, text) -> int:
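Taken together, the change forwards `config.n` to the PaLM API as `candidateCount`, so a single request can return several completions, and it folds the model name into the cache key so cached responses for different models no longer collide. Below is a minimal usage sketch of the new behavior. It assumes llmx's top-level `llm()` factory, a `"palm"` provider key, and that `response.text` holds one `Message` per candidate; everything outside the diff itself is an assumption, not verified against this version of the library.

    # Usage sketch (assumed API surface, not part of this commit):
    # request multiple PaLM candidates via config.n, which generate()
    # now forwards to the API as "candidateCount".
    from llmx import llm
    from llmx.datamodel import TextGenerationConfig

    palm_gen = llm(provider="palm")  # assumes GCP credentials are configured

    config = TextGenerationConfig(
        model="codechat-bison",
        temperature=0.7,
        n=2,  # sent to the PaLM API as "candidateCount"
    )

    response = palm_gen.generate(
        messages=[{"role": "user", "content": "Summarize what caching does."}],
        config=config,
    )

    # response.text is expected to hold one Message per returned candidate
    for candidate in response.text:
        print(candidate.content)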
