Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,31 +152,40 @@ def submit_run(self, json: dict) -> int:
response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
return response["run_id"]

def list_jobs(
    self, limit: int = 25, offset: int = 0, expand_tasks: bool = False, job_name: str | None = None
) -> list[dict[str, Any]]:
    """
    Lists the jobs in the Databricks Job Service.

    Paginates through the ``jobs/list`` endpoint until the API reports no
    more pages, accumulating every job returned.

    :param limit: The limit/batch size used to retrieve jobs.
    :param offset: The offset of the first job to return, relative to the most recently created job.
    :param expand_tasks: Whether to include task and cluster details in the response.
    :param job_name: Optional name of a job to search. Passed to the API as a
        server-side filter, and additionally matched exactly client-side,
        since the API filter may match on substrings — TODO confirm against
        the Databricks API version in use.
    :return: A list of jobs.
    """
    has_more = True
    all_jobs = []

    while has_more:
        payload: dict[str, Any] = {
            "limit": limit,
            "expand_tasks": expand_tasks,
            "offset": offset,
        }
        if job_name:
            payload["name"] = job_name
        response = self._do_api_call(LIST_JOBS_ENDPOINT, payload)
        jobs = response.get("jobs", [])
        if job_name:
            # Keep only exact name matches; the API-side "name" filter is
            # treated as a pre-filter only.
            all_jobs += [j for j in jobs if j["settings"]["name"] == job_name]
        else:
            all_jobs += jobs
        has_more = response.get("has_more", False)
        if has_more:
            if not jobs:
                # Defensive: a page that claims more results but returns no
                # jobs would otherwise loop forever with an unchanged offset.
                break
            offset += len(jobs)

    return all_jobs

def find_job_id_by_name(self, job_name: str) -> int | None:
"""
Expand All @@ -185,8 +194,7 @@ def find_job_id_by_name(self, job_name: str) -> int | None:
:param job_name: The name of the job to look up.
:return: The job_id as an int or None if no job was found.
"""
all_jobs = self.list_jobs()
matching_jobs = [j for j in all_jobs if j["settings"]["name"] == job_name]
matching_jobs = self.list_jobs(job_name=job_name)

if len(matching_jobs) > 1:
raise AirflowException(
Expand Down
9 changes: 5 additions & 4 deletions tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def test_get_job_id_by_name_success(self, mock_requests):
mock_requests.get.assert_called_once_with(
list_jobs_endpoint(HOST),
json=None,
params={"limit": 25, "offset": 0, "expand_tasks": False},
params={"limit": 25, "offset": 0, "expand_tasks": False, "name": JOB_NAME},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand All @@ -686,12 +686,13 @@ def test_get_job_id_by_name_not_found(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = LIST_JOBS_RESPONSE

job_id = self.hook.find_job_id_by_name("Non existing job")
job_name = "Non existing job"
job_id = self.hook.find_job_id_by_name(job_name)

mock_requests.get.assert_called_once_with(
list_jobs_endpoint(HOST),
json=None,
params={"limit": 25, "offset": 0, "expand_tasks": False},
params={"limit": 25, "offset": 0, "expand_tasks": False, "name": job_name},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand All @@ -714,7 +715,7 @@ def test_get_job_id_by_name_raise_exception_with_duplicates(self, mock_requests)
mock_requests.get.assert_called_once_with(
list_jobs_endpoint(HOST),
json=None,
params={"limit": 25, "offset": 0, "expand_tasks": False},
params={"limit": 25, "offset": 0, "expand_tasks": False, "name": JOB_NAME},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand Down