Skip to content

Feature: support link header pagination #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions codegen/templates/versions/rest.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,29 @@

import importlib
from weakref import WeakKeyDictionary, ref
from typing import TYPE_CHECKING, Any, Literal, overload
from typing import TYPE_CHECKING, Any, Union, Literal, TypeVar, Callable, Optional, Awaitable, overload
from typing_extensions import ParamSpec

from githubkit.rest.paginator import Paginator

from . import VERSIONS, LATEST_VERSION, VERSION_TYPE

if TYPE_CHECKING:
from githubkit import GitHubCore
from githubkit import GitHubCore, Response
{% for version, module in versions.items() %}
from .{{ module }}.rest import RestNamespace as {{ pascal_case(module) }}RestNamespace
{% endfor %}


CP = ParamSpec("CP")
CT = TypeVar("CT")
RT = TypeVar("RT")

R = Union[
Callable[CP, "Response[RT]"],
Callable[CP, Awaitable["Response[RT]"]],
]

if TYPE_CHECKING:

class _VersionProxy({{ pascal_case(versions[latest_version]) }}RestNamespace):
Expand Down Expand Up @@ -46,6 +59,33 @@ class RestVersionSwitcher(_VersionProxy):
"Do not use the namespace after the client has been collected."
)

@overload
def paginate(
self,
request: "R[CP, list[RT]]",
map_func: None = None,
*args: CP.args,
**kwargs: CP.kwargs,
) -> "Paginator[RT]": ...

@overload
def paginate(
self,
request: "R[CP, CT]",
map_func: Callable[["Response[CT]"], list[RT]],
*args: CP.args,
**kwargs: CP.kwargs,
) -> "Paginator[RT]": ...

def paginate(
self,
request: "R[CP, CT]",
map_func: Optional[Callable[["Response[CT]"], list[RT]]] = None,
*args: CP.args,
**kwargs: CP.kwargs,
) -> "Paginator[RT]":
return Paginator(self, request, map_func, *args, **kwargs) # type: ignore

{% for version, module in versions.items() %}
@overload
def __call__(self, version: Literal["{{ version }}"]) -> "{{ pascal_case(module) }}RestNamespace":
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/graphql.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ for result in github.graphql.paginate(
Note that the `result` is a dict containing the list of nodes/edges for each page and the `pageInfo` object. You should iterate over the `nodes` or `edges` list to get the actual data. For example:

```python
for result in g.graphql.paginate(query, {"owner": "owner", "repo": "repo"}):
for result in github.graphql.paginate(query, {"owner": "owner", "repo": "repo"}):
for issue in result["repository"]["issues"]["nodes"]:
print(issue)
```
Expand Down
10 changes: 5 additions & 5 deletions docs/usage/rest-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ Current supported versions are: (you can find it in the section `[[tool.codegen.

When a response from the REST API would include many results, GitHub will paginate the results and return a subset of the results. In this case, some APIs provide `page` and `per_page` parameters to control the pagination. See [GitHub Docs - Using pagination in the REST API](https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api) for more information.

githubkit provides a built-in pagination feature to handle this. You can use the `github.paginate` method to iterate over all the results:
githubkit provides a built-in pagination feature to handle this. You can use the `github.rest.paginate` method to iterate over all the results:

> Pagination typing is checked with Pylance ([Pyright](https://github.com/microsoft/pyright)).

Expand All @@ -307,7 +307,7 @@ githubkit provides a built-in pagination feature to handle this. You can use the
```python hl_lines="3-5"
from githubkit.versions.latest.models import Issue

for issue in github.paginate(
for issue in github.rest.paginate(
github.rest.issues.list_for_repo, owner="owner", repo="repo", state="open"
):
issue: Issue
Expand All @@ -319,7 +319,7 @@ githubkit provides a built-in pagination feature to handle this. You can use the
```python hl_lines="3-5"
from githubkit.versions.latest.models import Issue

async for issue in github.paginate(
async for issue in github.rest.paginate(
github.rest.issues.async_list_for_repo, owner="owner", repo="repo", state="open"
):
issue: Issue
Expand All @@ -333,7 +333,7 @@ You can also provide a custom map function to handle complex pagination (such as
```python hl_lines="5"
from githubkit.versions.latest.models import Repository

for accessible_repo in github.paginate(
for accessible_repo in github.rest.paginate(
github.rest.apps.list_installation_repos_for_authenticated_user,
map_func=lambda r: r.parsed_data.repositories,
installation_id=1,
Expand All @@ -347,7 +347,7 @@ You can also provide a custom map function to handle complex pagination (such as
```python hl_lines="5"
from githubkit.versions.latest.models import Repository

async for accessible_repo in github.paginate(
async for accessible_repo in github.rest.paginate(
github.rest.apps.async_list_installation_repos_for_authenticated_user,
map_func=lambda r: r.parsed_data.repositories,
installation_id=1,
Expand Down
6 changes: 5 additions & 1 deletion githubkit/paginator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Awaitable
from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast, overload
from typing_extensions import ParamSpec, Self
from typing_extensions import ParamSpec, Self, deprecated

from .response import Response
from .utils import is_async
Expand All @@ -16,6 +16,10 @@
]


@deprecated(
"Legacy pagination based on page and per_page is deprecated. "
"Use github.rest.paginate instead."
)
class Paginator(Generic[RT]):
"""Paginate through the responses of the rest api request."""

Expand Down
214 changes: 214 additions & 0 deletions githubkit/rest/paginator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from collections.abc import Awaitable
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import ParamSpec, Self

import httpx

from githubkit.response import Response
from githubkit.utils import is_async

if TYPE_CHECKING:
from githubkit.versions import RestVersionSwitcher

CP = ParamSpec("CP")
CT = TypeVar("CT")
RT = TypeVar("RT")
RTS = TypeVar("RTS")

R = Union[
Callable[CP, Response[RT]],
Callable[CP, Awaitable[Response[RT]]],
]

# https://github.com/octokit/plugin-paginate-rest.js/blob/1f44b5469b31ddec9621000e6e1aee63c71ea8bf/src/iterator.ts#L40
NEXT_LINK_PATTERN = r'<([^<>]+)>;\s*rel="next"'


# https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api
# https://github.com/octokit/plugin-paginate-rest.js/blob/1f44b5469b31ddec9621000e6e1aee63c71ea8bf/src/iterator.ts
class Paginator(Generic[RT]):
"""Paginate through the responses of the rest api request."""

@overload
def __init__(
self: "Paginator[RTS]",
rest: "RestVersionSwitcher",
request: R[CP, list[RTS]],
map_func: None = None,
*args: CP.args,
**kwargs: CP.kwargs,
): ...

@overload
def __init__(
self: "Paginator[RTS]",
rest: "RestVersionSwitcher",
request: R[CP, CT],
map_func: Callable[[Response[CT]], list[RTS]],
*args: CP.args,
**kwargs: CP.kwargs,
): ...

def __init__(
self,
rest: "RestVersionSwitcher",
request: R[CP, CT],
map_func: Optional[Callable[[Response[CT]], list[RT]]] = None,
*args: CP.args,
**kwargs: CP.kwargs,
):
self.rest = rest

self.request = request
self.args = args
self.kwargs = kwargs

self.map_func = map_func

self._initialized: bool = False
self._request_method: Optional[str] = None
self._response_model: Optional[Any] = None
self._next_link: Optional[httpx.URL] = None

self._index: int = 0
self._cached_data: list[RT] = []

@property
def finalized(self) -> bool:
"""Whether the paginator is finalized or not."""
return self._initialized and self._next_link is None

def reset(self) -> None:
"""Reset the paginator to the initial state."""

self._initialized = False
self._next_link = None
self._index = 0
self._cached_data = []

def __next__(self) -> RT:
while self._index >= len(self._cached_data):
self._get_next_page()
if self.finalized:
raise StopIteration

current = self._cached_data[self._index]
self._index += 1
return current

def __iter__(self: Self) -> Self:
if is_async(self.request):
raise TypeError(f"Request method {self.request} is not an sync function")
return self

async def __anext__(self) -> RT:
while self._index >= len(self._cached_data):
await self._aget_next_page()
if self.finalized:
raise StopAsyncIteration

current = self._cached_data[self._index]
self._index += 1
return current

def __aiter__(self: Self) -> Self:
if not is_async(self.request):
raise TypeError(f"Request method {self.request} is not an async function")
return self

def _find_next_link(self, response: Response[Any]) -> Optional[httpx.URL]:
"""Find the next link in the response headers."""
if links := response.headers.get("link"):
if match := re.search(NEXT_LINK_PATTERN, links):
return httpx.URL(match.group(1))
return None

def _apply_map_func(self, response: Response[Any]) -> list[RT]:
if self.map_func is not None:
result = self.map_func(response)
if not isinstance(result, list):
raise TypeError(f"Map function must return a list, got {type(result)}")
else:
result = cast(Response[list[RT]], response).parsed_data
if not isinstance(result, list):
raise TypeError(f"Response is not a list, got {type(result)}")
return result

def _fill_cache_data(self, data: list[RT]) -> None:
"""Fill the cache with the data."""
self._cached_data = data
self._index = 0

def _get_next_page(self) -> None:
if not self._initialized:
# First request
response = cast(
Response[Any],
self.request(*self.args, **self.kwargs),
)
self._initialized = True
self._request_method = response.raw_request.method
else:
# Next request
if self._next_link is None:
raise RuntimeError("Paginator is finalized, no more pages to fetch.")
if self._request_method is None:
raise RuntimeError("Request method is not set, this should not happen.")
if self._response_model is None:
raise RuntimeError("Response model is not set, this should not happen.")

# we request the next page with the same method and response model
response = cast(
Response[Any],
self.rest._github.request(
self._request_method,
self._next_link,
headers=self.kwargs.get("headers"), # type: ignore
response_model=self._response_model, # type: ignore
),
)

self._next_link = self._find_next_link(response)
self._fill_cache_data(self._apply_map_func(response))

async def _aget_next_page(self) -> None:
if not self._initialized:
# First request
response = cast(
Response[Any],
await self.request(*self.args, **self.kwargs), # type: ignore
)
self._initialized = True
self._request_method = response.raw_request.method
else:
# Next request
if self._next_link is None:
raise RuntimeError("Paginator is finalized, no more pages to fetch.")
if self._request_method is None:
raise RuntimeError("Request method is not set, this should not happen.")
if self._response_model is None:
raise RuntimeError("Response model is not set, this should not happen.")

response = cast(
Response[Any],
await self.rest._github.request(
self._request_method,
self._next_link,
headers=self.kwargs.get("headers"), # type: ignore
response_model=self._response_model, # type: ignore
),
)

self._next_link = self._find_next_link(response)
self._fill_cache_data(self._apply_map_func(response))
Loading