11"""Utilities for selecting and loading models."""
2+ from functools import partial
23from typing import Optional
3-
44import math
55import torch
66import numpy as np
1111from vllm .sequence import SamplerOutput
1212from vllm .utils import is_openvino_optimum_intel
1313
14+ import openvino as ov
15+
1416
1517def _flattenize_inputs (inputs ):
1618 """
@@ -53,7 +55,7 @@ def ov_wrapper(self, *args, **kwargs) -> torch.Tensor:
5355
5456
5557def patch_stateful_model (
56- model : torch . nn . Module ,
58+ model : ov . Model ,
5759 factory ):
5860 print ('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM' )
5961 from openvino .runtime .passes import Manager , MatcherPass , WrapType , Matcher , AnyInput , Or
@@ -194,7 +196,14 @@ def __init__(self):
194196 seq = WrapType ("opset13.Gather" , [kv_shape , AnyInput (), AnyInput ()])
195197
196198 def callback (m : Matcher ) -> bool :
197- replace_node (m .get_match_root (), max_context_len )
199+ gather = m .get_match_root ()
200+ target_type = gather .get_output_element_type (0 )
201+ if max_context_len .get_output_element_type (0 ) != target_type :
202+ print (f'Converting { max_context_len .get_output_element_type (0 )} of max_context_len to { target_type } ' )
203+ replacement = opset13 .convert (max_context_len , target_type )
204+ else :
205+ replacement = max_context_len
206+ replace_node (gather , replacement )
198207 print ("DETECTED PATTERN FOR max_sequence_length, CONNECTED TO A DEDICATED PARAMETER" )
199208 return True
200209
@@ -270,7 +279,6 @@ def _patch_model_with_openvino(
270279 from vllm .model_executor .layers .attention .attention import Attention
271280 from openvino .frontend .pytorch import ModuleExtension
272281 from openvino import Core , convert_model , Type , PartialShape
273- from functools import partial
274282
275283 # Avoid usage of vllm._C.ops
276284
@@ -426,7 +434,7 @@ def get_model(model_config: ModelConfig,
426434
427435 pt_model = None
428436
429- if is_openvino_optimum_intel () and False :
437+ if is_openvino_optimum_intel ():
430438 import openvino as ov
431439 from optimum .intel import OVModelForCausalLM
432440 pt_model = OVModelForCausalLM .from_pretrained (model_config .model , export = True , compile = False , load_in_8bit = False , trust_remote_code = True ) # need stateful because it also enables SDPA
@@ -438,9 +446,8 @@ def get_model(model_config: ModelConfig,
438446 patch_stateful_model (pt_model .model , pt_model .ov_node_factory )
439447 core = ov .Core ()
440448 ov_compiled = core .compile_model (pt_model .model , "CPU" )
441- pt_model .ov_request = ov_compiled .create_infer_request ()
449+ pt_model ._ov_request = ov_compiled .create_infer_request ()
442450
443- from functools import partial
444451 pt_model ._openvino_patch_orig_forward = pt_model .forward
445452 pt_model .forward = partial (ov_wrapper , pt_model )
446453
0 commit comments