Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions airflow/providers/papermill/operators/papermill.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,26 +81,23 @@ def __init__(

if not input_nb:
raise ValueError("Input notebook is not specified")
elif not isinstance(input_nb, NoteBook):
self.input_nb = NoteBook(url=input_nb, parameters=self.parameters)
else:
self.input_nb = input_nb
self.input_nb = input_nb

if not output_nb:
raise ValueError("Output notebook is not specified")
elif not isinstance(output_nb, NoteBook):
self.output_nb = NoteBook(url=output_nb)
else:
self.output_nb = output_nb
self.output_nb = output_nb

self.kernel_name = kernel_name
self.language_name = language_name
self.kernel_conn_id = kernel_conn_id

def execute(self, context: Context):
if not isinstance(self.input_nb, NoteBook):
self.input_nb = NoteBook(url=self.input_nb, parameters=self.parameters)
if not isinstance(self.output_nb, NoteBook):
self.output_nb = NoteBook(url=self.output_nb)
self.inlets.append(self.input_nb)
self.outlets.append(self.output_nb)

def execute(self, context: Context):
remote_kernel_kwargs = {}
kernel_hook = self.hook
if kernel_hook:
Expand Down
25 changes: 12 additions & 13 deletions tests/providers/papermill/operators/test_papermill.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,21 @@ def test_mandatory_attributes(self):
pytest.param(NoteBook(TEST_INPUT_URL), id="input-as-notebook-object"),
],
)
def test_notebooks_objects(self, input_nb, output_nb):
@patch("airflow.providers.papermill.operators.papermill.pm")
@patch("airflow.providers.papermill.operators.papermill.PapermillOperator.hook")
def test_notebooks_objects(self, mock_papermill, mock_hook, input_nb, output_nb):
"""Test different type of Input/Output notebooks arguments."""
op = PapermillOperator(task_id="test_notebooks_objects", input_nb=input_nb, output_nb=output_nb)

op.execute(None)

assert op.input_nb.url == TEST_INPUT_URL
assert op.output_nb.url == TEST_OUTPUT_URL

# Test render Lineage inlets/outlets
assert op.inlets[0] == op.input_nb
assert op.outlets[0] == op.output_nb

@patch("airflow.providers.papermill.operators.papermill.pm")
def test_execute(self, mock_papermill):
in_nb = "/tmp/does_not_exist"
Expand Down Expand Up @@ -173,19 +182,9 @@ def test_render_template(self, create_task_instance_of_operator):
task = ti.render_templates()

# Test render Input/Output notebook attributes
assert task.input_nb.url == "/tmp/test_render_template.ipynb"
assert task.input_nb.parameters == {
"msgs": "dag id is test_render_template!",
"test_dt": DEFAULT_DATE.date().isoformat(),
}
assert task.output_nb.url == "/tmp/out-test_render_template.ipynb"
assert task.output_nb.parameters == {}
assert task.input_nb == "/tmp/test_render_template.ipynb"
assert task.output_nb == "/tmp/out-test_render_template.ipynb"

# Test render other templated attributes
assert task.parameters == task.input_nb.parameters
assert "python3" == task.kernel_name
assert "python" == task.language_name

# Test render Lineage inlets/outlets
assert task.inlets[0] == task.input_nb
assert task.outlets[0] == task.output_nb