Skip to content

feat: add bpd.options.bigquery.requests_transport_adapters option #1755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion bigframes/_config/bigquery_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

from __future__ import annotations

from typing import Literal, Optional
from typing import Literal, Optional, Sequence, Tuple
import warnings

import google.auth.credentials
import requests.adapters

import bigframes.enums
import bigframes.exceptions as bfe
Expand Down Expand Up @@ -90,6 +91,9 @@ def __init__(
allow_large_results: bool = False,
ordering_mode: Literal["strict", "partial"] = "strict",
client_endpoints_override: Optional[dict] = None,
requests_transport_adapters: Sequence[
Tuple[str, requests.adapters.BaseAdapter]
] = (),
):
self._credentials = credentials
self._project = project
Expand All @@ -100,6 +104,7 @@ def __init__(
self._kms_key_name = kms_key_name
self._skip_bq_connection_check = skip_bq_connection_check
self._allow_large_results = allow_large_results
self._requests_transport_adapters = requests_transport_adapters
self._session_started = False
# Determines the ordering strictness for the session.
self._ordering_mode = _validate_ordering_mode(ordering_mode)
Expand Down Expand Up @@ -379,3 +384,43 @@ def client_endpoints_override(self, value: dict):
)

self._client_endpoints_override = value

@property
def requests_transport_adapters(
self,
) -> Sequence[Tuple[str, requests.adapters.BaseAdapter]]:
"""Transport adapters for requests-based REST clients such as the
google-cloud-bigquery package.

For more details, see the explanation in `requests guide to transport
adapters
<https://requests.readthedocs.io/en/latest/user/advanced/#transport-adapters>`_.

**Examples:**

Increase the connection pool size using the requests `HTTPAdapter
<https://requests.readthedocs.io/en/latest/api/#requests.adapters.HTTPAdapter>`_.

>>> import bigframes.pandas as bpd
>>> bpd.options.bigquery.requests_transport_adapters = (
... ("http://", requests.adapters.HTTPAdapter(pool_maxsize=100)),
... ("https://", requests.adapters.HTTPAdapter(pool_maxsize=100)),
... ) # doctest: +SKIP

Returns:
Sequence[Tuple[str, requests.adapters.BaseAdapter]]:
Prefixes and corresponding transport adapters to `mount
<https://requests.readthedocs.io/en/latest/api/#requests.Session.mount>`_
in requests-based REST clients.
"""
return self._requests_transport_adapters

@requests_transport_adapters.setter
def requests_transport_adapters(
self, value: Sequence[Tuple[str, requests.adapters.BaseAdapter]]
) -> None:
if self._session_started and self._requests_transport_adapters != value:
raise ValueError(
SESSION_STARTED_MESSAGE.format(attribute="requests_transport_adapters")
)
self._requests_transport_adapters = value
1 change: 1 addition & 0 deletions bigframes/pandas/io/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def _set_default_session_location_if_possible(query):
application_name=config.options.bigquery.application_name,
bq_kms_key_name=config.options.bigquery.kms_key_name,
client_endpoints_override=config.options.bigquery.client_endpoints_override,
requests_transport_adapters=config.options.bigquery.requests_transport_adapters,
)

bqclient = clients_provider.bqclient
Expand Down
1 change: 1 addition & 0 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(
application_name=context.application_name,
bq_kms_key_name=self._bq_kms_key_name,
client_endpoints_override=context.client_endpoints_override,
requests_transport_adapters=context.requests_transport_adapters,
)

# TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494
Expand Down
20 changes: 18 additions & 2 deletions bigframes/session/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
import os
import threading
import typing
from typing import Optional
from typing import Optional, Sequence, Tuple

import google.api_core.client_info
import google.api_core.client_options
import google.api_core.gapic_v1.client_info
import google.auth.credentials
import google.auth.transport.requests
import google.cloud.bigquery as bigquery
import google.cloud.bigquery_connection_v1
import google.cloud.bigquery_storage_v1
import google.cloud.functions_v2
import google.cloud.resourcemanager_v3
import pydata_google_auth
import requests

import bigframes.constants
import bigframes.version
Expand Down Expand Up @@ -79,6 +81,10 @@ def __init__(
application_name: Optional[str] = None,
bq_kms_key_name: Optional[str] = None,
client_endpoints_override: dict = {},
*,
requests_transport_adapters: Sequence[
Tuple[str, requests.adapters.BaseAdapter]
] = (),
):
credentials_project = None
if credentials is None:
Expand Down Expand Up @@ -124,6 +130,7 @@ def __init__(
)
self._location = location
self._use_regional_endpoints = use_regional_endpoints
self._requests_transport_adapters = requests_transport_adapters

self._credentials = credentials
self._bq_kms_key_name = bq_kms_key_name
Expand Down Expand Up @@ -173,12 +180,21 @@ def _create_bigquery_client(self):
user_agent=self._application_name
)

requests_session = google.auth.transport.requests.AuthorizedSession(
self._credentials
)
for prefix, adapter in self._requests_transport_adapters:
requests_session.mount(prefix, adapter)

bq_client = bigquery.Client(
client_info=bq_info,
client_options=bq_options,
credentials=self._credentials,
project=self._project,
location=self._location,
# Instead of credentials, use _http so that users can override
# requests options with transport adapters. See internal issue
# b/419106112.
_http=requests_session,
)

# If a new enough client library is available, we opt-in to the faster
Expand Down
1 change: 1 addition & 0 deletions tests/unit/_config/test_bigquery_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
("skip_bq_connection_check", False, True),
("client_endpoints_override", {}, {"bqclient": "endpoint_address"}),
("ordering_mode", "strict", "partial"),
("requests_transport_adapters", object(), object()),
],
)
def test_setter_raises_if_session_started(attribute, original_value, new_value):
Expand Down
28 changes: 22 additions & 6 deletions tests/unit/session/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,22 @@
import os
import pathlib
import tempfile
from typing import Optional
from typing import cast, Optional
import unittest.mock as mock

import google.api_core.client_info
import google.api_core.client_options
import google.api_core.exceptions
import google.api_core.gapic_v1.client_info
import google.auth.credentials
import google.cloud.bigquery
import google.cloud.bigquery_connection_v1
import google.cloud.bigquery_storage_v1
import google.cloud.functions_v2
import google.cloud.resourcemanager_v3
import requests.adapters

import bigframes.session.clients as clients
import bigframes.version


def create_clients_provider(application_name: Optional[str] = None):
def create_clients_provider(application_name: Optional[str] = None, **kwargs):
credentials = mock.create_autospec(google.auth.credentials.Credentials)
return clients.ClientsProvider(
project="test-project",
Expand All @@ -42,6 +39,7 @@ def create_clients_provider(application_name: Optional[str] = None):
credentials=credentials,
application_name=application_name,
bq_kms_key_name="projects/my-project/locations/us/keyRings/myKeyRing/cryptoKeys/myKey",
**kwargs,
)


Expand Down Expand Up @@ -136,6 +134,24 @@ def assert_clients_wo_user_agent(
)


def test_requests_transport_adapters_pool_maxsize(monkeypatch):
monkeypatch_client_constructors(monkeypatch)
requests_transport_adapters = (
("http://", requests.adapters.HTTPAdapter(pool_maxsize=123)),
("https://", requests.adapters.HTTPAdapter(pool_maxsize=123)),
) # doctest: +SKIP
provider = create_clients_provider(
requests_transport_adapters=requests_transport_adapters
)

_, kwargs = cast(mock.Mock, provider.bqclient).call_args
requests_session = kwargs.get("_http")
adapter: requests.adapters.HTTPAdapter = requests_session.get_adapter(
"https://bigquery.googleapis.com/"
)
assert adapter._pool_maxsize == 123 # type: ignore


def test_user_agent_default(monkeypatch):
monkeypatch_client_constructors(monkeypatch)
provider = create_clients_provider(application_name=None)
Expand Down