@@ -1261,17 +1261,31 @@ def _get_hook(self, caller: str) -> DatabricksHook:
     def databricks_task_key(self) -> str:
         return self._generate_databricks_task_key()

-    def _generate_databricks_task_key(self, task_id: str | None = None) -> str:
+    def _generate_databricks_task_key(
+        self, task_id: str | None = None, task_dict: dict[str, BaseOperator] | None = None
+    ) -> str:
         """Create a databricks task key using the hash of dag_id and task_id."""
-        if not self._databricks_task_key or len(self._databricks_task_key) > 100:
-            self.log.info(
-                "databricks_task_key has not be provided or the provided one exceeds 100 characters and will be truncated by the Databricks API. This will cause failure when trying to monitor the task. A task_key will be generated using the hash value of dag_id+task_id"
-            )
-            task_id = task_id or self.task_id
-            task_key = f"{self.dag_id}__{task_id}".encode()
-            self._databricks_task_key = hashlib.md5(task_key).hexdigest()
-            self.log.info("Generated databricks task_key: %s", self._databricks_task_key)
-        return self._databricks_task_key
+        if task_id:
+            if not task_dict:
+                raise ValueError(
+                    "Must pass task_dict if task_id is provided in _generate_databricks_task_key."
+                )
+            _task = task_dict.get(task_id)
+            if _task and hasattr(_task, "databricks_task_key"):
+                _databricks_task_key = _task.databricks_task_key
+            else:
+                task_key = f"{self.dag_id}__{task_id}".encode()
+                _databricks_task_key = hashlib.md5(task_key).hexdigest()
+            return _databricks_task_key
+        else:
+            if not self._databricks_task_key or len(self._databricks_task_key) > 100:
+                self.log.info(
+                    "databricks_task_key has not be provided or the provided one exceeds 100 characters and will be truncated by the Databricks API. This will cause failure when trying to monitor the task. A task_key will be generated using the hash value of dag_id+task_id"
+                )
+                task_key = f"{self.dag_id}__{self.task_id}".encode()
+                self._databricks_task_key = hashlib.md5(task_key).hexdigest()
+                self.log.info("Generated databricks task_key: %s", self._databricks_task_key)
+            return self._databricks_task_key

     @property
     def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
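The branching above is easier to see outside diff form. Below is a minimal standalone sketch of the resolution order the new code implements, assuming plain arguments instead of operator attributes (resolve_task_key is an illustrative name, not part of the provider):

import hashlib

def resolve_task_key(dag_id: str, task_id: str, task_dict: dict | None = None) -> str:
    # Prefer the key the referenced operator already exposes, so a user-supplied
    # databricks_task_key and a generated one never diverge.
    if task_dict and task_id in task_dict and hasattr(task_dict[task_id], "databricks_task_key"):
        return task_dict[task_id].databricks_task_key
    # Fallback: hash dag_id + task_id, matching the generated-key branch above.
    return hashlib.md5(f"{dag_id}__{task_id}".encode()).hexdigest()

print(resolve_task_key("example_dag", "test_task"))  # same digest the updated tests assert below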
@@ -1354,14 +1368,17 @@ def _get_current_databricks_task(self) -> dict[str, Any]:
         return {task["task_key"]: task for task in sorted_task_runs}[self.databricks_task_key]

     def _convert_to_databricks_workflow_task(
-        self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+        self,
+        relevant_upstreams: list[BaseOperator],
+        task_dict: dict[str, BaseOperator],
+        context: Context | None = None,
     ) -> dict[str, object]:
         """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
         base_task_json = self._get_task_base_json()
         result = {
             "task_key": self.databricks_task_key,
             "depends_on": [
-                {"task_key": self._generate_databricks_task_key(task_id)}
+                {"task_key": self._generate_databricks_task_key(task_id, task_dict)}
                 for task_id in self.upstream_task_ids
                 if task_id in relevant_upstreams
             ],
@@ -1571,7 +1588,10 @@ def _extend_workflow_notebook_packages(
                 self.notebook_packages.append(task_group_package)

     def _convert_to_databricks_workflow_task(
-        self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+        self,
+        relevant_upstreams: list[BaseOperator],
+        task_dict: dict[str, BaseOperator],
+        context: Context | None = None,
     ) -> dict[str, object]:
         """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
         databricks_workflow_task_group = self._databricks_workflow_task_group
@@ -1589,7 +1609,7 @@ def _convert_to_databricks_workflow_task(
                 **databricks_workflow_task_group.notebook_params,
             }

-        return super()._convert_to_databricks_workflow_task(relevant_upstreams, context=context)
+        return super()._convert_to_databricks_workflow_task(relevant_upstreams, task_dict, context=context)


 class DatabricksTaskOperator(DatabricksTaskBaseOperator):
@@ -88,7 +88,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
     :param max_concurrent_runs: The maximum number of concurrent runs for the workflow.
     :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
         will be passed to all notebooks in the workflow.
-    :param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be
+    :param tasks_to_convert: A dict of tasks to convert to a Databricks workflow. This list can also be
         populated after instantiation using the `add_task` method.
     """

@@ -105,7 +105,7 @@ def __init__(
         job_clusters: list[dict[str, object]] | None = None,
         max_concurrent_runs: int = 1,
         notebook_params: dict | None = None,
-        tasks_to_convert: list[BaseOperator] | None = None,
+        tasks_to_convert: dict[str, BaseOperator] | None = None,
         **kwargs,
     ):
         self.databricks_conn_id = databricks_conn_id
@@ -114,7 +114,7 @@ def __init__(
         self.job_clusters = job_clusters or []
         self.max_concurrent_runs = max_concurrent_runs
         self.notebook_params = notebook_params or {}
-        self.tasks_to_convert = tasks_to_convert or []
+        self.tasks_to_convert = tasks_to_convert or {}
         self.relevant_upstreams = [task_id]
         self.workflow_run_metadata: WorkflowRunMetadata | None = None
         super().__init__(task_id=task_id, **kwargs)
@@ -129,9 +129,9 @@ def _get_hook(self, caller: str) -> DatabricksHook:
     def _hook(self) -> DatabricksHook:
         return self._get_hook(caller=self.caller)

-    def add_task(self, task: BaseOperator) -> None:
-        """Add a task to the list of tasks to convert to a Databricks workflow."""
-        self.tasks_to_convert.append(task)
+    def add_task(self, task_id, task: BaseOperator) -> None:
+        """Add a task to the dict of tasks to convert to a Databricks workflow."""
+        self.tasks_to_convert[task_id] = task

     @property
     def job_name(self) -> str:
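Because tasks_to_convert is now a mapping keyed by task_id, registering the same id twice overwrites the earlier entry instead of appending a duplicate, and the converter can later look operators up by id. A toy sketch of that behaviour with a stand-in class (not the Airflow operator):

class _Registry:
    def __init__(self):
        self.tasks_to_convert: dict[str, object] = {}

    def add_task(self, task_id, task):
        # Dict assignment: re-adding the same task_id replaces the old entry.
        self.tasks_to_convert[task_id] = task

registry = _Registry()
registry.add_task("notebook_1", "operator A")
registry.add_task("notebook_1", "operator A, re-registered")
assert registry.tasks_to_convert == {"notebook_1": "operator A, re-registered"}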
@@ -143,9 +143,9 @@ def create_workflow_json(self, context: Context | None = None) -> dict[str, obje
         """Create a workflow json to be used in the Databricks API."""
         task_json = [
             task._convert_to_databricks_workflow_task(  # type: ignore[attr-defined]
-                relevant_upstreams=self.relevant_upstreams, context=context
+                relevant_upstreams=self.relevant_upstreams, task_dict=self.tasks_to_convert, context=context
             )
-            for task in self.tasks_to_convert
+            for task_id, task in self.tasks_to_convert.items()
         ]

         default_json = {
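Passing task_dict=self.tasks_to_convert gives each task's converter visibility into its siblings, so the depends_on entries can reuse the upstream operators' real keys. The resulting job payload ends up shaped roughly like the hand-written illustration below (made-up keys, and only a subset of the job-level fields):

workflow_json = {
    "name": "<job_name>",  # from the job_name property above
    "max_concurrent_runs": 1,
    "tasks": [
        # one entry per value in tasks_to_convert
        {"task_key": "my_custom_key", "depends_on": []},
        {"task_key": "8f1d7e...", "depends_on": [{"task_key": "my_custom_key"}]},
    ],
    # ...plus the remaining job-level defaults (clusters, notifications, etc.)
}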
@@ -334,7 +334,7 @@ def __exit__(

             task.workflow_run_metadata = create_databricks_workflow_task.output
             create_databricks_workflow_task.relevant_upstreams.append(task.task_id)
-            create_databricks_workflow_task.add_task(task)
+            create_databricks_workflow_task.add_task(task.task_id, task)

         for root_task in roots:
             root_task.set_upstream(create_databricks_workflow_task)
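The only change for the task group's exit hook is that each child is now registered under its task_id. A compressed, illustrative restatement of that wiring with stand-in objects rather than real Airflow operators:

from types import SimpleNamespace

launch = SimpleNamespace(relevant_upstreams=["launch"], tasks_to_convert={})
group_tasks = [SimpleNamespace(task_id="notebook_1"), SimpleNamespace(task_id="notebook_2")]

for task in group_tasks:
    launch.relevant_upstreams.append(task.task_id)
    launch.tasks_to_convert[task.task_id] = task  # was: tasks_to_convert.append(task)

assert list(launch.tasks_to_convert) == ["notebook_1", "notebook_2"]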
@@ -2385,8 +2385,9 @@ def test_convert_to_databricks_workflow_task(self):
         operator.task_id = "test_task"
         operator.upstream_task_ids = ["upstream_task"]
         relevant_upstreams = [MagicMock(task_id="upstream_task")]
+        task_dict = {"upstream_task": MagicMock(task_id="upstream_task")}

-        task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams)
+        task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)

         task_key = hashlib.md5(b"example_dag__test_task").hexdigest()
         expected_json = {
@@ -2423,12 +2424,13 @@ def test_convert_to_databricks_workflow_task_no_task_group(self):
         )
         operator.task_group = None
         relevant_upstreams = [MagicMock(task_id="upstream_task")]
+        task_dict = {"upstream_task": MagicMock(task_id="upstream_task")}

         with pytest.raises(
             AirflowException,
             match="Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup.",
         ):
-            operator._convert_to_databricks_workflow_task(relevant_upstreams)
+            operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)

     def test_convert_to_databricks_workflow_task_cluster_conflict(self):
         """Test that an error is raised if both `existing_cluster_id` and `job_cluster_key` are set."""
@@ -2446,12 +2448,13 @@ def test_convert_to_databricks_workflow_task_cluster_conflict(self):
         operator.job_cluster_key = "job-cluster-key"
         operator.task_group = databricks_workflow_task_group
         relevant_upstreams = [MagicMock(task_id="upstream_task")]
+        task_dict = {"upstream_task": MagicMock(task_id="upstream_task")}

         with pytest.raises(
             ValueError,
             match="Both existing_cluster_id and job_cluster_key are set. Only one can be set per task.",
         ):
-            operator._convert_to_databricks_workflow_task(relevant_upstreams)
+            operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)


 class TestDatabricksTaskOperator:
@@ -83,9 +83,9 @@ def test_create_workflow_json(mock_databricks_hook, context, mock_task_group):
     )
     operator.task_group = mock_task_group

-    task = MagicMock(spec=BaseOperator)
+    task = MagicMock(spec=BaseOperator, task_id="task_1")
     task._convert_to_databricks_workflow_task = MagicMock(return_value={})
-    operator.add_task(task)
+    operator.add_task(task.task_id, task)

     workflow_json = operator.create_workflow_json(context=context)

@@ -150,9 +150,9 @@ def test_execute(mock_databricks_hook, context, mock_task_group):
         life_cycle_state=RunLifeCycleState.RUNNING.value
     )

-    task = MagicMock(spec=BaseOperator)
+    task = MagicMock(spec=BaseOperator, task_id="task_1")
     task._convert_to_databricks_workflow_task = MagicMock(return_value={})
-    operator.add_task(task)
+    operator.add_task(task.task_id, task)

     result = operator.execute(context)
