Skip to content

Commit

Permalink
move ListField params into __init_subclass__
Browse files Browse the repository at this point in the history
  • Loading branch information
mesozoic committed Sep 8, 2024
1 parent 7460013 commit bf8e9ea
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 19 deletions.
47 changes: 28 additions & 19 deletions pyairtable/orm/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,26 @@ class _ListFieldBase(

valid_types = list
list_class: Type[T_ORM_List]
contains_type: Optional[Type[T_ORM]] = None
contains_type: Optional[Type[T_ORM]]

# List fields will always return a list, never ``None``, so we
# have to overload the type annotations for __get__

def __init_subclass__(cls, **kwargs: Any) -> None:
cls.contains_type = kwargs.pop("contains_type", None)
cls.list_class = kwargs.pop("list_class", ChangeTrackingList)

if cls.contains_type and not isinstance(cls.contains_type, type):
raise TypeError(f"contains_type= expected a type, got {cls.contains_type}")
if not isinstance(cls.list_class, type):
raise TypeError(f"list_class= expected a type, got {cls.list_class}")
if not issubclass(cls.list_class, ChangeTrackingList):
raise TypeError(
f"list_class= expected Type[ChangeTrackingList], got {cls.list_class}"
)

return super().__init_subclass__(**kwargs)

@overload
def __get__(self, instance: None, owner: Type[Any]) -> SelfType: ...

Expand All @@ -509,12 +524,9 @@ def _get_list_value(self, instance: "Model") -> T_ORM_List:
# We need to keep track of any mutations to this list, so we know
# whether to write the field back to the API when the model is saved.
if not isinstance(value, self.list_class):
if not isinstance(self.list_class, type):
raise RuntimeError(f"expected a type, got {self.list_class}")
if not issubclass(self.list_class, ChangeTrackingList):
raise RuntimeError(
f"expected Type[ChangeTrackingList], got {self.list_class}"
)
# These were already checked in __init_subclass__ but mypy doesn't know that.
assert isinstance(self.list_class, type)
assert issubclass(self.list_class, ChangeTrackingList)
value = self.list_class(value, field=self, model=instance)

# For implementers to be able to modify this list in place
Expand All @@ -532,7 +544,7 @@ def valid_or_raise(self, value: Any) -> None:


class _ListField(Generic[T], _ListFieldBase[T, T, ChangeTrackingList[T]]):
list_class = ChangeTrackingList
pass


class _LinkFieldOptions(Enum):
Expand All @@ -556,8 +568,6 @@ class LinkField(
See `Link to another record <https://airtable.com/developers/web/api/field-model#foreignkey>`__.
"""

list_class = ChangeTrackingList

_linked_model: Union[str, Literal[_LinkFieldOptions.LinkSelf], Type[T_Linked]]
_max_retrieve: Optional[int] = None

Expand Down Expand Up @@ -891,9 +901,12 @@ class AITextField(_DictField[AITextDict]):
readonly = True


class AttachmentsField(_ListFieldBase[AttachmentDict, AttachmentDict, AttachmentsList]):
contains_type = cast(Type[AttachmentDict], dict)
list_class = AttachmentsList
class AttachmentsField(
_ListFieldBase[AttachmentDict, AttachmentDict, AttachmentsList],
list_class=AttachmentsList,
contains_type=dict,
):
pass


class BarcodeField(_DictField[BarcodeDict]):
Expand Down Expand Up @@ -998,24 +1011,20 @@ class ManualSortField(TextField):
readonly = True


class MultipleCollaboratorsField(_ListField[CollaboratorDict]):
class MultipleCollaboratorsField(_ListField[CollaboratorDict], contains_type=dict):
"""
Accepts a list of dicts in the format detailed in
`Multiple Collaborators <https://airtable.com/developers/web/api/field-model#multicollaborator>`_.
"""

contains_type = cast(Type[CollaboratorDict], dict)


class MultipleSelectField(_ListField[str]):
class MultipleSelectField(_ListField[str], contains_type=str):
"""
Accepts a list of ``str``.
See `Multiple select <https://airtable.com/developers/web/api/field-model#multiselect>`__.
"""

contains_type = str


class PercentField(NumberField):
"""
Expand Down
19 changes: 19 additions & 0 deletions tests/test_orm_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,3 +1070,22 @@ class T(Model):
with mock.patch("pyairtable.Table.update", return_value=obj.to_record()) as m:
obj.save(force=True)
m.assert_called_once_with(obj.id, fields, typecast=True)


@pytest.mark.parametrize(
"class_kwargs",
[
{"contains_type": 1},
{"list_class": 1},
{"list_class": dict},
],
)
def test_invalid_list_class_params(class_kwargs):
"""
Test that certain parameters to ListField are invalid.
"""

with pytest.raises(TypeError):

class ListFieldSubclass(f._ListField, **class_kwargs):
pass

0 comments on commit bf8e9ea

Please sign in to comment.