Skip to content

Commit

Permalink
feat: Add support for ordery_by in Metadata SDK list methods for Arti…
Browse files Browse the repository at this point in the history
…fact, Execution and Context.

PiperOrigin-RevId: 488531299
  • Loading branch information
SinaChavoshi authored and copybara-github committed Nov 15, 2022
1 parent a9a438a commit 2377606
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 26 deletions.
14 changes: 13 additions & 1 deletion google/cloud/aiplatform/metadata/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def _list_resources(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
filter: Optional[str] = None,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
order_by: Optional[str] = None,
):
"""List artifacts in the parent path that matches the filter.
Expand All @@ -253,10 +254,21 @@ def _list_resources(
Required. The path where Artifacts are stored.
filter (str):
Optional. filter string to restrict the list result
order_by (str):
Optional. How the list of messages is ordered. Specify the
values to order by and an ordering operation. The default sorting
order is ascending. To specify descending order for a field, users
append a " desc" suffix; for example: "foo desc, bar". Subfields
are specified with a ``.`` character, such as foo.bar. see
https://google.aip.dev/132#ordering for more details.
Returns:
List of artifacts.
"""
list_request = gca_metadata_service.ListArtifactsRequest(
parent=parent,
filter=filter,
order_by=order_by,
)
return client.list_artifacts(request=list_request)

Expand Down
14 changes: 13 additions & 1 deletion google/cloud/aiplatform/metadata/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ def _list_resources(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
filter: Optional[str] = None,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
order_by: Optional[str] = None,
):
"""List Contexts in the parent path that matches the filter.
Expand All @@ -309,11 +310,22 @@ def _list_resources(
Required. The path where Contexts are stored.
filter (str):
Optional. filter string to restrict the list result
order_by (str):
Optional. How the list of messages is ordered. Specify the
values to order by and an ordering operation. The default sorting
order is ascending. To specify descending order for a field, users
append a " desc" suffix; for example: "foo desc, bar". Subfields
are specified with a ``.`` character, such as foo.bar. see
https://google.aip.dev/132#ordering for more details.
Returns:
List of Contexts.
"""

list_request = gca_metadata_service.ListContextsRequest(
parent=parent,
filter=filter,
order_by=order_by,
)
return client.list_contexts(request=list_request)

Expand Down
13 changes: 12 additions & 1 deletion google/cloud/aiplatform/metadata/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,8 @@ def _list_resources(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
filter: Optional[str] = None,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
order_by: Optional[str] = None,
):
"""List Executions in the parent path that matches the filter.
Expand All @@ -469,11 +470,21 @@ def _list_resources(
Required. The path where Executions are stored.
filter (str):
Optional. filter string to restrict the list result
order_by (str):
Optional. How the list of messages is ordered. Specify the
values to order by and an ordering operation. The default sorting
order is ascending. To specify descending order for a field, users
append a " desc" suffix; for example: "foo desc, bar". Subfields
are specified with a ``.`` character, such as foo.bar. see
https://google.aip.dev/132#ordering for more details.
Returns:
List of execution.
"""

list_request = gca_metadata_service.ListExecutionsRequest(
parent=parent,
filter=filter,
order_by=order_by,
)
return client.list_executions(request=list_request)

Expand Down
14 changes: 12 additions & 2 deletions google/cloud/aiplatform/metadata/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,14 @@ def update(
@classmethod
def list(
cls,
filter: Optional[str] = None,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["_Resource"]:
"""List Metadata resources that match the list filter in target metadataStore.
"""List resources that match the list filter in target metadataStore.
Args:
filter (str):
Expand All @@ -339,6 +340,14 @@ def list(
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
resources (sequence[_Resource]):
Expand All @@ -358,6 +367,7 @@ def list(
location=location,
credentials=credentials,
parent=parent,
order_by=order_by,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ def list_artifact_sample(
project: str,
location: str,
display_name_fitler: Optional[str] = "display_name=\"my_model_*\"",
create_date_filter: Optional[str] = "create_time>\"2022-06-11T12:30:00-08:00\"",
create_date_filter: Optional[str] = "create_time>\"2022-06-11\"",
order_by: Optional[str] = None,
):
aiplatform.init(
project=project,
location=location)
aiplatform.init(project=project, location=location)

combined_filters = f"{display_name_fitler} AND {create_date_filter}"
return aiplatform.Artifact.list(filter=combined_filters)
return aiplatform.Artifact.list(
filter=combined_filters,
order_by=order_by,
)


# [END aiplatform_sdk_create_artifact_with_sdk_sample]
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ def test_list_artifact_with_sdk_sample(mock_artifact, mock_list_artifact):
location=constants.LOCATION,
display_name_fitler=constants.DISPLAY_NAME,
create_date_filter=constants.CREATE_DATE,
order_by=constants.ORDER_BY,
)

mock_list_artifact.assert_called_with(
filter=f"{constants.DISPLAY_NAME} AND {constants.CREATE_DATE}"
filter=f"{constants.DISPLAY_NAME} AND {constants.CREATE_DATE}",
order_by=constants.ORDER_BY,
)
assert len(artifacts) == 2
# Returning list of 2 context to avoid confusion with get method
Expand Down
2 changes: 2 additions & 0 deletions samples/model-builder/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

CREATE_DATE = "2022-06-11T12:30:00-08:00"

ORDER_BY = "CREATE_TIME desc"

STAGING_BUCKET = "gs://my-staging-bucket"
EXPERIMENT_NAME = "fraud-detection-trial-72"
CREDENTIALS = credentials.AnonymousCredentials()
Expand Down
46 changes: 31 additions & 15 deletions tests/unit/aiplatform/test_metadata_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"test-param2": "test-value-1",
"test-param3": False,
}
_TEST_ORDER_BY = "test_order_by"

# context
_TEST_CONTEXT_ID = "test-context-id"
Expand Down Expand Up @@ -473,9 +474,11 @@ def test_update_context(self, update_context_mock):
def test_list_contexts(self, list_contexts_mock):
aiplatform.init(project=_TEST_PROJECT)

filter = "test-filter"
filter_query = "test-filter"
context_list = context.Context.list(
filter=filter, metadata_store_id=_TEST_METADATA_STORE
filter=filter_query,
metadata_store_id=_TEST_METADATA_STORE,
order_by=_TEST_ORDER_BY,
)

expected_context = GapicContext(
Expand All @@ -490,11 +493,14 @@ def test_list_contexts(self, list_contexts_mock):
list_contexts_mock.assert_called_once_with(
request={
"parent": _TEST_PARENT,
"filter": filter,
"filter": filter_query,
"order_by": _TEST_ORDER_BY,
}
)
assert len(context_list) == 2
# pylint: disable-next=protected-access
assert context_list[0]._gca_resource == expected_context
# pylint: disable-next=protected-access
assert context_list[1]._gca_resource == expected_context

@pytest.mark.usefixtures("get_context_mock")
Expand Down Expand Up @@ -710,9 +716,11 @@ def test_update_execution(self, update_execution_mock):
def test_list_executions(self, list_executions_mock):
aiplatform.init(project=_TEST_PROJECT)

filter = "test-filter"
filter_query = "test-filter"
execution_list = execution.Execution.list(
filter=filter, metadata_store_id=_TEST_METADATA_STORE
filter=filter_query,
metadata_store_id=_TEST_METADATA_STORE,
order_by=_TEST_ORDER_BY,
)

expected_execution = GapicExecution(
Expand All @@ -725,13 +733,16 @@ def test_list_executions(self, list_executions_mock):
)

list_executions_mock.assert_called_once_with(
request=dict(
parent=_TEST_PARENT,
filter=filter,
)
request={
"parent": _TEST_PARENT,
"filter": filter_query,
"order_by": _TEST_ORDER_BY,
}
)
assert len(execution_list) == 2
# pylint: disable-next=protected-access
assert execution_list[0]._gca_resource == expected_execution
# pylint: disable-next=protected-access
assert execution_list[1]._gca_resource == expected_execution

@pytest.mark.usefixtures("get_execution_mock", "get_artifact_mock")
Expand Down Expand Up @@ -1002,9 +1013,11 @@ def test_update_artifact(self, update_artifact_mock):
def test_list_artifacts(self, list_artifacts_mock):
aiplatform.init(project=_TEST_PROJECT)

filter = "test-filter"
filter_query = "test-filter"
artifact_list = artifact.Artifact.list(
filter=filter, metadata_store_id=_TEST_METADATA_STORE
filter=filter_query,
metadata_store_id=_TEST_METADATA_STORE,
order_by=_TEST_ORDER_BY,
)

expected_artifact = GapicArtifact(
Expand All @@ -1017,11 +1030,14 @@ def test_list_artifacts(self, list_artifacts_mock):
)

list_artifacts_mock.assert_called_once_with(
request=dict(
parent=_TEST_PARENT,
filter=filter,
)
request={
"parent": _TEST_PARENT,
"filter": filter_query,
"order_by": _TEST_ORDER_BY,
}
)
assert len(artifact_list) == 2
# pylint: disable-next=protected-access
assert artifact_list[0]._gca_resource == expected_artifact
# pylint: disable-next=protected-access
assert artifact_list[1]._gca_resource == expected_artifact

0 comments on commit 2377606

Please sign in to comment.