Convert to Triton Punica kernels #658

Merged (77 commits, Nov 5, 2024)

Changes from 1 commit

Commits (77)
832d905  Collect timings (tgaddair, Oct 23, 2024)
697bf4d  Profiler (tgaddair, Oct 23, 2024)
d155163  Allow max batch prefill tokens < max input length (tgaddair, Oct 23, 2024)
ca3280c  Fix fallback (tgaddair, Oct 23, 2024)
830ce3d  Vectorize test (tgaddair, Oct 23, 2024)
7f250fe  Triton punica kernels (tgaddair, Oct 24, 2024)
e4fb765  Use triton punica (tgaddair, Oct 24, 2024)
634c8e2  Fix format (tgaddair, Oct 24, 2024)
7870729  Plumb weights (tgaddair, Oct 24, 2024)
0e057f0  Fixed issues (tgaddair, Oct 24, 2024)
c8ad4cb  Fixed cuda graphs (tgaddair, Oct 24, 2024)
a82eb64  Remove debug (tgaddair, Oct 24, 2024)
f68d2c0  Remove debug (tgaddair, Oct 24, 2024)
2ffc1db  Move init to warmup (tgaddair, Oct 24, 2024)
ea6c86d  Fix preloaded and speculators (tgaddair, Oct 24, 2024)
0497a76  Docker test (tgaddair, Oct 24, 2024)
9e2a29d  Profiling docs (tgaddair, Oct 24, 2024)
94e3742  Revert timings (tgaddair, Oct 25, 2024)
0abeccc  Fixed merge (tgaddair, Oct 25, 2024)
6f5a976  Added LORAX_SPECULATION_MAX_BATCH_SIZE (tgaddair, Oct 26, 2024)
f89ee87  Try separate trees per adapter (tgaddair, Oct 27, 2024)
23a77d2  Allow refcount==0 (tgaddair, Oct 27, 2024)
22ed54d  Message (tgaddair, Oct 28, 2024)
327bb91  Docker test (tgaddair, Oct 28, 2024)
fbb2b3f  Cleanup (tgaddair, Oct 28, 2024)
f0693e9  Padding (tgaddair, Oct 28, 2024)
e62e0f8  Fixed turbo lora + compile (tgaddair, Oct 28, 2024)
66d8676  Fix (tgaddair, Oct 28, 2024)
55e5c41  Fix adapter root node id (tgaddair, Oct 30, 2024)
a6f3a17  More tests (tgaddair, Oct 30, 2024)
352c92a  Docker test (tgaddair, Oct 30, 2024)
1ea8d6e  Bump flashinfer (tgaddair, Oct 30, 2024)
c0640f2  Added logprobs fix (tgaddair, Oct 31, 2024)
54c36c9  Fix slots (tgaddair, Oct 31, 2024)
88cd932  No debugging (tgaddair, Oct 31, 2024)
3505b52  Docker test (tgaddair, Oct 31, 2024)
cf3d2d9  Fixed slot filtering (tgaddair, Oct 31, 2024)
d1ff7b4  Triton kernels (tgaddair, Oct 31, 2024)
57c33d7  Fix ragged (tgaddair, Oct 31, 2024)
ece47f7  More fixes (tgaddair, Oct 31, 2024)
779bff3  Merge (tgaddair, Oct 31, 2024)
cb99320  Revert docker (tgaddair, Oct 31, 2024)
466ea37  Renamed sgmv -> punica (tgaddair, Oct 31, 2024)
2f80c6a  Refactor PunicaWrapper (tgaddair, Oct 31, 2024)
47bfd0c  More configuration (tgaddair, Oct 31, 2024)
2343d78  More logs (tgaddair, Oct 31, 2024)
f915abe  Fixes (tgaddair, Oct 31, 2024)
ad460c0  Guard init (tgaddair, Nov 1, 2024)
43c129b  Guard model has lm_head (tgaddair, Nov 1, 2024)
1c70ec6  Determine trace set from preloaded adapter set (tgaddair, Nov 1, 2024)
3ebcbea  Plumb skip_lm_head (tgaddair, Nov 1, 2024)
922c5d6  Cleanup comments (tgaddair, Nov 1, 2024)
b2de54f  Fixed orient for rank (tgaddair, Nov 1, 2024)
35c7de2  Format (tgaddair, Nov 1, 2024)
295829f  Fixed tests (tgaddair, Nov 1, 2024)
ef86071  Fixed CausalLM and embedding model (tgaddair, Nov 1, 2024)
0d78a0a  Replace flume (tgaddair, Nov 1, 2024)
8cb79b2  Remove unused dep (tgaddair, Nov 1, 2024)
045a45a  Update axum (tgaddair, Nov 1, 2024)
20cf752  Client debug mode, fixed / (tgaddair, Nov 1, 2024)
2868acc  Docker test (tgaddair, Nov 1, 2024)
2131dc1  Fixed unused imports (tgaddair, Nov 1, 2024)
b727a94  Revert docker (tgaddair, Nov 1, 2024)
cc17d47  Add back tracing (tgaddair, Nov 1, 2024)
68991ba  Debug (tgaddair, Nov 1, 2024)
5380426  Docker test (tgaddair, Nov 1, 2024)
89abd51  Debug registration (tgaddair, Nov 1, 2024)
3c7b69b  Update tag (tgaddair, Nov 1, 2024)
d52f530  Don't skip filter (tgaddair, Nov 4, 2024)
45c6c53  Docker test (tgaddair, Nov 4, 2024)
3ad4d66  Remove register (tgaddair, Nov 4, 2024)
b45c219  Revert docker (tgaddair, Nov 4, 2024)
a4a2d5f  Fixed tests (tgaddair, Nov 4, 2024)
4a264bc  ruff (tgaddair, Nov 4, 2024)
e1067a0  Fix tests (tgaddair, Nov 4, 2024)
848b4c7  Clear cache (tgaddair, Nov 4, 2024)
107be9a  Check for key in lora weights (tgaddair, Nov 5, 2024)
ruff
tgaddair committed Nov 4, 2024
commit 4a264bcdf19023957a376fa0b3ea5aacc5130d3b

This commit is a ruff formatting pass: in each hunk below, the removed and re-added lines appear to differ only in whitespace.
server/lorax_server/models/causal_lm.py (6 changes: 3 additions & 3 deletions)

@@ -596,11 +596,11 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option
    # TODO(travis): don't update this if indices haven't changed
    # Use prefill=True in all cases to force use of SGMV, as the batch is heterogenous
    adapter_data = AdapterBatchData.from_meta(
-       meta=batch.adapter_meta,
-       weights=self.layer_to_adapter_weights,
+       meta=batch.adapter_meta,
+       weights=self.layer_to_adapter_weights,
        layer_to_lora_weights={},
        punica_wrapper=None,
-       prefill=True,
+       prefill=True,
        prefill_head_indices=None,
    )
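The comment in the hunk above notes that prefill=True forces the SGMV path because the batch is heterogeneous (requests may use different adapters and sequence lengths). As background on the Punica-style kernels this PR ports to Triton, below is a minimal pure-PyTorch reference of the segmented LoRA matmul (SGMV) that such kernels accelerate. This is an illustrative sketch only, not code from this PR; the function and argument names are hypothetical.

```python
import torch

def sgmv_lora_ref(
    x: torch.Tensor,            # (total_tokens, hidden_in): all requests' tokens, concatenated
    lora_a: torch.Tensor,       # (num_adapters, hidden_in, rank): per-adapter A matrices
    lora_b: torch.Tensor,       # (num_adapters, rank, hidden_out): per-adapter B matrices
    seg_starts: torch.Tensor,   # (num_segments,): start offset of each request's token segment
    seg_ends: torch.Tensor,     # (num_segments,): end offset (exclusive) of each segment
    adapter_ids: torch.Tensor,  # (num_segments,): adapter index used by each segment
) -> torch.Tensor:
    """Reference SGMV: for each segment, apply that request's LoRA delta (x @ A) @ B."""
    out = torch.zeros(x.shape[0], lora_b.shape[-1], dtype=x.dtype, device=x.device)
    for start, end, idx in zip(seg_starts.tolist(), seg_ends.tolist(), adapter_ids.tolist()):
        xs = x[start:end]                                   # tokens belonging to one request
        out[start:end] = (xs @ lora_a[idx]) @ lora_b[idx]   # shrink to rank, then expand
    return out

# Example: 2 requests with 3 and 2 tokens, each using a different adapter.
x = torch.randn(5, 64)
lora_a = torch.randn(2, 64, 8)
lora_b = torch.randn(2, 8, 64)
delta = sgmv_lora_ref(
    x, lora_a, lora_b,
    seg_starts=torch.tensor([0, 3]),
    seg_ends=torch.tensor([3, 5]),
    adapter_ids=torch.tensor([0, 1]),
)
assert delta.shape == (5, 64)
```

A Triton (or CUDA) Punica kernel computes the same result but fuses the per-segment loop into a single kernel launch, so one batch can mix requests that use different LoRA adapters.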
server/lorax_server/models/flash_qwen2.py (6 changes: 3 additions & 3 deletions)

@@ -122,11 +122,11 @@ def embed(self, batch) -> torch.Tensor:
    adapter_meta = batch.adapter_meta
    prefill = False
    adapter_data = AdapterBatchData.from_meta(
-       meta=adapter_meta,
-       weights=self.layer_to_adapter_weights,
+       meta=adapter_meta,
+       weights=self.layer_to_adapter_weights,
        layer_to_lora_weights={},
        punica_wrapper=None,
-       prefill=prefill,
+       prefill=prefill,
        prefill_head_indices=batch.prefill_head_indices,
    )
    embedding, _ = self.forward(batch, adapter_data=adapter_data)
server/lorax_server/models/flash_roberta.py (6 changes: 3 additions & 3 deletions)

@@ -210,11 +210,11 @@ def forward(self, batch: FlashEmbeddingClassificationBatch):
    @tracer.start_as_current_span("embed")
    def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
        adapter_data = AdapterBatchData.from_meta(
-           meta=batch.adapter_meta,
-           weights=self.layer_to_adapter_weights,
+           meta=batch.adapter_meta,
+           weights=self.layer_to_adapter_weights,
            layer_to_lora_weights={},
            punica_wrapper=None,
-           prefill=False,
+           prefill=False,
            prefill_head_indices=None,
        )