Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
Add timeout arg to github repo reader (#873)
Browse files Browse the repository at this point in the history
* update github repo reader with timeout

* formatting
  • Loading branch information
rwood-97 authored Jan 19, 2024
1 parent 0db48a5 commit 10850e8
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 12 deletions.
2 changes: 2 additions & 0 deletions llama_hub/github_repo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ loader = GithubRepositoryReader(
filter_file_extensions = ([".py"], GithubRepositoryReader.FilterType.INCLUDE),
verbose = True,
concurrent_requests = 10,
timeout = 5,
)

docs = loader.load_data(branch="main")
Expand Down Expand Up @@ -74,6 +75,7 @@ if docs is None:
filter_file_extensions = ([".py"], GithubRepositoryReader.FilterType.INCLUDE),
verbose = True,
concurrent_requests = 10,
timeout = 5,
)

docs = loader.load_data(branch="main")
Expand Down
13 changes: 10 additions & 3 deletions llama_hub/github_repo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
use_parser: bool = False,
verbose: bool = False,
concurrent_requests: int = 5,
timeout: Optional[int] = 5,
filter_directories: Optional[Tuple[List[str], FilterType]] = None,
filter_file_extensions: Optional[Tuple[List[str], FilterType]] = None,
):
Expand All @@ -87,6 +88,7 @@ def __init__(
- verbose (bool): Whether to print verbose messages.
- concurrent_requests (int): Number of concurrent requests to
make to the Github API.
- timeout (int or None): Timeout for the requests to the Github API. Default is 5.
- filter_directories (Optional[Tuple[List[str], FilterType]]): Tuple
containing a list of directories and a FilterType. If the FilterType
is INCLUDE, only the files in the directories in the list will be
Expand All @@ -109,6 +111,7 @@ def __init__(
self._use_parser = use_parser
self._verbose = verbose
self._concurrent_requests = concurrent_requests
self._timeout = timeout
self._filter_directories = filter_directories
self._filter_file_extensions = filter_file_extensions

Expand Down Expand Up @@ -224,7 +227,9 @@ def _load_data_from_commit(self, commit_sha: str) -> List[Document]:
:return: list of documents
"""
commit_response: GitCommitResponseModel = self._loop.run_until_complete(
self._github_client.get_commit(self._owner, self._repo, commit_sha)
self._github_client.get_commit(
self._owner, self._repo, commit_sha, timeout=self._timeout
)
)

tree_sha = commit_response.commit.tree.sha
Expand All @@ -247,7 +252,9 @@ def _load_data_from_branch(self, branch: str) -> List[Document]:
:return: list of documents
"""
branch_data: GitBranchResponseModel = self._loop.run_until_complete(
self._github_client.get_branch(self._owner, self._repo, branch)
self._github_client.get_branch(
self._owner, self._repo, branch, timeout=self._timeout
)
)

tree_sha = branch_data.commit.commit.tree.sha
Expand Down Expand Up @@ -319,7 +326,7 @@ async def _recurse_tree(
)

tree_data: GitTreeResponseModel = await self._github_client.get_tree(
self._owner, self._repo, tree_sha
self._owner, self._repo, tree_sha, timeout=self._timeout
)
print_if_verbose(
self._verbose, "\t" * current_depth + f"tree data: {tree_data}"
Expand Down
58 changes: 49 additions & 9 deletions llama_hub/github_repo/github_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ async def request(
endpoint: str,
method: str,
headers: Dict[str, Any] = {},
timeout: Optional[int] = 5,
**kwargs: Any,
) -> Any:
"""
Expand All @@ -283,6 +284,7 @@ async def request(
- `endpoint (str)`: Name of the endpoint to make the request to.
- `method (str)`: HTTP method to use for the request.
- `headers (dict)`: HTTP headers to include in the request.
- `timeout (int or None)`: Timeout for the request in seconds. Default is 5.
- `**kwargs`: Keyword arguments to pass to the endpoint URL.
Returns:
Expand All @@ -295,7 +297,7 @@ async def request(
Examples:
>>> response = client.request("getTree", "GET",
owner="owner", repo="repo",
tree_sha="tree_sha")
tree_sha="tree_sha", timeout=5)
"""
try:
import httpx
Expand All @@ -309,7 +311,9 @@ async def request(

_client: httpx.AsyncClient
async with httpx.AsyncClient(
headers=_headers, base_url=self._base_url
headers=_headers,
base_url=self._base_url,
timeout=timeout,
) as _client:
try:
response = await _client.request(
Expand All @@ -326,6 +330,7 @@ async def get_branch(
repo: str,
branch: Optional[str] = None,
branch_name: Optional[str] = None,
timeout: Optional[int] = 5,
) -> GitBranchResponseModel:
"""
Get information about a branch. (Github API endpoint: getBranch).
Expand All @@ -349,13 +354,22 @@ async def get_branch(
return GitBranchResponseModel.from_json(
(
await self.request(
"getBranch", "GET", owner=owner, repo=repo, branch=branch
"getBranch",
"GET",
owner=owner,
repo=repo,
branch=branch,
timeout=timeout,
)
).text
)

async def get_tree(
self, owner: str, repo: str, tree_sha: str
self,
owner: str,
repo: str,
tree_sha: str,
timeout: Optional[int] = 5,
) -> GitTreeResponseModel:
"""
Get information about a tree. (Github API endpoint: getTree).
Expand All @@ -364,6 +378,7 @@ async def get_tree(
- `owner (str)`: Owner of the repository.
- `repo (str)`: Name of the repository.
- `tree_sha (str)`: SHA of the tree.
- `timeout (int or None)`: Timeout for the request in seconds. Default is 5.
Returns:
- `tree_info (GitTreeResponseModel)`: Information about the tree.
Expand All @@ -374,13 +389,22 @@ async def get_tree(
return GitTreeResponseModel.from_json(
(
await self.request(
"getTree", "GET", owner=owner, repo=repo, tree_sha=tree_sha
"getTree",
"GET",
owner=owner,
repo=repo,
tree_sha=tree_sha,
timeout=timeout,
)
).text
)

async def get_blob(
self, owner: str, repo: str, file_sha: str
self,
owner: str,
repo: str,
file_sha: str,
timeout: Optional[int] = 5,
) -> GitBlobResponseModel:
"""
Get information about a blob. (Github API endpoint: getBlob).
Expand All @@ -389,6 +413,7 @@ async def get_blob(
- `owner (str)`: Owner of the repository.
- `repo (str)`: Name of the repository.
- `file_sha (str)`: SHA of the file.
- `timeout (int or None)`: Timeout for the request in seconds. Default is 5.
Returns:
- `blob_info (GitBlobResponseModel)`: Information about the blob.
Expand All @@ -399,13 +424,22 @@ async def get_blob(
return GitBlobResponseModel.from_json(
(
await self.request(
"getBlob", "GET", owner=owner, repo=repo, file_sha=file_sha
"getBlob",
"GET",
owner=owner,
repo=repo,
file_sha=file_sha,
timeout=timeout,
)
).text
)

async def get_commit(
self, owner: str, repo: str, commit_sha: str
self,
owner: str,
repo: str,
commit_sha: str,
timeout: Optional[int] = 5,
) -> GitCommitResponseModel:
"""
Get information about a commit. (Github API endpoint: getCommit).
Expand All @@ -414,6 +448,7 @@ async def get_commit(
- `owner (str)`: Owner of the repository.
- `repo (str)`: Name of the repository.
- `commit_sha (str)`: SHA of the commit.
- `timeout (int or None)`: Timeout for the request in seconds. Default is 5.
Returns:
- `commit_info (GitCommitResponseModel)`: Information about the commit.
Expand All @@ -424,7 +459,12 @@ async def get_commit(
return GitCommitResponseModel.from_json(
(
await self.request(
"getCommit", "GET", owner=owner, repo=repo, commit_sha=commit_sha
"getCommit",
"GET",
owner=owner,
repo=repo,
commit_sha=commit_sha,
timeout=timeout,
)
).text
)
Expand Down

0 comments on commit 10850e8

Please sign in to comment.