diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py
index d60ec869edd..3803884e44c 100644
--- a/nncf/common/tensor_statistics/aggregator.py
+++ b/nncf/common/tensor_statistics/aggregator.py
@@ -27,9 +27,7 @@
 from nncf.common.tensor_statistics.statistics_serializer import dump_statistics
 from nncf.common.tensor_statistics.statistics_serializer import load_statistics
 from nncf.common.utils.backend import BackendType
-from nncf.data.dataset import DataItem
 from nncf.data.dataset import Dataset
-from nncf.data.dataset import ModelInput
 from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic
 
 TensorType = TypeVar("TensorType")
@@ -47,7 +45,7 @@ class StatisticsAggregator(ABC):
 
     BACKEND: BackendType
 
-    def __init__(self, dataset: Dataset[DataItem, ModelInput]):
+    def __init__(self, dataset: Dataset):
         self.dataset = dataset
         self.stat_subset_size = None
         self.statistic_points = StatisticPointsContainer()
diff --git a/nncf/data/dataset.py b/nncf/data/dataset.py
index a5c3fdf0a19..10540f36a42 100644
--- a/nncf/data/dataset.py
+++ b/nncf/data/dataset.py
@@ -9,16 +9,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Generic, Iterable, List, Optional, TypeVar
+from typing import Any, Callable, Generator, Iterable, Iterator, List, Optional, cast
 
 from nncf.common.utils.api_marker import api
 
-DataItem = TypeVar("DataItem")
-ModelInput = TypeVar("ModelInput")
-
 
 @api(canonical_alias="nncf.Dataset")
-class Dataset(Generic[DataItem, ModelInput]):
+class Dataset:
     """
     Wrapper for passing custom user datasets into NNCF algorithms.
 
@@ -41,13 +38,11 @@ class Dataset(Generic[DataItem, ModelInput]):
         will be passed into the model as-is.
     """
 
-    def __init__(
-        self, data_source: Iterable[DataItem], transform_func: Optional[Callable[[DataItem], ModelInput]] = None
-    ):
+    def __init__(self, data_source: Iterable[Any], transform_func: Optional[Callable[..., Any]] = None):
        self._data_source = data_source
         self._transform_func = transform_func
 
-    def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]:
+    def get_data(self, indices: Optional[List[int]] = None) -> Iterable[Any]:
         """
         Returns the iterable object that contains selected data items from the data source as-is.
 
@@ -58,7 +53,7 @@ def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]:
         """
         return DataProvider(self._data_source, None, indices)
 
-    def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[ModelInput]:
+    def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[Any]:
         """
         Returns the iterable object that contains selected data items from the data source, for which
         the transformation function was applied. The item, which was returned per iteration from this
@@ -78,7 +73,7 @@ def get_length(self) -> Optional[int]:
         :return: The length of the data_source if __len__() is implemented for it, and None otherwise.
         """
         if hasattr(self._data_source, "__len__"):
-            return self._data_source.__len__()
+            return cast(int, self._data_source.__len__())
         return None
 
     def get_batch_size(self) -> Optional[int]:
@@ -87,26 +82,27 @@ def get_batch_size(self) -> Optional[int]:
         :return: The value of batch_size or _batch_size attributes of the data_source if exist, and None otherwise.
         """
         if hasattr(self._data_source, "batch_size"):  # Torch dataloader
-            return self._data_source.batch_size
+            return cast(int, self._data_source.batch_size)
         if hasattr(self._data_source, "_batch_size"):  # TF dataloader
-            return self._data_source._batch_size
+            return cast(int, self._data_source._batch_size)
         return None
 
 
-class DataProvider(Generic[DataItem, ModelInput]):
+class DataProvider:
     def __init__(
         self,
-        data_source: Iterable[DataItem],
-        transform_func: Callable[[DataItem], ModelInput],
+        data_source: Iterable[Any],
+        transform_func: Optional[Callable[..., Any]],
         indices: Optional[List[int]] = None,
     ):
         self._data_source = data_source
         if transform_func is None:
-            transform_func = lambda x: x
-        self._transform_func = transform_func
+            self._transform_func = lambda x: x
+        else:
+            self._transform_func = transform_func
         self._indices = indices
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         if self._indices is None:
             return map(self._transform_func, self._data_source)
 
@@ -117,15 +113,15 @@ def __iter__(self):
 
     @staticmethod
     def _get_iterator_for_map_style(
-        data_source: Iterable[DataItem], transform_func: Callable[[DataItem], ModelInput], indices: List[int]
-    ):
+        data_source: Iterable[Any], transform_func: Callable[..., Any], indices: List[int]
+    ) -> Generator[Any, None, None]:
         for index in indices:
-            yield transform_func(data_source[index])
+            yield transform_func(data_source[index])  # type: ignore[index]
 
     @staticmethod
     def _get_iterator_for_iter(
-        data_source: Iterable[DataItem], transform_func: Callable[[DataItem], ModelInput], indices: List[int]
-    ):
+        data_source: Iterable[Any], transform_func: Callable[..., Any], indices: List[int]
+    ) -> Generator[Any, None, None]:
         pos = 0
         num_indices = len(indices)
         for idx, data_item in enumerate(data_source):
diff --git a/nncf/data/generators.py b/nncf/data/generators.py
index 8720c3913bf..67fe033c00f 100644
--- a/nncf/data/generators.py
+++ b/nncf/data/generators.py
@@ -54,9 +54,9 @@ def generate_text_data(
         raise nncf.ModuleNotFoundError("torch is required in order to generate text data: `pip install torch`.")
 
     try:
-        from transformers import PreTrainedModel
+        from transformers import PreTrainedModel  # type: ignore
         from transformers import PreTrainedTokenizerBase
-        from transformers.utils import logging
+        from transformers.utils import logging  # type: ignore
 
         logging.set_verbosity_error()
     except ImportError:
@@ -70,7 +70,7 @@ def generate_text_data(
     if not isinstance(tokenizer, PreTrainedTokenizerBase.__bases__):
         raise nncf.ValidationError("tokenizer should be instance of the `transformers.PreTrainedTokenizerBase`.")
 
-    generated_data = []
+    generated_data: List[str] = []
 
     vocab_size_names = ["padded_vocab_size", "vocab_size"]
     vocab_size = BASE_VOCAB_SIZE
diff --git a/nncf/telemetry/extractors.py b/nncf/telemetry/extractors.py
index 6719c5d04aa..799a40e59a2 100644
--- a/nncf/telemetry/extractors.py
+++ b/nncf/telemetry/extractors.py
@@ -60,9 +60,9 @@ def extract(self, argvalue: SerializableData) -> CollectedEvent:
 
 
 class FunctionCallTelemetryExtractor(TelemetryExtractor):
-    def __init__(self, argvalue=None):
+    def __init__(self, argvalue: Any = None) -> None:
         super().__init__()
         self._argvalue = argvalue
 
-    def extract(self, _: Any):
+    def extract(self, _: Any) -> CollectedEvent:
         return CollectedEvent(name="function_call", data=self._argvalue)
diff --git a/pyproject.toml b/pyproject.toml
index 7492e0b6fa8..439807c1fe6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,6 +94,7 @@ strict = true
 implicit_optional = true
 files = [
     "nncf/api",
+    "nncf/data",
     "nncf/common/collector.py",
     "nncf/common/engine.py",
     "nncf/common/hook_handle.py",