Commit 4dc4bcc

Katsiarynka and Kate Yanchenko authored
Fixes for claude/gemini calls (#18)
Co-authored-by: Kate Yanchenko <kate@lamoom.com>
1 parent 61545e0 commit 4dc4bcc

3 files changed: 22 additions & 13 deletions

flow_prompt/ai_models/claude/claude_model.py

Lines changed: 13 additions & 3 deletions
@@ -83,16 +83,26 @@ def get_client(self, client_secrets: dict) -> anthropic.Anthropic:
         return anthropic.Anthropic(api_key=client_secrets.get('api_key'))
 
 
+    def uny_all_messages_with_same_role(self, messages: t.List[dict]) -> t.List[dict]:
+        result = []
+        last_role = None
+        for message in messages:
+            if last_role != message.get("role"):
+                result.append(message)
+                last_role = message.get("role")
+            else:
+                result[-1]["content"] += message.get("content")
+        return result
+
+
     def call(self, messages: t.List[dict], max_tokens: int, client_secrets: dict = {}, **kwargs) -> AIResponse:
         common_args = get_common_args(max_tokens)
         kwargs = {
-            **{
-                "messages": messages,
-            },
             **common_args,
             **self.get_params(),
             **kwargs,
         }
+        messages = self.uny_all_messages_with_same_role(messages)
 
         logger.debug(
             f"Calling {messages} with max_tokens {max_tokens} and kwargs {kwargs}"

flow_prompt/ai_models/gemini/gemini_model.py

Lines changed: 8 additions & 9 deletions
@@ -91,25 +91,24 @@ def call(self, messages: t.List[dict], max_tokens: int, client_secrets: dict = {
         stream_params = kwargs.get("stream_params")
 
         # Parse only prompt content due to gemini call specifics
-        prompts = [obj["content"] for obj in messages if obj["role"] == "user"]
+        prompt = '\n\n'.join([obj["content"] for obj in messages])
 
         content = ""
 
         try:
             if not kwargs.get('stream'):
-                for prompt in prompts:
-                    response = self.model.generate_content(prompt, stream=False)
-                    content = response.text
+                response = self.model.generate_content(prompt, stream=False)
+                content = response.text
             else:
+                response = self.model.generate_content(prompt, stream=True)
                 idx = 0
-                for prompt in prompts:
+                for chunk in response:
                     if idx % 5 == 0:
+                        idx = 0
                         if not check_connection(**stream_params):
                             raise ConnectionLostError("Connection was lost!")
-                    response = self.model.generate_content(prompt, stream=True)
-                    for chunk in response:
-                        stream_function(chunk.text, **stream_params)
-                        content += chunk.text
+                    stream_function(chunk.text, **stream_params)
+                    content += chunk.text
                     idx += 1
 
         return GeminiAIResponse(
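On the Gemini side, the commit replaces the per-message loop (which previously overwrote `content` and re-issued a streaming request per prompt) with a single flattened prompt and a single `generate_content` call; in streaming mode the response chunks are consumed directly, with a connectivity check every fifth chunk. Below is a minimal sketch of the new flow, assuming the `google-generativeai` package that `self.model` appears to wrap; the API key, model name, and sample messages are placeholders, and the connection check is stubbed out:

import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder key
model = genai.GenerativeModel("gemini-pro")  # model name assumed

messages = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What is streaming generation?"},
]
# Same flattening as the committed code: join every message's content.
prompt = "\n\n".join(obj["content"] for obj in messages)

content = ""
response = model.generate_content(prompt, stream=True)
for idx, chunk in enumerate(response):
    if idx % 5 == 0:
        pass  # flow_prompt calls check_connection(**stream_params) here
    content += chunk.text  # each streamed chunk carries a text fragment

print(content)

Joining on '\n\n' discards the role labels, so every turn reaches Gemini as undifferentiated text; that trade-off seems to be what the "due to gemini call specifics" comment refers to.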

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "flow-prompt"
-version = "0.1.14a17"
+version = "0.1.20"
 description = ""
 authors = ["Flow-prompt Engineering Team <engineering@flow-prompt.com>"]
 readme = "README.md"
