From 3500eab379593023147c35654758daf2c0eaf02d Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Fri, 24 Feb 2023 14:51:30 -0800 Subject: [PATCH] feat: add support for enable_dashboard_access field for Training jobs in SDK PiperOrigin-RevId: 512169773 --- google/cloud/aiplatform/training_jobs.py | 40 +++- tests/unit/aiplatform/test_training_jobs.py | 246 ++++++++++++++++++++ 2 files changed, 285 insertions(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 83880c3fa8..ef3c566f7d 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1479,6 +1479,7 @@ def _prepare_training_task_inputs_and_output_dir( timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, + enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1508,6 +1509,9 @@ def _prepare_training_task_inputs_and_output_dir( Whether you want Vertex AI to enable interactive shell access to training containers. https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + enable_dashboard_access (bool): + Whether you want Vertex AI to enable access to the customized dashboard + to training containers. tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -1547,6 +1551,8 @@ def _prepare_training_task_inputs_and_output_dir( training_task_inputs["tensorboard"] = tensorboard if enable_web_access: training_task_inputs["enable_web_access"] = enable_web_access + if enable_dashboard_access: + training_task_inputs["enable_dashboard_access"] = enable_dashboard_access if timeout or restart_job_on_worker_restart: timeout = f"{timeout}s" if timeout else None @@ -1608,7 +1614,9 @@ def _wait_callback(self): self._has_logged_custom_job = True - if self._gca_resource.training_task_inputs.get("enable_web_access"): + if self._gca_resource.training_task_inputs.get( + "enable_web_access" + ) or self._gca_resource.training_task_inputs.get("enable_dashboard_access"): self._log_web_access_uris() def _custom_job_console_uri(self) -> str: @@ -2902,6 +2910,7 @@ def run( timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, + enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, @@ -3164,6 +3173,9 @@ def run( Whether you want Vertex AI to enable interactive shell access to training containers. https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + enable_dashboard_access (bool): + Whether you want Vertex AI to enable access to the customized dashboard + to training containers. tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -3238,6 +3250,7 @@ def run( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, enable_web_access=enable_web_access, + enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, reduction_server_container_uri=reduction_server_container_uri if reduction_server_replica_count > 0 @@ -3283,6 +3296,7 @@ def _run( timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, + enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, reduction_server_container_uri: Optional[str] = None, sync=True, @@ -3444,6 +3458,9 @@ def _run( Whether you want Vertex AI to enable interactive shell access to training containers. https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + enable_dashboard_access (bool): + Whether you want Vertex AI to enable access to the customized dashboard + to training containers. tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -3517,6 +3534,7 @@ def _run( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, enable_web_access=enable_web_access, + enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, ) @@ -3834,6 +3852,7 @@ def run( timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, + enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, @@ -4089,6 +4108,9 @@ def run( Whether you want Vertex AI to enable interactive shell access to training containers. https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + enable_dashboard_access (bool): + Whether you want Vertex AI to enable access to the customized dashboard + to training containers. tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -4162,6 +4184,7 @@ def run( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, enable_web_access=enable_web_access, + enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, reduction_server_container_uri=reduction_server_container_uri if reduction_server_replica_count > 0 @@ -4206,6 +4229,7 @@ def _run( timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, + enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, reduction_server_container_uri: Optional[str] = None, sync=True, @@ -4363,6 +4387,9 @@ def _run( Whether you want Vertex AI to enable interactive shell access to training containers. https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + enable_dashboard_access (bool): + Whether you want Vertex AI to enable access to the customized dashboard + to training containers. tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -4430,6 +4457,7 @@ def _run( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, enable_web_access=enable_web_access, + enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, ) @@ -6124,6 +6152,7 @@ def run( timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, + enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, @@ -6379,6 +6408,9 @@ def run( Whether you want Vertex AI to enable interactive shell access to training containers. https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + enable_dashboard_access (bool): + Whether you want Vertex AI to enable access to the customized dashboard + to training containers. tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -6447,6 +6479,7 @@ def run( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, enable_web_access=enable_web_access, + enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, reduction_server_container_uri=reduction_server_container_uri if reduction_server_replica_count > 0 @@ -6491,6 +6524,7 @@ def _run( timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, + enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, reduction_server_container_uri: Optional[str] = None, sync=True, @@ -6635,6 +6669,9 @@ def _run( Whether you want Vertex AI to enable interactive shell access to training containers. https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + enable_dashboard_access (bool): + Whether you want Vertex AI to enable access to the customized dashboard + to training containers. tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -6702,6 +6739,7 @@ def _run( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, enable_web_access=enable_web_access, + enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, ) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index c9e5e9b528..da3bb4739e 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -208,7 +208,9 @@ _TEST_RESTART_JOB_ON_WORKER_RESTART = True _TEST_ENABLE_WEB_ACCESS = True +_TEST_ENABLE_DASHBOARD_ACCESS = True _TEST_WEB_ACCESS_URIS = {"workerpool0-0": "uri"} +_TEST_DASHBOARD_ACCESS_URIS = {"workerpool0-0:8888": "uri"} _TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob( job_spec=gca_custom_job.CustomJobSpec(), @@ -226,6 +228,19 @@ def _get_custom_job_proto_with_enable_web_access(state=None, name=None, version= return custom_job_proto +def _get_custom_job_proto_with_enable_dashboard_access( + state=None, name=None, version="v1" +): + custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) + custom_job_proto.name = name + custom_job_proto.state = state + + custom_job_proto.job_spec.enable_dashboard_access = _TEST_ENABLE_DASHBOARD_ACCESS + if state == gca_job_state.JobState.JOB_STATE_RUNNING: + custom_job_proto.web_access_uris = _TEST_DASHBOARD_ACCESS_URIS + return custom_job_proto + + def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"): custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) custom_job_proto.name = name @@ -351,6 +366,40 @@ def mock_get_backing_custom_job_with_enable_web_access(): yield get_custom_job_mock +@pytest.fixture +def mock_get_backing_custom_job_with_enable_dashboard_access(): + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.side_effect = [ + _get_custom_job_proto_with_enable_dashboard_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_PENDING, + ), + _get_custom_job_proto_with_enable_dashboard_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_dashboard_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_dashboard_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_dashboard_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, + ), + _get_custom_job_proto_with_enable_dashboard_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_custom_job_mock + + @pytest.mark.skipif( sys.executable is None, reason="requires python path to invoke subprocess" ) @@ -635,6 +684,19 @@ def make_training_pipeline_with_enable_web_access(state): return training_pipeline +def make_training_pipeline_with_enable_dashboard_access(state): + training_pipeline = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=state, + training_task_inputs={"enable_dashboard_access": _TEST_ENABLE_DASHBOARD_ACCESS}, + ) + if state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING: + training_pipeline.training_task_metadata = { + "backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME + } + return training_pipeline + + def make_training_pipeline_with_scheduling(state): training_pipeline = gca_training_pipeline.TrainingPipeline( name=_TEST_PIPELINE_RESOURCE_NAME, @@ -706,6 +768,35 @@ def mock_pipeline_service_get_with_enable_web_access(): yield mock_get_training_pipeline +@pytest.fixture +def mock_pipeline_service_get_with_enable_dashboard_access(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.side_effect = [ + make_training_pipeline_with_enable_dashboard_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ), + make_training_pipeline_with_enable_dashboard_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_enable_dashboard_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_enable_dashboard_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_enable_dashboard_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + make_training_pipeline_with_enable_dashboard_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + ] + + yield mock_get_training_pipeline + + @pytest.fixture def mock_pipeline_service_get_with_scheduling(): with mock.patch.object( @@ -770,6 +861,19 @@ def mock_pipeline_service_create_with_enable_web_access(): yield mock_create_training_pipeline +@pytest.fixture +def mock_pipeline_service_create_with_enable_dashboard_access(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = ( + make_training_pipeline_with_enable_dashboard_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ) + ) + yield mock_create_training_pipeline + + @pytest.fixture def mock_pipeline_service_create_with_scheduling(): with mock.patch.object( @@ -2041,6 +2145,54 @@ def test_run_call_pipeline_service_create_with_enable_web_access( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) + # TODO: Update test to address Mutant issue b/270708320 + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_enable_dashboard_access", + "mock_pipeline_service_get_with_enable_dashboard_access", + "mock_get_backing_custom_job_with_enable_dashboard_access", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_enable_dashboard_access( + self, sync, caplog + ): + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + enable_dashboard_access=_TEST_ENABLE_DASHBOARD_ACCESS, + sync=sync, + create_request_timeout=None, + ) + + if not sync: + job.wait() + + print(caplog.text) + assert "workerpool0-0:8888" in caplog.text + assert job._gca_resource == make_training_pipeline_with_enable_dashboard_access( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) @pytest.mark.usefixtures( @@ -3994,6 +4146,53 @@ def test_run_call_pipeline_service_create_with_enable_web_access( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) + # TODO: Update test to address Mutant issue b/270708320 + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_enable_dashboard_access", + "mock_pipeline_service_get_with_enable_dashboard_access", + "mock_get_backing_custom_job_with_enable_dashboard_access", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_enable_dashboard_access( + self, sync, caplog + ): + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + enable_dashboard_access=_TEST_ENABLE_DASHBOARD_ACCESS, + sync=sync, + create_request_timeout=None, + ) + + if not sync: + job.wait() + + print(caplog.text) + assert "workerpool0-0:8888" in caplog.text + assert job._gca_resource == make_training_pipeline_with_enable_dashboard_access( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) @pytest.mark.usefixtures( @@ -6221,6 +6420,53 @@ def test_run_call_pipeline_service_create_with_enable_web_access( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) + # TODO: Update test to address Mutant issue b/270708320 + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_enable_dashboard_access", + "mock_pipeline_service_get_with_enable_dashboard_access", + "mock_get_backing_custom_job_with_enable_dashboard_access", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_enable_dashboard_access( + self, sync, caplog + ): + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + enable_dashboard_access=_TEST_ENABLE_DASHBOARD_ACCESS, + sync=sync, + create_request_timeout=None, + ) + + if not sync: + job.wait() + print(caplog.text) + assert "workerpool0-0:8888" in caplog.text + assert job._gca_resource == make_training_pipeline_with_enable_dashboard_access( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) @pytest.mark.usefixtures(