Skip to content

Commit 33b3568

Browse files
authored
[tensorflow] Add __slots__ (#15459)
1 parent 2b36c45 commit 33b3568

File tree

4 files changed

+24
-0
lines changed

4 files changed

+24
-0
lines changed

stubs/tensorflow/tensorflow/__init__.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ class Operation:
222222
def __getattr__(self, name: str) -> Incomplete: ...
223223

224224
class TensorShape(metaclass=ABCMeta):
225+
__slots__ = ["_dims"]
225226
def __init__(self, dims: ShapeLike) -> None: ...
226227
@property
227228
def rank(self) -> int: ...
@@ -308,6 +309,7 @@ class UnconnectedGradients(Enum):
308309
_SpecProto = TypeVar("_SpecProto", bound=Message)
309310

310311
class TypeSpec(ABC, Generic[_SpecProto]):
312+
__slots__ = ["_cached_cmp_key"]
311313
@property
312314
@abstractmethod
313315
def value_type(self) -> Any: ...
@@ -323,6 +325,7 @@ class TypeSpec(ABC, Generic[_SpecProto]):
323325
def most_specific_compatible_type(self, other: Self) -> Self: ...
324326

325327
class TensorSpec(TypeSpec[struct_pb2.TensorSpecProto]):
328+
__slots__: list[str] = []
326329
def __init__(self, shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None) -> None: ...
327330
@property
328331
def value_type(self) -> Tensor: ...
@@ -339,6 +342,7 @@ class TensorSpec(TypeSpec[struct_pb2.TensorSpecProto]):
339342
def is_compatible_with(self, spec_or_tensor: Self | TensorCompatible) -> _bool: ... # type: ignore[override]
340343

341344
class SparseTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
345+
__slots__ = ["_shape", "_dtype"]
342346
def __init__(self, shape: ShapeLike | None = None, dtype: DTypeLike = ...) -> None: ...
343347
@property
344348
def value_type(self) -> SparseTensor: ...
@@ -350,6 +354,7 @@ class SparseTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
350354
def from_value(cls, value: SparseTensor) -> Self: ...
351355

352356
class RaggedTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
357+
__slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype", "_flat_values_spec"]
353358
def __init__(
354359
self,
355360
shape: ShapeLike | None = None,

stubs/tensorflow/tensorflow/dtypes.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from tensorflow.python.framework.dtypes import HandleData
1010
class _DTypeMeta(ABCMeta): ...
1111

1212
class DType(metaclass=_DTypeMeta):
13+
__slots__ = ["_handle_data"]
1314
def __init__(self, type_enum: int, handle_data: HandleData | None = None) -> None: ...
1415
@property
1516
def name(self) -> str: ...

stubs/tensorflow/tensorflow/saved_model/__init__.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ class Asset:
1818
def __init__(self, path: str | Path | tf.Tensor) -> None: ...
1919

2020
class LoadOptions:
21+
__slots__ = (
22+
"allow_partial_checkpoint",
23+
"experimental_io_device",
24+
"experimental_skip_checkpoint",
25+
"experimental_variable_policy",
26+
"experimental_load_function_aliases",
27+
)
2128
allow_partial_checkpoint: bool
2229
experimental_io_device: str | None
2330
experimental_skip_checkpoint: bool

stubs/tensorflow/tensorflow/train/__init__.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from _typeshed import Incomplete
12
from collections.abc import Callable
23
from typing import Any, TypeVar
34
from typing_extensions import Self
@@ -18,10 +19,20 @@ from tensorflow.python.trackable.base import Trackable
1819
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
1920

2021
class CheckpointOptions:
22+
__slots__ = (
23+
"experimental_io_device",
24+
"experimental_enable_async_checkpoint",
25+
"experimental_write_callbacks",
26+
"enable_async",
27+
"experimental_sharding_callback",
28+
"experimental_skip_slot_variables",
29+
)
2130
experimental_io_device: None | str
2231
experimental_enable_async_checkpoint: bool
2332
experimental_write_callbacks: None | list[Callable[[str], object] | Callable[[], object]]
2433
enable_async: bool
34+
experimental_sharding_callback: Incomplete # should be ShardingCallback
35+
experimental_skip_slot_variables: bool
2536

2637
def __init__(
2738
self,

0 commit comments

Comments
 (0)