Skip to content

Commit e83c839

Browse files
authored
Gemini tests (#23)
1 parent fb4bb32 commit e83c839

File tree

9 files changed

+227
-38
lines changed

9 files changed

+227
-38
lines changed

flow_prompt/ai_models/ai_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ def name(self) -> str:
3131
def price_per_prompt_1k_tokens(self) -> Decimal:
3232
return self._price_per_prompt_1k_tokens
3333

34+
def _decimal(self, value) -> Decimal:
35+
return Decimal(value).quantize(Decimal(".00001"))
36+
37+
def get_prompt_price(self, count_tokens: int) -> Decimal:
    """USD cost of a prompt of *count_tokens* tokens, quantized to 5 places."""
    raw_price = self.price_per_prompt_1k_tokens * Decimal(count_tokens) / 1000
    return self._decimal(raw_price)
39+
40+
def get_sample_price(self, prompt_sample, count_tokens: int) -> Decimal:
    """USD cost of *count_tokens* sampled (output) tokens, quantized to 5 places.

    prompt_sample is unused in this base implementation; overrides (e.g. the
    Gemini model) use it to select a pricing tier by prompt size.
    """
    raw_price = self.price_per_sample_1k_tokens * Decimal(count_tokens) / 1000
    return self._decimal(raw_price)
42+
3443
@property
3544
def price_per_sample_1k_tokens(self) -> Decimal:
3645
return self._price_per_sample_1k_tokens

flow_prompt/ai_models/claude/claude_model.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from flow_prompt.ai_models.ai_model import AI_MODELS_PROVIDER, AIModel
22
import logging
33

4+
from flow_prompt.ai_models.constants import C_200K
45
from flow_prompt.responses import AIResponse
56
from decimal import Decimal
67
from enum import Enum
@@ -20,8 +21,6 @@
2021
logger = logging.getLogger(__name__)
2122

2223

23-
C_200K = 200000
24-
2524

2625
class FamilyModel(Enum):
2726
haiku = "Claude 3 Haiku"
@@ -30,8 +29,8 @@ class FamilyModel(Enum):
3029

3130

3231
DEFAULT_PRICING = {
33-
"price_per_prompt_1k_tokens": Decimal(0.00025),
34-
"price_per_sample_1k_tokens": Decimal(0.00125),
32+
"price_per_prompt_1k_tokens": Decimal(0.003),
33+
"price_per_sample_1k_tokens": Decimal(0.015),
3534
}
3635

3736
CLAUDE_AI_PRICING = {
@@ -156,13 +155,17 @@ def name(self) -> str:
156155

157156
@property
def price_per_prompt_1k_tokens(self) -> Decimal:
    """Prompt rate (USD per 1K tokens) for this family and context size.

    When max_tokens has no exact entry in the family's price table, falls
    back to the family's first configured tier (insertion order).
    """
    family_pricing = CLAUDE_AI_PRICING[self.family]
    first_tier = next(iter(family_pricing))
    tier = family_pricing.get(self.max_tokens, family_pricing[first_tier])
    return tier["price_per_prompt_1k_tokens"]
162163

163164
@property
def price_per_sample_1k_tokens(self) -> Decimal:
    """Sample (output) rate (USD per 1K tokens) for this family and context size.

    When max_tokens has no exact entry in the family's price table, falls
    back to the family's first configured tier (insertion order).
    """
    family_pricing = CLAUDE_AI_PRICING[self.family]
    first_tier = next(iter(family_pricing))
    tier = family_pricing.get(self.max_tokens, family_pricing[first_tier])
    return tier["price_per_sample_1k_tokens"]
168171

flow_prompt/ai_models/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
# Common context-window sizes, expressed in tokens.
C_4K = 4096
C_8K = 8192
C_16K = 16384
C_32K = 32768

C_128K = 128_000
C_200K = 200_000
C_1M = 1_000_000

flow_prompt/ai_models/gemini/constants.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

flow_prompt/ai_models/gemini/gemini_model.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from flow_prompt.ai_models.ai_model import AI_MODELS_PROVIDER, AIModel
22
import logging
33

4+
from flow_prompt.ai_models.constants import C_1M, C_128K
45
from flow_prompt.responses import AIResponse
56
from decimal import Decimal
67
from enum import Enum
@@ -9,22 +10,27 @@
910
from dataclasses import dataclass
1011

1112
from flow_prompt.ai_models.gemini.responses import GeminiAIResponse
12-
from flow_prompt.ai_models.gemini.constants import FLASH, PRO
13+
1314
from flow_prompt.ai_models.utils import get_common_args
1415
from openai.types.chat import ChatCompletionMessage as Message
1516
from flow_prompt.responses import Prompt
1617
from flow_prompt.exceptions import RetryableCustomError, ConnectionLostError
1718
import google.generativeai as genai
1819

20+
1921
logger = logging.getLogger(__name__)
2022

2123

22-
C_128K = 127_000
24+
FLASH = "gemini-1.5-flash"
25+
PRO = "gemini-1.5-pro"
26+
PRO_1_0 = "gemini-1.0-pro"
27+
2328

2429

2530
class FamilyModel(Enum):
    """Human-readable display names for the supported Gemini model families."""

    flash = "Gemini 1.5 Flash"
    pro = "Gemini 1.5 Pro"
    pro_1_0 = "Gemini 1.0 Pro"
2834

2935

3036
DEFAULT_PRICING = {
@@ -35,14 +41,28 @@ class FamilyModel(Enum):
3541
# Per-family price table. Tier key = context size (tokens) the rate applies to;
# get_prompt_price/get_sample_price pick the smallest tier covering the prompt.
# Rates are built from *strings* so Decimal holds the exact published value
# (Decimal(0.0075) would embed binary-float noise).
# NOTE(review): flash rates here are higher than pro rates, which is unusual
# for Gemini pricing — confirm against the current Google price list.
GEMINI_AI_PRICING = {
    FamilyModel.flash.value: {
        C_128K: {
            "price_per_prompt_1k_tokens": Decimal("0.0075"),
            "price_per_sample_1k_tokens": Decimal("0.030"),
        },
        C_1M: {
            "price_per_prompt_1k_tokens": Decimal("0.015"),
            "price_per_sample_1k_tokens": Decimal("0.060"),
        },
    },
    FamilyModel.pro_1_0.value: {
        C_1M: {
            "price_per_prompt_1k_tokens": Decimal("0.0005"),
            "price_per_sample_1k_tokens": Decimal("0.0015"),
        },
    },
    FamilyModel.pro.value: {
        C_128K: {
            "price_per_prompt_1k_tokens": Decimal("0.0035"),
            "price_per_sample_1k_tokens": Decimal("0.0105"),
        },
        C_1M: {
            "price_per_prompt_1k_tokens": Decimal("0.007"),
            "price_per_sample_1k_tokens": Decimal("0.021"),
        },
    },
}
@@ -51,13 +71,17 @@ class FamilyModel(Enum):
5171
@dataclass(kw_only=True)
5272
class GeminiAIModel(AIModel):
5373
model: str
74+
max_tokens: int = C_1M
5475
gemini_model: genai.GenerativeModel = None
5576
provider: AI_MODELS_PROVIDER = AI_MODELS_PROVIDER.GEMINI
5677
family: str = None
5778

5879
def __post_init__(self):
80+
self.model = self.model.lower()
5981
if FLASH in self.model:
6082
self.family = FamilyModel.flash.value
83+
elif PRO_1_0 in self.model:
84+
self.family = FamilyModel.pro_1_0.value
6185
elif PRO in self.model:
6286
self.family = FamilyModel.pro.value
6387
else:
@@ -129,18 +153,6 @@ def call(self, messages: t.List[dict], max_tokens: int, client_secrets: dict = {
129153
def name(self) -> str:
130154
return self.model
131155

132-
@property
133-
def price_per_prompt_1k_tokens(self) -> Decimal:
134-
return GEMINI_AI_PRICING[self.family].get(self.max_tokens, DEFAULT_PRICING)[
135-
"price_per_prompt_1k_tokens"
136-
]
137-
138-
@property
139-
def price_per_sample_1k_tokens(self) -> Decimal:
140-
return GEMINI_AI_PRICING[self.family].get(self.max_tokens, DEFAULT_PRICING)[
141-
"price_per_sample_1k_tokens"
142-
]
143-
144156
def get_params(self) -> t.Dict[str, t.Any]:
145157
return {
146158
"model": self.model,
@@ -152,3 +164,19 @@ def get_metrics_data(self) -> t.Dict[str, t.Any]:
152164
"model": self.model,
153165
"max_tokens": self.max_tokens,
154166
}
167+
168+
169+
def get_prompt_price(self, count_tokens: int) -> Decimal:
    """Return the quantized USD price for a prompt of *count_tokens* tokens.

    Tiers in GEMINI_AI_PRICING are keyed by the largest context size the
    rate applies to; pick the smallest tier whose limit covers the prompt.
    """
    family_pricing = GEMINI_AI_PRICING[self.family]
    for tier_limit in sorted(family_pricing):
        # BUG FIX: was `<`, which pushed a prompt of exactly tier_limit
        # tokens (e.g. 128_000) into the next, more expensive tier.
        if count_tokens <= tier_limit:
            price = (
                family_pricing[tier_limit]["price_per_prompt_1k_tokens"]
                * Decimal(count_tokens)
                / 1000
            )
            # Lazy %-style args so formatting is skipped when INFO is off.
            logger.info("Prompt price for %s tokens is %s", count_tokens, price)
            return self._decimal(price)
    # NOTE(review): fallback uses the base-class property; confirm
    # _price_per_prompt_1k_tokens is actually initialized for Gemini models.
    return self._decimal(self.price_per_prompt_1k_tokens * Decimal(count_tokens) / 1000)
177+
def get_sample_price(self, prompt_sample, count_tokens: int) -> Decimal:
    """Return the quantized USD price for *count_tokens* sampled (output) tokens.

    The tier is selected by *prompt_sample* (the prompt token count): Gemini
    output rates depend on how large the prompt was, not on the output size.
    """
    family_pricing = GEMINI_AI_PRICING[self.family]
    for tier_limit in sorted(family_pricing):
        # `<=` keeps a prompt of exactly tier_limit tokens in this tier,
        # consistent with get_prompt_price.
        if prompt_sample <= tier_limit:
            price = (
                family_pricing[tier_limit]["price_per_sample_1k_tokens"]
                * Decimal(count_tokens)
                / 1000
            )
            # BUG FIX: the log previously computed the *prompt* rate
            # (price_per_prompt_1k_tokens) while returning the sample price.
            logger.info("Sample price for %s tokens is %s", count_tokens, price)
            return self._decimal(price)
    return self._decimal(self.price_per_sample_1k_tokens * Decimal(count_tokens) / 1000)

flow_prompt/ai_models/openai/openai_models.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from openai import OpenAI
88

99
from flow_prompt.ai_models.ai_model import AI_MODELS_PROVIDER, AIModel
10+
from flow_prompt.ai_models.constants import C_128K, C_16K, C_32K, C_4K
1011
from flow_prompt.ai_models.openai.responses import OpenAIResponse
1112
from flow_prompt.ai_models.utils import get_common_args
1213
from flow_prompt.exceptions import ConnectionLostError
@@ -16,11 +17,6 @@
1617

1718
from .utils import raise_openai_exception
1819

19-
C_4K = 4096
20-
C_8K = 8192
21-
C_128K = 127_000
22-
C_16K = 16384
23-
C_32K = 32768
2420
M_DAVINCI = "davinci"
2521

2622
logger = logging.getLogger(__name__)

flow_prompt/prompt/flow_prompt.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,7 @@ def calculate_budget_for_text(self, user_prompt: UserPrompt, text: str) -> int:
200200
return 0
201201
return len(user_prompt.encoding.encode(text))
202202

203-
def _decimal(self, value) -> Decimal:
204-
return Decimal(value).quantize(Decimal(".00001"))
205-
206203
def get_price(
    self, attempt: AttemptToCall, sample_budget: int, prompt_budget: int
) -> Decimal:
    """Total USD price of the call: prompt cost plus sampled-output cost.

    Delegates pricing (including rounding and any tiering) to the attempt's
    AI model; the prompt budget is also passed to get_sample_price so
    tiered models can pick the output rate by prompt size.
    """
    model = attempt.ai_model
    prompt_cost = model.get_prompt_price(prompt_budget)
    sample_cost = model.get_sample_price(prompt_budget, sample_budget)
    return prompt_cost + sample_cost

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "flow-prompt"
3-
version = "0.1.22a1"
3+
version = "0.1.26"
44
description = ""
55
authors = ["Flow-prompt Engineering Team <engineering@flow-prompt.com>"]
66
readme = "README.md"

0 commit comments

Comments
 (0)