Skip to content

Fix: rest paginator iteration error #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 47 additions & 45 deletions githubkit/rest/paginator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
Generic,
Optional,
TypedDict,
TypeVar,
Union,
cast,
Expand All @@ -16,6 +17,7 @@
import httpx

from githubkit.response import Response
from githubkit.typing import HeaderTypes
from githubkit.utils import is_async

if TYPE_CHECKING:
Expand All @@ -35,6 +37,12 @@
NEXT_LINK_PATTERN = r'<([^<>]+)>;\s*rel="next"'


class PaginatorState(TypedDict):
next_link: Optional[httpx.URL]
request_method: str
response_model: Any


# https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api
# https://github.com/octokit/plugin-paginate-rest.js/blob/1f44b5469b31ddec9621000e6e1aee63c71ea8bf/src/iterator.ts
class Paginator(Generic[RT]):
Expand Down Expand Up @@ -76,33 +84,34 @@ def __init__(

self.map_func = map_func

self._initialized: bool = False
self._request_method: Optional[str] = None
self._response_model: Optional[Any] = None
self._next_link: Optional[httpx.URL] = None
self._state: Optional[PaginatorState] = None

self._index: int = 0
self._cached_data: list[RT] = []

@property
def finalized(self) -> bool:
"""Whether the paginator is finalized or not."""
return self._initialized and self._next_link is None
return (self._state["next_link"] is None) if self._state is not None else False

@property
def _headers(self) -> Optional[HeaderTypes]:
return self.kwargs.get("headers") # type: ignore

def reset(self) -> None:
"""Reset the paginator to the initial state."""

self._initialized = False
self._next_link = None
self._state = None
self._index = 0
self._cached_data = []

def __next__(self) -> RT:
while self._index >= len(self._cached_data):
self._get_next_page()
if self.finalized:
raise StopIteration

self._get_next_page()

current = self._cached_data[self._index]
self._index += 1
return current
Expand All @@ -114,10 +123,11 @@ def __iter__(self: Self) -> Self:

async def __anext__(self) -> RT:
while self._index >= len(self._cached_data):
await self._aget_next_page()
if self.finalized:
raise StopAsyncIteration

await self._aget_next_page()

current = self._cached_data[self._index]
self._index += 1
return current
Expand Down Expand Up @@ -151,64 +161,56 @@ def _fill_cache_data(self, data: list[RT]) -> None:
self._index = 0

def _get_next_page(self) -> None:
if not self._initialized:
if self._state is None:
# First request
response = cast(
Response[Any],
self.request(*self.args, **self.kwargs),
)
self._initialized = True
self._request_method = response.raw_request.method
response = cast(Response[Any], self.request(*self.args, **self.kwargs))
else:
# Next request
if self._next_link is None:
raise RuntimeError("Paginator is finalized, no more pages to fetch.")
if self._request_method is None:
raise RuntimeError("Request method is not set, this should not happen.")
if self._response_model is None:
raise RuntimeError("Response model is not set, this should not happen.")

# we request the next page with the same method and response model
if self._state["next_link"] is None:
raise RuntimeError("No next page to request")

response = cast(
Response[Any],
self.rest._github.request(
self._request_method,
self._next_link,
headers=self.kwargs.get("headers"), # type: ignore
response_model=self._response_model, # type: ignore
self._state["request_method"],
self._state["next_link"],
headers=self._headers, # type: ignore
response_model=self._state["response_model"], # type: ignore
),
)

self._next_link = self._find_next_link(response)
self._state = PaginatorState(
next_link=self._find_next_link(response),
request_method=response.raw_request.method,
response_model=response._data_model,
)
self._fill_cache_data(self._apply_map_func(response))

async def _aget_next_page(self) -> None:
if not self._initialized:
if self._state is None:
# First request
response = cast(
Response[Any],
await self.request(*self.args, **self.kwargs), # type: ignore
)
self._initialized = True
self._request_method = response.raw_request.method
else:
# Next request
if self._next_link is None:
raise RuntimeError("Paginator is finalized, no more pages to fetch.")
if self._request_method is None:
raise RuntimeError("Request method is not set, this should not happen.")
if self._response_model is None:
raise RuntimeError("Response model is not set, this should not happen.")
# we request the next page with the same method and response model
if self._state["next_link"] is None:
raise RuntimeError("No next page to request")

response = cast(
Response[Any],
await self.rest._github.request(
self._request_method,
self._next_link,
headers=self.kwargs.get("headers"), # type: ignore
response_model=self._response_model, # type: ignore
await self.rest._github.arequest(
self._state["request_method"],
self._state["next_link"],
headers=self._headers, # type: ignore
response_model=self._state["response_model"], # type: ignore
),
)

self._next_link = self._find_next_link(response)
self._state = PaginatorState(
next_link=self._find_next_link(response),
request_method=response.raw_request.method,
response_model=response._data_model,
)
self._fill_cache_data(self._apply_map_func(response))
50 changes: 38 additions & 12 deletions tests/test_rest/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,19 @@

from githubkit import GitHub
from githubkit.versions import LATEST_VERSION
from githubkit.versions.latest.models import FullRepository
from githubkit.versions.latest.models import FullRepository, Issue

OWNER = "yanyongyu"
REPO = "githubkit"
ISSUE_COUNT_QUERY = """
query($owner: String!, $repo: String!) {
repository(owner: $owner, name: $repo) {
issues {
totalCount
}
}
}
"""


def test_call(g: GitHub):
Expand Down Expand Up @@ -56,34 +65,51 @@ async def test_async_call_with_raw_body(g: GitHub):

def test_paginate(g: GitHub):
paginator = g.rest.paginate(
g.rest.issues.list_for_repo, owner=OWNER, repo=REPO, per_page=50
g.rest.issues.list_for_repo, owner=OWNER, repo=REPO, state="all", per_page=50
)
for _ in paginator:
...
count = 0
for issue in paginator:
assert isinstance(issue, Issue)
if not issue.pull_request:
count += 1

result = g.graphql.request(ISSUE_COUNT_QUERY, {"owner": OWNER, "repo": REPO})
assert result["repository"]["issues"]["totalCount"] == count


def test_paginate_with_partial(g: GitHub):
paginator = g.rest.paginate(
partial(g.rest.issues.list_for_repo, OWNER, REPO), per_page=50
partial(g.rest.issues.list_for_repo, OWNER, REPO), state="all", per_page=50
)
for _ in paginator:
...
for issue in paginator:
assert isinstance(issue, Issue)


@pytest.mark.anyio
async def test_async_paginate(g: GitHub):
paginator = g.rest.paginate(
g.rest.issues.async_list_for_repo, owner=OWNER, repo=REPO, per_page=50
g.rest.issues.async_list_for_repo,
owner=OWNER,
repo=REPO,
state="all",
per_page=50,
)
async for _ in paginator:
...
count = 0
async for issue in paginator:
assert isinstance(issue, Issue)
if not issue.pull_request:
count += 1

result = g.graphql.request(ISSUE_COUNT_QUERY, {"owner": OWNER, "repo": REPO})
assert result["repository"]["issues"]["totalCount"] == count


@pytest.mark.anyio
async def test_async_paginate_with_partial(g: GitHub):
paginator = g.rest.paginate(
partial(g.rest.issues.async_list_for_repo, OWNER, REPO),
state="all",
per_page=50,
)
async for _ in paginator:
...
async for issue in paginator:
assert isinstance(issue, Issue)