Skip to content

Commit

Permalink
applied blacken formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanesmith authored and Bryan Smith committed Aug 30, 2022
1 parent 2504858 commit d6e20f6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
27 changes: 11 additions & 16 deletions google/cloud/aiplatform/featurestore/featurestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,11 +1221,10 @@ def batch_serve_to_df(
if bq_dataset_id is None:
temp_bq_full_dataset_id = self._get_ephemeral_bq_full_dataset_id(
featurestore_name_components["featurestore"],
featurestore_name_components["project"]
featurestore_name_components["project"],
)
temp_bq_dataset = self._create_ephemeral_bq_dataset(
bigquery_client,
temp_bq_full_dataset_id
bigquery_client, temp_bq_full_dataset_id
)
temp_bq_batch_serve_table_name = "batch_serve"
temp_bq_read_instances_table_name = "read_instances"
Expand All @@ -1237,8 +1236,8 @@ def batch_serve_to_df(
temp_bq_batch_serve_table_name = f"tmp_batch_serve_{uuid.uuid4()}".replace(
"-", "_"
)
temp_bq_read_instances_table_name = f"tmp_read_instances_{uuid.uuid4()}".replace(
"-", "_"
temp_bq_read_instances_table_name = (
f"tmp_read_instances_{uuid.uuid4()}".replace("-", "_")
)

temp_bq_batch_serve_table_id = (
Expand All @@ -1255,14 +1254,16 @@ def batch_serve_to_df(
# not Datetime
job_config = bigquery.LoadJobConfig(
schema=[
bigquery.SchemaField("timestamp", bigquery.enums.SqlTypeNames.TIMESTAMP)
bigquery.SchemaField(
"timestamp", bigquery.enums.SqlTypeNames.TIMESTAMP
)
]
)

job = bigquery_client.load_table_from_dataframe(
dataframe=read_instances_df,
destination=temp_bq_read_instances_table_id,
job_config=job_config
job_config=job_config,
)
job.result()

Expand Down Expand Up @@ -1312,12 +1313,9 @@ def batch_serve_to_df(

return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(frames)


def _get_ephemeral_bq_full_dataset_id(
self,
featurestore_id: str,
project_number: str
) -> str :
self, featurestore_id: str, project_number: str
) -> str:
temp_bq_dataset_name = f"temp_{featurestore_id}_{uuid.uuid4()}".replace(
"-", "_"
)
Expand All @@ -1329,11 +1327,8 @@ def _get_ephemeral_bq_full_dataset_id(

return f"{project_id}.{temp_bq_dataset_name}"[:1024]


def _create_ephemeral_bq_dataset(
self,
bigquery_client: bigquery.Client,
dataset_id: str
self, bigquery_client: bigquery.Client, dataset_id: str
) -> "bigquery.Dataset":

temp_bq_dataset = bigquery.Dataset(dataset_ref=dataset_id)
Expand Down
32 changes: 21 additions & 11 deletions tests/unit/aiplatform/test_featurestores.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,13 @@ def bq_delete_dataset_mock(bq_client_mock):
with patch.object(bq_client_mock, "delete_dataset") as bq_delete_dataset_mock:
yield bq_delete_dataset_mock


@pytest.fixture
def bq_delete_table_mock(bq_client_mock):
with patch.object(bq_client_mock, "delete_table") as bq_delete_table_mock:
yield bq_delete_table_mock


@pytest.fixture
def bqs_client_mock():
mock = MagicMock(bigquery_storage.BigQueryReadClient)
Expand Down Expand Up @@ -1705,7 +1707,13 @@ def test_batch_serve_to_df(self, batch_read_feature_values_mock):
"get_project_mock",
)
@patch("uuid.uuid4", uuid_mock)
def test_batch_serve_to_df_user_specified_bq_dataset(self, batch_read_feature_values_mock, bq_create_dataset_mock, bq_delete_dataset_mock, bq_delete_table_mock):
def test_batch_serve_to_df_user_specified_bq_dataset(
self,
batch_read_feature_values_mock,
bq_create_dataset_mock,
bq_delete_dataset_mock,
bq_delete_table_mock,
):

aiplatform.init(project=_TEST_PROJECT_DIFF)

Expand All @@ -1715,16 +1723,18 @@ def test_batch_serve_to_df_user_specified_bq_dataset(self, batch_read_feature_va

read_instances_df = pd.DataFrame()

expected_temp_bq_dataset_name = 'my_dataset_name'
expected_temp_bq_dataset_id = f"{_TEST_PROJECT}.{expected_temp_bq_dataset_name}"[
:1024
]
expected_temp_bq_batch_serve_table_name = f"tmp_batch_serve_{uuid.uuid4()}".replace(
"-", "_"
expected_temp_bq_dataset_name = "my_dataset_name"
expected_temp_bq_dataset_id = (
f"{_TEST_PROJECT}.{expected_temp_bq_dataset_name}"[:1024]
)
expected_temp_bq_batch_serve_table_id = f"{expected_temp_bq_dataset_id}.{expected_temp_bq_batch_serve_table_name}"
expected_temp_bq_read_instances_table_name = f"tmp_read_instances_{uuid.uuid4()}".replace(
"-", "_"
expected_temp_bq_batch_serve_table_name = (
f"tmp_batch_serve_{uuid.uuid4()}".replace("-", "_")
)
expected_temp_bq_batch_serve_table_id = (
f"{expected_temp_bq_dataset_id}.{expected_temp_bq_batch_serve_table_name}"
)
expected_temp_bq_read_instances_table_name = (
f"tmp_read_instances_{uuid.uuid4()}".replace("-", "_")
)
expected_temp_bq_read_instances_table_id = f"{expected_temp_bq_dataset_id}.{expected_temp_bq_read_instances_table_name}"

Expand All @@ -1751,7 +1761,6 @@ def test_batch_serve_to_df_user_specified_bq_dataset(self, batch_read_feature_va
bigquery_read_instances=gca_io.BigQuerySource(
input_uri=f"bq://{expected_temp_bq_read_instances_table_id}"
),

)
)

Expand Down Expand Up @@ -1779,6 +1788,7 @@ def test_batch_serve_to_df_user_specified_bq_dataset(self, batch_read_feature_va
bq_create_dataset_mock.assert_not_called()
bq_delete_dataset_mock.assert_not_called()


class TestEntityType:
def setup_method(self):
reload(initializer)
Expand Down

0 comments on commit d6e20f6

Please sign in to comment.