273 changes: 94 additions & 179 deletions airflow/providers/databricks/operators/databricks.py
@@ -40,7 +40,7 @@
WorkflowJobRunLink,
)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -186,17 +186,6 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
raise AirflowException(error_message)


def _handle_overridden_json_params(operator):
for key, value in operator.overridden_json_params.items():
if value is not None:
operator.json[key] = value


def normalise_json_content(operator):
if operator.json:
operator.json = _normalise_json_content(operator.json)


class DatabricksJobRunLink(BaseOperatorLink):
"""Constructs a link to monitor a Databricks Job Run."""

@@ -263,23 +252,7 @@ class DatabricksCreateJobsOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"name",
"description",
"tags",
"tasks",
"job_clusters",
"email_notifications",
"webhook_notifications",
"notification_settings",
"timeout_seconds",
"schedule",
"max_concurrent_runs",
"git_source",
"access_control_list",
)
template_fields: Sequence[str] = ("json", "databricks_conn_id")
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
ui_fgcolor = "#fff"
@@ -316,19 +289,34 @@ def __init__(
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.name = name
self.description = description
self.tags = tags
self.tasks = tasks
self.job_clusters = job_clusters
self.email_notifications = email_notifications
self.webhook_notifications = webhook_notifications
self.notification_settings = notification_settings
self.timeout_seconds = timeout_seconds
self.schedule = schedule
self.max_concurrent_runs = max_concurrent_runs
self.git_source = git_source
self.access_control_list = access_control_list
if name is not None:
self.json["name"] = name
if description is not None:
self.json["description"] = description
if tags is not None:
self.json["tags"] = tags
if tasks is not None:
self.json["tasks"] = tasks
if job_clusters is not None:
self.json["job_clusters"] = job_clusters
if email_notifications is not None:
self.json["email_notifications"] = email_notifications
if webhook_notifications is not None:
self.json["webhook_notifications"] = webhook_notifications
if notification_settings is not None:
self.json["notification_settings"] = notification_settings
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if schedule is not None:
self.json["schedule"] = schedule
if max_concurrent_runs is not None:
self.json["max_concurrent_runs"] = max_concurrent_runs
if git_source is not None:
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if self.json:
self.json = normalise_json_content(self.json)

@cached_property
def _hook(self):
@@ -340,40 +328,16 @@ def _hook(self):
caller="DatabricksCreateJobsOperator",
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"name": self.name,
"description": self.description,
"tags": self.tags,
"tasks": self.tasks,
"job_clusters": self.job_clusters,
"email_notifications": self.email_notifications,
"webhook_notifications": self.webhook_notifications,
"notification_settings": self.notification_settings,
"timeout_seconds": self.timeout_seconds,
"schedule": self.schedule,
"max_concurrent_runs": self.max_concurrent_runs,
"git_source": self.git_source,
"access_control_list": self.access_control_list,
}

_handle_overridden_json_params(self)

def execute(self, context: Context) -> int:
if "name" not in self.json:
raise AirflowException("Missing required parameter: name")

normalise_json_content(self)

def execute(self, context: Context) -> int:
self._setup_and_validate_json()

job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
if (access_control_list := self.json.get("access_control_list")) is not None:
acl_json = {"access_control_list": access_control_list}
self._hook.update_job_permission(job_id, _normalise_json_content(acl_json))
self._hook.update_job_permission(job_id, normalise_json_content(acl_json))

return job_id

@@ -500,25 +464,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"tasks",
"spark_jar_task",
"notebook_task",
"spark_python_task",
"spark_submit_task",
"pipeline_task",
"dbt_task",
"new_cluster",
"existing_cluster_id",
"libraries",
"run_name",
"timeout_seconds",
"idempotency_token",
"access_control_list",
"git_source",
)
template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
@@ -564,21 +510,43 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.tasks = tasks
self.spark_jar_task = spark_jar_task
self.notebook_task = notebook_task
self.spark_python_task = spark_python_task
self.spark_submit_task = spark_submit_task
self.pipeline_task = pipeline_task
self.dbt_task = dbt_task
self.new_cluster = new_cluster
self.existing_cluster_id = existing_cluster_id
self.libraries = libraries
self.run_name = run_name
self.timeout_seconds = timeout_seconds
self.idempotency_token = idempotency_token
self.access_control_list = access_control_list
self.git_source = git_source
if tasks is not None:
self.json["tasks"] = tasks
if spark_jar_task is not None:
self.json["spark_jar_task"] = spark_jar_task
if notebook_task is not None:
self.json["notebook_task"] = notebook_task
if spark_python_task is not None:
self.json["spark_python_task"] = spark_python_task
if spark_submit_task is not None:
self.json["spark_submit_task"] = spark_submit_task
if pipeline_task is not None:
self.json["pipeline_task"] = pipeline_task
if dbt_task is not None:
self.json["dbt_task"] = dbt_task
if new_cluster is not None:
self.json["new_cluster"] = new_cluster
if existing_cluster_id is not None:
self.json["existing_cluster_id"] = existing_cluster_id
if libraries is not None:
self.json["libraries"] = libraries
if run_name is not None:
self.json["run_name"] = run_name
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if "run_name" not in self.json:
self.json["run_name"] = run_name or kwargs["task_id"]
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if git_source is not None:
self.json["git_source"] = git_source

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")

# This variable will be used in case our task gets killed.
self.run_id: int | None = None
@@ -597,43 +565,7 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"tasks": self.tasks,
"spark_jar_task": self.spark_jar_task,
"notebook_task": self.notebook_task,
"spark_python_task": self.spark_python_task,
"spark_submit_task": self.spark_submit_task,
"pipeline_task": self.pipeline_task,
"dbt_task": self.dbt_task,
"new_cluster": self.new_cluster,
"existing_cluster_id": self.existing_cluster_id,
"libraries": self.libraries,
"run_name": self.run_name,
"timeout_seconds": self.timeout_seconds,
"idempotency_token": self.idempotency_token,
"access_control_list": self.access_control_list,
"git_source": self.git_source,
}

_handle_overridden_json_params(self)

if "run_name" not in self.json or self.json["run_name"] is None:
self.json["run_name"] = self.task_id

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if (
"pipeline_task" in self.json
and "pipeline_id" in self.json["pipeline_task"]
and "pipeline_name" in self.json["pipeline_task"]
):
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")

normalise_json_content(self)

def execute(self, context: Context):
self._setup_and_validate_json()
if (
"pipeline_task" in self.json
and self.json["pipeline_task"].get("pipeline_id") is None
@@ -643,7 +575,7 @@ def execute(self, context: Context):
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
json_normalised = _normalise_json_content(self.json)
json_normalised = normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
@@ -679,7 +611,7 @@ def __init__(self, *args, **kwargs):

def execute(self, context):
hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
json_normalised = _normalise_json_content(self.json)
json_normalised = normalise_json_content(self.json)
self.run_id = hook.submit_run(json_normalised)
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)

@@ -836,18 +768,7 @@ class DatabricksRunNowOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"job_id",
"job_name",
"notebook_params",
"python_params",
"python_named_params",
"jar_params",
"spark_submit_params",
"idempotency_token",
)
template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
@@ -890,14 +811,27 @@ def __init__(
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs
self.job_id = job_id
self.job_name = job_name
self.notebook_params = notebook_params
self.python_params = python_params
self.python_named_params = python_named_params
self.jar_params = jar_params
self.spark_submit_params = spark_submit_params
self.idempotency_token = idempotency_token

if job_id is not None:
self.json["job_id"] = job_id
if job_name is not None:
self.json["job_name"] = job_name
if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
if notebook_params is not None:
self.json["notebook_params"] = notebook_params
if python_params is not None:
self.json["python_params"] = python_params
if python_named_params is not None:
self.json["python_named_params"] = python_named_params
if jar_params is not None:
self.json["jar_params"] = jar_params
if spark_submit_params is not None:
self.json["spark_submit_params"] = spark_submit_params
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if self.json:
self.json = normalise_json_content(self.json)
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
@@ -915,26 +849,7 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"job_id": self.job_id,
"job_name": self.job_name,
"notebook_params": self.notebook_params,
"python_params": self.python_params,
"python_named_params": self.python_named_params,
"jar_params": self.jar_params,
"spark_submit_params": self.spark_submit_params,
"idempotency_token": self.idempotency_token,
}
_handle_overridden_json_params(self)

if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")

normalise_json_content(self)

def execute(self, context: Context):
self._setup_and_validate_json()
hook = self._hook
if "job_name" in self.json:
job_id = hook.find_job_id_by_name(self.json["job_name"])
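For context, a minimal sketch (not part of this diff) of how the reworked constructors above behave after this change: keyword arguments such as job_id and notebook_params are written into self.json during __init__, on top of whatever was passed via the json argument. The connection id, job id, and parameter values below are placeholders, not values from this PR.

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

with DAG(dag_id="databricks_run_now_example", start_date=datetime(2024, 1, 1), schedule=None):
    run_job = DatabricksRunNowOperator(
        task_id="run_job",
        databricks_conn_id="databricks_default",  # placeholder connection id
        json={"notebook_params": {"env": "staging"}},
        job_id=42,  # placeholder; ends up as self.json["job_id"]
        # Named kwargs are applied after the json argument, so this value
        # replaces the "notebook_params" entry supplied via json above.
        notebook_params={"env": "production"},
    )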
4 changes: 2 additions & 2 deletions airflow/providers/databricks/utils/databricks.py
@@ -21,7 +21,7 @@
from airflow.providers.databricks.hooks.databricks import RunState


def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
"""
Normalize content or all values of content if it is a dict to a string.

Expand All @@ -33,7 +33,7 @@ def _normalise_json_content(content, json_path: str = "json") -> str | bool | li
The only one exception is when we have boolean values, they can not be converted
to string type because databricks does not understand 'True' or 'False' values.
"""
normalise = _normalise_json_content
normalise = normalise_json_content
if isinstance(content, (str, bool)):
return content
elif isinstance(content, (int, float)):
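As a quick illustration of the renamed helper (example values are made up), normalise_json_content recursively converts numbers to strings while leaving strings and booleans untouched, as the docstring above describes:

from airflow.providers.databricks.utils.databricks import normalise_json_content

payload = {"timeout_seconds": 600, "max_concurrent_runs": 1, "tags": ["etl", 1], "notifications": {"alert": True}}
print(normalise_json_content(payload))
# {'timeout_seconds': '600', 'max_concurrent_runs': '1', 'tags': ['etl', '1'], 'notifications': {'alert': True}}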