Skip to content

Commit

Permalink
feat: Add scheduled pipelines client list/pause/resume methods and un…
Browse files Browse the repository at this point in the history
…it tests.

PiperOrigin-RevId: 537995661
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jun 5, 2023
1 parent 74c2066 commit ce5dee4
Show file tree
Hide file tree
Showing 3 changed files with 560 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import Optional
from typing import List, Optional

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
Expand All @@ -39,6 +39,7 @@
from google.cloud.aiplatform_v1beta1.types import (
pipeline_job as gca_pipeline_job_v1beta1,
)
from google.protobuf import field_mask_pb2 as field_mask


_LOGGER = base.Logger(__name__)
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
):
"""Retrieves a PipelineJobSchedule resource and instantiates its
representation.
Args:
pipeline_job (PipelineJob):
Required. PipelineJob used to init the schedule.
Expand Down Expand Up @@ -255,3 +257,131 @@ def _create(
)

_LOGGER.info("View Schedule:\n%s" % self._dashboard_uri())

@classmethod
def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
enable_simple_view: bool = True,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List["PipelineJobSchedule"]:
"""List all instances of this PipelineJobSchedule resource.
Example Usage:
aiplatform.PipelineJobSchedule.list(
filter='display_name="experiment_a27"',
order_by='create_time desc'
)
Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
enable_simple_view (bool):
Optional. Whether to pass the `read_mask` parameter to the list call.
Defaults to False if not provided. This will improve the performance of calling
list(). However, the returned PipelineJobSchedule list will not include all fields for
each PipelineJobSchedule. Setting this to True will exclude the following fields in your
response: 'create_pipeline_job_request', 'next_run_time', 'last_pause_time',
'last_resume_time', 'max_concurrent_run_count', 'allow_queueing','last_scheduled_run_response'.
The following fields will be included in each PipelineJobSchedule resource in your
response: 'name', 'display_name', 'start_time', 'end_time', 'max_run_count',
'started_run_count', 'state', 'create_time', 'update_time', 'cron', 'catch_up'.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve list from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.
Returns:
List[PipelineJobSchedule] - A list of PipelineJobSchedule resource objects.
"""

read_mask_fields = None

if enable_simple_view:
read_mask_fields = field_mask.FieldMask(paths=_READ_MASK_FIELDS)
_LOGGER.warn(
"By enabling simple view, the PipelineJobSchedule resources returned from this method will not contain all fields."
)

return cls._list_with_local_order(
filter=filter,
order_by=order_by,
read_mask=read_mask_fields,
project=project,
location=location,
credentials=credentials,
)

def list_jobs(
self,
filter: Optional[str] = None,
order_by: Optional[str] = None,
enable_simple_view: bool = False,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[PipelineJob]:
"""List all PipelineJob 's created by this PipelineJobSchedule.
Example usage:
pipeline_job_schedule.list_jobs(order_by='create_time_desc')
Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
enable_simple_view (bool):
Optional. Whether to pass the `read_mask` parameter to the list call.
Defaults to False if not provided. This will improve the performance of calling
list(). However, the returned PipelineJob list will not include all fields for
each PipelineJob. Setting this to True will exclude the following fields in your
response: `runtime_config`, `service_account`, `network`, and some subfields of
`pipeline_spec` and `job_detail`. The following fields will be included in
each PipelineJob resource in your response: `state`, `display_name`,
`pipeline_spec.pipeline_info`, `create_time`, `start_time`, `end_time`,
`update_time`, `labels`, `template_uri`, `template_metadata.version`,
`job_detail.pipeline_run_context`, `job_detail.pipeline_context`.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve list from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.
Returns:
List[PipelineJob] - A list of PipelineJob resource objects.
"""
list_filter = f"schedule_name={self._gca_resource.name}"
if filter:
list_filter = list_filter + f" AND {filter}"

return PipelineJob.list(
filter=list_filter,
order_by=order_by,
enable_simple_view=enable_simple_view,
project=project,
location=location,
credentials=credentials,
)
30 changes: 29 additions & 1 deletion google/cloud/aiplatform/preview/schedule/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(
location: str,
):
"""Retrieves a Schedule resource and instantiates its representation.
Args:
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to create this Schedule.
Expand Down Expand Up @@ -111,6 +110,35 @@ def get(

return self

def pause(self) -> None:
"""Starts asynchronous pause on the Schedule.
Changes Schedule state from State.ACTIVE to State.PAUSED.
"""
self.api_client.pause_schedule(name=self.resource_name)

def resume(
self,
catch_up: bool = True,
) -> None:
"""Starts asynchronous resume on the Schedule.
Changes Schedule state from State.PAUSED to State.ACTIVE.
Args:
catch_up (bool):
Optional. Whether to backfill missed runs when the Schedule is
resumed from State.PAUSED.
"""
self.api_client.resume_schedule(name=self.resource_name)

def done(self) -> bool:
"""Helper method that return True is Schedule is done. False otherwise."""
if not self._gca_resource:
return False

return self.state in _SCHEDULE_COMPLETE_STATES

def wait(self) -> None:
"""Wait for this Schedule to complete."""
if self._latest_future is None:
Expand Down
Loading

0 comments on commit ce5dee4

Please sign in to comment.