Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions v03_pipeline/lib/misc/gcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import datetime

import google.auth
import google.auth.transport.requests
import google.oauth2.credentials
import pytz

# Module-level cache for the credentials object; populated on the first call
# to get_service_account_credentials() and reused by later calls.
SERVICE_ACCOUNT_CREDENTIALS = None
# OAuth2 scopes passed to google.auth.default() when the cached credentials
# are first created.
SOCIAL_AUTH_GOOGLE_OAUTH2_SCOPE = [
'https://www.googleapis.com/auth/userinfo.profile',
'https://www.googleapis.com/auth/userinfo.email',
'openid',
]
# Minimum remaining token lifetime, in seconds, below which
# get_service_account_credentials() refreshes the token.
ONE_MINUTE_S = 60


def get_service_account_credentials() -> google.oauth2.credentials.Credentials:
    """Return the cached default credentials, refreshing the token when needed.

    The credentials object is created once via ``google.auth.default`` with
    ``SOCIAL_AUTH_GOOGLE_OAUTH2_SCOPE`` and stored in the module-level
    ``SERVICE_ACCOUNT_CREDENTIALS``. Later calls reuse that object and only
    refresh its token when it is missing or has ``ONE_MINUTE_S`` seconds or
    less of lifetime left.

    Returns:
        The cached credentials object, holding a token that was either still
        fresh or has just been refreshed.
    """
    # ``global`` is required because this function rebinds the module-level
    # cache: we do not want to rebuild the credentials on every call, only
    # refresh the token they carry.
    global SERVICE_ACCOUNT_CREDENTIALS
    if SERVICE_ACCOUNT_CREDENTIALS is None:
        SERVICE_ACCOUNT_CREDENTIALS, _ = google.auth.default(
            scopes=SOCIAL_AUTH_GOOGLE_OAUTH2_SCOPE,
        )

    if SERVICE_ACCOUNT_CREDENTIALS.token:
        expiry = SERVICE_ACCOUNT_CREDENTIALS.expiry
        if expiry is None:
            # A token with no recorded expiry cannot be compared against the
            # clock; treat it as still usable rather than crashing on
            # ``localize(None)``.
            # NOTE(review): this mirrors how google-auth itself appears to
            # treat a missing expiry (never expired) -- confirm.
            return SERVICE_ACCOUNT_CREDENTIALS
        tz = pytz.UTC
        if expiry.tzinfo is None:
            # NOTE(review): expiry is assumed to be a naive UTC datetime, as
            # the original localize() call implied; pytz raises if asked to
            # localize an already-aware datetime, hence the guard.
            expiry = tz.localize(expiry)
        remaining_s = (expiry - datetime.datetime.now(tz=tz)).total_seconds()
        if remaining_s > ONE_MINUTE_S:
            return SERVICE_ACCOUNT_CREDENTIALS

    # Token is missing or about to expire: fetch a new one in place.
    SERVICE_ACCOUNT_CREDENTIALS.refresh(
        request=google.auth.transport.requests.Request(),
    )
    return SERVICE_ACCOUNT_CREDENTIALS
5 changes: 5 additions & 0 deletions v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pip._internal.operations import freeze as pip_freeze

from v03_pipeline.lib.logger import get_logger
from v03_pipeline.lib.misc.gcp import get_service_account_credentials
from v03_pipeline.lib.model import Env, ReferenceGenome
from v03_pipeline.lib.tasks.base.base_loading_pipeline_params import (
BaseLoadingPipelineParams,
Expand All @@ -22,9 +23,11 @@


def get_cluster_config(reference_genome: ReferenceGenome, run_id: str):
service_account_credentials = get_service_account_credentials()
return {
'project_id': Env.GCLOUD_PROJECT,
'cluster_name': f'{CLUSTER_NAME_PREFIX}-{reference_genome.value.lower()}-{run_id}',
# Schema found at https://cloud.google.com/dataproc/docs/reference/rest/v1/ClusterConfig
'config': {
'gce_cluster_config': {
'zone_uri': Env.GCLOUD_ZONE,
Expand All @@ -35,6 +38,8 @@ def get_cluster_config(reference_genome: ReferenceGenome, run_id: str):
'REFERENCE_GENOME': reference_genome.value,
'PIPELINE_RUNNER_APP_VERSION': Env.PIPELINE_RUNNER_APP_VERSION,
},
'service_account': service_account_credentials.service_account_email,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: we also call get_service_account_credentials() downstream when we actually make the API request. Setting the service account and scopes here in the cluster config ensures the cluster itself is properly scoped and credentialed at spin-up time.

'service_account_scopes': service_account_credentials.scopes,
},
'master_config': {
'num_instances': 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,29 @@
import google.api_core.exceptions
import luigi

from v03_pipeline.lib.misc.gcp import SOCIAL_AUTH_GOOGLE_OAUTH2_SCOPE
from v03_pipeline.lib.model import DatasetType, ReferenceGenome
from v03_pipeline.lib.tasks.dataproc.create_dataproc_cluster import (
CreateDataprocClusterTask,
)


@patch(
'v03_pipeline.lib.tasks.dataproc.create_dataproc_cluster.get_service_account_credentials',
return_value=SimpleNamespace(
service_account_email='test@serviceaccount.com',
scopes=SOCIAL_AUTH_GOOGLE_OAUTH2_SCOPE,
),
)
@patch(
'v03_pipeline.lib.tasks.dataproc.create_dataproc_cluster.dataproc.ClusterControllerClient',
)
class CreateDataprocClusterTaskTest(unittest.TestCase):
def test_dataset_type_unsupported(self, mock_cluster_controller: Mock) -> None:
def test_dataset_type_unsupported(
self,
mock_cluster_controller: Mock,
_: Mock,
) -> None:
worker = luigi.worker.Worker()
task = CreateDataprocClusterTask(
reference_genome=ReferenceGenome.GRCh38,
Expand All @@ -29,6 +41,7 @@ def test_dataset_type_unsupported(self, mock_cluster_controller: Mock) -> None:
def test_spinup_cluster_already_exists_failed(
self,
mock_cluster_controller: Mock,
_: Mock,
) -> None:
mock_client = mock_cluster_controller.return_value
mock_client.get_cluster.return_value = SimpleNamespace(
Expand All @@ -50,6 +63,7 @@ def test_spinup_cluster_already_exists_failed(
def test_spinup_cluster_already_exists_success(
self,
mock_cluster_controller: Mock,
_: Mock,
) -> None:
mock_client = mock_cluster_controller.return_value
mock_client.get_cluster.return_value = SimpleNamespace(
Expand All @@ -73,6 +87,7 @@ def test_spinup_cluster_doesnt_exist_failed(
self,
mock_logger: Mock,
mock_cluster_controller: Mock,
_: Mock,
) -> None:
mock_client = mock_cluster_controller.return_value
mock_client.get_cluster.side_effect = google.api_core.exceptions.NotFound(
Expand All @@ -98,6 +113,7 @@ def test_spinup_cluster_doesnt_exist_success(
self,
mock_logger: Mock,
mock_cluster_controller: Mock,
_: Mock,
) -> None:
mock_client = mock_cluster_controller.return_value
mock_client.get_cluster.side_effect = google.api_core.exceptions.NotFound(
Expand Down
Loading