Skip to content

Commit

Permalink
add ruff lint rule to remove unused imports via ruff (pytorch#969)
Browse files Browse the repository at this point in the history
remove unused imports via ruff
  • Loading branch information
jimexist authored and melvinebenezer committed Oct 7, 2024
1 parent e86acdd commit 21a3534
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 56 deletions.
13 changes: 9 additions & 4 deletions .github/workflows/ruff_linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff
pip install ruff==0.6.8
- name: Analyzing the code with ruff
run: |
ruff check .
- name: Check all Python files for syntax errors (E999) and undefined vars (F821)
- name: Check *all* Python files for F821, F823, and W191
run: |
ruff check --isolated --select E999,F821
- name: Check well formatted code
# --isolated is used to skip the allowlist at all so this applies to all files
# please be careful when using this large changes means everyone needs to rebase
ruff check --isolated --select F821,F823,W191
- name: Check the allow-listed files for F,I
run: |
ruff check --select F,I
- name: Check the allow-listed files for well formatted code
run: |
ruff format --check
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ tabulate # QOL for printing tables to stdout
ninja

# Linting
ruff
ruff==0.6.8
pre-commit
33 changes: 17 additions & 16 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,36 @@
import pytest

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
)
import pytest

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import copy
import io
import random
import unittest
from contextlib import nullcontext
from functools import partial
from typing import Tuple

import pytest
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils

from torchao.float8.float8_utils import compute_error
from torchao.quantization import (
quantize_,
float8_weight_only,
float8_dynamic_activation_float8_weight,
float8_weight_only,
quantize_,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.quant_api import (
float8_static_activation_float8_weight,
)
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
from torchao.quantization.observer import PerTensor, PerRow
from torchao.float8.float8_utils import compute_error
import torch
import unittest
import pytest
import copy
import random
from functools import partial
from typing import Tuple
from contextlib import nullcontext
import io

from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

random.seed(0)
torch.manual_seed(0)
Expand Down
21 changes: 11 additions & 10 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import copy
import io
import logging
import unittest
from packaging import version
import math
import unittest
from collections import OrderedDict
from typing import Tuple, Union

import pytest
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
CheckpointWrapper,
apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
Expand All @@ -19,18 +23,15 @@
parametrize,
run_tests,
)

import torchao
from packaging import version
from torchao.dtypes.nf4tensor import (
_INNER_TENSOR_NAMES_FOR_SHARDING,
NF4Tensor,
linear_nf4,
to_nf4,
_INNER_TENSOR_NAMES_FOR_SHARDING,
)
import torch.nn.functional as F
import io
from collections import OrderedDict
import torchao
from typing import Tuple, Union


bnb_available = False

Expand Down
20 changes: 11 additions & 9 deletions test/quantization/test_observer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import re
import unittest

import torch
import torch.nn as nn

# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase

from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
PerTensor,
PerAxis,
)
from torchao.quantization.quant_primitives import (
MappingType,
PerTensor,
)
from torchao.quantization.quant_api import (
insert_observers_,
)
from torch.testing._internal import common_utils
import unittest

# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torchao.quantization.quant_primitives import (
MappingType,
)


class TestQuantFlow(TestCase):
Expand Down
7 changes: 3 additions & 4 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import functools
from dataclasses import dataclass, replace
import math
from typing import Dict, Tuple, Any, Optional, Union
import sys
from dataclasses import dataclass, replace
from enum import Enum, auto
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import DeviceMesh
from torch._prims_common import make_contiguous_strides_for

from torch.distributed.device_mesh import DeviceMesh

aten = torch.ops.aten

Expand Down
6 changes: 3 additions & 3 deletions torchao/float8/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import enum
from typing import Dict, Optional, NamedTuple
from typing import Dict, NamedTuple, Optional

import torch

import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor

from torchao.float8.float8_utils import (
e4m3_dtype,
to_fp8_saturated,
)
from torch.distributed._tensor import DTensor

aten = torch.ops.aten

Expand Down
4 changes: 2 additions & 2 deletions torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from typing import Iterable, Literal, Tuple, Union

import torchao.float8.config as config

import torch
import torch.distributed as dist

import torchao.float8.config as config

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html

Expand Down
3 changes: 2 additions & 1 deletion torchao/float8/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
Defines an nn module designed to be used during inference
"""

from typing import Optional, Tuple, NamedTuple
from typing import NamedTuple, Optional, Tuple

import torch

from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul

Tensor = torch.Tensor
Expand Down
9 changes: 5 additions & 4 deletions torchao/quantization/linear_activation_weight_observer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Callable, Dict, Optional

import torch
from typing import Callable, Optional, Dict
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.quantization.observer import AffineQuantizedObserverBase
from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
)

from torchao.quantization.observer import AffineQuantizedObserverBase

__all__ = [
"LinearActivationWeightObservedTensor",
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Callable, Dict, Optional

import torch
from typing import Callable, Optional, Dict, Any
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
)

__all__ = [
Expand Down

0 comments on commit 21a3534

Please sign in to comment.