35 changes: 33 additions & 2 deletions airflow/providers/apache/beam/hooks/beam.py
@@ -19,6 +19,7 @@
from __future__ import annotations

import contextlib
import copy
import json
import os
import select
@@ -310,11 +311,10 @@ def start_go_pipeline(
should_init_module: bool = False,
) -> None:
"""
Starts Apache Beam Go pipeline.
Starts Apache Beam Go pipeline with a source file.

:param variables: Variables passed to the job.
:param go_file: Path to the Go file with your beam pipeline.
:param go_file:
:param process_line_callback: (optional) Callback that can be used to process each line of
the stdout and stderr file descriptors.
:param should_init_module: If False (default), will just execute a `go run` command. If True, will
@@ -346,3 +346,34 @@ def start_go_pipeline(
process_line_callback=process_line_callback,
working_directory=working_directory,
)
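
For orientation, a minimal sketch (not part of this diff) of how the source-file entry point is called; the paths and options are invented for illustration:

```python
from airflow.providers.apache.beam.hooks.beam import BeamHook

hook = BeamHook(runner="DirectRunner")

# should_init_module=True initializes a Go module in the file's directory,
# which matters when the source was pulled into a bare temporary directory.
hook.start_go_pipeline(
    variables={"output": "/tmp/beam-output"},  # illustrative pipeline option
    go_file="/tmp/apache-beam-go/main.go",     # hypothetical local path
    should_init_module=True,
)
```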

def start_go_pipeline_with_binary(
self,
variables: dict,
launcher_binary: str,
worker_binary: str,
process_line_callback: Callable[[str], None] | None = None,
) -> None:
"""
Starts Apache Beam Go pipeline with an executable binary.

:param variables: Variables passed to the job.
:param launcher_binary: Path to the binary compiled for the launching platform.
:param worker_binary: Path to the binary compiled for the worker platform.
:param process_line_callback: (optional) Callback that can be used to process each line of
the stdout and stderr file descriptors.
"""
job_variables = copy.deepcopy(variables)

if "labels" in job_variables:
job_variables["labels"] = json.dumps(job_variables["labels"], separators=(",", ":"))

job_variables["worker_binary"] = worker_binary

command_prefix = [launcher_binary]

self._start_pipeline(
variables=job_variables,
command_prefix=command_prefix,
process_line_callback=process_line_callback,
)
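
A hedged usage sketch of the new method (paths invented). Judging from the code above and the test added in this PR, the launcher binary becomes the command prefix, `variables` are rendered as `--key=value` flags with `labels` JSON-encoded, and `--worker_binary` is appended:

```python
from airflow.providers.apache.beam.hooks.beam import BeamHook

hook = BeamHook(runner="DataflowRunner")

# Roughly equivalent to running:
#   /opt/beam/launcher-main --runner=DataflowRunner \
#       --output=gs://bucket/out --worker_binary=/opt/beam/worker-main
hook.start_go_pipeline_with_binary(
    variables={"output": "gs://bucket/out"},    # illustrative pipeline option
    launcher_binary="/opt/beam/launcher-main",  # hypothetical local path
    worker_binary="/opt/beam/worker-main",      # hypothetical local path
)
```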
179 changes: 159 additions & 20 deletions airflow/providers/apache/beam/operators/beam.py
@@ -19,9 +19,13 @@
from __future__ import annotations

import copy
import os
import stat
import tempfile
from abc import ABC, ABCMeta
from abc import ABC, ABCMeta, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from functools import partial
from typing import TYPE_CHECKING, Callable, Sequence

from airflow import AirflowException
@@ -31,10 +35,10 @@
DataflowHook,
process_line_and_extract_dataflow_job_id_callback,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.helpers import convert_camel_to_snake, exactly_one
from airflow.version import version

if TYPE_CHECKING:
@@ -520,12 +524,27 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
For more detail on Apache Beam have a look at the reference:
https://beam.apache.org/documentation/

:param go_file: Reference to the Go Apache Beam pipeline e.g.,
/some/local/file/path/to/your/go/pipeline/file.go
:param go_file: Reference to the Apache Beam pipeline Go source file,
e.g. /local/path/to/main.go or gs://bucket/path/to/main.go.
Exactly one of go_file and launcher_binary must be provided.

:param launcher_binary: Reference to the Apache Beam pipeline Go binary compiled for the launching
platform, e.g. /local/path/to/launcher-main or gs://bucket/path/to/launcher-main.
Exactly one of go_file and launcher_binary must be provided.

:param worker_binary: Reference to the Apache Beam pipeline Go binary compiled for the worker platform,
e.g. /local/path/to/worker-main or gs://bucket/path/to/worker-main.
Needed if the OS or architecture of the workers running the pipeline is different from that
of the platform launching the pipeline. For more information, see the Apache Beam documentation
for Go cross compilation: https://beam.apache.org/documentation/sdks/go-cross-compilation/.
If launcher_binary is not set, providing a worker_binary will have no effect. If launcher_binary is
set and worker_binary is not, worker_binary will default to the value of launcher_binary.
"""

template_fields = [
"go_file",
"launcher_binary",
"worker_binary",
"runner",
"pipeline_options",
"default_pipeline_options",
@@ -537,7 +556,9 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
def __init__(
self,
*,
go_file: str,
go_file: str = "",
launcher_binary: str = "",
worker_binary: str = "",
runner: str = "DirectRunner",
default_pipeline_options: dict | None = None,
pipeline_options: dict | None = None,
@@ -563,8 +584,13 @@ def __init__(
)
self.dataflow_support_impersonation = False

if not exactly_one(go_file, launcher_binary):
raise ValueError("Exactly one of `go_file` and `launcher_binary` must be set")

self.go_file = go_file
self.should_init_go_module = False
self.launcher_binary = launcher_binary
self.worker_binary = worker_binary or launcher_binary

self.pipeline_options.setdefault("labels", {}).update(
{"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
)
@@ -581,24 +607,24 @@ def execute(self, context: Context):
if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")

go_artifact: _GoArtifact = (
_GoFile(file=self.go_file)
if self.go_file
else _GoBinary(launcher=self.launcher_binary, worker=self.worker_binary)
)

with ExitStack() as exit_stack:
if self.go_file.lower().startswith("gs://"):
if go_artifact.is_located_on_gcs():
gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)

tmp_dir = exit_stack.enter_context(tempfile.TemporaryDirectory(prefix="apache-beam-go"))
tmp_gcs_file = exit_stack.enter_context(
gcs_hook.provide_file(object_url=self.go_file, dir=tmp_dir)
)
self.go_file = tmp_gcs_file.name
self.should_init_go_module = True
go_artifact.download_from_gcs(gcs_hook=gcs_hook, tmp_dir=tmp_dir)

if is_dataflow and self.dataflow_hook:
with self.dataflow_hook.provide_authorized_gcloud():
self.beam_hook.start_go_pipeline(
go_artifact.start_pipeline(
beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
go_file=self.go_file,
process_line_callback=process_line_callback,
should_init_module=self.should_init_go_module,
)

DataflowJobLink.persist(
@@ -618,11 +644,10 @@ def execute(self, context: Context):
)
return {"dataflow_job_id": self.dataflow_job_id}
else:
self.beam_hook.start_go_pipeline(
go_artifact.start_pipeline(
beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
go_file=self.go_file,
process_line_callback=process_line_callback,
should_init_module=self.should_init_go_module,
)

def on_kill(self) -> None:
@@ -632,3 +657,117 @@ def on_kill(self) -> None:
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
)


class _GoArtifact(ABC):
@abstractmethod
def is_located_on_gcs(self) -> bool:
...

@abstractmethod
def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
...

@abstractmethod
def start_pipeline(
self,
beam_hook: BeamHook,
variables: dict,
process_line_callback: Callable[[str], None] | None = None,
) -> None:
...


class _GoFile(_GoArtifact):
def __init__(self, file: str) -> None:
self.file = file
self.should_init_go_module = False

def is_located_on_gcs(self) -> bool:
return _object_is_located_on_gcs(self.file)

def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
self.file = _download_object_from_gcs(gcs_hook=gcs_hook, uri=self.file, tmp_dir=tmp_dir)
self.should_init_go_module = True

def start_pipeline(
self,
beam_hook: BeamHook,
variables: dict,
process_line_callback: Callable[[str], None] | None = None,
) -> None:
beam_hook.start_go_pipeline(
variables=variables,
go_file=self.file,
process_line_callback=process_line_callback,
should_init_module=self.should_init_go_module,
)


class _GoBinary(_GoArtifact):
def __init__(self, launcher: str, worker: str) -> None:
self.launcher = launcher
self.worker = worker

def is_located_on_gcs(self) -> bool:
return any(_object_is_located_on_gcs(path) for path in (self.launcher, self.worker))

def download_from_gcs(self, gcs_hook: GCSHook, tmp_dir: str) -> None:
binaries_are_equal = self.launcher == self.worker

binaries_to_download = []

if _object_is_located_on_gcs(self.launcher):
binaries_to_download.append("launcher")

if not binaries_are_equal and _object_is_located_on_gcs(self.worker):
binaries_to_download.append("worker")

download_fn = partial(_download_object_from_gcs, gcs_hook=gcs_hook, tmp_dir=tmp_dir)

with ThreadPoolExecutor(max_workers=len(binaries_to_download)) as executor:
futures = {
executor.submit(download_fn, uri=getattr(self, binary), tmp_prefix=f"{binary}-"): binary
for binary in binaries_to_download
}

for future in as_completed(futures):
binary = futures[future]
tmp_path = future.result()
_make_executable(tmp_path)
setattr(self, binary, tmp_path)

if binaries_are_equal:
self.worker = self.launcher

def start_pipeline(
self,
beam_hook: BeamHook,
variables: dict,
process_line_callback: Callable[[str], None] | None = None,
) -> None:
beam_hook.start_go_pipeline_with_binary(
variables=variables,
launcher_binary=self.launcher,
worker_binary=self.worker,
process_line_callback=process_line_callback,
)


def _object_is_located_on_gcs(path: str) -> bool:
return path.lower().startswith("gs://")


def _download_object_from_gcs(gcs_hook: GCSHook, uri: str, tmp_dir: str, tmp_prefix: str = "") -> str:
tmp_name = f"{tmp_prefix}{os.path.basename(uri)}"
tmp_path = os.path.join(tmp_dir, tmp_name)

bucket, prefix = _parse_gcs_url(uri)
gcs_hook.download(bucket_name=bucket, object_name=prefix, filename=tmp_path)

return tmp_path


def _make_executable(path: str) -> None:
st = os.stat(path)
os.chmod(path, st.st_mode | stat.S_IEXEC)
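
Pulling the operator changes together, a hedged DAG-level sketch (not part of this diff); bucket names, paths, and the task id are invented. Exactly one of `go_file` and `launcher_binary` may be set, and GCS-hosted binaries are downloaded, made executable, and swapped in by the `_GoBinary` helper above:

```python
from airflow.providers.apache.beam.operators.beam import BeamRunGoPipelineOperator

# The launcher binary runs where Airflow executes the task; the worker binary
# is handed to the pipeline workers, which may be a different OS/architecture.
run_from_binary = BeamRunGoPipelineOperator(
    task_id="run_go_pipeline_from_binary",
    runner="DataflowRunner",
    launcher_binary="gs://my-bucket/launcher-linux-amd64",
    worker_binary="gs://my-bucket/worker-linux-amd64",
    pipeline_options={"output": "gs://my-bucket/out"},
)
```

Splitting the two artifact kinds into the `_GoFile` and `_GoBinary` strategies keeps `execute` free of download- and start-specific branching.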
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -1268,6 +1268,7 @@ schedulable
schedulername
schemas
sdk
sdks
searchpath
SearchResultGenerator
SecretManagerClient
29 changes: 29 additions & 0 deletions tests/providers/apache/beam/hooks/test_beam.py
@@ -312,6 +312,35 @@ def test_start_go_pipeline_without_go_installed_raises(self, mock_which):
variables=copy.deepcopy(BEAM_VARIABLES_GO),
)

@mock.patch(BEAM_STRING.format("BeamCommandRunner"))
def test_start_go_pipeline_with_binary(self, mock_runner):
hook = BeamHook(runner=DEFAULT_RUNNER)
wait_for_done_method = mock_runner.return_value.wait_for_done
process_line_callback = MagicMock()

launcher_binary = "/path/to/launcher-main"
worker_binary = "/path/to/worker-main"

hook.start_go_pipeline_with_binary(
variables=BEAM_VARIABLES_GO,
launcher_binary=launcher_binary,
worker_binary=worker_binary,
process_line_callback=process_line_callback,
)

expected_cmd = [
launcher_binary,
f"--runner={DEFAULT_RUNNER}",
"--output=gs://test/output",
'--labels={"foo":"bar"}',
f"--worker_binary={worker_binary}",
]

mock_runner.assert_called_once_with(
cmd=expected_cmd, process_line_callback=process_line_callback, working_directory=None
)
wait_for_done_method.assert_called_once_with()


class TestBeamRunner:
@mock.patch("airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log")