
Commit 60b4ac0

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d1fca76 commit 60b4ac0
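
The hook configuration that produced these fixes is not shown on this page. As a minimal sketch, a .pre-commit-config.yaml along the following lines (isort for import ordering, black for line wrapping) would generate exactly this kind of change; the repo's actual hook set and pinned revisions may differ:

    # Hypothetical excerpt of .pre-commit-config.yaml; this repo's real hooks may differ.
    repos:
      - repo: https://github.com/PyCQA/isort
        rev: 5.13.2
        hooks:
          - id: isort    # alphabetizes import lists (the import reshuffles below)
      - repo: https://github.com/psf/black
        rev: 24.8.0
        hooks:
          - id: black    # re-wraps over-long calls (the block_forward reflows below)

With such a config, "pre-commit run --all-files" applies the same fixes locally; pre-commit.ci runs the hooks on each push and commits whatever they change, which is why this commit is pure reordering and re-wrapping with no behavioral difference.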

File tree

auto_round/compressors/base.py
auto_round/compressors/diffusion/compressor.py
auto_round/wrapper.py

3 files changed: +13 additions, -9 deletions


auto_round/compressors/base.py

Lines changed: 4 additions & 4 deletions
@@ -86,8 +86,8 @@
     is_mx_fp,
     is_nv_fp,
     is_standard_fp,
-    is_torch_compile_enabled,
     is_static_wfp8afp8,
+    is_torch_compile_enabled,
     is_wfp8afp8,
     llm_load_model,
     mv_module_from_gpu,
@@ -2095,9 +2095,9 @@ def _get_block_outputs(
             tmp_input_ids, tmp_input_others = self._sampling_inputs(
                 input_ids, input_others, indices, self.seqlen, self.batch_dim, share_cache_keys=self.shared_cache_keys
             )
-            tmp_output = self.block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
-                cache_device
-            )
+            tmp_output = self.block_forward(
+                block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device
+            ).to(cache_device)
             if save_output:
                 if self.batch_size == 1:
                     output.append(tmp_output)

auto_round/compressors/diffusion/compressor.py

Lines changed: 7 additions & 3 deletions
@@ -28,12 +28,12 @@
     LazyImport,
     block_forward,
     clear_memory,
+    compile_func,
     diffusion_load_model,
     extract_block_names_to_str,
     find_matching_blocks,
     get_block_names,
     is_torch_compile_enabled,
-    compile_func,
 )

 pipeline_utils = LazyImport("diffusers.pipelines.pipeline_utils")
@@ -208,7 +208,9 @@ def _get_current_q_output(
         hidden_states = current_input_ids.pop("hidden_states")
         current_input_others.update(current_input_ids)
         current_input_ids = hidden_states
-        output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device, idx)
+        output_q = self.block_forward(
+            block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device, idx
+        )
         return output_q

     @torch.no_grad()
@@ -252,7 +254,9 @@ def _get_block_outputs(
             tmp_input_others.update(tmp_input_ids)
             tmp_input_ids = hidden_states

-            tmp_output = self.block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device, None)
+            tmp_output = self.block_forward(
+                block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device, None
+            )
             assert len(output_config) == len(tmp_output)
             tmp_output = dict(zip(output_config, tmp_output))

auto_round/wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -22,13 +22,13 @@
 from .utils import (
     SUPPORTED_LAYER_TYPES,
     check_to_quantized,
+    compile_func,
     deepspeed_exists,
     get_scale_shape,
     is_mx_fp,
     is_nv_fp,
-    set_module,
     is_torch_compile_enabled,
-    compile_func,
+    set_module,
 )

 if deepspeed_exists:
