Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return streaming errors as an event formatted for openai's client #2668

Merged
Next Next commit
feat: return streaming errors as an event formatted for openai's client
  • Loading branch information
drbh authored and Narsil committed Nov 15, 2024
commit 84cd8434b076734b091172ab32c38e106e1cc388
197 changes: 113 additions & 84 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1274,99 +1274,108 @@ pub(crate) async fn chat_completions(
};
let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await {
if let Ok(stream_token) = result {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
match result {
Ok(stream_token) => {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
continue;
}
}
None => {
continue;
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);

yield Ok::<Event, Infallible>(event);
}

// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);

yield Ok::<Event, Infallible>(event);
}
}
Err(err) => {
let error_event: ErrorEvent = err.into();
let event = Event::default().json_data(error_event).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
});
yield Ok::<Event, Infallible>(event);
break;
}
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
Expand Down Expand Up @@ -2517,6 +2526,26 @@ impl From<InferError> for Event {
}
}

#[derive(serde::Serialize)]
pub struct ErrorWithMessage {
message: String,
}

#[derive(serde::Serialize)]
pub struct ErrorEvent {
error: ErrorWithMessage,
}

impl From<InferError> for ErrorEvent {
fn from(err: InferError) -> Self {
ErrorEvent {
error: ErrorWithMessage {
message: err.to_string(),
},
}
}
}

#[derive(Debug, Error)]
pub enum WebServerError {
#[error("Axum error: {0}")]
Expand Down