
Commit 83d7b55

Feat/bedrock support for Nova models through the ConverseAPI (#207)
* feat: added support for bedrock nova models
* feat: tokens are now read from usage if available to ensure accuracy
* chore: removed duplicated integration tests folder in wrong place
* feat: refactored bedrock provider into being a single file instead of folder
* chore: renamed bedrock to bedrock-converse in examples/core.py
* chore: renamed bedrock in config.yaml
1 parent ddc250f commit 83d7b55

File tree

7 files changed: +68 −61 lines changed


examples/core.py

Lines changed: 15 additions & 8 deletions
@@ -8,7 +8,7 @@
 from dotenv import load_dotenv
 load_dotenv()

-def run_provider(provider, model, api_key, **kwargs):
+def run_provider(provider, model, api_key=None, **kwargs):
     print(f"\n\n###RUNNING for <{provider}>, <{model}> ###")
     llm = LLMCore(provider=provider, api_key=api_key, **kwargs)

@@ -107,6 +107,16 @@ def build_chat_request(model: str, chat_input: str, is_stream: bool, max_tokens:
             "max_completion_tokens": max_tokens
         }
     }
+    elif 'amazon.nova' in model or 'anthropic.claude' in model:
+        chat_request = {
+            "chat_input": chat_input,
+            "model": model,
+            "is_stream": is_stream,
+            "retries": 0,
+            "parameters": {
+                "maxTokens": max_tokens
+            }
+        }
     else:
         chat_request = {
             "chat_input": chat_input,
@@ -150,10 +160,7 @@ def multiple_provider_runs(provider:str, model:str, num_runs:int, api_key:str, *


 multiple_provider_runs(provider="vertexai", model="gemini-1.5-flash", num_runs=1, api_key=os.environ["GOOGLE_API_KEY"])
-# provider = "vertexai"
-# model = "gemini-1.5-pro-latest"
-# for _ in range(1):
-#     latencies = run_provider(provider=provider, model=model,
-#                              api_key=os.environ["GOOGLE_API_KEY"],
-#                              )
-#     pprint(latencies)
+
+# Bedrock
+multiple_provider_runs(provider="bedrock", model="us.amazon.nova-lite-v1:0", num_runs=1, api_key=None, region=os.environ["BEDROCK_REGION"], secret_key=os.environ["BEDROCK_SECRET_KEY"], access_key=os.environ["BEDROCK_ACCESS_KEY"])
+#multiple_provider_runs(provider="bedrock", model="anthropic.claude-3-5-sonnet-20241022-v2:0", num_runs=1, api_key=None, region=os.environ["BEDROCK_REGION"], secret_key=os.environ["BEDROCK_SECRET_KEY"], access_key=os.environ["BEDROCK_ACCESS_KEY"])
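For reference, a standalone call against one of the new Nova models would plausibly look like the sketch below. The LLMCore import path and the llm.chat(**chat_request) call shape are assumptions inferred from this example file, and the BEDROCK_* variables must be set in the environment:

    import os

    from llmstudio_core import LLMCore  # import path assumed

    llm = LLMCore(
        provider="bedrock",
        api_key=None,  # Bedrock authenticates with AWS credentials, not an API key
        region=os.environ["BEDROCK_REGION"],
        access_key=os.environ["BEDROCK_ACCESS_KEY"],
        secret_key=os.environ["BEDROCK_SECRET_KEY"],
    )

    chat_request = {
        "chat_input": "Hello, Nova!",
        "model": "us.amazon.nova-lite-v1:0",
        "is_stream": False,
        "retries": 0,
        # Converse API parameters are camelCase, hence maxTokens rather than max_tokens
        "parameters": {"maxTokens": 100},
    }
    print(llm.chat(**chat_request))  # call shape assumed from run_provider's usage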

libs/core/llmstudio_core/config.yaml

Lines changed: 17 additions & 1 deletion
@@ -73,7 +73,7 @@ providers:
         step: 1
   bedrock:
     id: bedrock
-    name: Bedrock
+    name: Bedrock ConverseAPI
     chat: true
     embed: true
     keys:
@@ -126,6 +126,22 @@ providers:
         max_tokens: 100000
         input_token_cost: 0.000008
         output_token_cost: 0.000024
+      us.amazon.nova-pro-v1:0:
+        mode: chat
+        max_tokens: 300000
+        input_token_cost: 0.0000008
+        output_token_cost: 0.0000016
+      us.amazon.nova-lite-v1:0:
+        mode: chat
+        max_tokens: 300000
+        input_token_cost: 0.00000006
+        output_token_cost: 0.00000012
+      us.amazon.nova-micro-v1:0:
+        mode: chat
+        max_tokens: 128000
+        input_token_cost: 0.000000035
+        output_token_cost: 0.00000007
     parameters:
       temperature:
         name: "Temperature"

libs/core/llmstudio_core/providers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 from typing import Optional

 from llmstudio_core.providers.azure import AzureProvider
-from llmstudio_core.providers.bedrock.provider import BedrockProvider
+from llmstudio_core.providers.bedrock_converse import BedrockConverseProvider

 # from llmstudio_core.providers.ollama import OllamaProvider #TODO: adapt it
 from llmstudio_core.providers.openai import OpenAIProvider
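The import now targets the single-file module. The @provider decorator and the _provider_config_name() hook (both visible in the renamed module below) imply a registry mapping config names to provider classes; the rename of the config name from the misspelled "bedrock-antropic" to "bedrock" keeps that mapping aligned with the bedrock: key in config.yaml. A minimal sketch of how such a registry could work (the registry internals here are an assumption, not llmstudio's actual code):

    # Hypothetical sketch of a name -> class provider registry.
    _provider_registry: dict = {}

    def provider(cls):
        """Register a provider class under its config name (assumed behavior)."""
        _provider_registry[cls._provider_config_name()] = cls
        return cls

    @provider
    class BedrockConverseProvider:
        @staticmethod
        def _provider_config_name():
            return "bedrock"  # must match the `bedrock:` key in config.yaml

    assert _provider_registry["bedrock"] is BedrockConverseProvider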

libs/core/llmstudio_core/providers/bedrock/__init__.py

Whitespace-only changes.

libs/core/llmstudio_core/providers/bedrock/provider.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

libs/core/llmstudio_core/providers/bedrock/anthropic.py renamed to libs/core/llmstudio_core/providers/bedrock_converse.py

Lines changed: 24 additions & 5 deletions
@@ -23,14 +23,15 @@
     ChoiceDelta,
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
+    CompletionUsage,
 )
 from pydantic import ValidationError

 SERVICE = "bedrock-runtime"


 @provider
-class BedrockAnthropicProvider(ProviderCore):
+class BedrockConverseProvider(ProviderCore):
     def __init__(self, config, **kwargs):
         super().__init__(config, **kwargs)
         self._client = boto3.client(
@@ -46,17 +47,17 @@ def __init__(self, config, **kwargs):

     @staticmethod
     def _provider_config_name():
-        return "bedrock-antropic"
+        return "bedrock"

     def validate_request(self, request: ChatRequest):
         return ChatRequest(**request)

     async def agenerate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Any]:
-        """Generate an AWS Bedrock client"""
+        """Generate an AWS Bedrock Converse client"""
         return self.generate_client(request=request)

     def generate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Generator]:
-        """Generate an AWS Bedrock client"""
+        """Generate an AWS Bedrock Converse client"""
         try:
             messages, system_prompt = self._process_messages(request.chat_input)
             tools = self._process_tools(request.parameters)
@@ -83,7 +84,9 @@ def generate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Generator
     async def aparse_response(
         self, response: Any, **kwargs
     ) -> AsyncGenerator[Any, None]:
-        return self.parse_response(response=response, **kwargs)
+        result = self.parse_response(response=response, **kwargs)
+        for chunk in result:
+            yield chunk

     def parse_response(self, response: AsyncGenerator[Any, None], **kwargs) -> Any:
         tool_name = None
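A note on the aparse_response change above: an async def that merely returns a sync generator produces a coroutine, not an async iterator, so callers cannot consume it with `async for`; re-yielding each chunk turns the method into a true async generator. A minimal, self-contained illustration of the pattern (names are illustrative):

    import asyncio
    from typing import AsyncGenerator

    def sync_chunks():
        """Stand-in for parse_response: a plain (sync) generator."""
        yield from ("a", "b", "c")

    async def achunks() -> AsyncGenerator[str, None]:
        """Re-yield each item so the coroutine becomes an async generator."""
        for chunk in sync_chunks():
            yield chunk

    async def main():
        async for chunk in achunks():  # works only because achunks yields
            print(chunk)

    asyncio.run(main())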
@@ -222,6 +225,22 @@ def parse_response(self, response: AsyncGenerator[Any, None], **kwargs) -> Any:
                 )
                 yield final_chunk.model_dump()

+            elif chunk.get("metadata"):
+                usage = chunk["metadata"].get("usage")
+                final_stream_chunk = ChatCompletionChunk(
+                    id=str(uuid.uuid4()),
+                    choices=[],
+                    created=int(time.time()),
+                    model=kwargs.get("request").model,
+                    object="chat.completion.chunk",
+                    usage=CompletionUsage(
+                        completion_tokens=usage["outputTokens"],
+                        prompt_tokens=usage["inputTokens"],
+                        total_tokens=usage["totalTokens"],
+                    ),
+                )
+                yield final_stream_chunk.model_dump()
+
     @staticmethod
     def _process_messages(
         chat_input: Union[str, List[Dict[str, str]]]
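The new metadata branch consumes the final event of a Bedrock ConverseStream response, which reports token usage in camelCase fields, and converts it into an OpenAI-style ChatCompletionChunk carrying a CompletionUsage payload. Roughly, the event it matches looks like this (values illustrative, structure per the Bedrock Converse API):

    # Illustrative shape of the final ConverseStream event the new branch handles.
    chunk = {
        "metadata": {
            "usage": {
                "inputTokens": 12,    # prompt tokens counted by Bedrock
                "outputTokens": 87,   # completion tokens counted by Bedrock
                "totalTokens": 99,
            },
            "metrics": {"latencyMs": 1234},
        }
    }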

libs/core/llmstudio_core/providers/provider.py

Lines changed: 11 additions & 3 deletions
@@ -808,6 +808,11 @@ def _calculate_metrics(
         output_tokens = len(self.tokenizer.encode(self._output_to_string(output)))
         total_tokens = input_tokens + output_tokens

+        if usage:
+            input_tokens = usage.get("prompt_tokens", input_tokens)
+            output_tokens = usage.get("completion_tokens", output_tokens)
+            total_tokens = usage.get("total_tokens", total_tokens)
+
         # Cost calculations
         input_cost = self._calculate_cost(input_tokens, model_config.input_token_cost)
         output_cost = self._calculate_cost(
@@ -823,9 +828,12 @@ def _calculate_metrics(
         )
         total_cost_usd -= cached_savings

-        reasoning_tokens = usage.get("completion_tokens_details", {}).get(
-            "reasoning_tokens", None
-        )
+        completion_tokens_details = usage.get("completion_tokens_details")
+        if completion_tokens_details:
+            reasoning_tokens = completion_tokens_details.get(
+                "reasoning_tokens", None
+            )
+
         if reasoning_tokens:
             total_tokens += reasoning_tokens
             reasoning_cost = self._calculate_cost(
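Together with the Bedrock change above, the effect is that locally tokenized counts now serve only as fallbacks: when a provider reports usage (already normalized to OpenAI-style prompt_tokens/completion_tokens/total_tokens keys), those counts win. A condensed sketch of the precedence rule (function and variable names are illustrative, not the library's API):

    from typing import Optional

    def resolve_token_counts(estimated_in: int, estimated_out: int,
                             usage: Optional[dict]):
        """Prefer provider-reported usage; fall back to local tokenizer estimates."""
        input_tokens = estimated_in
        output_tokens = estimated_out
        total_tokens = estimated_in + estimated_out
        if usage:
            input_tokens = usage.get("prompt_tokens", input_tokens)
            output_tokens = usage.get("completion_tokens", output_tokens)
            total_tokens = usage.get("total_tokens", total_tokens)
        return input_tokens, output_tokens, total_tokens

    # With Bedrock-reported usage, the local estimates are overridden:
    print(resolve_token_counts(10, 80,
                               {"prompt_tokens": 12, "completion_tokens": 87,
                                "total_tokens": 99}))
    # -> (12, 87, 99)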
