Skip to content

Commit 9a22be7

Browse files
committed
feat: add factory support to resource sequences
1 parent f4c5445 commit 9a22be7

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

src/posit/connect/resources.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,35 +92,36 @@ def update(self, **attributes): # type: ignore[reportIncompatibleMethodOverride
9292
super().update(**result)
9393

9494

95-
T = TypeVar("T", bound=Resource)
95+
_T = TypeVar("_T", bound=Resource)
96+
_T_co = TypeVar("_T_co", bound=Resource, covariant=True)
9697

9798

98-
class ResourceFactory(Protocol):
99-
def __call__(self, ctx: Context, path: str, **attributes) -> Resource: ...
99+
class ResourceFactory(Protocol[_T_co]):
100+
def __call__(self, ctx: Context, path: str, **attributes: Any) -> _T_co: ...
100101

101102

102-
class ResourceSequence(Protocol[T]):
103+
class ResourceSequence(Protocol[_T]):
103104
@overload
104-
def __getitem__(self, index: SupportsIndex, /) -> T: ...
105+
def __getitem__(self, index: SupportsIndex, /) -> _T: ...
105106

106107
@overload
107-
def __getitem__(self, index: slice, /) -> List[T]: ...
108+
def __getitem__(self, index: slice, /) -> List[_T]: ...
108109

109110
def __len__(self) -> int: ...
110111

111-
def __iter__(self) -> Iterator[T]: ...
112+
def __iter__(self) -> Iterator[_T]: ...
112113

113114
def __str__(self) -> str: ...
114115

115116
def __repr__(self) -> str: ...
116117

117118

118-
class _ResourceSequence(Sequence[T], ResourceSequence[T]):
119+
class _ResourceSequence(Sequence[_T], ResourceSequence[_T]):
119120
def __init__(
120121
self,
121122
ctx: Context,
122123
path: str,
123-
factory: ResourceFactory = _Resource,
124+
factory: ResourceFactory[_T_co] = _Resource,
124125
uid: str = "guid",
125126
):
126127
self._ctx = ctx
@@ -134,7 +135,7 @@ def __getitem__(self, index):
134135
def __len__(self) -> int:
135136
return len(list(self.fetch()))
136137

137-
def __iter__(self) -> Iterator[T]:
138+
def __iter__(self) -> Iterator[_T]:
138139
return iter(self.fetch())
139140

140141
def __str__(self) -> str:
@@ -143,32 +144,32 @@ def __str__(self) -> str:
143144
def __repr__(self) -> str:
144145
return repr(self.fetch())
145146

146-
def create(self, **attributes: Any) -> T:
147+
def create(self, **attributes: Any) -> _T:
147148
response = self._ctx.client.post(self._path, json=attributes)
148149
result = response.json()
149150
uid = result[self._uid]
150151
path = posixpath.join(self._path, uid)
151-
return cast(T, self._factory(self._ctx, path, **result))
152+
return cast(_T, self._factory(self._ctx, path, **result))
152153

153-
def fetch(self, **conditions) -> Iterable[T]:
154+
def fetch(self, **conditions) -> Iterable[_T]:
154155
response = self._ctx.client.get(self._path, params=conditions)
155156
results = response.json()
156-
resources: List[T] = []
157+
resources: List[_T] = []
157158
for result in results:
158159
uid = result[self._uid]
159160
path = posixpath.join(self._path, uid)
160-
resource = cast(T, self._factory(self._ctx, path, **result))
161+
resource = cast(_T, self._factory(self._ctx, path, **result))
161162
resources.append(resource)
162163

163164
return resources
164165

165-
def find(self, *args: str) -> T:
166+
def find(self, *args: str) -> _T:
166167
path = posixpath.join(self._path, *args)
167168
response = self._ctx.client.get(path)
168169
result = response.json()
169-
return cast(T, self._factory(self._ctx, path, **result))
170+
return cast(_T, self._factory(self._ctx, path, **result))
170171

171-
def find_by(self, **conditions) -> T | None:
172+
def find_by(self, **conditions) -> _T | None:
172173
"""
173174
Find the first record matching the specified conditions.
174175
@@ -183,19 +184,19 @@ def find_by(self, **conditions) -> T | None:
183184
Optional[T]
184185
The first record matching the conditions, or `None` if no match is found.
185186
"""
186-
collection: Iterable[T] = self.fetch(**conditions)
187+
collection: Iterable[_T] = self.fetch(**conditions)
187188
return next((v for v in collection if v.items() >= conditions.items()), None)
188189

189190

190-
class _PaginatedResourceSequence(_ResourceSequence[T]):
191-
def fetch(self, **conditions) -> Iterator[T]:
191+
class _PaginatedResourceSequence(_ResourceSequence[_T]):
192+
def fetch(self, **conditions) -> Iterator[_T]:
192193
paginator = Paginator(self._ctx, self._path, dict(**conditions))
193194
for page in paginator.fetch_pages():
194195
resources = []
195196
results = page.results
196197
for result in results:
197198
uid = result[self._uid]
198199
path = posixpath.join(self._path, uid)
199-
resource = cast(T, self._factory(self._ctx, path, **result))
200+
resource = cast(_T, self._factory(self._ctx, path, **result))
200201
resources.append(resource)
201202
yield from resources

0 commit comments

Comments
 (0)