@@ -361,6 +361,7 @@ def __init__(
361361 self .infer_bs_coeff = 1
362362 self .enable_torch_compile = enable_torch_compile
363363 self ._adjust_torch_compile (enable_torch_compile )
364+ os .environ ["AR_TORCH_COMPILE" ] = "1" if self .enable_torch_compile else "0"
364365 self ._check_configs ()
365366 torch .set_printoptions (precision = 3 , sci_mode = True )
366367
@@ -1911,10 +1912,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
19111912
19121913 self .model = mv_module_from_gpu (self .model , self .low_cpu_mem_usage )
19131914 clear_memory ()
1914- if self .enable_torch_compile :
1915- quant_layer = compile_func (self ._quantize_layer , self .device )
1916- else :
1917- quant_layer = self ._quantize_layer
1915+ quant_layer = self ._quantize_layer
19181916 for layer_name in layer_names :
19191917 layer_input = layer_inputs [layer_name ]
19201918 layer_input = to_device (layer_input , self .cache_device )
@@ -3008,14 +3006,8 @@ def _quantize_blocks(
30083006 logger .info ("using algorithm extension for quantization." )
30093007 except (ImportError , ModuleNotFoundError ):
30103008 quantize_block = self ._quantize_block
3011- if self .enable_torch_compile :
3012- quantize_block = compile_func (quantize_block , device )
3013- else :
3014- quantize_block = quantize_block
30153009 else :
30163010 quantize_block = self ._quantize_block
3017- if self .enable_torch_compile :
3018- quantize_block = compile_func (quantize_block , device )
30193011
30203012 if pbar is None :
30213013 pbar = tqdm (range (0 , len (block_names ), nblocks ))
0 commit comments