Skip to content

Commit 3b51a36

Browse files
authored
1 parent 1ff2755 commit 3b51a36

File tree

5 files changed

+33
-46
lines changed

5 files changed

+33
-46
lines changed

bigframes/clients.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
)
3030
logger = logging.getLogger(__name__)
3131

32-
_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"
33-
3432

3533
class BqConnectionManager:
3634
"""Manager to handle operations with BQ connections."""
@@ -46,6 +44,23 @@ def __init__(
4644
self._bq_connection_client = bq_connection_client
4745
self._cloud_resource_manager_client = cloud_resource_manager_client
4846

47+
@classmethod
48+
def resolve_full_connection_name(
49+
cls, connection_name: str, default_project: str, default_location: str
50+
) -> str:
51+
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
52+
Use default project, location or connection_id when any of them are missing."""
53+
if connection_name.count(".") == 2:
54+
return connection_name
55+
56+
if connection_name.count(".") == 1:
57+
return f"{default_project}.{connection_name}"
58+
59+
if connection_name.count(".") == 0:
60+
return f"{default_project}.{default_location}.{connection_name}"
61+
62+
raise ValueError(f"Invalid connection name format: {connection_name}.")
63+
4964
def create_bq_connection(
5065
self, project_id: str, location: str, connection_id: str, iam_role: str
5166
):
@@ -164,25 +179,3 @@ def _get_service_account_if_connection_exists(
164179
pass
165180

166181
return service_account
167-
168-
169-
def get_connection_name_full(
170-
connection_name: Optional[str], default_project: str, default_location: str
171-
) -> str:
172-
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
173-
Use default project, location or connection_id when any of them are missing."""
174-
if connection_name is None:
175-
return (
176-
f"{default_project}.{default_location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}"
177-
)
178-
179-
if connection_name.count(".") == 2:
180-
return connection_name
181-
182-
if connection_name.count(".") == 1:
183-
return f"{default_project}.{connection_name}"
184-
185-
if connection_name.count(".") == 0:
186-
return f"{default_project}.{default_location}.{connection_name}"
187-
188-
raise ValueError(f"Invalid connection name format: {connection_name}.")

bigframes/ml/llm.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ def __init__(
4949
connection_name: Optional[str] = None,
5050
):
5151
self.session = session or bpd.get_global_session()
52+
self._bq_connection_manager = clients.BqConnectionManager(
53+
self.session.bqconnectionclient, self.session.resourcemanagerclient
54+
)
5255

5356
connection_name = connection_name or self.session._bq_connection
54-
self.connection_name = clients.get_connection_name_full(
57+
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
5558
connection_name,
5659
default_project=self.session._project,
5760
default_location=self.session._location,
5861
)
5962

60-
self._bq_connection_manager = clients.BqConnectionManager(
61-
self.session.bqconnectionclient, self.session.resourcemanagerclient
62-
)
6363
self._bqml_model_factory = globals.bqml_model_factory()
6464
self._bqml_model: core.BqmlModel = self._create_bqml_model()
6565

@@ -188,17 +188,17 @@ def __init__(
188188
connection_name: Optional[str] = None,
189189
):
190190
self.session = session or bpd.get_global_session()
191+
self._bq_connection_manager = clients.BqConnectionManager(
192+
self.session.bqconnectionclient, self.session.resourcemanagerclient
193+
)
191194

192195
connection_name = connection_name or self.session._bq_connection
193-
self.connection_name = clients.get_connection_name_full(
196+
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
194197
connection_name,
195198
default_project=self.session._project,
196199
default_location=self.session._location,
197200
)
198201

199-
self._bq_connection_manager = clients.BqConnectionManager(
200-
self.session.bqconnectionclient, self.session.resourcemanagerclient
201-
)
202202
self._bqml_model_factory = globals.bqml_model_factory()
203203
self._bqml_model: core.BqmlModel = self._create_bqml_model()
204204

bigframes/remote_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ def remote_function(
772772
if not bigquery_connection:
773773
bigquery_connection = session._bq_connection # type: ignore
774774

775-
bigquery_connection = clients.get_connection_name_full(
775+
bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name(
776776
bigquery_connection,
777777
default_project=dataset_ref.project,
778778
default_location=bq_location,

bigframes/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@
9797
_BIGQUERYCONNECTION_REGIONAL_ENDPOINT = "{location}-bigqueryconnection.googleapis.com"
9898
_BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "{location}-bigquerystorage.googleapis.com"
9999

100+
_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"
101+
100102
_MAX_CLUSTER_COLUMNS = 4
101103

102104
# TODO(swast): Need to connect to regional endpoints when performing remote
@@ -321,7 +323,7 @@ def __init__(
321323
),
322324
)
323325

324-
self._bq_connection = context.bq_connection
326+
self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID
325327

326328
# Now that we're starting the session, don't allow the options to be
327329
# changed.

tests/unit/test_clients.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,22 @@
1717
from bigframes import clients
1818

1919

20-
def test_get_connection_name_full_none():
21-
connection_name = clients.get_connection_name_full(
22-
None, default_project="default-project", default_location="us"
23-
)
24-
assert connection_name == "default-project.us.bigframes-default-connection"
25-
26-
2720
def test_get_connection_name_full_connection_id():
28-
connection_name = clients.get_connection_name_full(
21+
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
2922
"connection-id", default_project="default-project", default_location="us"
3023
)
3124
assert connection_name == "default-project.us.connection-id"
3225

3326

3427
def test_get_connection_name_full_location_connection_id():
35-
connection_name = clients.get_connection_name_full(
28+
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
3629
"eu.connection-id", default_project="default-project", default_location="us"
3730
)
3831
assert connection_name == "default-project.eu.connection-id"
3932

4033

4134
def test_get_connection_name_full_all():
42-
connection_name = clients.get_connection_name_full(
35+
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
4336
"my-project.eu.connection-id",
4437
default_project="default-project",
4538
default_location="us",
@@ -48,9 +41,8 @@ def test_get_connection_name_full_all():
4841

4942

5043
def test_get_connection_name_full_raise_value_error():
51-
5244
with pytest.raises(ValueError):
53-
clients.get_connection_name_full(
45+
clients.BqConnectionManager.resolve_full_connection_name(
5446
"my-project.eu.connection-id.extra_field",
5547
default_project="default-project",
5648
default_location="us",

0 commit comments

Comments
 (0)