173 changes: 128 additions & 45 deletions airflow/providers/databricks/operators/databricks.py
@@ -263,7 +263,23 @@ class DatabricksCreateJobsOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"name",
"description",
"tags",
"tasks",
"job_clusters",
"email_notifications",
"webhook_notifications",
"notification_settings",
"timeout_seconds",
"schedule",
"max_concurrent_runs",
"git_source",
"access_control_list",
)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
ui_fgcolor = "#fff"
@@ -300,21 +316,19 @@ def __init__(
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.overridden_json_params = {
"name": name,
"description": description,
"tags": tags,
"tasks": tasks,
"job_clusters": job_clusters,
"email_notifications": email_notifications,
"webhook_notifications": webhook_notifications,
"notification_settings": notification_settings,
"timeout_seconds": timeout_seconds,
"schedule": schedule,
"max_concurrent_runs": max_concurrent_runs,
"git_source": git_source,
"access_control_list": access_control_list,
}
self.name = name
self.description = description
self.tags = tags
self.tasks = tasks
self.job_clusters = job_clusters
self.email_notifications = email_notifications
self.webhook_notifications = webhook_notifications
self.notification_settings = notification_settings
self.timeout_seconds = timeout_seconds
self.schedule = schedule
self.max_concurrent_runs = max_concurrent_runs
self.git_source = git_source
self.access_control_list = access_control_list

@cached_property
def _hook(self):
@@ -327,6 +341,22 @@ def _hook(self):
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"name": self.name,
"description": self.description,
"tags": self.tags,
"tasks": self.tasks,
"job_clusters": self.job_clusters,
"email_notifications": self.email_notifications,
"webhook_notifications": self.webhook_notifications,
"notification_settings": self.notification_settings,
"timeout_seconds": self.timeout_seconds,
"schedule": self.schedule,
"max_concurrent_runs": self.max_concurrent_runs,
"git_source": self.git_source,
"access_control_list": self.access_control_list,
}

_handle_overridden_json_params(self)

if "name" not in self.json:
@@ -470,7 +500,25 @@ class DatabricksSubmitRunOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"tasks",
"spark_jar_task",
"notebook_task",
"spark_python_task",
"spark_submit_task",
"pipeline_task",
"dbt_task",
"new_cluster",
"existing_cluster_id",
"libraries",
"run_name",
"timeout_seconds",
"idempotency_token",
"access_control_list",
"git_source",
)
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
@@ -516,23 +564,21 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.overridden_json_params = {
"tasks": tasks,
"spark_jar_task": spark_jar_task,
"notebook_task": notebook_task,
"spark_python_task": spark_python_task,
"spark_submit_task": spark_submit_task,
"pipeline_task": pipeline_task,
"dbt_task": dbt_task,
"new_cluster": new_cluster,
"existing_cluster_id": existing_cluster_id,
"libraries": libraries,
"run_name": run_name,
"timeout_seconds": timeout_seconds,
"idempotency_token": idempotency_token,
"access_control_list": access_control_list,
"git_source": git_source,
}
self.tasks = tasks
self.spark_jar_task = spark_jar_task
self.notebook_task = notebook_task
self.spark_python_task = spark_python_task
self.spark_submit_task = spark_submit_task
self.pipeline_task = pipeline_task
self.dbt_task = dbt_task
self.new_cluster = new_cluster
self.existing_cluster_id = existing_cluster_id
self.libraries = libraries
self.run_name = run_name
self.timeout_seconds = timeout_seconds
self.idempotency_token = idempotency_token
self.access_control_list = access_control_list
self.git_source = git_source

# This variable will be used in case our task gets killed.
self.run_id: int | None = None
@@ -552,6 +598,24 @@ def _get_hook(self, caller: str) -> DatabricksHook:
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"tasks": self.tasks,
"spark_jar_task": self.spark_jar_task,
"notebook_task": self.notebook_task,
"spark_python_task": self.spark_python_task,
"spark_submit_task": self.spark_submit_task,
"pipeline_task": self.pipeline_task,
"dbt_task": self.dbt_task,
"new_cluster": self.new_cluster,
"existing_cluster_id": self.existing_cluster_id,
"libraries": self.libraries,
"run_name": self.run_name,
"timeout_seconds": self.timeout_seconds,
"idempotency_token": self.idempotency_token,
"access_control_list": self.access_control_list,
"git_source": self.git_source,
}

_handle_overridden_json_params(self)

if "run_name" not in self.json or self.json["run_name"] is None:
@@ -772,7 +836,18 @@ class DatabricksRunNowOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"job_id",
"job_name",
"notebook_params",
"python_params",
"python_named_params",
"jar_params",
"spark_submit_params",
"idempotency_token",
)
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
@@ -815,16 +890,14 @@ def __init__(
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs
self.overridden_json_params = {
"job_id": job_id,
"job_name": job_name,
"notebook_params": notebook_params,
"python_params": python_params,
"python_named_params": python_named_params,
"jar_params": jar_params,
"spark_submit_params": spark_submit_params,
"idempotency_token": idempotency_token,
}
self.job_id = job_id
self.job_name = job_name
self.notebook_params = notebook_params
self.python_params = python_params
self.python_named_params = python_named_params
self.jar_params = jar_params
self.spark_submit_params = spark_submit_params
self.idempotency_token = idempotency_token
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
@@ -843,6 +916,16 @@ def _get_hook(self, caller: str) -> DatabricksHook:
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"job_id": self.job_id,
"job_name": self.job_name,
"notebook_params": self.notebook_params,
"python_params": self.python_params,
"python_named_params": self.python_named_params,
"jar_params": self.jar_params,
"spark_submit_params": self.spark_submit_params,
"idempotency_token": self.idempotency_token,
}
_handle_overridden_json_params(self)

if "job_id" in self.json and "job_name" in self.json: