Skip to content

Commit cdd3261

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4bbb48d commit cdd3261

File tree

6 files changed

+6
-17
lines changed

6 files changed

+6
-17
lines changed

auto_round/inference/backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1025,4 +1025,3 @@ def build_pip_commands(gptq_req, other_reqs):
10251025
log(joined_cmds)
10261026
if logger_level == "error":
10271027
exit(-1)
1028-

auto_round/testing_utils.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import importlib.util
1616
import unittest
1717
from functools import wraps
18-
from typing import Literal
18+
from typing import Callable, Literal
19+
1920
import torch
20-
from typing import Callable
2121
from transformers.utils.versions import require_version
2222

2323
from auto_round.logger import logger
@@ -218,9 +218,7 @@ def require_vlm_env(test_case):
218218

219219

220220
def require_package_version(
221-
package: str,
222-
version_spec: str,
223-
on_fail: Literal["skip", "warn", "error"] = "skip"
221+
package: str, version_spec: str, on_fail: Literal["skip", "warn", "error"] = "skip"
224222
) -> bool:
225223
"""
226224
Check if a package satisfies a version requirement.
@@ -264,10 +262,9 @@ def require_package_version_ut(package: str, version_spec: str) -> Callable:
264262
Returns:
265263
Callable: A decorator to wrap unittest test methods.
266264
"""
265+
267266
def decorator(test_func: Callable) -> Callable:
268267
reason = f"Test requires {package}{version_spec}"
269-
return unittest.skipUnless(
270-
require_package_version(package, version_spec, on_fail="skip"),
271-
reason
272-
)(test_func)
268+
return unittest.skipUnless(require_package_version(package, version_spec, on_fail="skip"), reason)(test_func)
269+
273270
return decorator

test/test_cuda/test_auto_round_format.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,4 +330,3 @@ def test_load_gptq_model_3bits(self):
330330

331331
if __name__ == "__main__":
332332
unittest.main()
333-

test/test_cuda/test_export.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,4 +405,3 @@ def test_nvfp4_llmcompressor_format(self):
405405

406406
if __name__ == "__main__":
407407
unittest.main()
408-

test/test_cuda/test_main_func.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def test_backend_awq(self):
8383
assert accuracy > 0.35
8484
shutil.rmtree("./saved", ignore_errors=True)
8585

86-
8786
@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
8887
@require_gptqmodel
8988
def test_fp_layers(self):
@@ -108,7 +107,6 @@ def test_fp_layers(self):
108107
assert accuracy > 0.35
109108
shutil.rmtree("./saved", ignore_errors=True)
110109

111-
112110
@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
113111
@require_awq
114112
@require_package_version_ut("transformers", "<4.57.0")
@@ -134,7 +132,6 @@ def test_fp_layers_awq(self):
134132
assert accuracy > 0.35
135133
shutil.rmtree("./saved", ignore_errors=True)
136134

137-
138135
@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
139136
def test_undivided_group_size_tuning(self):
140137
model_name = "/models/opt-125m"
@@ -185,4 +182,3 @@ def test_autoround_asym(self): ##need to install false
185182

186183
if __name__ == "__main__":
187184
unittest.main()
188-

test/test_cuda/test_support_vlms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,4 +388,3 @@ def test_deepseek_vl2(self):
388388

389389
if __name__ == "__main__":
390390
unittest.main()
391-

0 commit comments

Comments
 (0)