Convert to Triton Punica kernels #658

Merged: 77 commits, Nov 5, 2024
Changes shown below are from 1 commit: 89abd51 ("Debug registration").
Commits (77, all by tgaddair)

832d905  Collect timings (Oct 23, 2024)
697bf4d  Profiler (Oct 23, 2024)
d155163  Allow max batch prefill tokens < max input length (Oct 23, 2024)
ca3280c  Fix fallback (Oct 23, 2024)
830ce3d  Vectorize test (Oct 23, 2024)
7f250fe  Triton punica kernels (Oct 24, 2024)
e4fb765  Use triton punica (Oct 24, 2024)
634c8e2  Fix format (Oct 24, 2024)
7870729  Plumb weights (Oct 24, 2024)
0e057f0  Fixed issues (Oct 24, 2024)
c8ad4cb  Fixed cuda graphs (Oct 24, 2024)
a82eb64  Remove debug (Oct 24, 2024)
f68d2c0  Remove debug (Oct 24, 2024)
2ffc1db  Move init to warmup (Oct 24, 2024)
ea6c86d  Fix preloaded and speculators (Oct 24, 2024)
0497a76  Docker test (Oct 24, 2024)
9e2a29d  Profiling docs (Oct 24, 2024)
94e3742  Revert timings (Oct 25, 2024)
0abeccc  Fixed merge (Oct 25, 2024)
6f5a976  Added LORAX_SPECULATION_MAX_BATCH_SIZE (Oct 26, 2024)
f89ee87  Try separate trees per adapter (Oct 27, 2024)
23a77d2  Allow refcount==0 (Oct 27, 2024)
22ed54d  Message (Oct 28, 2024)
327bb91  Docker test (Oct 28, 2024)
fbb2b3f  Cleanup (Oct 28, 2024)
f0693e9  Padding (Oct 28, 2024)
e62e0f8  Fixed turbo lora + compile (Oct 28, 2024)
66d8676  Fix (Oct 28, 2024)
55e5c41  Fix adapter root node id (Oct 30, 2024)
a6f3a17  More tests (Oct 30, 2024)
352c92a  Docker test (Oct 30, 2024)
1ea8d6e  Bump flashinfer (Oct 30, 2024)
c0640f2  Added logprobs fix (Oct 31, 2024)
54c36c9  Fix slots (Oct 31, 2024)
88cd932  No debugging (Oct 31, 2024)
3505b52  Docker test (Oct 31, 2024)
cf3d2d9  Fixed slot filtering (Oct 31, 2024)
d1ff7b4  Triton kernels (Oct 31, 2024)
57c33d7  Fix ragged (Oct 31, 2024)
ece47f7  More fixes (Oct 31, 2024)
779bff3  Merge (Oct 31, 2024)
cb99320  Revert docker (Oct 31, 2024)
466ea37  Renamed sgmv -> punica (Oct 31, 2024)
2f80c6a  Refactor PunicaWrapper (Oct 31, 2024)
47bfd0c  More configuration (Oct 31, 2024)
2343d78  More logs (Oct 31, 2024)
f915abe  Fixes (Oct 31, 2024)
ad460c0  Guard init (Nov 1, 2024)
43c129b  Guard model has lm_head (Nov 1, 2024)
1c70ec6  Determine trace set from preloaded adapter set (Nov 1, 2024)
3ebcbea  Plumb skip_lm_head (Nov 1, 2024)
922c5d6  Cleanup comments (Nov 1, 2024)
b2de54f  Fixed orient for rank (Nov 1, 2024)
35c7de2  Format (Nov 1, 2024)
295829f  Fixed tests (Nov 1, 2024)
ef86071  Fixed CausalLM and embedding model (Nov 1, 2024)
0d78a0a  Replace flume (Nov 1, 2024)
8cb79b2  Remove unused dep (Nov 1, 2024)
045a45a  Update axum (Nov 1, 2024)
20cf752  Client debug mode, fixed / (Nov 1, 2024)
2868acc  Docker test (Nov 1, 2024)
2131dc1  Fixed unused imports (Nov 1, 2024)
b727a94  Revert docker (Nov 1, 2024)
cc17d47  Add back tracing (Nov 1, 2024)
68991ba  Debug (Nov 1, 2024)
5380426  Docker test (Nov 1, 2024)
89abd51  Debug registration (Nov 1, 2024)
3c7b69b  Update tag (Nov 1, 2024)
d52f530  Don't skip filter (Nov 4, 2024)
45c6c53  Docker test (Nov 4, 2024)
3ad4d66  Remove register (Nov 4, 2024)
b45c219  Revert docker (Nov 4, 2024)
a4a2d5f  Fixed tests (Nov 4, 2024)
4a264bc  ruff (Nov 4, 2024)
e1067a0  Fix tests (Nov 4, 2024)
848b4c7  Clear cache (Nov 4, 2024)
107be9a  Check for key in lora weights (Nov 5, 2024)
Debug registration
tgaddair committed Nov 1, 2024
commit 89abd515efb20f1b2f364b10a0506c0840529ba7
29 changes: 27 additions & 2 deletions router/src/infer.rs
@@ -408,9 +408,17 @@ impl Infer {
let mut result_start = None;
let mut result_queued = None;

tracing::info!("Waiting for response");

let mut id = None;

// Iterate on stream
while let Some(response) = stream.next().await {
match response? {
InferStreamResponse::Register { id_val } => {
id = Some(id_val);
tracing::info!("Register response id={id:?}");
}
// Add prefill tokens
InferStreamResponse::Prefill {
tokens,
@@ -428,9 +436,13 @@
.collect();
}
result_prefill_length = tokens_length;
tracing::info!("Prefill response id={id:?}");
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
InferStreamResponse::Token(token) => {
tracing::info!("Token response id={id:?}");
result_tokens.push(token)
}
// Final message
// Set return values
InferStreamResponse::End {
@@ -439,6 +451,7 @@
start,
queued,
} => {
tracing::info!("End response id={id:?}");
result_tokens.push(token);
result_generated_text = Some(generated_text);
result_start = Some(start);
@@ -455,6 +468,8 @@
}
}

tracing::info!("Finished response id={id:?}");

// Check that we received a `InferStreamResponse::End` message
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
@@ -564,6 +579,9 @@ impl Infer {
let mut stream = UnboundedReceiverStream::new(response_rx);
while let Some(response) = stream.next().await {
match response? {
InferStreamResponse::Register { .. } => {
tracing::error!("Received a Register message in embed. This is a bug.");
}
// Add prefill tokens
InferStreamResponse::Prefill { .. } => {
tracing::error!("Received a Prefill message in embed. This is a bug.");
@@ -667,6 +685,9 @@ impl Infer {
let mut stream = UnboundedReceiverStream::new(response_rx);
while let Some(response) = stream.next().await {
match response? {
InferStreamResponse::Register { .. } => {
tracing::error!("Received a Register message in classify. This is a bug.");
}
// Add prefill tokens
InferStreamResponse::Prefill { .. } => {
tracing::error!("Received a Prefill message in classify. This is a bug.");
@@ -1350,7 +1371,8 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is closed
if entry.response_tx.is_closed() {
tracing::error!("Entry response channel closed.");
let id = generation.request_id;
tracing::error!("Entry id={id:?} response channel closed.");
metrics::increment_counter!("lorax_request_failure", "err" => "dropped");
return Ok(true);
}
@@ -1497,6 +1519,9 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Register {
id_val: u64,
},
Prefill {
tokens: Option<PrefillTokens>,
tokens_length: u32,
10 changes: 9 additions & 1 deletion router/src/queue.rs
@@ -9,7 +9,7 @@ use std::{
use tokio::{sync::Notify, time::Instant};
use tracing::info_span;

use crate::{adapter::Adapter, batch::Entry};
use crate::{adapter::Adapter, batch::Entry, infer::InferStreamResponse};

#[derive(Debug, PartialEq)]
pub(crate) enum AdapterStatus {
@@ -74,6 +74,11 @@ impl QueueState {
let queue_span = info_span!(parent: &entry.span, "queued");
entry.temp_span = Some(queue_span);

entry
.response_tx
.send(Ok(InferStreamResponse::Register { id_val: entry_id }))
.unwrap();

// Push entry in the queue
self.entries.push_back((entry_id, entry));
}
@@ -214,9 +219,12 @@ impl AdapterQueuesState {

// ensure that append completes before sending batcher message
let queue = self.queue_map.get_mut(&adapter).unwrap();
let id = self.next_id;
queue.append(self.next_id, entry);
self.next_id += 1;

tracing::info!("append entry id={:?} adapter={:?}", id, adapter.index());

return download;
}

5 changes: 5 additions & 0 deletions router/src/server.rs
@@ -1062,6 +1062,11 @@ async fn generate_stream_with_callback(
match response {
Ok(response) => {
match response {
InferStreamResponse::Register {
..
} => {
// Register is ignored
}
// Prefill is ignored
InferStreamResponse::Prefill {
tokens_length,
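For readers skimming the diff above, the following is a condensed, self-contained sketch of the Register flow this commit wires up. It is not the LoRAX source: the channel setup, the register_entry and consume helpers, and main are invented for illustration, while InferStreamResponse, the Register { id_val } variant, and the logging pattern are taken from the diff (the real sender is QueueState::append in router/src/queue.rs and the real receiver is the response loop in Infer in router/src/infer.rs).

// Condensed sketch of the Register flow (illustrative only; the real code has
// more variants and fields). Requires the tokio "sync" and "rt" features.
use tokio::sync::mpsc;

#[derive(Debug)]
enum InferStreamResponse {
    // Optional first message carrying the queue-assigned entry id.
    Register { id_val: u64 },
    Token(String),
    End,
}

// Queue side: announce the entry id before any tokens are produced,
// mirroring what QueueState::append does in this commit.
fn register_entry(entry_id: u64, tx: &mpsc::UnboundedSender<InferStreamResponse>) {
    let _ = tx.send(InferStreamResponse::Register { id_val: entry_id });
}

// Receiver side: remember the id so later log lines can be correlated,
// mirroring the match arm added to the response loop in Infer.
async fn consume(mut rx: mpsc::UnboundedReceiver<InferStreamResponse>) {
    let mut id: Option<u64> = None;
    while let Some(msg) = rx.recv().await {
        match msg {
            InferStreamResponse::Register { id_val } => {
                id = Some(id_val);
                println!("Register response id={id:?}");
            }
            InferStreamResponse::Token(t) => println!("Token response id={id:?}: {t}"),
            InferStreamResponse::End => {
                println!("End response id={id:?}");
                break;
            }
        }
    }
}

fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    register_entry(42, &tx);
    tx.send(InferStreamResponse::Token("hello".into())).unwrap();
    tx.send(InferStreamResponse::End).unwrap();
    drop(tx);
    // A single-threaded runtime is enough to drive the consumer here.
    tokio::runtime::Builder::new_current_thread()
        .build()
        .unwrap()
        .block_on(consume(rx));
}

In the actual change, the HTTP streaming path in router/src/server.rs simply ignores Register, while the embed and classify paths in router/src/infer.rs log an error if they ever receive one.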