Skip to content

Commit

Permalink
[python] Make paged attention configurable (deepjavalibrary#986)
Browse files — browse the repository at this point in the history
  • Loading branch information
xyang16 authored and KexinFeng committed Aug 16, 2023
1 parent 51000f6 commit 767d946
Showing 1 changed file with 6 additions and 3 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
self.properties = properties
self.batch_cls = None
self._init_model(kwargs, model_id_or_path)
self._warmup(**kwargs)
self.paged_attention = self.properties.get("paged_attention", "true").lower() == "true"
if self.paged_attention:
self._warmup(**kwargs)
self.batch_id_counter = 0
self.cache: Batch = None

Expand All @@ -67,7 +69,8 @@ def _init_model(self, kwargs, model_id_or_path):
sharded=sharded,
quantize=quantize,
dtype=dtype,
trust_remote_code=kwargs.get("trust_remote_code"))
trust_remote_code=kwargs.get("trust_remote_code"),
paged_attention=self.paged_attention)
self.batch_cls = self.model.batch_type

def _warmup(self, **kwargs):
Expand All @@ -84,7 +87,7 @@ def _warmup(self, **kwargs):
max_prefill_tokens = int(
self.properties.get(
"max_rolling_batch_prefill_tokens",
int(self.properties.get("max_rolling_batch_size", 4)) * 512))
int(self.properties.get("max_rolling_batch_size", 4)) * 272))
requests = [
lmi_dist.utils.types.Request(id=0,
inputs='_test ' * max_prefill_tokens,
Expand Down

0 comments on commit 767d946

Please sign in to comment.