85 changes: 85 additions & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -25,6 +25,7 @@
import time
import uuid
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Sequence
@@ -76,6 +77,38 @@ class PreemptibilityType(Enum):
PREEMPTIBILITY_UNSPECIFIED = "PREEMPTIBILITY_UNSPECIFIED"


@dataclass
class InstanceSelection:
"""Defines machines types and a rank to which the machines types belong.

Representation for
google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.InstanceSelection.

:param machine_types: Full machine-type names, e.g. "n1-standard-16".
:param rank: Preference of this instance selection. Lower number means higher preference.
Dataproc will first try to create a VM based on the machine types with the highest-priority rank and fall back
to the next rank based on availability. Machine types and instance selections with the same priority have
the same preference.
"""

machine_types: list[str]
rank: int = 0


@dataclass
class InstanceFlexibilityPolicy:
"""
Instance flexibility Policy allowing a mixture of VM shapes and provisioning models.

Representation for google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.

:param instance_selection_list: List of instance selection options that the group will use when
creating new VMs.
"""

instance_selection_list: list[InstanceSelection]
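
For illustration, a minimal sketch of how these two dataclasses combine (mirroring the usage in the new tests below; the machine type names are placeholders, not values required by the API):

    flexibility_policy = InstanceFlexibilityPolicy(
        instance_selection_list=[
            # Rank 0 is tried first; these machine type names are illustrative only.
            InstanceSelection(machine_types=["n1-standard-16", "n1-highmem-16"], rank=0),
            # Fallback selection used if the rank-0 machine types are unavailable.
            InstanceSelection(machine_types=["n1-standard-8"], rank=1),
        ]
    )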


class ClusterGenerator:
"""Create a new Dataproc Cluster.

Expand All @@ -84,6 +117,11 @@ class ClusterGenerator:
to create the cluster. (templated)
:param num_workers: The # of workers to spin up. If set to zero will
spin up cluster in a single node mode
:param min_num_workers: The minimum number of primary worker instances to create.
Contributor:

It's not immediately clear how this value interacts with num_workers, might be worth elaborating a bit more?

From the API docs:

Example: Cluster creation request with num_instances = 5 and min_num_instances = 3:

If 4 VMs are created and 1 instance fails, the failed VM is deleted. The cluster is resized to 4 instances and placed in a RUNNING state.
If 2 instances are created and 3 instances fail, the cluster is placed in an ERROR state. The failed VMs are not deleted.

Could we add some of that context here? I don't have any strong opinions on the exact wording.

Contributor Author:

Updated.

If more than ``min_num_workers`` VMs are created out of ``num_workers``, the failed VMs will be
deleted, the cluster is resized to the available VMs, and it is set to RUNNING.
If fewer than ``min_num_workers`` VMs are created, the cluster is placed in ERROR state. The failed
VMs are not deleted.
:param storage_bucket: The storage bucket to use, setting to None lets dataproc
generate a custom one for you
:param init_actions_uris: List of GCS uri's containing
Expand Down Expand Up @@ -152,12 +190,18 @@ class ClusterGenerator:
``projects/[PROJECT_STORING_KEYS]/locations/[LOCATION]/keyRings/[KEY_RING_NAME]/cryptoKeys/[KEY_NAME]`` # noqa
:param enable_component_gateway: Provides access to the web interfaces of default and selected optional
components on the cluster.
:param driver_pool_size: The number of driver nodes in the node group.
:param driver_pool_id: The ID for the driver pool. Must be unique within the cluster. Use this ID to
identify the driver group in future operations, such as resizing the node group.
:param secondary_worker_instance_flexibility_policy: Instance flexibility Policy allowing a mixture of VM
shapes and provisioning models.
"""

def __init__(
self,
project_id: str,
num_workers: int | None = None,
min_num_workers: int | None = None,
zone: str | None = None,
network_uri: str | None = None,
subnetwork_uri: str | None = None,
@@ -190,11 +234,15 @@ def __init__(
auto_delete_ttl: int | None = None,
customer_managed_key: str | None = None,
enable_component_gateway: bool | None = False,
driver_pool_size: int = 0,
driver_pool_id: str | None = None,
secondary_worker_instance_flexibility_policy: InstanceFlexibilityPolicy | None = None,
**kwargs,
) -> None:
self.project_id = project_id
self.num_masters = num_masters
self.num_workers = num_workers
self.min_num_workers = min_num_workers
self.num_preemptible_workers = num_preemptible_workers
self.preemptibility = self._set_preemptibility_type(preemptibility)
self.storage_bucket = storage_bucket
Expand Down Expand Up @@ -227,6 +275,9 @@ def __init__(
self.customer_managed_key = customer_managed_key
self.enable_component_gateway = enable_component_gateway
self.single_node = num_workers == 0
self.driver_pool_size = driver_pool_size
self.driver_pool_id = driver_pool_id
self.secondary_worker_instance_flexibility_policy = secondary_worker_instance_flexibility_policy

if self.custom_image and self.image_version:
raise ValueError("The custom_image and image_version can't be both set")
@@ -240,6 +291,15 @@ def __init__(
if self.single_node and self.num_preemptible_workers > 0:
raise ValueError("Single node cannot have preemptible workers.")

if self.min_num_workers:
if not self.num_workers:
raise ValueError("Must specify num_workers when min_num_workers are provided.")
if self.min_num_workers > self.num_workers:
raise ValueError(
"The value of min_num_workers must be less than or equal to num_workers. "
f"Provided {self.min_num_workers}(min_num_workers) and {self.num_workers}(num_workers)."
)

def _set_preemptibility_type(self, preemptibility: str):
return PreemptibilityType(preemptibility.upper())

@@ -306,6 +366,17 @@ def _build_lifecycle_config(self, cluster_data):

return cluster_data

def _build_driver_pool(self):
driver_pool = {
"node_group": {
"roles": ["DRIVER"],
"node_group_config": {"num_instances": self.driver_pool_size},
Contributor:

Thoughts on allowing instance type here as well? Something like:

Suggested change
"node_group_config": {"num_instances": self.driver_pool_size},
"node_group_config": {
    "num_instances": self.driver_pool_size,
    "machine_type_uri": self.driver_pool_machine_type,
},

Maybe this is too much config though? I'm fine with leaving this for another day.

Contributor Author:

Let's keep it to these two for now. Later on, we can add a driver pool config (like InstanceFlexibilityPolicy) to cover the rest of the config options.

},
}
if self.driver_pool_id:
driver_pool["node_group_id"] = self.driver_pool_id
return driver_pool
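
For reference, with the values used in the tests below (driver_pool_size=2 and driver_pool_id="cluster_driver_pool"), this helper yields a dict along these lines:

    {
        "node_group": {
            "roles": ["DRIVER"],
            "node_group_config": {"num_instances": 2},
        },
        "node_group_id": "cluster_driver_pool",
    }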

def _build_cluster_data(self):
if self.zone:
master_type_uri = (
Expand Down Expand Up @@ -343,6 +414,10 @@ def _build_cluster_data(self):
"autoscaling_config": {},
"endpoint_config": {},
}

if self.min_num_workers:
Contributor:

Should we raise an Exception at initialization time if min_num_workers > num_workers?

It's usually best to surface issues as early as possible.

Contributor Author:

Done

cluster_data["worker_config"]["min_num_instances"] = self.min_num_workers

if self.num_preemptible_workers > 0:
cluster_data["secondary_worker_config"] = {
"num_instances": self.num_preemptible_workers,
@@ -354,6 +429,13 @@
"is_preemptible": True,
"preemptibility": self.preemptibility.value,
}
if self.secondary_worker_instance_flexibility_policy:
cluster_data["secondary_worker_config"]["instance_flexibility_policy"] = {
"instance_selection_list": [
vars(s)
for s in self.secondary_worker_instance_flexibility_policy.instance_selection_list
]
}
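
Since InstanceSelection is a plain dataclass, vars(s) here is simply its __dict__, so each entry in instance_selection_list serializes to something like the following (machine type names are whatever was passed in; the tests below use full machineTypes URIs):

    {"machine_types": ["n1-standard-16", "n1-highmem-16"], "rank": 0}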

if self.storage_bucket:
cluster_data["config_bucket"] = self.storage_bucket
@@ -381,6 +463,9 @@ def _build_cluster_data(self):
if not self.single_node:
cluster_data["worker_config"]["image_uri"] = custom_image_url

if self.driver_pool_size > 0:
cluster_data["auxiliary_node_groups"] = [self._build_driver_pool()]

cluster_data = self._build_gce_cluster_config(cluster_data)

if self.single_node:
2 changes: 1 addition & 1 deletion airflow/providers/google/provider.yaml
@@ -102,7 +102,7 @@ dependencies:
- google-cloud-dataflow-client>=0.8.2
- google-cloud-dataform>=0.5.0
- google-cloud-dataplex>=1.4.2
- google-cloud-dataproc>=5.4.0
- google-cloud-dataproc>=5.5.0
- google-cloud-dataproc-metastore>=1.12.0
- google-cloud-dlp>=3.12.0
- google-cloud-kms>=2.15.0
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
@@ -791,7 +791,9 @@ InspectContentResponse
InspectTemplate
instafail
installable
InstanceFlexibilityPolicy
InstanceGroupConfig
InstanceSelection
instanceTemplates
instantiation
integrations
140 changes: 140 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -60,6 +60,8 @@
DataprocSubmitSparkJobOperator,
DataprocSubmitSparkSqlJobOperator,
DataprocUpdateClusterOperator,
InstanceFlexibilityPolicy,
InstanceSelection,
)
from airflow.providers.google.cloud.triggers.dataproc import (
DataprocBatchTrigger,
@@ -112,6 +114,7 @@
"disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
"image_uri": "https://www.googleapis.com/compute/beta/projects/"
"custom_image_project_id/global/images/custom_image",
"min_num_instances": 1,
},
"secondary_worker_config": {
"num_instances": 4,
@@ -132,6 +135,17 @@
{"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}}
],
"endpoint_config": {},
"auxiliary_node_groups": [
{
"node_group": {
"roles": ["DRIVER"],
"node_group_config": {
"num_instances": 2,
},
},
"node_group_id": "cluster_driver_pool",
}
],
}
VIRTUAL_CLUSTER_CONFIG = {
"kubernetes_cluster_config": {
@@ -197,6 +211,64 @@
},
}

CONFIG_WITH_FLEX_MIG = {
"gce_cluster_config": {
"zone_uri": "https://www.googleapis.com/compute/v1/projects/project_id/zones/zone",
"metadata": {"metadata": "data"},
"network_uri": "network_uri",
"subnetwork_uri": "subnetwork_uri",
"internal_ip_only": True,
"tags": ["tags"],
"service_account": "service_account",
"service_account_scopes": ["service_account_scopes"],
},
"master_config": {
"num_instances": 2,
"machine_type_uri": "projects/project_id/zones/zone/machineTypes/master_machine_type",
"disk_config": {"boot_disk_type": "master_disk_type", "boot_disk_size_gb": 128},
"image_uri": "https://www.googleapis.com/compute/beta/projects/"
"custom_image_project_id/global/images/custom_image",
},
"worker_config": {
"num_instances": 2,
"machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
"disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
"image_uri": "https://www.googleapis.com/compute/beta/projects/"
"custom_image_project_id/global/images/custom_image",
},
"secondary_worker_config": {
"num_instances": 4,
"machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
"disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
"is_preemptible": True,
"preemptibility": "SPOT",
"instance_flexibility_policy": {
"instance_selection_list": [
{
"machine_types": [
"projects/project_id/zones/zone/machineTypes/machine1",
"projects/project_id/zones/zone/machineTypes/machine2",
],
"rank": 0,
},
{"machine_types": ["projects/project_id/zones/zone/machineTypes/machine3"], "rank": 1},
],
},
},
"software_config": {"properties": {"properties": "data"}, "optional_components": ["optional_components"]},
"lifecycle_config": {
"idle_delete_ttl": {"seconds": 60},
"auto_delete_time": "2019-09-12T00:00:00.000000Z",
},
"encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
"autoscaling_config": {"policy_uri": "autoscaling_policy"},
"config_bucket": "storage_bucket",
"initialization_actions": [
{"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}}
],
"endpoint_config": {},
}

LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION}

LABELS.update({"airflow-version": "v" + airflow_version.replace(".", "-").replace("+", "-")})
@@ -361,10 +433,26 @@ def test_nodes_number(self):
)
assert "num_workers == 0 means single" in str(ctx.value)

def test_min_num_workers_less_than_num_workers(self):
with pytest.raises(ValueError) as ctx:
ClusterGenerator(
num_workers=3, min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
)
assert (
"The value of min_num_workers must be less than or equal to num_workers. "
"Provided 4(min_num_workers) and 3(num_workers)." in str(ctx.value)
)

def test_min_num_workers_without_num_workers(self):
with pytest.raises(ValueError) as ctx:
ClusterGenerator(min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME)
assert "Must specify num_workers when min_num_workers are provided." in str(ctx.value)

def test_build(self):
generator = ClusterGenerator(
project_id="project_id",
num_workers=2,
min_num_workers=1,
zone="zone",
network_uri="network_uri",
subnetwork_uri="subnetwork_uri",
@@ -395,6 +483,8 @@ def test_build(self):
auto_delete_time=datetime(2019, 9, 12),
auto_delete_ttl=250,
customer_managed_key="customer_managed_key",
driver_pool_id="cluster_driver_pool",
driver_pool_size=2,
)
cluster = generator.make()
assert CONFIG == cluster
@@ -438,6 +528,56 @@ def test_build_with_custom_image_family(self):
cluster = generator.make()
assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster

def test_build_with_flex_migs(self):
generator = ClusterGenerator(
project_id="project_id",
num_workers=2,
zone="zone",
network_uri="network_uri",
subnetwork_uri="subnetwork_uri",
internal_ip_only=True,
tags=["tags"],
storage_bucket="storage_bucket",
init_actions_uris=["init_actions_uris"],
init_action_timeout="10m",
metadata={"metadata": "data"},
custom_image="custom_image",
custom_image_project_id="custom_image_project_id",
autoscaling_policy="autoscaling_policy",
properties={"properties": "data"},
optional_components=["optional_components"],
num_masters=2,
master_machine_type="master_machine_type",
master_disk_type="master_disk_type",
master_disk_size=128,
worker_machine_type="worker_machine_type",
worker_disk_type="worker_disk_type",
worker_disk_size=256,
num_preemptible_workers=4,
preemptibility="Spot",
region="region",
service_account="service_account",
service_account_scopes=["service_account_scopes"],
idle_delete_ttl=60,
auto_delete_time=datetime(2019, 9, 12),
auto_delete_ttl=250,
customer_managed_key="customer_managed_key",
secondary_worker_instance_flexibility_policy=InstanceFlexibilityPolicy(
[
InstanceSelection(
[
"projects/project_id/zones/zone/machineTypes/machine1",
"projects/project_id/zones/zone/machineTypes/machine2",
],
0,
),
InstanceSelection(["projects/project_id/zones/zone/machineTypes/machine3"], 1),
]
),
)
cluster = generator.make()
assert CONFIG_WITH_FLEX_MIG == cluster


class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
def test_deprecation_warning(self):