@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
                   example_inputs,
                   additional_inductor_config,
                   compilation_config: CompilationConfig,
+                  vllm_backend: "VllmBackend",
                   graph_index: int = 0,
                   num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
     # see https://github.com/pytorch/pytorch/issues/138980
     graph = copy.deepcopy(graph)
 
-    cache_data = compilation_config.inductor_hash_cache
+    cache_data = vllm_backend.inductor_hash_cache
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
                 hash_str, example_inputs, True, False)
             assert inductor_compiled_graph is not None, (
                 "Inductor cache lookup failed. Please remove"
-                f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
+                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
             )
 
     # Inductor calling convention (function signature):
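
Note: the lookup above keys the compiled-artifact cache on a (runtime_shape, graph_index) pair and now reaches it through the owning VllmBackend instead of the compilation config. Below is a minimal sketch of that lookup pattern, using a hypothetical DummyHashCache as a stand-in for vLLM's InductorHashCache (illustration only, not the real class, which also persists its entries to cache_file_path):

from typing import Dict, Optional, Tuple

class DummyHashCache:
    """Illustrative stand-in: maps (runtime_shape, graph_index)
    to the hash string of a previously compiled Inductor graph."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[Optional[int], int], str] = {}

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        return key in self._entries

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        return self._entries[key]

    def __setitem__(self, key: Tuple[Optional[int], int], hash_str: str):
        self._entries[key] = hash_str

cache = DummyHashCache()
cache[(None, 0)] = "fx_abc123"   # general-shape entry for graph 0
if (None, 0) in cache:           # same membership check wrap_inductor performs
    print("cache hit:", cache[(None, 0)])
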
@@ -354,14 +355,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
 
     def __init__(self, module: torch.fx.GraphModule,
                  compile_submod_names: List[str], vllm_config: VllmConfig,
-                 graph_pool):
+                 graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
         self.compile_submod_names = compile_submod_names
         self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.vllm_config = vllm_config
+        self.vllm_backend = vllm_backend
 
     def run(self, *args):
         fake_args = [
@@ -389,6 +391,7 @@ def call_module(self, target: torch.fx.node.Target,
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
@@ -397,7 +400,7 @@ def call_module(self, target: torch.fx.node.Target,
             self.module.__dict__[target] = PiecewiseBackend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
-                compiled_graph_for_general_shape)
+                compiled_graph_for_general_shape, self.vllm_backend)
 
             compilation_counter.num_piecewise_capturable_graphs_seen += 1
 
@@ -430,6 +433,7 @@ class VllmBackend:
     post_grad_passes: Sequence[Callable]
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]
+    inductor_hash_cache: InductorHashCache
 
     def __init__(
         self,
@@ -472,6 +476,53 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        if not self.compilation_config.cache_dir:
+            # no provided cache dir, generate one based on the known factors
+            # that affects the compilation. if none of the factors change,
+            # the cache dir will be the same so that we can reuse the compiled
+            # graph.
+
+            # 1. factors come from the vllm_config (it mainly summarizes how the
+            # model is created)
+            vllm_config = self.vllm_config
+            config_hash = vllm_config.compute_hash()
+
+            # 2. factors come from the code files that are traced by Dynamo (
+            # it mainly summarizes how the model is used in forward pass)
+            forward_code_files = list(
+                sorted(self.compilation_config.traced_files))
+            self.compilation_config.traced_files.clear()
+            logger.debug(
+                "Traced files (to be considered for compilation cache):\n%s",
+                "\n".join(forward_code_files))
+            hash_content = []
+            for filepath in forward_code_files:
+                hash_content.append(filepath)
+                with open(filepath) as f:
+                    hash_content.append(f.read())
+            import hashlib
+            code_hash = hashlib.md5(
+                "\n".join(hash_content).encode()).hexdigest()
+
+            # combine the two hashes to generate the cache dir
+            hash_key = hashlib.md5(
+                f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
+            cache_dir = os.path.join(
+                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
+                f"rank_{vllm_config.parallel_config.rank}")
+        else:
+            cache_dir = self.compilation_config.cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
+        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
+            cache_dir, disabled=disabled)
+        if disabled:
+            logger.info("vLLM's torch.compile cache is disabled.")
+        else:
+            logger.info("Using cache directory: %s for vLLM's torch.compile",
+                        cache_dir)
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
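
The block added above derives the compile cache directory from two fingerprints: the VllmConfig hash and an md5 over the source files traced by Dynamo, combined and shortened to a 10-character key under VLLM_CACHE_ROOT/torch_compile_cache/. Below is a standalone sketch of that derivation, assuming a precomputed config_hash string and an explicit file list in place of vllm_config.compute_hash() and compilation_config.traced_files:

import hashlib
import os
from typing import List

def derive_cache_dir(config_hash: str, traced_files: List[str],
                     cache_root: str, rank: int) -> str:
    # hash file paths plus contents so edits to traced code invalidate the cache
    hash_content = []
    for filepath in sorted(traced_files):
        hash_content.append(filepath)
        with open(filepath) as f:
            hash_content.append(f.read())
    code_hash = hashlib.md5("\n".join(hash_content).encode()).hexdigest()

    # combine config and code fingerprints into a short directory key,
    # then separate caches per rank as the diff does
    hash_key = hashlib.md5(
        f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
    return os.path.join(cache_root, "torch_compile_cache", hash_key,
                        f"rank_{rank}")

# example: identical inputs always map to the same directory
print(derive_cache_dir("deadbeef", [__file__], "/tmp/vllm_cache", rank=0))
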
@@ -507,8 +558,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.vllm_config,
-                                    self.graph_pool).run(*example_inputs)
+                                    self.vllm_config, self.graph_pool,
+                                    self).run(*example_inputs)
 
         self._called = True
 
@@ -577,7 +628,8 @@ class PiecewiseBackend:
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
                  total_piecewise_compiles: int, sym_shape_indices: List[int],
-                 compiled_graph_for_general_shape: Callable):
+                 compiled_graph_for_general_shape: Callable,
+                 vllm_backend: VllmBackend):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.
@@ -597,6 +649,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
+        self.vllm_backend = vllm_backend
 
         self.is_first_graph = piecewise_compile_index == 0
         self.is_last_graph = (
@@ -634,7 +687,7 @@ def check_for_ending_compilation(self):
         if self.is_last_graph and not self.to_be_compiled_sizes:
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
-            self.compilation_config.inductor_hash_cache.save_to_file()
+            self.vllm_backend.inductor_hash_cache.save_to_file()
             end_monitoring_torch_compile(self.vllm_config)
 
     def __call__(self, *args) -> Any:
@@ -662,6 +715,7 @@ def __call__(self, *args) -> Any:
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,