@@ -52,6 +52,11 @@ class BatchState(Enum):
SUCCESS = "success"


def sanitize_endpoint_prefix(endpoint_prefix: str | None) -> str:
"""Ensure that the endpoint prefix is prefixed with a slash."""
return f"/{endpoint_prefix.strip('/')}" if endpoint_prefix else ""


class LivyHook(HttpHook):
"""
Hook for Apache Livy through the REST API.
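
For reference, the new `sanitize_endpoint_prefix` helper behaves as follows — a minimal sketch derived directly from its one-line body; the sample values are illustrative, not taken from this change:

```python
from airflow.providers.apache.livy.hooks.livy import sanitize_endpoint_prefix

# A single leading slash is enforced and any trailing slash is dropped.
assert sanitize_endpoint_prefix("livy") == "/livy"
assert sanitize_endpoint_prefix("/livy/") == "/livy"
assert sanitize_endpoint_prefix("livy/foo/bar/") == "/livy/foo/bar"

# No prefix (None or an empty string) collapses to "", so endpoints such as
# f"{self.endpoint_prefix}/batches" stay exactly "/batches".
assert sanitize_endpoint_prefix(None) == ""
assert sanitize_endpoint_prefix("") == ""
```
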
@@ -86,12 +91,14 @@ def __init__(
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
auth_type: Any | None = None,
endpoint_prefix: str | None = None,
) -> None:
super().__init__()
self.method = "POST"
self.http_conn_id = livy_conn_id
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
if auth_type:
self.auth_type = auth_type

@@ -163,7 +170,10 @@ def post_batch(self, *args: Any, **kwargs: Any) -> int:
self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url)

response = self.run_method(
method="POST", endpoint="/batches", data=batch_submit_body, headers=self.extra_headers
method="POST",
endpoint=f"{self.endpoint_prefix}/batches",
data=batch_submit_body,
headers=self.extra_headers,
)
self.log.debug("Got response: %s", response.text)

@@ -192,7 +202,9 @@ def get_batch(self, session_id: int | str) -> dict:
self._validate_session_id(session_id)

self.log.debug("Fetching info for batch session %s", session_id)
response = self.run_method(endpoint=f"/batches/{session_id}", headers=self.extra_headers)
response = self.run_method(
endpoint=f"{self.endpoint_prefix}/batches/{session_id}", headers=self.extra_headers
)

try:
response.raise_for_status()
@@ -217,7 +229,9 @@ def get_batch_state(self, session_id: int | str, retry_args: dict[str, Any] | No

self.log.debug("Fetching info for batch session %s", session_id)
response = self.run_method(
endpoint=f"/batches/{session_id}/state", retry_args=retry_args, headers=self.extra_headers
endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state",
retry_args=retry_args,
headers=self.extra_headers,
)

try:
@@ -244,7 +258,9 @@ def delete_batch(self, session_id: int | str) -> dict:

self.log.info("Deleting batch session %s", session_id)
response = self.run_method(
method="DELETE", endpoint=f"/batches/{session_id}", headers=self.extra_headers
method="DELETE",
endpoint=f"{self.endpoint_prefix}/batches/{session_id}",
headers=self.extra_headers,
)

try:
@@ -270,7 +286,9 @@ def get_batch_logs(self, session_id: int | str, log_start_position, log_batch_si
self._validate_session_id(session_id)
log_params = {"from": log_start_position, "size": log_batch_size}
response = self.run_method(
endpoint=f"/batches/{session_id}/log", data=log_params, headers=self.extra_headers
endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log",
data=log_params,
headers=self.extra_headers,
)
try:
response.raise_for_status()
@@ -490,12 +508,14 @@ def __init__(
livy_conn_id: str = default_conn_name,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
endpoint_prefix: str | None = None,
) -> None:
super().__init__()
self.method = "POST"
self.http_conn_id = livy_conn_id
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)

async def _do_api_call_async(
self,
@@ -624,7 +644,7 @@ async def get_batch_state(self, session_id: int | str) -> Any:
"""
self._validate_session_id(session_id)
self.log.info("Fetching info for batch session %s", session_id)
result = await self.run_method(endpoint=f"/batches/{session_id}/state")
result = await self.run_method(endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state")
if result["status"] == "error":
self.log.info(result)
return {"batch_state": "error", "response": result, "status": "error"}
@@ -659,7 +679,9 @@ async def get_batch_logs(
"""
self._validate_session_id(session_id)
log_params = {"from": log_start_position, "size": log_batch_size}
result = await self.run_method(endpoint=f"/batches/{session_id}/log", data=log_params)
result = await self.run_method(
endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log", data=log_params
)
if result["status"] == "error":
self.log.info(result)
return {"response": result["response"], "status": "error"}
@@ -88,6 +88,7 @@ def __init__(
proxy_user: str | None = None,
livy_conn_id: str = "livy_default",
livy_conn_auth_type: Any | None = None,
livy_endpoint_prefix: str | None = None,
polling_interval: int = 0,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
@@ -119,6 +120,7 @@ def __init__(
self.spark_params = spark_params
self._livy_conn_id = livy_conn_id
self._livy_conn_auth_type = livy_conn_auth_type
self._livy_endpoint_prefix = livy_endpoint_prefix
self._polling_interval = polling_interval
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}
@@ -139,6 +141,7 @@ def hook(self) -> LivyHook:
extra_headers=self._extra_headers,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
endpoint_prefix=self._livy_endpoint_prefix,
)

def execute(self, context: Context) -> Any:
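
With the operator wiring above, a DAG can target a Livy server that is served under a path prefix (for example behind a reverse proxy or gateway). A minimal sketch — the DAG id, task id, application file, and the "/livy" prefix are illustrative assumptions, not values from this change:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.apache.livy.operators.livy import LivyOperator

with DAG(dag_id="livy_prefix_example", start_date=datetime(2024, 1, 1), schedule=None):
    submit_job = LivyOperator(
        task_id="submit_spark_app",                       # illustrative task id
        file="local:///opt/spark/examples/spark-pi.jar",  # illustrative application file
        livy_conn_id="livy_default",
        # Forwarded to LivyHook(endpoint_prefix=...), so the batch is submitted to
        # <connection base URL>/livy/batches rather than <connection base URL>/batches.
        livy_endpoint_prefix="/livy",
        polling_interval=30,
    )
```
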
@@ -46,6 +46,7 @@ def __init__(
livy_conn_id: str = "livy_default",
livy_conn_auth_type: Any | None = None,
extra_options: dict[str, Any] | None = None,
endpoint_prefix: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -54,6 +55,7 @@ def __init__(
self._livy_conn_auth_type = livy_conn_auth_type
self._livy_hook: LivyHook | None = None
self._extra_options = extra_options or {}
self._endpoint_prefix = endpoint_prefix

def get_hook(self) -> LivyHook:
"""
@@ -66,6 +68,7 @@ def get_hook(self) -> LivyHook:
livy_conn_id=self._livy_conn_id,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
endpoint_prefix=self._endpoint_prefix,
)
return self._livy_hook

@@ -57,6 +57,7 @@ def __init__(
extra_headers: dict[str, Any] | None = None,
livy_hook_async: LivyAsyncHook | None = None,
execution_timeout: timedelta | None = None,
endpoint_prefix: str | None = None,
):
super().__init__()
self._batch_id = batch_id
@@ -67,6 +68,7 @@ def __init__(
self._extra_headers = extra_headers
self._livy_hook_async = livy_hook_async
self._execution_timeout = execution_timeout
self._endpoint_prefix = endpoint_prefix

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize LivyTrigger arguments and classpath."""
@@ -170,5 +172,6 @@ def _get_async_hook(self) -> LivyAsyncHook:
livy_conn_id=self._livy_conn_id,
extra_headers=self._extra_headers,
extra_options=self._extra_options,
endpoint_prefix=self._endpoint_prefix,
)
return self._livy_hook_async
@@ -428,6 +428,77 @@ def test_alternate_auth_type(self):

auth_type.assert_called_once_with("login", "secret")

@patch("airflow.providers.apache.livy.hooks.livy.LivyHook.run_method")
def test_post_batch_with_endpoint_prefix(self, mock_request):
mock_request.return_value.status_code = 201
mock_request.return_value.json.return_value = {
"id": BATCH_ID,
"state": BatchState.STARTING.value,
"log": [],
}

resp = LivyHook(endpoint_prefix="/livy").post_batch(file="sparkapp")

mock_request.assert_called_once_with(
method="POST", endpoint="/livy/batches", data=json.dumps({"file": "sparkapp"}), headers={}
)

request_args = mock_request.call_args.kwargs
assert "data" in request_args
assert isinstance(request_args["data"], str)

assert isinstance(resp, int)
assert resp == BATCH_ID

def test_get_batch_with_endpoint_prefix(self, requests_mock):
requests_mock.register_uri(
"GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200
)
resp = LivyHook(endpoint_prefix="/livy").get_batch(BATCH_ID)
assert isinstance(resp, dict)
assert "id" in resp

def test_get_batch_state_with_endpoint_prefix(self, requests_mock):
running = BatchState.RUNNING

requests_mock.register_uri(
"GET",
f"{MATCH_URL}/livy/batches/{BATCH_ID}/state",
json={"id": BATCH_ID, "state": running.value},
status_code=200,
)

state = LivyHook(endpoint_prefix="/livy").get_batch_state(BATCH_ID)
assert isinstance(state, BatchState)
assert state == running

def test_delete_batch_with_endpoint_prefix(self, requests_mock):
requests_mock.register_uri(
"DELETE", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"msg": "deleted"}, status_code=200
)
assert LivyHook(endpoint_prefix="/livy").delete_batch(BATCH_ID) == {"msg": "deleted"}

@pytest.mark.parametrize(
"prefix",
["/livy/", "livy", "/livy", "livy/"],
ids=["leading_and_trailing_slashes", "no_slashes", "leading_slash", "trailing_slash"],
)
def test_endpoint_prefix_is_sanitized_simple(self, requests_mock, prefix):
requests_mock.register_uri(
"GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200
)
resp = LivyHook(endpoint_prefix=prefix).get_batch(BATCH_ID)
assert isinstance(resp, dict)
assert "id" in resp

def test_endpoint_prefix_is_sanitized_multiple_path_elements(self, requests_mock):
requests_mock.register_uri(
"GET", f"{MATCH_URL}/livy/foo/bar/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200
)
resp = LivyHook(endpoint_prefix="/livy/foo/bar/").get_batch(BATCH_ID)
assert isinstance(resp, dict)
assert "id" in resp


class TestLivyAsyncHook:
@pytest.mark.asyncio
@@ -815,3 +886,31 @@ def test_check_session_id_success(self, conn_id):
def test_check_session_id_failure(self, conn_id):
with pytest.raises(TypeError):
LivyAsyncHook._validate_session_id(None)

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
async def test_get_batch_state_with_endpoint_prefix(self, mock_run_method):
mock_run_method.return_value = {"status": "success", "response": {"state": BatchState.RUNNING}}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy")
state = await hook.get_batch_state(BATCH_ID)
assert state == {
"batch_state": BatchState.RUNNING,
"response": "successfully fetched the batch state.",
"status": "success",
}
mock_run_method.assert_called_once_with(
endpoint=f"/livy/batches/{BATCH_ID}/state",
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
async def test_get_batch_logs_with_endpoint_prefix(self, mock_run_method):
mock_run_method.return_value = {"status": "success", "response": {}}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy")
state = await hook.get_batch_logs(BATCH_ID, 0, 100)
assert state["status"] == "success"

mock_run_method.assert_called_once_with(
endpoint=f"/livy/batches/{BATCH_ID}/log",
data={"from": 0, "size": 100},
)