9 changes: 1 addition & 8 deletions atoma-bin/atoma_node.rs
@@ -362,14 +362,7 @@ async fn main() -> Result<()> {
         encryption_sender: app_state_encryption_sender,
         compute_shared_secret_sender,
         tokenizers: Arc::new(tokenizers),
-        models: Arc::new(
-            config
-                .service
-                .models
-                .into_iter()
-                .map(|model| model.to_lowercase())
-                .collect(),
-        ),
+        models: Arc::new(config.service.models),
         chat_completions_service_urls: config.service.chat_completions_service_urls,
         embeddings_service_url: config
             .service
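The hunk above stops lowercasing model names when the node's state is built, so they are stored exactly as configured. Below is a minimal sketch of the resulting behavior — `AppState` and the values are simplified stand-ins, and the lowercase-keyed URL map is an assumption inferred from the `.get(&model.to_lowercase())` call sites later in this diff, not something the PR states:

```rust
use std::{collections::HashMap, sync::Arc};

// Hypothetical stand-in for the service's state; field names mirror the
// diff, but the types are simplified for illustration.
struct AppState {
    models: Arc<Vec<String>>,
    chat_completions_service_urls: HashMap<String, Vec<String>>,
}

fn main() {
    let state = AppState {
        // After this PR: stored verbatim, no eager `.to_lowercase()`.
        models: Arc::new(vec!["Llama-3.3-70B-Instruct".to_string()]),
        // Assumption: the URL map is keyed by lowercase names, which is
        // what the lookup sites below suggest.
        chat_completions_service_urls: HashMap::from([(
            "llama-3.3-70b-instruct".to_string(),
            vec!["http://localhost:8000".to_string()],
        )]),
    };

    // The original casing survives for logs, metrics, and error messages...
    assert_eq!(state.models[0], "Llama-3.3-70B-Instruct");
    // ...and lookups normalize on demand instead of at ingestion.
    let urls = state
        .chat_completions_service_urls
        .get(&state.models[0].to_lowercase());
    assert!(urls.is_some());
}
```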
87 changes: 44 additions & 43 deletions atoma-service/src/handlers/chat_completions.rs
@@ -238,8 +238,7 @@ pub async fn chat_completions_handler(
     let model = payload
         .get(MODEL_KEY)
         .and_then(|m| m.as_str())
-        .unwrap_or_default()
-        .to_lowercase();
+        .unwrap_or_default();

     match handle_response(
         &state,
@@ -261,33 +260,35 @@
         Ok(response) => {
             CHAT_COMPLETIONS_ESTIMATED_TOTAL_TOKENS.add(
                 num_input_tokens + estimated_output_tokens,
-                &[KeyValue::new(MODEL_KEY, model.clone())],
+                &[KeyValue::new(MODEL_KEY, model.to_owned())],
             );
             if !is_stream {
-                TOTAL_COMPLETED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                TOTAL_COMPLETED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
             }
             Ok(response)
         }
         Err(e) => {
             match e.status_code() {
                 StatusCode::TOO_MANY_REQUESTS => {
-                    TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::BAD_REQUEST => {
-                    TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::LOCKED => {
-                    TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::TOO_EARLY => {
-                    TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::UNAUTHORIZED => {
-                    TOTAL_UNAUTHORIZED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_UNAUTHORIZED_REQUESTS
+                        .add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 _ => {
-                    TOTAL_FAILED_CHAT_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
-                    TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_FAILED_CHAT_REQUESTS
+                        .add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
+                    TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
             }

@@ -316,7 +317,7 @@ pub async fn chat_completions_handler(
                 &state.state_manager_sender,
                 user_id,
                 user_address,
-                model.clone(),
+                model.to_string(),
                 num_input_tokens,
                 0,
                 estimated_output_tokens,
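Most of the churn in this file follows from one type change: with the trailing `.to_lowercase()` gone (it returned a fresh `String`), `model` is now a `&str` borrowed from the JSON payload. Call sites that previously `.clone()`d the owned `String` now copy explicitly only where ownership is required — `to_owned()` for metric labels, `to_string()` for values handed off elsewhere. A minimal sketch of the pattern, using `serde_json` directly rather than the handler's real types:

```rust
use serde_json::{json, Value};

fn main() {
    let payload: Value = json!({ "model": "Llama-3.3-70B-Instruct" });

    // Borrowed straight out of `payload`: no allocation on the hot path.
    let model: &str = payload
        .get("model")
        .and_then(|m| m.as_str())
        .unwrap_or_default();

    // Owned copies are made only at the boundaries that need them, e.g.
    // the metric labels above. A plain tuple stands in here for the
    // OpenTelemetry `KeyValue` type.
    let label = ("model", model.to_owned());
    assert_eq!(label.1, "Llama-3.3-70B-Instruct");
}
```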
@@ -464,10 +465,10 @@ pub async fn confidential_chat_completions_handler(
     let model = payload
         .get(MODEL_KEY)
         .and_then(|m| m.as_str())
-        .unwrap_or(UNKNOWN_MODEL)
-        .to_lowercase();
+        .unwrap_or(UNKNOWN_MODEL);

-    CHAT_COMPLETIONS_CONFIDENTIAL_NUM_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+    CHAT_COMPLETIONS_CONFIDENTIAL_NUM_REQUESTS
+        .add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);

     let endpoint = request_metadata.endpoint_path.clone();

@@ -491,34 +492,35 @@
         Ok(response) => {
             CHAT_COMPLETIONS_ESTIMATED_TOTAL_TOKENS.add(
                 num_input_tokens + estimated_output_tokens,
-                &[KeyValue::new(MODEL_KEY, model.clone())],
+                &[KeyValue::new(MODEL_KEY, model.to_owned())],
             );
             if !is_stream {
-                TOTAL_COMPLETED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                TOTAL_COMPLETED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
             }
             Ok(response)
         }
         Err(e) => {
             match e.status_code() {
                 StatusCode::TOO_MANY_REQUESTS => {
-                    TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::BAD_REQUEST => {
-                    TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::LOCKED => {
-                    TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::TOO_EARLY => {
-                    TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 StatusCode::UNAUTHORIZED => {
-                    TOTAL_UNAUTHORIZED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                    TOTAL_UNAUTHORIZED_REQUESTS
+                        .add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
                 _ => {
                     TOTAL_FAILED_CHAT_CONFIDENTIAL_REQUESTS
-                        .add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
-                    TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+                        .add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
+                    TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
                 }
             }
             // NOTE: We need to update the stack number of tokens as the service failed to generate
@@ -546,7 +548,7 @@ pub async fn confidential_chat_completions_handler(
                 &state.state_manager_sender,
                 user_id,
                 user_address,
-                model.clone(),
+                model.to_string(),
                 num_input_tokens,
                 0,
                 estimated_output_tokens,
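Both handlers' error paths swap `model.clone()` for `model.to_string()` when reporting usage through `state.state_manager_sender`. Assuming that sender is a channel to a separate task (the name suggests it, though this diff doesn't show its type), an owned `String` is required because the receiver can outlive the request's payload. A hypothetical sketch of that ownership boundary — the message struct and field set here are invented for illustration:

```rust
use tokio::sync::mpsc;

// Hypothetical message type; the real call site passes more fields
// (user_id, user_address, token counts, ...).
#[derive(Debug)]
struct UsageUpdate {
    model: String, // owned: the receiving task outlives the request
    num_input_tokens: i64,
    estimated_output_tokens: i64,
}

#[tokio::main]
async fn main() {
    let (state_manager_sender, mut rx) = mpsc::unbounded_channel::<UsageUpdate>();

    let model: &str = "Llama-3.3-70B-Instruct"; // borrowed inside the handler
    state_manager_sender
        .send(UsageUpdate {
            model: model.to_string(), // copy made at the task boundary
            num_input_tokens: 128,
            estimated_output_tokens: 0,
        })
        .unwrap();

    println!("{:?}", rx.recv().await);
}
```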
@@ -755,10 +757,9 @@ async fn handle_non_streaming_response(
     let model = payload
         .get(MODEL_KEY)
         .and_then(|m| m.as_str())
-        .unwrap_or(UNKNOWN_MODEL)
-        .to_lowercase();
+        .unwrap_or(UNKNOWN_MODEL);

-    CHAT_COMPLETIONS_NUM_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+    CHAT_COMPLETIONS_NUM_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
     let timer = Instant::now();
     debug!(
         target = "atoma-service",
@@ -778,7 +779,7 @@
         level = "debug",
         "Received non-streaming chat completions response from {endpoint}"
     );
-    let (input_tokens, output_tokens) = utils::extract_total_num_tokens(&response_body, &model);
+    let (input_tokens, output_tokens) = utils::extract_total_num_tokens(&response_body, model);

     utils::serve_non_streaming_response(
         state,
@@ -795,7 +796,7 @@
         client_encryption_metadata,
         endpoint,
         timer,
-        &model,
+        model,
     )
     .await
 }
@@ -886,14 +887,13 @@ async fn handle_streaming_response(
     let model = payload
         .get(MODEL_KEY)
         .and_then(|m| m.as_str())
-        .unwrap_or(UNKNOWN_MODEL)
-        .to_lowercase();
-    CHAT_COMPLETIONS_NUM_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
+        .unwrap_or(UNKNOWN_MODEL);
+    CHAT_COMPLETIONS_NUM_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model.to_owned())]);
     let timer = Instant::now();

     let chat_completions_service_urls = state
         .chat_completions_service_urls
-        .get(&model)
+        .get(&model.to_lowercase())
         .ok_or_else(|| {
             AtomaServiceError::InternalError {
                 message: format!(
@@ -907,7 +907,7 @@
     get_best_available_chat_completions_service_url(
         &state.running_num_requests,
         chat_completions_service_urls,
-        &model,
+        &model.to_lowercase(),
         state.memory_upper_threshold,
         state.max_num_queued_requests,
     )
@@ ... @@
             endpoint: endpoint.clone(),
         })?;
     if status_code == StatusCode::TOO_MANY_REQUESTS {
-        state.too_many_requests.insert(model, Instant::now());
+        state
+            .too_many_requests
+            .insert(model.to_string(), Instant::now());
         return Err(AtomaServiceError::ChatCompletionsServiceUnavailable {
             message: "Too many requests".to_string(),
             endpoint: endpoint.clone(),
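In this streaming path, normalization now happens only where the lowercase form is actually the key — the service-URL lookup and the scheduler call — while the `too_many_requests` backoff map is keyed by the request's original casing via `model.to_string()`. A small sketch of the lookup half with simplified types; note that `to_lowercase()` allocates a fresh `String` on every call, the cost accepted here in exchange for keeping `model` borrowed and case-intact:

```rust
use std::collections::HashMap;

// Assumed invariant (suggested by the diff, not stated in it): the URL
// map is keyed by lowercase model names.
fn service_urls<'a>(
    urls_by_model: &'a HashMap<String, Vec<String>>,
    model: &str,
) -> Option<&'a Vec<String>> {
    // One String allocation per lookup; the original `model` stays
    // untouched for metrics, logs, and error messages.
    urls_by_model.get(&model.to_lowercase())
}

fn main() {
    let urls = HashMap::from([(
        "llama-3.3-70b-instruct".to_string(),
        vec!["http://localhost:8000".to_string()],
    )]);
    assert!(service_urls(&urls, "Llama-3.3-70B-Instruct").is_some());
    assert!(service_urls(&urls, "unknown-model").is_none());
}
```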
@@ -997,7 +999,7 @@
         payload_hash,
         state.keystore.clone(),
         state.address_index,
-        model.clone(),
+        model.to_string(),
         streaming_encryption_metadata,
         endpoint,
         request_id,
@@ -1323,11 +1325,10 @@ pub mod utils {
         let model = payload
             .get(MODEL_KEY)
             .and_then(|m| m.as_str())
-            .unwrap_or(UNKNOWN_MODEL)
-            .to_lowercase();
+            .unwrap_or(UNKNOWN_MODEL);
         let chat_completions_service_url_services = state
             .chat_completions_service_urls
-            .get(&model)
+            .get(&model.to_lowercase())
             .ok_or_else(|| {
                 AtomaServiceError::InternalError {
                     message: format!(
@@ -1341,7 +1342,7 @@
         get_best_available_chat_completions_service_url(
             &state.running_num_requests,
             chat_completions_service_url_services,
-            &model,
+            model,
             state.memory_upper_threshold,
             state.max_num_queued_requests,
         )
@@ ... @@
         if status_code == StatusCode::TOO_MANY_REQUESTS {
             state
                 .too_many_requests
-                .insert(model.clone(), Instant::now());
+                .insert(model.to_string(), Instant::now());
             return Err(AtomaServiceError::ChatCompletionsServiceUnavailable {
                 message: "Too many requests".to_string(),
                 endpoint: endpoint.to_string(),
@@ -1677,7 +1678,7 @@
             &state.state_manager_sender,
             user_id,
             user_address,
-            model.to_owned(),
+            model.to_string(),
             estimated_input_tokens,
             input_tokens,
             estimated_output_tokens,