Skip to content

Commit

Permalink
fix crash in multi-modal (#2245)
Browse files — browse the repository at this point in the history
* fix crash in multi-modal

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update according to review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix llava_next regression in latest main

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
• Loading branch information
sywangyi authored Jul 24, 2024
1 parent a895029 commit 5ad39dd
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def __init__(self, prefix, config, weights):
FlashLlamaLayer(
index=0,
prefix=(
"model.layers.0" if not prefix else "{prefix}.model.layers.0"
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
),
config=config,
weights=weights,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,7 @@ def forward(
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def forward(
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
Expand Down
2 changes: 2 additions & 0 deletions server/text_generation_server/models/vlm_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen

tracer = trace.get_tracer(__name__)

Expand Down Expand Up @@ -348,6 +349,7 @@ def forward(
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
Expand Down

0 comments on commit 5ad39dd

Please sign in to comment.