[mypy] nncf/data (#3173)
### Changes

Enable mypy check for `nncf/data`.
Removed `Generic[DataItem, ModelInput]` from `nncf.Dataset` to simplify use of the class (see the usage sketch below).
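
As a usage sketch (not part of the commit itself): a minimal example of how the now non-generic `Dataset` is constructed and annotated. The calibration items and the transform are hypothetical; only the `Dataset` API visible in the diff below is assumed.

```python
from nncf import Dataset

# Hypothetical calibration items; any iterable can serve as a data source.
calibration_items = [{"input": [1.0, 2.0]}, {"input": [3.0, 4.0]}]

# Before this change an annotation could read Dataset[dict, list];
# the class is now non-generic, so a bare Dataset is the whole type.
dataset: Dataset = Dataset(calibration_items, lambda item: item["input"])

for model_input in dataset.get_inference_data():
    print(model_input)  # the transform_func output, e.g. [1.0, 2.0]
```

Dropping the two type parameters means annotations like `Dataset[DataItem, ModelInput]` reduce to plain `Dataset`, at the cost of `Any`-typed items inside the class.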
AlexanderDokuchaev authored Jan 8, 2025
1 parent 801e5af commit f37e5ad
Showing 5 changed files with 27 additions and 32 deletions.
`nncf/common/tensor_statistics/aggregator.py` (4 addition & 3 deletions)

```diff
@@ -27,9 +27,7 @@
 from nncf.common.tensor_statistics.statistics_serializer import dump_statistics
 from nncf.common.tensor_statistics.statistics_serializer import load_statistics
 from nncf.common.utils.backend import BackendType
-from nncf.data.dataset import DataItem
 from nncf.data.dataset import Dataset
-from nncf.data.dataset import ModelInput
 from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic
 
 TensorType = TypeVar("TensorType")
@@ -47,7 +45,7 @@ class StatisticsAggregator(ABC):
 
     BACKEND: BackendType
 
-    def __init__(self, dataset: Dataset[DataItem, ModelInput]):
+    def __init__(self, dataset: Dataset):
         self.dataset = dataset
         self.stat_subset_size = None
         self.statistic_points = StatisticPointsContainer()
```
`nncf/data/dataset.py` (20 additions & 24 deletions)

```diff
@@ -9,16 +9,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Generic, Iterable, List, Optional, TypeVar
+from typing import Any, Callable, Generator, Iterable, Iterator, List, Optional, cast
 
 from nncf.common.utils.api_marker import api
 
-DataItem = TypeVar("DataItem")
-ModelInput = TypeVar("ModelInput")
-
 
 @api(canonical_alias="nncf.Dataset")
-class Dataset(Generic[DataItem, ModelInput]):
+class Dataset:
     """
     Wrapper for passing custom user datasets into NNCF algorithms.
@@ -41,13 +38,11 @@ class Dataset(Generic[DataItem, ModelInput]):
     will be passed into the model as-is.
     """
 
-    def __init__(
-        self, data_source: Iterable[DataItem], transform_func: Optional[Callable[[DataItem], ModelInput]] = None
-    ):
+    def __init__(self, data_source: Iterable[Any], transform_func: Optional[Callable[..., Any]] = None):
         self._data_source = data_source
         self._transform_func = transform_func
 
-    def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]:
+    def get_data(self, indices: Optional[List[int]] = None) -> Iterable[Any]:
         """
         Returns the iterable object that contains selected data items from the data source as-is.
@@ -58,7 +53,7 @@ def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]:
         """
         return DataProvider(self._data_source, None, indices)
 
-    def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[ModelInput]:
+    def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[Any]:
         """
         Returns the iterable object that contains selected data items from the data source, for which
         the transformation function was applied. The item, which was returned per iteration from this
@@ -78,7 +73,7 @@ def get_length(self) -> Optional[int]:
         :return: The length of the data_source if __len__() is implemented for it, and None otherwise.
         """
         if hasattr(self._data_source, "__len__"):
-            return self._data_source.__len__()
+            return cast(int, self._data_source.__len__())
         return None
 
     def get_batch_size(self) -> Optional[int]:
@@ -87,26 +82,27 @@ def get_batch_size(self) -> Optional[int]:
         :return: The value of batch_size or _batch_size attributes of the data_source if exist, and None otherwise.
         """
         if hasattr(self._data_source, "batch_size"):  # Torch dataloader
-            return self._data_source.batch_size
+            return cast(int, self._data_source.batch_size)
         if hasattr(self._data_source, "_batch_size"):  # TF dataloader
-            return self._data_source._batch_size
+            return cast(int, self._data_source._batch_size)
         return None
 
 
-class DataProvider(Generic[DataItem, ModelInput]):
+class DataProvider:
     def __init__(
         self,
-        data_source: Iterable[DataItem],
-        transform_func: Callable[[DataItem], ModelInput],
+        data_source: Iterable[Any],
+        transform_func: Optional[Callable[..., Any]],
         indices: Optional[List[int]] = None,
     ):
         self._data_source = data_source
         if transform_func is None:
-            transform_func = lambda x: x
-        self._transform_func = transform_func
+            self._transform_func = lambda x: x
+        else:
+            self._transform_func = transform_func
         self._indices = indices
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         if self._indices is None:
             return map(self._transform_func, self._data_source)
 
@@ -117,15 +113,15 @@ def __iter__(self):
 
     @staticmethod
     def _get_iterator_for_map_style(
-        data_source: Iterable[DataItem], transform_func: Callable[[DataItem], ModelInput], indices: List[int]
-    ):
+        data_source: Iterable[Any], transform_func: Callable[..., Any], indices: List[int]
+    ) -> Generator[Any, None, None]:
         for index in indices:
-            yield transform_func(data_source[index])
+            yield transform_func(data_source[index])  # type: ignore[index]
 
     @staticmethod
     def _get_iterator_for_iter(
-        data_source: Iterable[DataItem], transform_func: Callable[[DataItem], ModelInput], indices: List[int]
-    ):
+        data_source: Iterable[Any], transform_func: Callable[..., Any], indices: List[int]
+    ) -> Generator[Any, None, None]:
         pos = 0
         num_indices = len(indices)
         for idx, data_item in enumerate(data_source):
```
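The `cast(int, ...)` calls above address a strict-mode detail: `hasattr()` narrows an attribute's type only to `Any`, and mypy's strict profile (which enables `warn_return_any`) rejects returning `Any` from a function annotated `-> Optional[int]`. A self-contained sketch of the same pattern, with a hypothetical function name:

```python
from typing import Any, Iterable, Optional, cast

def get_batch_size(data_source: Iterable[Any]) -> Optional[int]:
    # hasattr() proves the attribute exists but leaves its type as Any,
    # so returning it directly would trip warn_return_any under strict mypy.
    if hasattr(data_source, "batch_size"):  # e.g. a torch DataLoader
        return cast(int, data_source.batch_size)
    return None
```

Note that `cast()` performs no runtime check; it only records the expected type for the type checker.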
`nncf/data/generators.py` (3 additions & 3 deletions)

```diff
@@ -54,9 +54,9 @@ def generate_text_data(
         raise nncf.ModuleNotFoundError("torch is required in order to generate text data: `pip install torch`.")
 
     try:
-        from transformers import PreTrainedModel
+        from transformers import PreTrainedModel  # type: ignore
         from transformers import PreTrainedTokenizerBase
-        from transformers.utils import logging
+        from transformers.utils import logging  # type: ignore
 
         logging.set_verbosity_error()
     except ImportError:
@@ -70,7 +70,7 @@ def generate_text_data(
     if not isinstance(tokenizer, PreTrainedTokenizerBase.__bases__):
         raise nncf.ValidationError("tokenizer should be instance of the `transformers.PreTrainedTokenizerBase`.")
 
-    generated_data = []
+    generated_data: List[str] = []
 
     vocab_size_names = ["padded_vocab_size", "vocab_size"]
     vocab_size = BASE_VOCAB_SIZE
```
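Two migration patterns recur in this file: `# type: ignore` on imports from packages that ship without type information (`transformers` here), and an explicit annotation on the empty literal `generated_data: List[str] = []`, since mypy cannot infer an element type from `[]`. A minimal sketch of the latter, with invented names:

```python
from typing import List

def collect_generated(count: int) -> List[str]:
    # Under strict mypy, `generated = []` alone fails with
    # "Need type annotation" (var-annotated): an empty list literal
    # has no inferable element type.
    generated: List[str] = []
    for i in range(count):
        generated.append(f"sample {i}")
    return generated
```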
`nncf/telemetry/extractors.py` (2 additions & 2 deletions)

```diff
@@ -60,9 +60,9 @@ def extract(self, argvalue: SerializableData) -> CollectedEvent:
 
 
 class FunctionCallTelemetryExtractor(TelemetryExtractor):
-    def __init__(self, argvalue=None):
+    def __init__(self, argvalue: Any = None) -> None:
         super().__init__()
         self._argvalue = argvalue
 
-    def extract(self, _: Any):
+    def extract(self, _: Any) -> CollectedEvent:
         return CollectedEvent(name="function_call", data=self._argvalue)
```
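Strict mode's `disallow_untyped_defs` is what forces the added annotations here: every parameter and return type must be spelled out, including `-> None` on `__init__`. A freestanding sketch of the rule, with a hypothetical class:

```python
from typing import Any, Optional

class Recorder:
    # disallow_untyped_defs requires complete signatures; __init__ must
    # declare "-> None" explicitly once its parameters are annotated.
    def __init__(self, payload: Any = None) -> None:
        self._payload = payload

    def describe(self) -> Optional[str]:
        return None if self._payload is None else str(self._payload)
```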
`pyproject.toml` (1 addition & 0 deletions)

```diff
@@ -94,6 +94,7 @@ strict = true
 implicit_optional = true
 files = [
     "nncf/api",
+    "nncf/data",
     "nncf/common/collector.py",
     "nncf/common/engine.py",
     "nncf/common/hook_handle.py",
```
