11from flow_prompt .ai_models .ai_model import AI_MODELS_PROVIDER , AIModel
22import logging
33
4+ from flow_prompt .ai_models .constants import C_1M , C_128K
45from flow_prompt .responses import AIResponse
56from decimal import Decimal
67from enum import Enum
910from dataclasses import dataclass
1011
1112from flow_prompt .ai_models .gemini .responses import GeminiAIResponse
12- from flow_prompt . ai_models . gemini . constants import FLASH , PRO
13+
1314from flow_prompt .ai_models .utils import get_common_args
1415from openai .types .chat import ChatCompletionMessage as Message
1516from flow_prompt .responses import Prompt
1617from flow_prompt .exceptions import RetryableCustomError , ConnectionLostError
1718import google .generativeai as genai
1819
20+
1921logger = logging .getLogger (__name__ )
2022
2123
22- C_128K = 127_000
24+ FLASH = "gemini-1.5-flash"
25+ PRO = "gemini-1.5-pro"
26+ PRO_1_0 = "gemini-1.0-pro"
27+
2328
2429
2530class FamilyModel (Enum ):
2631 flash = "Gemini 1.5 Flash"
2732 pro = "Gemini 1.5 Pro"
33+ pro_1_0 = "Gemini 1.0 Pro"
2834
2935
3036DEFAULT_PRICING = {
@@ -35,14 +41,28 @@ class FamilyModel(Enum):
3541GEMINI_AI_PRICING = {
3642 FamilyModel .flash .value : {
3743 C_128K : {
38- "price_per_prompt_1k_tokens" : Decimal (0.00035 ),
39- "price_per_sample_1k_tokens" : Decimal (0.00105 ),
44+ "price_per_prompt_1k_tokens" : Decimal (0.0075 ),
45+ "price_per_sample_1k_tokens" : Decimal (0.030 ),
46+ },
47+ C_1M : {
48+ "price_per_prompt_1k_tokens" : Decimal (0.015 ),
49+ "price_per_sample_1k_tokens" : Decimal (0.060 ),
50+ }
51+ },
52+ FamilyModel .pro_1_0 .value : {
53+ C_1M : {
54+ "price_per_prompt_1k_tokens" : Decimal (0.0005 ),
55+ "price_per_sample_1k_tokens" : Decimal (0.0015 ),
4056 }
4157 },
4258 FamilyModel .pro .value : {
4359 C_128K : {
4460 "price_per_prompt_1k_tokens" : Decimal (0.0035 ),
4561 "price_per_sample_1k_tokens" : Decimal (0.0105 ),
62+ },
63+ C_1M : {
64+ "price_per_prompt_1k_tokens" : Decimal (0.007 ),
65+ "price_per_sample_1k_tokens" : Decimal (0.021 ),
4666 }
4767 },
4868}
@@ -51,13 +71,17 @@ class FamilyModel(Enum):
5171@dataclass (kw_only = True )
5272class GeminiAIModel (AIModel ):
5373 model : str
74+ max_tokens : int = C_1M
5475 gemini_model : genai .GenerativeModel = None
5576 provider : AI_MODELS_PROVIDER = AI_MODELS_PROVIDER .GEMINI
5677 family : str = None
5778
5879 def __post_init__ (self ):
80+ self .model = self .model .lower ()
5981 if FLASH in self .model :
6082 self .family = FamilyModel .flash .value
83+ elif PRO_1_0 in self .model :
84+ self .family = FamilyModel .pro_1_0 .value
6185 elif PRO in self .model :
6286 self .family = FamilyModel .pro .value
6387 else :
@@ -129,18 +153,6 @@ def call(self, messages: t.List[dict], max_tokens: int, client_secrets: dict = {
129153 def name (self ) -> str :
130154 return self .model
131155
132- @property
133- def price_per_prompt_1k_tokens (self ) -> Decimal :
134- return GEMINI_AI_PRICING [self .family ].get (self .max_tokens , DEFAULT_PRICING )[
135- "price_per_prompt_1k_tokens"
136- ]
137-
138- @property
139- def price_per_sample_1k_tokens (self ) -> Decimal :
140- return GEMINI_AI_PRICING [self .family ].get (self .max_tokens , DEFAULT_PRICING )[
141- "price_per_sample_1k_tokens"
142- ]
143-
144156 def get_params (self ) -> t .Dict [str , t .Any ]:
145157 return {
146158 "model" : self .model ,
@@ -152,3 +164,19 @@ def get_metrics_data(self) -> t.Dict[str, t.Any]:
152164 "model" : self .model ,
153165 "max_tokens" : self .max_tokens ,
154166 }
167+
168+
169+ def get_prompt_price (self , count_tokens : int ) -> Decimal :
170+ for key in sorted (GEMINI_AI_PRICING [self .family ].keys ()):
171+ if count_tokens < key :
172+ logger .info (f"Prompt price for { count_tokens } tokens is { GEMINI_AI_PRICING [self .family ][key ]['price_per_prompt_1k_tokens' ] * Decimal (count_tokens ) / 1000 } " )
173+ return self ._decimal (GEMINI_AI_PRICING [self .family ][key ]["price_per_prompt_1k_tokens" ] * Decimal (count_tokens ) / 1000 )
174+
175+ return self ._decimal (self .price_per_prompt_1k_tokens * Decimal (count_tokens ) / 1000 )
176+
177+ def get_sample_price (self , prompt_sample , count_tokens : int ) -> Decimal :
178+ for key in sorted (GEMINI_AI_PRICING [self .family ].keys ()):
179+ if prompt_sample < key :
180+ logger .info (f"Sample price for { count_tokens } tokens is { GEMINI_AI_PRICING [self .family ][key ]['price_per_prompt_1k_tokens' ] * Decimal (count_tokens ) / 1000 } " )
181+ return self ._decimal (GEMINI_AI_PRICING [self .family ][key ]["price_per_sample_1k_tokens" ] * Decimal (count_tokens ) / 1000 )
182+ return self ._decimal (self .price_per_sample_1k_tokens * Decimal (count_tokens ) / 1000 )
0 commit comments