Convert to Triton Punica kernels #658

Merged: 77 commits, Nov 5, 2024

Changes from 1 commit

Commits (77):
832d905
Collect timings
tgaddair Oct 23, 2024
697bf4d
Profiler
tgaddair Oct 23, 2024
d155163
Allow max batch prefill tokens < max input length
tgaddair Oct 23, 2024
ca3280c
Fix fallback
tgaddair Oct 23, 2024
830ce3d
Vectorize test
tgaddair Oct 23, 2024
7f250fe
Triton punica kernels
tgaddair Oct 24, 2024
e4fb765
Use triton punica
tgaddair Oct 24, 2024
634c8e2
Fix format
tgaddair Oct 24, 2024
7870729
Plumb weights
tgaddair Oct 24, 2024
0e057f0
Fixed issues
tgaddair Oct 24, 2024
c8ad4cb
Fixed cuda graphs
tgaddair Oct 24, 2024
a82eb64
Remove debug
tgaddair Oct 24, 2024
f68d2c0
Remove debug
tgaddair Oct 24, 2024
2ffc1db
Move init to warmup
tgaddair Oct 24, 2024
ea6c86d
Fix preloaded and speculators
tgaddair Oct 24, 2024
0497a76
Docker test
tgaddair Oct 24, 2024
9e2a29d
Profiling docs
tgaddair Oct 24, 2024
94e3742
Revert timings
tgaddair Oct 25, 2024
0abeccc
Fixed merge
tgaddair Oct 25, 2024
6f5a976
Added LORAX_SPECULATION_MAX_BATCH_SIZE
tgaddair Oct 26, 2024
f89ee87
Try separate trees per adapter
tgaddair Oct 27, 2024
23a77d2
Allow refcount==0
tgaddair Oct 27, 2024
22ed54d
Message
tgaddair Oct 28, 2024
327bb91
Docker test
tgaddair Oct 28, 2024
fbb2b3f
Cleanup
tgaddair Oct 28, 2024
f0693e9
Padding
tgaddair Oct 28, 2024
e62e0f8
Fixed turbo lora + compile
tgaddair Oct 28, 2024
66d8676
Fix
tgaddair Oct 28, 2024
55e5c41
Fix adapter root node id
tgaddair Oct 30, 2024
a6f3a17
More tests
tgaddair Oct 30, 2024
352c92a
Docker test
tgaddair Oct 30, 2024
1ea8d6e
Bump flashinfer
tgaddair Oct 30, 2024
c0640f2
Added logprobs fix
tgaddair Oct 31, 2024
54c36c9
Fix slots
tgaddair Oct 31, 2024
88cd932
No debugging
tgaddair Oct 31, 2024
3505b52
Docker test
tgaddair Oct 31, 2024
cf3d2d9
Fixed slot filtering
tgaddair Oct 31, 2024
d1ff7b4
Triton kernels
tgaddair Oct 31, 2024
57c33d7
Fix ragged
tgaddair Oct 31, 2024
ece47f7
More fixes
tgaddair Oct 31, 2024
779bff3
Merge
tgaddair Oct 31, 2024
cb99320
Revert docker
tgaddair Oct 31, 2024
466ea37
Renamed sgmv -> punica
tgaddair Oct 31, 2024
2f80c6a
Refactor PunicaWrapper
tgaddair Oct 31, 2024
47bfd0c
More configuration
tgaddair Oct 31, 2024
2343d78
More logs
tgaddair Oct 31, 2024
f915abe
Fixes
tgaddair Oct 31, 2024
ad460c0
Guard init
tgaddair Nov 1, 2024
43c129b
Guard model has lm_head
tgaddair Nov 1, 2024
1c70ec6
Determine trace set from preloaded adapter set
tgaddair Nov 1, 2024
3ebcbea
Plumb skip_lm_head
tgaddair Nov 1, 2024
922c5d6
Cleanup comments
tgaddair Nov 1, 2024
b2de54f
Fixed orient for rank
tgaddair Nov 1, 2024
35c7de2
Format
tgaddair Nov 1, 2024
295829f
Fixed tests
tgaddair Nov 1, 2024
ef86071
Fixed CausalLM and embedding model
tgaddair Nov 1, 2024
0d78a0a
Replace flume
tgaddair Nov 1, 2024
8cb79b2
Remove unused dep
tgaddair Nov 1, 2024
045a45a
Update axum
tgaddair Nov 1, 2024
20cf752
Client debug mode, fixed /
tgaddair Nov 1, 2024
2868acc
Docker test
tgaddair Nov 1, 2024
2131dc1
Fixed unused imports
tgaddair Nov 1, 2024
b727a94
Revert docker
tgaddair Nov 1, 2024
cc17d47
Add back tracing
tgaddair Nov 1, 2024
68991ba
Debug
tgaddair Nov 1, 2024
5380426
Docker test
tgaddair Nov 1, 2024
89abd51
Debug registration
tgaddair Nov 1, 2024
3c7b69b
Update tag
tgaddair Nov 1, 2024
d52f530
Don't skip filter
tgaddair Nov 4, 2024
45c6c53
Docker test
tgaddair Nov 4, 2024
3ad4d66
Remove register
tgaddair Nov 4, 2024
b45c219
Revert docker
tgaddair Nov 4, 2024
a4a2d5f
Fixed tests
tgaddair Nov 4, 2024
4a264bc
ruff
tgaddair Nov 4, 2024
e1067a0
Fix tests
tgaddair Nov 4, 2024
848b4c7
Clear cache
tgaddair Nov 4, 2024
107be9a
Check for key in lora weights
tgaddair Nov 5, 2024

Fixed issues
tgaddair committed Oct 24, 2024
commit 0e057f08740345f42dd0c5e880920ee31cfa4921
11 changes: 9 additions & 2 deletions server/lorax_server/adapters/weights.py
@@ -118,7 +118,7 @@ class AdapterBatchData:
     data: Dict[str, Dict[str, BatchAdapterWeights]]

     # layer type -> fused lora weights
-    layer_to_lora_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]]
+    layer_to_lora_weights: Dict[Tuple[str, int], Tuple[torch.Tensor, torch.Tensor]]

     punica_wrapper: "PunicaWrapper"

@@ -128,6 +128,7 @@ class AdapterBatchData:
     def from_meta(
         meta: AdapterBatchMetadata,
         weights: Dict[str, LayerAdapterWeights],
+        layer_to_lora_weights: Dict[Tuple[str, int], Tuple[torch.Tensor, torch.Tensor]],
         punica_wrapper: "PunicaWrapper",
         prefill: bool,
         prefill_head_indices: Optional[torch.Tensor],
@@ -139,7 +140,13 @@ def from_meta(
             layer_weights = v.get_data(meta, k, prefill, prefill_head_indices if k == LM_HEAD else None)
             if layer_weights:
                 data[k] = layer_weights
-        return AdapterBatchData(meta=meta, data=data, punica_wrapper=punica_wrapper, prefill=prefill)
+        return AdapterBatchData(
+            meta=meta,
+            data=data,
+            layer_to_lora_weights=layer_to_lora_weights,
+            punica_wrapper=punica_wrapper,
+            prefill=prefill,
+        )

     def ranks(self) -> Set[int]:
         # TODO(travis): refactor to be less coupled to lora implementation
@@ -303,8 +303,8 @@ def __init__(self, prefix, config, weights, layer_id):
             layer_id,
             [MLP_GATE_PROJ, MLP_UP_PROJ],
             sizes=[
-                config.intermediate_size // 2,
-                config.intermediate_size // 2,
+                config.intermediate_size,
+                config.intermediate_size,
             ],
             process_group=weights.process_group,
         )
41 changes: 28 additions & 13 deletions server/lorax_server/models/model.py
@@ -5,6 +5,7 @@

 from lorax_server.adapters.lora import LoraWeights
 from lorax_server.adapters.medusa_lora import MedusaLoraWeights
+from lorax_server.utils.sgmv import pad_to_min_rank
 import torch
 from loguru import logger
 from transformers import PreTrainedTokenizerBase

@@ -252,14 +253,14 @@ def register_preloaded_adapters(
         self.preloaded_adapters.extend(preloaded_adapters)

         # For Triton kernels: need weights into contiguous tensor
-        # dict of layer_name -> (lora_a_weights, lora_b_weights)
+        # dict of (layer_name, layer_id) -> (lora_a_weights, lora_b_weights)
         # where:
         # lora_a_weights = [num_adapters, r, hidden_size]
         # lora_b_weights = [num_adapters, hidden_size, r]
         self.layer_to_lora_weights = {}
         for layer_name, layer_adapter_weights in self.layer_to_adapter_weights.items():
-            lora_a_weights = []
-            lora_b_weights = []
+            layer_id_to_lora_a_weights = defaultdict(list)
+            layer_id_to_lora_b_weights = defaultdict(list)
             for i, adapter in enumerate(preloaded_adapters):
                 adapter_index = adapter.adapter_index
                 adapter_weights = layer_adapter_weights.adapter_weights.get(adapter_index)
@@ -271,17 +272,31 @@ def register_preloaded_adapters(
                     # only applicable to lora for now
                     continue

-                # transpose to ensure col major
-                lora_a = adapter_weights.weights_a_t
-                lora_b = adapter_weights.weights_b_t
-
-                lora_a_weights.append(lora_a)
-                lora_b_weights.append(lora_b)
+                # transpose into col major
+                lora_a = adapter_weights.weights_a.transpose(1, 2)
+                lora_b = adapter_weights.weights_b.transpose(1, 2)
+
+                nlayers = lora_a.size(0)
+                for layer_id in range(nlayers):
+                    layer_id_to_lora_a_weights[layer_id].append(lora_a[layer_id])
+                    layer_id_to_lora_b_weights[layer_id].append(lora_b[layer_id])

-            # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r]
-            lora_a_weights = torch.stack(lora_a_weights, device=self.device).contiguous()
-            lora_b_weights = torch.stack(lora_b_weights, device=self.device).contiguous()
-            self.layer_to_lora_weights[layer_name] = (lora_a_weights, lora_b_weights)
+            for layer_id, lora_a_weights in layer_id_to_lora_a_weights.items():
+                lora_b_weights = layer_id_to_lora_b_weights[layer_id]
+
+                # right pad every adapter to the max rank
+                # TODO(travis)
+                # r = max([w.size(-1) for w in lora_b_weights])
+                # lora_a_weights = [pad_to_min_rank(w, 1, r) for w in lora_a_weights]
+                # lora_b_weights = [pad_to_min_rank(w, 2, r) for w in lora_b_weights]
+
+                # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r]
+                lora_a_weights = torch.stack(lora_a_weights).to(self.device).contiguous()
+                lora_b_weights = torch.stack(lora_b_weights).to(self.device).contiguous()
+                print("!!! lora_a_weights", lora_a_weights.shape, layer_name, layer_id)
+                print("!!! lora_b_weights", lora_b_weights.shape)
+                # ('self_attn.q_proj', 32)
+                self.layer_to_lora_weights[(layer_name, layer_id)] = (lora_a_weights, lora_b_weights)

     def load_adapter(
         self,
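To make the shape bookkeeping above concrete, here is a small, self-contained sketch of the stacking scheme (not taken from the repo): the adapter count, rank, hidden size, the pre-transpose layouts of weights_a/weights_b, and the "self_attn.q_proj" key are made-up assumptions for illustration only.

import torch
from collections import defaultdict

num_adapters, nlayers, hidden_size, r = 3, 2, 16, 8

# Stand-ins for adapter_weights.weights_a / weights_b, assumed here to be
# [nlayers, hidden_size, r] and [nlayers, r, hidden_size] before the transpose.
adapters = [
    (torch.randn(nlayers, hidden_size, r), torch.randn(nlayers, r, hidden_size))
    for _ in range(num_adapters)
]

layer_id_to_a, layer_id_to_b = defaultdict(list), defaultdict(list)
for weights_a, weights_b in adapters:
    lora_a = weights_a.transpose(1, 2)  # -> [nlayers, r, hidden_size]
    lora_b = weights_b.transpose(1, 2)  # -> [nlayers, hidden_size, r]
    for layer_id in range(nlayers):
        layer_id_to_a[layer_id].append(lora_a[layer_id])
        layer_id_to_b[layer_id].append(lora_b[layer_id])

layer_to_lora_weights = {}
for layer_id in range(nlayers):
    a = torch.stack(layer_id_to_a[layer_id]).contiguous()  # [num_adapters, r, hidden_size]
    b = torch.stack(layer_id_to_b[layer_id]).contiguous()  # [num_adapters, hidden_size, r]
    layer_to_lora_weights[("self_attn.q_proj", layer_id)] = (a, b)

assert layer_to_lora_weights[("self_attn.q_proj", 0)][0].shape == (num_adapters, r, hidden_size)
assert layer_to_lora_weights[("self_attn.q_proj", 0)][1].shape == (num_adapters, hidden_size, r)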
4 changes: 3 additions & 1 deletion server/lorax_server/utils/layers.py
@@ -76,6 +76,7 @@ def forward_layer_type(
         data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None

         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            print("!!! layer_type", layer_type, "start_idx", start_idx, "end_idx", end_idx, "result", result.shape)
             if end_idx - start_idx != result.shape[1]:
                 # proj = torch.zeros_like(result[:, start_idx:end_idx])
                 y_offset = start_idx
@@ -89,7 +90,7 @@
             # lora_a_ptr = rank_segments.lora_a_ptr
             # lora_b_ptr = rank_segments.lora_b_ptr

-            lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[layer_type]
+            lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)]
             adapter_data.punica_wrapper.add_lora(
                 result,
                 input,
@@ -230,6 +231,7 @@ def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torch.Tensor:
                 end_idx = offset // self.process_group.size()
             else:
                 end_idx = result.shape[1]
+            print("!!! sizes", self.sizes, self.process_group.size())

             result = self.forward_layer_type(result, input, adapter_data, layer_name, start_idx, end_idx)

10 changes: 9 additions & 1 deletion server/lorax_server/utils/ops/sgmv_expand.py
@@ -9,7 +9,7 @@
 import triton
 import triton.language as tl

-from lorax_server.utils.ops import libentry
+from lorax_server.utils.ops.libentry import libentry


 @libentry()
@@ -128,6 +128,14 @@ def sgmv_expand(
         add_inputs (bool, optional): Defaults to False. adds the final lora
             results to the output.
     """
+    print("!!! inputs", inputs.shape)
+    print("!!! lora_b_weights", lora_b_weights.shape)
+    print("!!! output_tensor", output_tensor.shape)
+    print("!!! b_seq_start_loc", b_seq_start_loc)
+    print("!!! seq_len_tensor", seq_len_tensor)
+    print("!!! lora_indices_tensor", lora_indices_tensor)
+    print("!!! batches", batches)
+    print("!!! max_seq_length", max_seq_length)

     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
12 changes: 11 additions & 1 deletion server/lorax_server/utils/ops/sgmv_expand_slice.py
@@ -9,7 +9,7 @@
 import triton
 import triton.language as tl

-from lorax_server.utils.ops import libentry
+from lorax_server.utils.ops.libentry import libentry


 @libentry()
@@ -137,6 +137,16 @@ def sgmv_expand_slice(
         add_inputs (bool, optional): Defaults to False. adds the final lora
             results to the output..
     """
+    print("!!! inputs", inputs.shape)
+    print("!!! lora_b_weights", lora_b_weights.shape)
+    print("!!! output_tensor", output_tensor.shape)
+    print("!!! b_seq_start_loc", b_seq_start_loc)
+    print("!!! seq_len_tensor", seq_len_tensor)
+    print("!!! lora_indices_tensor", lora_indices_tensor)
+    print("!!! batches", batches)
+    print("!!! max_seq_length", max_seq_length)
+    print("!!! slice_offset", slice_offset)
+    print("!!! slice_size", slice_size)

     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
11 changes: 10 additions & 1 deletion server/lorax_server/utils/ops/sgmv_shrink.py
@@ -9,7 +9,7 @@
 import triton
 import triton.language as tl

-from lorax_server.utils.ops import libentry
+from lorax_server.utils.ops.libentry import libentry


 @libentry()
@@ -131,6 +131,15 @@ def sgmv_shrink(
             in the batch
         scaling (float): Scaling factor.
     """
+    print("!!! inputs", inputs.shape)
+    print("!!! lora_a_weights", lora_a_weights.shape)
+    print("!!! output_tensor", output_tensor.shape)
+    print("!!! b_seq_start_loc", b_seq_start_loc)
+    print("!!! seq_len_tensor", seq_len_tensor)
+    print("!!! lora_indices_tensor", lora_indices_tensor)
+    print("!!! batch_size", batches)
+    print("!!! max_seq_length", max_seq_length)
+    print("!!! scaling", scaling)
     assert inputs.dtype == lora_a_weights.dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
    assert lora_a_weights.dtype in [
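For readers unfamiliar with these kernels, here is a hedged, pure-PyTorch reference of the shrink-then-expand computation that the sgmv_shrink/sgmv_expand Triton kernels fuse, using the arguments printed above. The weight layouts follow the comments in model.py; treating a negative lora index as "no adapter" and applying scaling in the shrink step are assumptions for this sketch, not taken from the diff.

import torch

def sgmv_reference(inputs, lora_a_weights, lora_b_weights, output_tensor,
                   b_seq_start_loc, seq_len_tensor, lora_indices_tensor, scaling):
    # Segmented grouped matmul: each request's token span uses that request's adapter.
    for start, length, idx in zip(b_seq_start_loc.tolist(),
                                  seq_len_tensor.tolist(),
                                  lora_indices_tensor.tolist()):
        if idx < 0:  # assumed convention: negative index means no adapter for this segment
            continue
        x = inputs[start:start + length]          # [length, hidden_size]
        a = lora_a_weights[idx]                   # [r, hidden_size]
        b = lora_b_weights[idx]                   # [hidden_size, r]
        v = (x @ a.transpose(0, 1)) * scaling     # shrink: [length, r]
        output_tensor[start:start + length] += v @ b.transpose(0, 1)  # expand: [length, hidden_size]
    return output_tensor

# Tiny smoke test with made-up sizes
hidden_size, r, num_adapters = 16, 8, 2
x = torch.randn(6, hidden_size)
A = torch.randn(num_adapters, r, hidden_size)
B = torch.randn(num_adapters, hidden_size, r)
y = torch.zeros(6, hidden_size)
sgmv_reference(x, A, B, y,
               b_seq_start_loc=torch.tensor([0, 4]),
               seq_len_tensor=torch.tensor([4, 2]),
               lora_indices_tensor=torch.tensor([0, 1]),
               scaling=1.0)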
11 changes: 10 additions & 1 deletion server/lorax_server/utils/sgmv.py
@@ -3,10 +3,16 @@
 from functools import lru_cache
 from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

-from lorax_server.utils.ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink
 import torch
 import torch.nn.functional as F

+from lorax_server.utils.ops.bgmv_expand import bgmv_expand
+from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice
+from lorax_server.utils.ops.bgmv_shrink import bgmv_shrink
+from lorax_server.utils.ops.sgmv_expand import sgmv_expand
+from lorax_server.utils.ops.sgmv_expand_slice import sgmv_expand_slice
+from lorax_server.utils.ops.sgmv_shrink import sgmv_shrink
+
 if TYPE_CHECKING:
     from lorax_server.adapters.weights import AdapterBatchMetadata

@@ -39,7 +45,10 @@ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
     # tensor parallelism will result in effective rank being divided by world_size,
     # so we need to scale the min rank to offset that effect
     min_rank = MIN_SGMV_RANK * world_size
+    return pad_to_min_rank(t, dim, min_rank)
+
+
+def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
     # if we're at or below the min rank, pad up to the min rank
     # otherwise, pad to the nearest multiple of the block size
     current_rank = t.size(dim)
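Since the body of pad_to_min_rank is cut off in this view, here is a minimal sketch of the padding behaviour its comments describe. The MIN_SGMV_RANK and BLOCK_SIZE values and the exact rounding rule are assumptions for illustration, not the repo's actual constants.

import torch
import torch.nn.functional as F

MIN_SGMV_RANK = 8    # assumed minimum rank
BLOCK_SIZE = 16      # assumed block size for ranks above the minimum

def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
    # if we're at or below the min rank, pad up to the min rank
    # otherwise, pad to the nearest multiple of the block size (assumed rule)
    current_rank = t.size(dim)
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else ((current_rank + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
    )
    if current_rank == target_rank:
        return t

    # F.pad takes (left, right) pairs starting from the last dimension,
    # so build the pad list so that only dimension `dim` grows on the right.
    pad_size = target_rank - current_rank
    pad = [0, 0] * (t.dim() - dim - 1) + [0, pad_size]
    return F.pad(t, pad, mode="constant", value=0.0)

# Example: a rank-4 LoRA tensor padded up to the assumed minimum rank of 8
w = torch.randn(32, 4)                               # [hidden_size, r]
print(pad_to_min_rank(w, 1, MIN_SGMV_RANK).shape)    # torch.Size([32, 8])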