diff --git a/modeling_dots_ocr_vllm.py b/modeling_dots_ocr_vllm.py
index a8ba8d0..eb84b0d 100644
--- a/modeling_dots_ocr_vllm.py
+++ b/modeling_dots_ocr_vllm.py
@@ -178,11 +178,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
     )
     _tp_plan = {}
 
-    @classmethod
-    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
-        if modality in ("image",):
-            return "<|img|><|imgpad|><|endofimg|>"
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
@@ -424,12 +419,20 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+        if modality.startswith("image"):
+            return "<|img|><|imgpad|><|endofimg|>"
+
+        raise ValueError("Only image modality is supported")
+
 
 def patch_vllm_chat_placeholder():
     import vllm
     # return when vllm version > 0.9.1
-    if not (vllm.__version_tuple__[0]==0 and vllm.__version_tuple__[1] <= 9 and vllm.__version_tuple__[2] <= 1):
-        return
+    # Our vLLM version is 0.9.0.dev, so skip the version check below.
+    # if not (vllm.__version_tuple__[0]==0 and vllm.__version_tuple__[1] <= 9 and vllm.__version_tuple__[2] <= 1):
+    #     return
     from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
 
     ori = BaseMultiModalItemTracker._placeholder_str
@@ -448,4 +451,4 @@ ModelRegistry.register_model(
 )
 
 
-patch_vllm_chat_placeholder()
\ No newline at end of file
+# patch_vllm_chat_placeholder()
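
# --- Illustration only, not part of the patch above ---
# A minimal sketch of how the relocated get_placeholder_str classmethod behaves
# after this change: the image modality resolves to the placeholder tokens, and
# any other modality now fails loudly instead of implicitly returning None.
placeholder = DotsOCRForCausalLM.get_placeholder_str("image", 0)
assert placeholder == "<|img|><|imgpad|><|endofimg|>"

try:
    DotsOCRForCausalLM.get_placeholder_str("audio", 0)
except ValueError:
    pass  # only the image modality is supported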
diff --git a/modeling_dots_vision.py b/modeling_dots_vision.py
index 1046513..56009a8 100644
--- a/modeling_dots_vision.py
+++ b/modeling_dots_vision.py
@@ -8,10 +8,16 @@ import torch.utils.checkpoint
 flash_attn_available = True
 npu_available = True
 
+# try:
+#     from flash_attn import flash_attn_varlen_func
+# except ImportError:
+#     flash_attn_available = False
+
 try:
-    from flash_attn import flash_attn_varlen_func
-except ImportError:
-    flash_attn_available = False
+    import intel_extension_for_pytorch as ipex
+except ImportError as e:
+    raise ImportError("IPEX is not installed but is required for the XPU build") from e
+
 
 from torch.nn import LayerNorm
 from transformers.modeling_utils import PreTrainedModel
@@ -159,9 +165,41 @@ class VisionFlashAttention2(nn.Module):
         q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
         k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(
-            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=self.is_causal
-        ).reshape(seq_length, -1)
+
+        # Original code with flash_attn_varlen_func
+        # attn_output = flash_attn_varlen_func(
+        #     q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=self.is_causal
+        # ).reshape(seq_length, -1)
+        # Original code ends
+
+        # Changes start for XPU
+        attn_output = torch.empty(
+            q.shape,
+            dtype=q.dtype,
+            device=q.device)
+        ipex.llm.functional.varlen_attention(
+            q.contiguous(),  # query
+            k.contiguous(),  # key
+            v.contiguous(),  # value
+            attn_output,  # out
+            cu_seqlens.int(),  # seqlen_q
+            cu_seqlens.int(),  # seqlen_k
+            None,  # alibi_slopes
+            max_seqlen,  # max_seqlen_q
+            max_seqlen,  # max_seqlen_k
+            0.0,  # pdropout
+            1.0 / (q.shape[-1] ** 0.5),  # softmax_scale
+            False,  # zero_tensors
+            self.is_causal,  # is_causal
+            False,  # return_softmax
+            None,  # gen_
+            -1,  # window_size_left
+            -1,  # window_size_right
+            -1,  # logits_soft_cap
+        )
+        attn_output = attn_output.reshape(seq_length, -1)
+        # Changes end for XPU
+
         attn_output = self.proj(attn_output)
 
         return attn_output
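
# --- Illustration only, not part of the patch above ---
# A standalone sanity-check sketch for the XPU attention swap: it invokes
# ipex.llm.functional.varlen_attention with the same argument layout as the
# patched VisionFlashAttention2 and compares the output against a naive
# softmax-attention reference. The shapes, dtype, and tolerances are
# assumptions for illustration; an XPU device and IPEX are required.
import torch
import intel_extension_for_pytorch as ipex


def reference_attention(q, k, v):
    # q, k, v: (seq, heads, head_dim) for a single packed sequence
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float()).to(q.dtype)


seq, heads, head_dim = 16, 4, 64
q = torch.randn(seq, heads, head_dim, dtype=torch.float16, device="xpu")
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, seq], dtype=torch.int32, device="xpu")
max_seqlen = seq

attn_output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
ipex.llm.functional.varlen_attention(
    q.contiguous(), k.contiguous(), v.contiguous(), attn_output,
    cu_seqlens.int(), cu_seqlens.int(),
    None,                        # alibi_slopes
    max_seqlen, max_seqlen,
    0.0,                         # pdropout
    1.0 / (q.shape[-1] ** 0.5),  # softmax_scale
    False,                       # zero_tensors
    False,                       # is_causal (non-causal for this check)
    False,                       # return_softmax
    None,                        # gen_
    -1, -1,                      # window_size_left / window_size_right
    -1,                          # logits_soft_cap
)
torch.testing.assert_close(attn_output, reference_attention(q, k, v),
                           rtol=2e-2, atol=2e-2)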