Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: EntitlementIterator behavior and type-hinting #2555

Merged
merged 11 commits into from
Aug 28, 2024
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ These changes are available on the `master` branch, but have not yet been releas
- Added `Guild.fetch_role` method.
([#2528](https://github.com/Pycord-Development/pycord/pull/2528))

### Fixed

- Fixed `EntitlementIterator` behavior with `limit > 100`.
([#2555](https://github.com/Pycord-Development/pycord/pull/2555))

## [2.6.0] - 2024-07-09

### Added
Expand Down
74 changes: 56 additions & 18 deletions discord/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .types.audit_log import AuditLog as AuditLogPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload
from .types.monetization import Entitlement as EntitlementPayload
from .types.threads import Thread as ThreadPayload
from .types.user import PartialUser as PartialUserPayload
from .user import User
Expand Down Expand Up @@ -988,11 +989,21 @@ def __init__(
self.guild_id = guild_id
self.exclude_ended = exclude_ended

self._filter = None

if self.before and self.after:
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy
self._filter = lambda e: int(e["id"]) > self.after.id
elif self.after:
self._retrieve_entitlements = self._retrieve_entitlements_after_strategy
else:
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy

self.state = state
self.get_entitlements = state.http.list_entitlements
self.entitlements = asyncio.Queue()

async def next(self) -> BanEntry:
async def next(self) -> Entitlement:
if self.entitlements.empty():
await self.fill_entitlements()

Expand All @@ -1014,30 +1025,57 @@ async def fill_entitlements(self):
if not self._get_retrieve():
return

data = await self._retrieve_entitlements(self.retrieve)

if self._filter:
data = list(filter(self._filter, data))

if len(data) < 100:
self.limit = 0 # terminate loop

for element in data:
await self.entitlements.put(Entitlement(data=element, state=self.state))

async def _retrieve_entitlements(self, retrieve) -> list[Entitlement]:
"""Retrieve entitlements and update next parameters."""
raise NotImplementedError

async def _retrieve_entitlements_before_strategy(
self, retrieve: int
) -> list[EntitlementPayload]:
"""Retrieve entitlements using before parameter."""
before = self.before.id if self.before else None
after = self.after.id if self.after else None
data = await self.get_entitlements(
self.state.application_id,
before=before,
after=after,
limit=self.retrieve,
limit=retrieve,
user_id=self.user_id,
guild_id=self.guild_id,
sku_ids=self.sku_ids,
exclude_ended=self.exclude_ended,
)
if data:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]["id"]))
return data

if not data:
# no data, terminate
return

if self.limit:
self.limit -= self.retrieve

if len(data) < 100:
self.limit = 0 # terminate loop

self.after = Object(id=int(data[-1]["id"]))

for element in reversed(data):
await self.entitlements.put(Entitlement(data=element, state=self.state))
async def _retrieve_entitlements_after_strategy(
self, retrieve: int
) -> list[EntitlementPayload]:
"""Retrieve entitlements using after parameter."""
after = self.after.id if self.after else None
data = await self.get_entitlements(
self.state.application_id,
after=after,
limit=retrieve,
user_id=self.user_id,
guild_id=self.guild_id,
sku_ids=self.sku_ids,
exclude_ended=self.exclude_ended,
)
if data:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[-1]["id"]))
return data