feat(server): auto max_batch_total_tokens for flash att models #630

Merged · 19 commits · Jul 19, 2023
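This PR makes `--max-batch-total-tokens` optional in both the launcher and the router. The warmup RPC gains an optional `max_supported_total_tokens` field in its response: flash attention models report the token budget they can actually serve and the router adopts that value, warning that a user-supplied `--max-batch-total-tokens` is ignored. Older models, which return nothing, fall back to the user-provided value or to a default of 16000, raised if needed to cover `max_total_tokens` and `max_batch_prefill_tokens`.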
54 changes: 28 additions & 26 deletions launcher/src/main.rs
@@ -184,8 +184,8 @@ struct Args {
/// depends on other parameters like if you're using quantization, flash attention
/// or the model implementation, text-generation-inference cannot infer this number
/// automatically.
#[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,

/// This setting defines how many tokens can be passed before forcing the waiting
/// queries to be put on the batch (if the size of the batch allows for it).
@@ -369,12 +369,6 @@ fn shard_manager(
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

// Use cuda allocator. It leads to less memory fragmentation
envs.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));

// Torch Distributed Env vars
envs.push(("RANK".into(), rank.to_string().into()));
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@@ -428,7 +422,7 @@ fn shard_manager(
}

// Start process
tracing::info!("Starting shard {rank}");
tracing::info!("Starting shard");
let mut p = match Command::new("text-generation-server")
.args(shard_args)
.envs(envs)
@@ -493,17 +487,17 @@ fn shard_manager(
if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap();
let _ = p.wait();
tracing::info!("Shard {rank} terminated");
tracing::info!("Shard terminated");
return;
}

// Shard is ready
if uds.exists() && !ready {
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
tracing::info!("Shard ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {rank} to be ready...");
tracing::info!("Waiting for shard to be ready...");
wait_time = Instant::now();
}
sleep(Duration::from_millis(100));
@@ -860,8 +854,6 @@ fn spawn_webserver(
args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
@@ -878,6 +870,12 @@
args.model_id,
];

// Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
router_args.push("--max-batch-total-tokens".to_string());
router_args.push(max_batch_total_tokens.to_string());
}

// Model optional revision
if let Some(ref revision) = args.revision {
router_args.push("--revision".to_string());
Expand Down Expand Up @@ -1036,18 +1034,7 @@ fn main() -> Result<(), LauncherError> {
args.max_batch_prefill_tokens, args.max_input_length
)));
}
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, args.max_batch_total_tokens
)));
}

if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
@@ -1065,6 +1052,21 @@
tracing::info!("Sharding model on {num_shard} processes");
}

if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, max_batch_total_tokens
)));
}
}

// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
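The launcher-side change boils down to `max_batch_total_tokens` becoming an `Option<u32>`: the `16000` default is gone, the flag is only forwarded to the router when it was actually set, and the two related validations now run only inside the `if let Some(...)` block. The diff also drops the `PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync` override and the `{rank}` field from the shard log lines. A minimal sketch of how such an optional clap argument behaves (illustrative only, assuming the `derive` and `env` features of clap; only the field name is taken from the PR):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Stays `None` when neither the CLI flag nor the env var is set,
    /// leaving the router free to infer a value during warmup.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
}

fn main() {
    let args = Args::parse();
    match args.max_batch_total_tokens {
        Some(n) => println!("pinning max batch total tokens to {n}"),
        None => println!("max batch total tokens will be inferred at warmup"),
    }
}
```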
7 changes: 4 additions & 3 deletions proto/generate.proto
@@ -198,9 +198,10 @@ message DecodeResponse {
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
}

/// Empty response
message WarmupResponse {}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}
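For reference, the generated Rust for these two messages should come out roughly as below (a sketch assuming default `tonic-build`/prost codegen, not copied from the generated file): `max_total_tokens` disappears from `WarmupRequest`, and the proto3 `optional uint32` maps to the `Option<u32>` that `client.rs` now returns.

```rust
// Approximate prost output for the updated messages (assumption: default
// tonic-build settings; `Batch` is the message defined elsewhere in this file).
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupRequest {
    /// Batch to warmup on
    #[prost(message, optional, tag = "1")]
    pub batch: ::core::option::Option<Batch>,
}

#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupResponse {
    /// Maximum number of tokens supported by the model
    #[prost(uint32, optional, tag = "1")]
    pub max_supported_total_tokens: ::core::option::Option<u32>,
}
```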
13 changes: 4 additions & 9 deletions router/client/src/client.rs
@@ -103,8 +103,7 @@ impl Client {
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();

@@ -143,13 +142,9 @@
max_tokens: 0,
};

let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_total_tokens,
})
.inject_context();
self.stub.warmup(request).await?.into_inner();
Ok(())
let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}

/// Generate one token for each request in the given batch
7 changes: 2 additions & 5 deletions router/client/src/sharded_client.rs
@@ -95,14 +95,11 @@ impl ShardedClient {
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
2 changes: 1 addition & 1 deletion router/src/infer.rs
@@ -53,7 +53,7 @@ impl Infer {
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding);
let queue = Queue::new(requires_padding, 16);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
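The only change here is the new second argument to `Queue::new`. The queue itself is outside this diff, so the meaning of the `16` is an assumption: it looks like a block granularity used when budgeting tokens for a batch, in line with flash/paged attention allocating the KV cache in fixed-size blocks. A sketch of the rounding such a block size would imply (illustrative, not code from the PR):

```rust
// Assumes the `16` passed to Queue::new is a block size and that per-request
// token budgets are rounded up to whole blocks.
fn padded_tokens(tokens: u32, block_size: u32) -> u32 {
    // Round up to the next multiple of `block_size`.
    ((tokens + block_size - 1) / block_size) * block_size
}

fn main() {
    assert_eq!(padded_tokens(1, 16), 16);
    assert_eq!(padded_tokens(16, 16), 16);
    assert_eq!(padded_tokens(17, 16), 32);
}
```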
57 changes: 41 additions & 16 deletions router/src/main.rs
@@ -37,8 +37,8 @@ struct Args {
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "0.0.0.0", long, env)]
@@ -110,18 +110,22 @@ fn main() -> Result<(), RouterError> {
if max_input_length as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}

if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}

if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}

// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
@@ -210,14 +214,35 @@

// Warmup model
tracing::info!("Warming up model");
sharded_client
.warmup(
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
let max_supported_batch_total_tokens = match sharded_client
.warmup(max_input_length as u32, max_batch_prefill_tokens)
.await
.map_err(RouterError::Warmup)?;
.map_err(RouterError::Warmup)?
{
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
tracing::warn!("Model does not support automatic max batch total tokens");
max_batch_total_tokens
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
max_supported_batch_total_tokens
}
};
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
tracing::info!("Connected");

let addr = match hostname.parse() {
@@ -240,7 +265,7 @@
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_supported_batch_total_tokens,
max_waiting_tokens,
sharded_client,
tokenizer,
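Pulled out of the `match` above for readability, the router's effective resolution of the batch token budget is roughly the following (a restatement of the new logic, not extra code from the PR). The resolved value is what then replaces `max_batch_total_tokens` in the final hunk's argument list.

```rust
// Condensed restatement of the new router logic: prefer the budget reported by
// the shards at warmup, otherwise fall back to the user flag or to a 16000
// floor raised to cover the per-request and prefill limits.
fn resolve_max_batch_total_tokens(
    reported_by_warmup: Option<u32>,
    user_flag: Option<u32>,
    max_total_tokens: u32,
    max_batch_prefill_tokens: u32,
) -> u32 {
    match reported_by_warmup {
        // Flash attention models report what they can actually serve; a
        // user-supplied --max-batch-total-tokens is ignored (with a warning).
        Some(supported) => supported,
        // Older models keep the previous behaviour.
        None => user_flag
            .unwrap_or_else(|| 16_000u32.max(max_total_tokens).max(max_batch_prefill_tokens)),
    }
}
```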