Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions pytools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,32 +419,32 @@ def __init__(self,
fields.add(key)
setattr(self, key, value)

def get_copy_kwargs(self, **kwargs):
def get_copy_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
for f in self.__class__.fields:
if f not in kwargs:
with contextlib.suppress(AttributeError):
kwargs[f] = getattr(self, f)
return kwargs

def copy(self, **kwargs):
def copy(self, **kwargs: Any) -> "RecordWithoutPickling":
return self.__class__(**self.get_copy_kwargs(**kwargs))

def __repr__(self):
def __repr__(self) -> str:
return "{}({})".format(
self.__class__.__name__,
", ".join(f"{fld}={getattr(self, fld)!r}"
for fld in sorted(self.__class__.fields)
if hasattr(self, fld)))

def register_fields(self, new_fields):
def register_fields(self, new_fields: Set[str]) -> None:
try:
fields = self.__class__.fields
except AttributeError:
self.__class__.fields = fields = set()

fields.update(new_fields)

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
# This method is implemented to avoid pylint 'no-member' errors for
# attribute access.
raise AttributeError(
Expand All @@ -455,13 +455,13 @@ def __getattr__(self, name):
class Record(RecordWithoutPickling):
__slots__: ClassVar[list[str]] = []

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
return {
key: getattr(self, key)
for key in self.__class__.fields
if hasattr(self, key)}

def __setstate__(self, valuedict):
def __setstate__(self, valuedict: Dict[str, Any]) -> None:
try:
fields = self.__class__.fields
except AttributeError:
Expand All @@ -471,31 +471,28 @@ def __setstate__(self, valuedict):
fields.add(key)
setattr(self, key, value)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if self is other:
return True
return (self.__class__ == other.__class__
and self.__getstate__() == other.__getstate__())

def __ne__(self, other):
return not self.__eq__(other)


class ImmutableRecordWithoutPickling(RecordWithoutPickling):
"""Hashable record. Does not explicitly enforce immutability."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
RecordWithoutPickling.__init__(self, *args, **kwargs)
self._cached_hash = None
self._cached_hash: Optional[int] = None

def __hash__(self):
def __hash__(self) -> int:
# This attribute may vanish during pickling.
if getattr(self, "_cached_hash", None) is None:
self._cached_hash = hash((
type(self),
*(getattr(self, field) for field in self.__class__.fields)
))

return self._cached_hash
return cast(int, self._cached_hash)


class ImmutableRecord(ImmutableRecordWithoutPickling, Record):
Expand Down
Loading