Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions airflow/providers/google/cloud/hooks/vertex_ai/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def upload_model(
project_id: str,
region: str,
model: Model | dict,
parent_model: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
Expand All @@ -218,18 +219,24 @@ def upload_model(
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param model: Required. The Model to create.
:param parent_model: The name of the parent model to create a new version under.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_model_service_client(region)
parent = client.common_location_path(project_id, region)

request = {
"parent": parent,
"model": model,
}

if parent_model:
request["parent_model"] = parent_model

result = client.upload_model(
request={
"parent": parent,
"model": model,
},
request=request,
retry=retry,
timeout=timeout,
metadata=metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param model: Required. The Model to create.
:param parent_model: The name of the parent model to create a new version under.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
Expand All @@ -385,6 +386,7 @@ def __init__(
project_id: str,
region: str,
model: Model | dict,
parent_model: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
Expand All @@ -396,6 +398,7 @@ def __init__(
self.project_id = project_id
self.region = region
self.model = model
self.parent_model = parent_model
self.retry = retry
self.timeout = timeout
self.metadata = metadata
Expand All @@ -412,6 +415,7 @@ def execute(self, context: Context):
project_id=self.project_id,
region=self.region,
model=self.model,
parent_model=self.parent_model,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
Expand Down
37 changes: 37 additions & 0 deletions tests/providers/google/cloud/hooks/vertex_ai/test_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TEST_REGION: str = "test-region"
TEST_PROJECT_ID: str = "test-project-id"
TEST_MODEL = None
TEST_PARENT_MODEL = "test-parent-model"
TEST_MODEL_NAME: str = "test_model_name"
TEST_OUTPUT_CONFIG: dict = {}

Expand Down Expand Up @@ -136,6 +137,24 @@ def test_upload_model(self, mock_client) -> None:
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_upload_model_with_parent_model(self, mock_client) -> None:
self.hook.upload_model(
project_id=TEST_PROJECT_ID, region=TEST_REGION, model=TEST_MODEL, parent_model=TEST_PARENT_MODEL
)
mock_client.assert_called_once_with(TEST_REGION)
mock_client.return_value.upload_model.assert_called_once_with(
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
model=TEST_MODEL,
parent_model=TEST_PARENT_MODEL,
),
metadata=(),
retry=DEFAULT,
timeout=None,
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_list_model_versions(self, mock_client) -> None:
self.hook.list_model_versions(
Expand Down Expand Up @@ -322,6 +341,24 @@ def test_upload_model(self, mock_client) -> None:
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_upload_model_with_parent_model(self, mock_client) -> None:
self.hook.upload_model(
project_id=TEST_PROJECT_ID, region=TEST_REGION, model=TEST_MODEL, parent_model=TEST_PARENT_MODEL
)
mock_client.assert_called_once_with(TEST_REGION)
mock_client.return_value.upload_model.assert_called_once_with(
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
model=TEST_MODEL,
parent_model=TEST_PARENT_MODEL,
),
metadata=(),
retry=DEFAULT,
timeout=None,
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_list_model_versions(self, mock_client) -> None:
self.hook.list_model_versions(
Expand Down
28 changes: 28 additions & 0 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2849,6 +2849,34 @@ def test_execute(self, mock_hook, to_dict_mock):
region=GCP_LOCATION,
project_id=GCP_PROJECT,
model=TEST_MODEL_OBJ,
parent_model=None,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)

@mock.patch(VERTEX_AI_PATH.format("model_service.model_service.UploadModelResponse.to_dict"))
@mock.patch(VERTEX_AI_PATH.format("model_service.ModelServiceHook"))
def test_execute_with_parent_model(self, mock_hook, to_dict_mock):
op = UploadModelOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
model=TEST_MODEL_OBJ,
parent_model=TEST_PARENT_MODEL,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.upload_model.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
model=TEST_MODEL_OBJ,
parent_model=TEST_PARENT_MODEL,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@
model=MODEL_OBJ,
)
# [END how_to_cloud_vertex_ai_upload_model_operator]
upload_model_with_parent_model = UploadModelOperator(
task_id="upload_model_with_parent_model",
region=REGION,
project_id=PROJECT_ID,
model=MODEL_OBJ,
parent_model=MODEL_DISPLAY_NAME,
)

# [START how_to_cloud_vertex_ai_export_model_operator]
export_model = ExportModelOperator(
Expand All @@ -251,6 +258,13 @@
trigger_rule=TriggerRule.ALL_DONE,
)
# [END how_to_cloud_vertex_ai_delete_model_operator]
delete_model_with_parent_model = DeleteModelOperator(
task_id="delete_model_with_parent_model",
project_id=PROJECT_ID,
region=REGION,
model_id=upload_model_with_parent_model.output["model_id"],
trigger_rule=TriggerRule.ALL_DONE,
)

# [START how_to_cloud_vertex_ai_list_models_operator]
list_models = ListModelsOperator(
Expand Down Expand Up @@ -317,8 +331,10 @@
>> set_default_version
>> add_version_alias
>> upload_model
>> upload_model_with_parent_model
>> export_model
>> delete_model
>> delete_model_with_parent_model
>> list_models
# TEST TEARDOWN
>> delete_version_alias
Expand Down