diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py index 8633b3dc1fbb3..b2aad96d98904 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataflow.py +++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py @@ -30,7 +30,6 @@ from airflow.providers.google.cloud.operators.gcs import GCSToLocalOperator from airflow.utils.dates import days_ago -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') GCS_TMP = os.environ.get('GCP_DATAFLOW_GCS_TMP', 'gs://test-dataflow-example/temp/') GCS_STAGING = os.environ.get('GCP_DATAFLOW_GCS_STAGING', 'gs://test-dataflow-example/staging/') GCS_OUTPUT = os.environ.get('GCP_DATAFLOW_GCS_OUTPUT', 'gs://test-dataflow-example/output') @@ -44,7 +43,6 @@ default_args = { "start_date": days_ago(1), 'dataflow_default_options': { - 'project': GCP_PROJECT_ID, 'tempLocation': GCS_TMP, 'stagingLocation': GCS_STAGING, } @@ -68,6 +66,7 @@ poll_sleep=10, job_class='org.apache.beam.examples.WordCount', check_if_running=CheckJobRunning.IgnoreJob, + location='europe-west3' ) # [END howto_operator_start_java_job] @@ -104,7 +103,8 @@ 'apache-beam[gcp]>=2.14.0' ], py_interpreter='python3', - py_system_site_packages=False + py_system_site_packages=False, + location='europe-west3' ) # [END howto_operator_start_python_job] @@ -130,4 +130,5 @@ 'inputFile': "gs://dataflow-samples/shakespeare/kinglear.txt", 'output': GCS_OUTPUT }, + location='europe-west3' ) diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 1313037755c6c..d42bb0b43b392 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -25,6 +25,7 @@ import subprocess import time import uuid +import warnings from copy import deepcopy from tempfile import TemporaryDirectory from typing import Any, Callable, Dict, List, Optional, TypeVar @@ -49,36 +50,44 @@ RT = TypeVar('RT') # pylint: disable=invalid-name -def _fallback_to_project_id_from_variables(func: Callable[..., RT]) -> Callable[..., RT]: - """ - Decorator that provides fallback for Google Cloud Platform project id. +def _fallback_variable_parameter(parameter_name, variable_key_name): - :param func: function to wrap - :return: result of the function call - """ - @functools.wraps(func) - def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT: - if args: - raise AirflowException( - "You must use keyword arguments in this methods rather than positional") - - parameter_project_id = kwargs.get('project_id') - variables_project_id = kwargs.get('variables', {}).get('project') - - if parameter_project_id and variables_project_id: - raise AirflowException( - "The mutually exclusive parameter `project_id` and `project` key in `variables` parameters " - "are both present. Please remove one." - ) + def _wrapper(func: Callable[..., RT]) -> Callable[..., RT]: + """ + Decorator that provides fallback for location from `region` key in `variables` parameters. 
+ + :param func: function to wrap + :return: result of the function call + """ + @functools.wraps(func) + def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT: + if args: + raise AirflowException( + "You must use keyword arguments in this methods rather than positional") + + parameter_location = kwargs.get(parameter_name) + variables_location = kwargs.get('variables', {}).get(variable_key_name) + + if parameter_location and variables_location: + raise AirflowException( + f"The mutually exclusive parameter `{parameter_name}` and `{variable_key_name}` key " + f"in `variables` parameter are both present. Please remove one." + ) + if parameter_location or variables_location: + kwargs[parameter_name] = parameter_location or variables_location + if variables_location: + copy_variables = deepcopy(kwargs['variables']) + del copy_variables[variable_key_name] + kwargs['variables'] = copy_variables + + return func(self, *args, **kwargs) + return inner_wrapper - kwargs['project_id'] = parameter_project_id or variables_project_id - if variables_project_id: - copy_variables = deepcopy(kwargs['variables']) - del copy_variables['project'] - kwargs['variables'] = copy_variables + return _wrapper - return func(self, *args, **kwargs) - return inner_wrapper + +_fallback_to_location_from_variables = _fallback_variable_parameter('location', 'region') +_fallback_to_project_id_from_variables = _fallback_variable_parameter('project_id', 'project') class DataflowJobStatus: @@ -425,9 +434,9 @@ def _start_dataflow( label_formatter: Callable[[Dict], List[str]], project_id: str, multiple_jobs: bool = False, - on_new_job_id_callback: Optional[Callable[[str], None]] = None + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + location: str = DEFAULT_DATAFLOW_LOCATION ) -> None: - variables = self._set_variables(variables) cmd = command_prefix + self._build_cmd(variables, label_formatter, project_id) runner = _DataflowRunner( cmd=cmd, @@ -438,7 +447,7 @@ def _start_dataflow( dataflow=self.get_conn(), project_number=project_id, name=name, - location=variables['region'], + location=location, poll_sleep=self.poll_sleep, job_id=job_id, num_retries=self.num_retries, @@ -446,12 +455,7 @@ def _start_dataflow( ) job_controller.wait_for_done() - @staticmethod - def _set_variables(variables: Dict) -> Dict: - if 'region' not in variables.keys(): - variables['region'] = DEFAULT_DATAFLOW_LOCATION - return variables - + @_fallback_to_location_from_variables @_fallback_to_project_id_from_variables @GoogleBaseHook.fallback_to_default_project_id def start_java_dataflow( @@ -463,7 +467,8 @@ def start_java_dataflow( job_class: Optional[str] = None, append_job_name: bool = True, multiple_jobs: bool = False, - on_new_job_id_callback: Optional[Callable[[str], None]] = None + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + location: str = DEFAULT_DATAFLOW_LOCATION ) -> None: """ Starts Dataflow java job. @@ -484,9 +489,12 @@ def start_java_dataflow( :type multiple_jobs: bool :param on_new_job_id_callback: Callback called when the job ID is known. :type on_new_job_id_callback: callable + :param location: Job location. 
+ :type location: str """ name = self._build_dataflow_job_name(job_name, append_job_name) variables['jobName'] = name + variables['region'] = location def label_formatter(labels_dict): return ['--labels={}'.format( @@ -501,9 +509,11 @@ def label_formatter(labels_dict): label_formatter=label_formatter, project_id=project_id, multiple_jobs=multiple_jobs, - on_new_job_id_callback=on_new_job_id_callback + on_new_job_id_callback=on_new_job_id_callback, + location=location ) + @_fallback_to_location_from_variables @_fallback_to_project_id_from_variables @GoogleBaseHook.fallback_to_default_project_id def start_template_dataflow( @@ -514,7 +524,8 @@ def start_template_dataflow( dataflow_template: str, project_id: str, append_job_name: bool = True, - on_new_job_id_callback: Optional[Callable[[str], None]] = None + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + location: str = DEFAULT_DATAFLOW_LOCATION ) -> Dict: """ Starts Dataflow template job. @@ -533,8 +544,9 @@ def start_template_dataflow( :type append_job_name: bool :param on_new_job_id_callback: Callback called when the job ID is known. :type on_new_job_id_callback: callable + :param location: Job location. + :type location: str """ - variables = self._set_variables(variables) name = self._build_dataflow_job_name(job_name, append_job_name) # Builds RuntimeEnvironment from variables dictionary # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment @@ -550,7 +562,7 @@ def start_template_dataflow( service = self.get_conn() request = service.projects().locations().templates().launch( # pylint: disable=no-member projectId=project_id, - location=variables['region'], + location=location, gcsPath=dataflow_template, body=body ) @@ -560,18 +572,18 @@ def start_template_dataflow( if on_new_job_id_callback: on_new_job_id_callback(job_id) - variables = self._set_variables(variables) jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), project_number=project_id, name=name, job_id=job_id, - location=variables['region'], + location=location, poll_sleep=self.poll_sleep, num_retries=self.num_retries) jobs_controller.wait_for_done() return response["job"] + @_fallback_to_location_from_variables @_fallback_to_project_id_from_variables @GoogleBaseHook.fallback_to_default_project_id def start_python_dataflow( # pylint: disable=too-many-arguments @@ -585,7 +597,8 @@ def start_python_dataflow( # pylint: disable=too-many-arguments py_requirements: Optional[List[str]] = None, py_system_site_packages: bool = False, append_job_name: bool = True, - on_new_job_id_callback: Optional[Callable[[str], None]] = None + on_new_job_id_callback: Optional[Callable[[str], None]] = None, + location: str = DEFAULT_DATAFLOW_LOCATION ): """ Starts Dataflow job. @@ -620,9 +633,12 @@ def start_python_dataflow( # pylint: disable=too-many-arguments If set to None or missing, the default project_id from the GCP connection is used. :param on_new_job_id_callback: Callback called when the job ID is known. :type on_new_job_id_callback: callable + :param location: Job location. 
+ :type location: str """ name = self._build_dataflow_job_name(job_name, append_job_name) variables['job_name'] = name + variables['region'] = location def label_formatter(labels_dict): return ['--labels={}={}'.format(key, value) @@ -644,7 +660,8 @@ def label_formatter(labels_dict): command_prefix=command_prefix, label_formatter=label_formatter, project_id=project_id, - on_new_job_id_callback=on_new_job_id_callback + on_new_job_id_callback=on_new_job_id_callback, + location=location ) else: command_prefix = [py_interpreter] + py_options + [dataflow] @@ -655,7 +672,8 @@ def label_formatter(labels_dict): command_prefix=command_prefix, label_formatter=label_formatter, project_id=project_id, - on_new_job_id_callback=on_new_job_id_callback + on_new_job_id_callback=on_new_job_id_callback, + location=location ) @staticmethod @@ -700,27 +718,38 @@ def _build_cmd(variables: Dict, label_formatter: Callable, project_id: str) -> L command.append(f"--{attr}={value}") return command + @_fallback_to_location_from_variables @_fallback_to_project_id_from_variables @GoogleBaseHook.fallback_to_default_project_id - def is_job_dataflow_running(self, name: str, variables: Dict, project_id: str) -> bool: + def is_job_dataflow_running( + self, + name: str, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + variables: Optional[Dict] = None + ) -> bool: """ Helper method to check if jos is still running in dataflow :param name: The name of the job. :type name: str - :param variables: Variables passed to the job. - :type variables: dict :param project_id: Optional, the GCP project ID in which to start a job. If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param location: Job location. + :type location: str :return: True if job is running. :rtype: bool """ - variables = self._set_variables(variables) + if variables: + warnings.warn( + "The variables parameter has been deprecated. You should pass location using " + "the location parameter.", DeprecationWarning, stacklevel=4) jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), project_number=project_id, name=name, - location=variables['region'], + location=location, poll_sleep=self.poll_sleep ) return jobs_controller.is_job_running() @@ -731,7 +760,7 @@ def cancel_job( project_id: str, job_name: Optional[str] = None, job_id: Optional[str] = None, - location: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, ) -> None: """ Cancels the job with the specified name prefix or Job ID. 
@@ -753,7 +782,7 @@ def cancel_job( project_number=project_id, name=job_name, job_id=job_id, - location=location or DEFAULT_DATAFLOW_LOCATION, + location=location, poll_sleep=self.poll_sleep ) jobs_controller.cancel() diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index f8bffd7207bd8..965e59a9e1a54 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -26,7 +26,7 @@ from typing import List, Optional from airflow.models import BaseOperator -from airflow.providers.google.cloud.hooks.dataflow import DataflowHook +from airflow.providers.google.cloud.hooks.dataflow import DEFAULT_DATAFLOW_LOCATION, DataflowHook from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.utils.decorators import apply_defaults from airflow.version import version @@ -108,6 +108,11 @@ class DataflowCreateJavaJobOperator(BaseOperator): When defining labels (``labels`` option), you can also provide a dictionary. :type options: dict + :param project_id: Optional, the GCP project ID in which to start a job. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param location: Job location. + :type location: str :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. :type gcp_conn_id: str @@ -143,7 +148,6 @@ class DataflowCreateJavaJobOperator(BaseOperator): default_args = { 'dataflow_default_options': { - 'project': 'my-gcp-project', 'zone': 'europe-west1-d', 'stagingLocation': 'gs://my-staging-bucket/staging/' } @@ -182,6 +186,7 @@ def __init__( dataflow_default_options: Optional[dict] = None, options: Optional[dict] = None, project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, poll_sleep: int = 10, @@ -197,6 +202,7 @@ def __init__( options.setdefault('labels', {}).update( {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}) self.project_id = project_id + self.location = location self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.jar = jar @@ -224,10 +230,13 @@ def execute(self, context): name=self.job_name, variables=dataflow_options, project_id=self.project_id, + location=self.location ) while is_running and self.check_if_running == CheckJobRunning.WaitForRun: is_running = self.hook.is_job_dataflow_running( - name=self.job_name, variables=dataflow_options, project_id=self.project_id) + name=self.job_name, variables=dataflow_options, project_id=self.project_id, + location=self.location + ) if not is_running: with ExitStack() as exit_stack: @@ -250,6 +259,7 @@ def set_current_job_id(job_id): multiple_jobs=self.multiple_jobs, on_new_job_id_callback=set_current_job_id, project_id=self.project_id, + location=self.location ) def on_kill(self) -> None: @@ -271,6 +281,11 @@ class DataflowTemplatedJobStartOperator(BaseOperator): :type dataflow_default_options: dict :param parameters: Map of job specific parameters for the template. :type parameters: dict + :param project_id: Optional, the GCP project ID in which to start a job. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param location: Job location. + :type location: str :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
:type gcp_conn_id: str @@ -294,8 +309,6 @@ class DataflowTemplatedJobStartOperator(BaseOperator): default_args = { 'dataflow_default_options': { - 'project': 'my-gcp-project', - 'region': 'europe-west1', 'zone': 'europe-west1-d', 'tempLocation': 'gs://my-staging-bucket/staging/', } @@ -342,6 +355,7 @@ def __init__( dataflow_default_options: Optional[dict] = None, parameters: Optional[dict] = None, project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, poll_sleep: int = 10, @@ -357,6 +371,7 @@ def __init__( self.dataflow_default_options = dataflow_default_options self.parameters = parameters self.project_id = project_id + self.location = location self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.poll_sleep = poll_sleep @@ -380,6 +395,7 @@ def set_current_job_id(job_id): dataflow_template=self.template, on_new_job_id_callback=set_current_job_id, project_id=self.project_id, + location=self.location ) return job @@ -442,9 +458,13 @@ class DataflowCreatePythonJobOperator(BaseOperator): See virtualenv documentation for more information. This option is only relevant if the ``py_requirements`` parameter is passed. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud - Platform. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. :type gcp_conn_id: str + :param project_id: Optional, the GCP project ID in which to start a job. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param location: Job location. + :type location: str :param delegate_to: The account to impersonate, if any. For this to work, the service account making the request must have domain-wide delegation enabled. 
@@ -468,6 +488,7 @@ def __init__( # pylint: disable=too-many-arguments py_requirements: Optional[List[str]] = None, py_system_site_packages: bool = False, project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, poll_sleep: int = 10, @@ -487,6 +508,7 @@ def __init__( # pylint: disable=too-many-arguments self.py_requirements = py_requirements or [] self.py_system_site_packages = py_system_site_packages self.project_id = project_id + self.location = location self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.poll_sleep = poll_sleep @@ -528,6 +550,7 @@ def set_current_job_id(job_id): py_system_site_packages=self.py_system_site_packages, on_new_job_id_callback=set_current_job_id, project_id=self.project_id, + location=self.location, ) def on_kill(self) -> None: diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py index 8471491389dcd..1b02f4672ac41 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -27,8 +27,8 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataflow import ( - DataflowHook, DataflowJobStatus, DataflowJobType, _DataflowJobsController, _DataflowRunner, - _fallback_to_project_id_from_variables, + DEFAULT_DATAFLOW_LOCATION, DataflowHook, DataflowJobStatus, DataflowJobType, _DataflowJobsController, + _DataflowRunner, _fallback_to_project_id_from_variables, ) TASK_ID = 'test-dataflow-operator' @@ -79,7 +79,7 @@ TEST_PROJECT = 'test-project' TEST_JOB_NAME = 'test-job-name' TEST_JOB_ID = 'test-job-id' -TEST_LOCATION = 'us-central1' +TEST_LOCATION = 'custom-location' DEFAULT_PY_INTERPRETER = 'python3' @@ -119,7 +119,7 @@ def test_fn(self, *args, **kwargs): with self.assertRaisesRegex( AirflowException, - "The mutually exclusive parameter `project_id` and `project` key in `variables` parameters are " + "The mutually exclusive parameter `project_id` and `project` key in `variables` parameter are " "both present\\. Please remove one\\." 
): FixtureFallback().test_fn(variables={'project': "TEST"}, project_id="TEST2") @@ -185,6 +185,63 @@ def test_start_python_dataflow( self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) + @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) + @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + def test_start_python_dataflow_with_custom_region_as_variable( + self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid + ): + mock_uuid.return_value = MOCK_UUID + mock_conn.return_value = None + dataflow_instance = mock_dataflow.return_value + dataflow_instance.wait_for_done.return_value = None + dataflowjob_instance = mock_dataflowjob.return_value + dataflowjob_instance.wait_for_done.return_value = None + variables = copy.deepcopy(DATAFLOW_VARIABLES_PY) + variables['region'] = TEST_LOCATION + self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter + job_name=JOB_NAME, variables=variables, + dataflow=PY_FILE, py_options=PY_OPTIONS, + ) + expected_cmd = ["python3", '-m', PY_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), + sorted(expected_cmd)) + + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) + @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) + @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + def test_start_python_dataflow_with_custom_region_as_parameter( + self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid + ): + mock_uuid.return_value = MOCK_UUID + mock_conn.return_value = None + dataflow_instance = mock_dataflow.return_value + dataflow_instance.wait_for_done.return_value = None + dataflowjob_instance = mock_dataflowjob.return_value + dataflowjob_instance.wait_for_done.return_value = None + self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter + job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY, + dataflow=PY_FILE, py_options=PY_OPTIONS, + location=TEST_LOCATION + ) + expected_cmd = ["python3", '-m', PY_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), + sorted(expected_cmd)) + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) @@ -270,8 +327,10 @@ def test_start_java_dataflow(self, mock_conn, '--stagingLocation=gs://test/staging', '--labels={"foo":"bar"}', '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - sorted(expected_cmd)) + self.assertListEqual( + sorted(expected_cmd), + sorted(mock_dataflow.call_args[1]["cmd"]), + ) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) @@ -303,6 +362,70 @@ def test_start_java_dataflow_with_multiple_values_in_variables( self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) +
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) + @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + def test_start_java_dataflow_with_custom_region_as_variable( + self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid + ): + mock_uuid.return_value = MOCK_UUID + mock_conn.return_value = None + dataflow_instance = mock_dataflow.return_value + dataflow_instance.wait_for_done.return_value = None + dataflowjob_instance = mock_dataflowjob.return_value + dataflowjob_instance.wait_for_done.return_value = None + + variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA) + variables['region'] = TEST_LOCATION + + self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter + job_name=JOB_NAME, variables=variables, + jar=JAR_FILE) + expected_cmd = ['java', '-jar', JAR_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--stagingLocation=gs://test/staging', + '--labels={"foo":"bar"}', + '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] + self.assertListEqual( + sorted(expected_cmd), + sorted(mock_dataflow.call_args[1]["cmd"]), + ) + + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) + @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) + @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + def test_start_java_dataflow_with_custom_region_as_parameter( + self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid + ): + mock_uuid.return_value = MOCK_UUID + mock_conn.return_value = None + dataflow_instance = mock_dataflow.return_value + dataflow_instance.wait_for_done.return_value = None + dataflowjob_instance = mock_dataflowjob.return_value + dataflowjob_instance.wait_for_done.return_value = None + + self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter + job_name=JOB_NAME, + variables=DATAFLOW_VARIABLES_JAVA, + jar=JAR_FILE, + location=TEST_LOCATION, + ) + expected_cmd = ['java', '-jar', JAR_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--stagingLocation=gs://test/staging', + '--labels={"foo":"bar"}', + '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] + self.assertListEqual( + sorted(expected_cmd), + sorted(mock_dataflow.call_args[1]["cmd"]), + ) + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) @@ -382,7 +505,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid): job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_TEMPLATE, parameters=PARAMETERS, dataflow_template=TEMPLATE, ) - options_with_region = {'region': 'us-central1'} + options_with_region = {'region': DEFAULT_DATAFLOW_LOCATION} options_with_region.update(DATAFLOW_VARIABLES_TEMPLATE) options_with_region_without_project = copy.deepcopy(options_with_region) del options_with_region_without_project['project'] @@ -397,18 +520,116 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid): } }, gcsPath='gs://dataflow-templates/wordcount/template_file', - location='us-central1', - projectId='test' + projectId='test', + location=DEFAULT_DATAFLOW_LOCATION, ) mock_controller.assert_called_once_with( dataflow=mock_conn.return_value, job_id='test-job-id', - location='us-central1', name='test-dataflow-pipeline-12345678', num_retries=5, poll_sleep=10, -
project_number='test' + project_number='test', + location=DEFAULT_DATAFLOW_LOCATION + ) + mock_controller.return_value.wait_for_done.assert_called_once() + + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID) + @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + def test_start_template_dataflow_with_custom_region_as_variable( + self, mock_conn, mock_controller, mock_uuid + ): + launch_method = ( + mock_conn.return_value. + projects.return_value. + locations.return_value. + templates.return_value. + launch + ) + launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}} + variables_with_region = {'region': TEST_LOCATION} + variables_with_region.update(DATAFLOW_VARIABLES_TEMPLATE) + variables_with_region_without_project = copy.deepcopy(variables_with_region) + del variables_with_region_without_project['project'] + + self.dataflow_hook.start_template_dataflow( # pylint: disable=no-value-for-parameter + job_name=JOB_NAME, variables=variables_with_region, parameters=PARAMETERS, + dataflow_template=TEMPLATE, + ) + + launch_method.assert_called_once_with( + body={ + 'jobName': 'test-dataflow-pipeline-12345678', + 'parameters': PARAMETERS, + 'environment': { + 'zone': 'us-central1-f', + 'tempLocation': 'gs://test/temp' + } + }, + gcsPath='gs://dataflow-templates/wordcount/template_file', + projectId='test', + location=TEST_LOCATION, + ) + + mock_controller.assert_called_once_with( + dataflow=mock_conn.return_value, + job_id='test-job-id', + name='test-dataflow-pipeline-12345678', + num_retries=5, + poll_sleep=10, + project_number='test', + location=TEST_LOCATION, + ) + mock_controller.return_value.wait_for_done.assert_called_once() + + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID) + @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) + @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) + def test_start_template_dataflow_with_custom_region_as_parameter( + self, mock_conn, mock_controller, mock_uuid + ): + launch_method = ( + mock_conn.return_value. + projects.return_value. + locations.return_value. + templates.return_value. 
+ launch + ) + launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}} + variables_with_region = {'region': TEST_LOCATION} + variables_with_region.update(DATAFLOW_VARIABLES_TEMPLATE) + variables_with_region_without_project = copy.deepcopy(variables_with_region) + del variables_with_region_without_project['project'] + + self.dataflow_hook.start_template_dataflow( # pylint: disable=no-value-for-parameter + job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_TEMPLATE, parameters=PARAMETERS, + dataflow_template=TEMPLATE, location=TEST_LOCATION + ) + + launch_method.assert_called_once_with( + body={ + 'jobName': 'test-dataflow-pipeline-12345678', + 'parameters': PARAMETERS, + 'environment': { + 'zone': 'us-central1-f', + 'tempLocation': 'gs://test/temp' + } + }, + gcsPath='gs://dataflow-templates/wordcount/template_file', + projectId='test', + location=TEST_LOCATION, + ) + + mock_controller.assert_called_once_with( + dataflow=mock_conn.return_value, + job_id='test-job-id', + name='test-dataflow-pipeline-12345678', + num_retries=5, + poll_sleep=10, + project_number='test', + location=TEST_LOCATION, ) mock_controller.return_value.wait_for_done.assert_called_once() @@ -463,11 +684,12 @@ def test_cancel_job(self, mock_get_conn, jobs_controller): job_name=TEST_JOB_NAME, job_id=TEST_JOB_ID, project_id=TEST_PROJECT, + location=TEST_LOCATION ) jobs_controller.assert_called_once_with( dataflow=mock_get_conn.return_value, job_id=TEST_JOB_ID, - location='us-central1', + location=TEST_LOCATION, name=TEST_JOB_NAME, poll_sleep=10, project_number=TEST_PROJECT @@ -811,7 +1033,7 @@ def test_dataflow_job_cancel_job(self): mock_update.assert_called_once_with( body={'requestedState': 'JOB_STATE_CANCELLED'}, jobId='test-job-id', - location='us-central1', + location=TEST_LOCATION, projectId='test-project', ) mock_batch.add.assert_called_once_with( diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index c76ed39734555..8806ef7040931 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -60,6 +60,7 @@ } POLL_SLEEP = 30 GCS_HOOK_STRING = 'airflow.providers.google.cloud.operators.dataflow.{}' +TEST_LOCATION = "custom-location" class TestDataflowPythonOperator(unittest.TestCase): @@ -72,7 +73,9 @@ def setUp(self): py_options=PY_OPTIONS, dataflow_default_options=DEFAULT_OPTIONS_PYTHON, options=ADDITIONAL_OPTIONS, - poll_sleep=POLL_SLEEP) + poll_sleep=POLL_SLEEP, + location=TEST_LOCATION + ) def test_init(self): """Test DataFlowPythonOperator instance is properly initialized.""" @@ -115,6 +118,7 @@ def test_exec(self, gcs_hook, dataflow_mock): py_system_site_packages=False, on_new_job_id_callback=mock.ANY, project_id=None, + location=TEST_LOCATION ) self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow')) @@ -129,7 +133,9 @@ def setUp(self): job_class=JOB_CLASS, dataflow_default_options=DEFAULT_OPTIONS_JAVA, options=ADDITIONAL_OPTIONS, - poll_sleep=POLL_SLEEP) + poll_sleep=POLL_SLEEP, + location=TEST_LOCATION + ) def test_init(self): """Test DataflowTemplateOperator instance is properly initialized.""" @@ -166,6 +172,7 @@ def test_exec(self, gcs_hook, dataflow_mock): multiple_jobs=None, on_new_job_id_callback=mock.ANY, project_id=None, + location=TEST_LOCATION ) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @@ -185,7 +192,7 @@ def test_check_job_running_exec(self, gcs_hook, dataflow_mock): 
gcs_provide_file.assert_not_called() start_java_hook.assert_not_called() dataflow_running.assert_called_once_with( - name=JOB_NAME, variables=mock.ANY, project_id=None) + name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') @@ -211,9 +218,10 @@ def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock): multiple_jobs=None, on_new_job_id_callback=mock.ANY, project_id=None, + location=TEST_LOCATION ) dataflow_running.assert_called_once_with( - name=JOB_NAME, variables=mock.ANY, project_id=None) + name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') @@ -240,9 +248,11 @@ def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock): multiple_jobs=True, on_new_job_id_callback=mock.ANY, project_id=None, + location=TEST_LOCATION ) dataflow_running.assert_called_once_with( - name=JOB_NAME, variables=mock.ANY, project_id=None) + name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION + ) class TestDataflowTemplateOperator(unittest.TestCase): @@ -254,7 +264,9 @@ def setUp(self): job_name=JOB_NAME, parameters=PARAMETERS, dataflow_default_options=DEFAULT_OPTIONS_TEMPLATE, - poll_sleep=POLL_SLEEP) + poll_sleep=POLL_SLEEP, + location=TEST_LOCATION + ) def test_init(self): """Test DataflowTemplateOperator instance is properly initialized.""" @@ -288,4 +300,5 @@ def test_exec(self, dataflow_mock): dataflow_template=TEMPLATE, on_new_job_id_callback=mock.ANY, project_id=None, + location=TEST_LOCATION ) diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py index 665948b0b4840..692cd3ff4431e 100644 --- a/tests/providers/google/cloud/operators/test_mlengine_utils.py +++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py @@ -119,6 +119,7 @@ def test_successful_run(self): py_system_site_packages=False, on_new_job_id_callback=ANY, project_id='test-project', + location='us-central1', ) with patch('airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook') as mock_gcs_hook:
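Usage sketch for the new `location` parameter (a minimal example assuming the operator signatures introduced above; the DAG id, bucket names, and template path are illustrative placeholders):

from airflow import models
from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator
from airflow.utils.dates import days_ago

default_args = {
    "start_date": days_ago(1),
    "dataflow_default_options": {
        # 'project' and 'region' keys are no longer needed here: the project id
        # falls back to the GCP connection and the region comes from `location`.
        "tempLocation": "gs://my-bucket/temp/",
        "stagingLocation": "gs://my-bucket/staging/",
    },
}

with models.DAG(
    "example_dataflow_location",
    default_args=default_args,
    schedule_interval=None,
) as dag:
    start_template_job = DataflowTemplatedJobStartOperator(
        task_id="start-template-job",
        template="gs://dataflow-templates/latest/Word_Count",
        parameters={
            "inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt",
            "output": "gs://my-bucket/output",
        },
        location="europe-west3",  # forwarded to DataflowHook.start_template_dataflow
    )

If a 'region' key is still supplied in the variables dict at the hook level, the `_fallback_to_location_from_variables` decorator promotes it to the `location` argument; supplying both raises an AirflowException.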