support full_refresh sources in airbyte connector (#7247)
GitOrigin-RevId: c9af893064a73bbd73f0d062d5bb3cf486cec8d8
zxqfd555-pw authored and Manul from Pathway committed Sep 11, 2024
1 parent 059da89 commit c752bba
Showing 11 changed files with 166 additions and 35 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
- **Experimental** A ``pw.xpacks.llm.document_store.DocumentStore`` to process and index documents.
- ``pw.xpacks.llm.servers.DocumentStoreServer`` used to expose REST server for retrieving documents from ``pw.xpacks.llm.document_store.DocumentStore``.
- `pw.xpacks.stdlib.indexing.HybridIndex` used for querying multiple indices and combining their results.
- `pw.io.airbyte.read` now also supports streams that only operate in `full_refresh` mode.

### Changed
- Running servers for answering queries is extracted from `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` into `pw.xpacks.llm.servers.QARestServer` and `pw.xpacks.llm.servers.QASummaryRestServer`.
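For orientation, a minimal sketch of how the new full_refresh support is used, mirroring the integration test later in this commit (the connection file path and stream name are illustrative):

    import pathway as pw

    # "dataset" is a stream whose connector only offers full_refresh.
    # Pathway re-runs the sync every refresh_interval_ms and diffs
    # consecutive snapshots into additions and deletions.
    table = pw.io.airbyte.read(
        "./connection.yaml",
        streams=["dataset"],
        mode="streaming",
        execution_type="local",
        refresh_interval_ms=100,
    )
    pw.io.jsonlines.write(table, "./table.jsonl")
    pw.run()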
8 changes: 8 additions & 0 deletions integration_tests/airbyte/test-file-connection.yaml
@@ -0,0 +1,8 @@
source:
docker_image: "airbyte/source-file:latest"
config:
dataset_name: "dataset"
format: "csv"
url: "./input.txt"
provider:
storage: "local"
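This connection file points the stock airbyte/source-file connector at a local CSV; the integration test below rewrites the url field to a per-test temporary path before running it.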
54 changes: 51 additions & 3 deletions integration_tests/airbyte/test_airbyte.py
@@ -1,19 +1,31 @@
import os
import threading

import pytest
import yaml

import pathway as pw
from pathway.tests.utils import run_all
from pathway.tests.utils import (
FileLinesNumberChecker,
run_all,
wait_result_with_checker,
write_lines,
)

CREDENTIALS_PATH = os.path.join(os.path.dirname(__file__), "credentials.json")
TEST_CONNECTION_PATH = os.path.join(os.path.dirname(__file__), "test-connection.yaml")
TEST_FAKER_CONNECTION_PATH = os.path.join(
os.path.dirname(__file__), "test-faker-connection.yaml"
)
TEST_FILE_CONNECTION_PATH = os.path.join(
os.path.dirname(__file__), "test-file-connection.yaml"
)


@pytest.mark.parametrize("gcp_job_name", [None, "pw-integration-custom-gcp-job-name"])
@pytest.mark.parametrize("env_vars", [None, {"''": "\"''''\"\""}, {"KEY": "VALUE"}])
def test_airbyte_remote_run(gcp_job_name, env_vars, tmp_path):
table = pw.io.airbyte.read(
TEST_CONNECTION_PATH,
TEST_FAKER_CONNECTION_PATH,
["Users"],
service_user_credentials_file=CREDENTIALS_PATH,
mode="static",
@@ -29,3 +41,39 @@ def test_airbyte_remote_run(gcp_job_name, env_vars, tmp_path):
for _ in f:
total_lines += 1
assert total_lines == 500


def test_airbyte_full_refresh_streams(tmp_path):
input_path = tmp_path / "input.csv"
connection_path = tmp_path / "connection.yaml"
with open(TEST_FILE_CONNECTION_PATH, "r") as f:
config = yaml.safe_load(f)
config["source"]["config"]["url"] = os.fspath(input_path)
with open(connection_path, "w") as f:
yaml.dump(config, f)

write_lines(input_path, "header\nfoo\nbar\nbaz")
output_path = tmp_path / "table.jsonl"

table = pw.io.airbyte.read(
connection_path,
streams=["dataset"],
mode="streaming",
execution_type="local",
enforce_method="venv",
refresh_interval_ms=100,
)
pw.io.jsonlines.write(table, output_path)

def stream_target():
wait_result_with_checker(FileLinesNumberChecker(output_path, 3), 5, target=None)
write_lines(input_path, "header\nbaz")

wait_result_with_checker(FileLinesNumberChecker(output_path, 5), 5, target=None)
write_lines(input_path, "header\nfoo")

inputs_thread = threading.Thread(target=stream_target, daemon=True)
inputs_thread.start()

wait_result_with_checker(FileLinesNumberChecker(output_path, 7), 15)
inputs_thread.join()
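The expected line counts follow from snapshot diffing: the first sync adds foo, bar and baz (3 output lines); rewriting the input to contain only baz retracts foo and bar (2 deletion events, 5 lines in total); the final rewrite retracts baz and adds foo (2 more events, 7 lines).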
8 changes: 4 additions & 4 deletions python/pathway/conftest.py
@@ -11,7 +11,7 @@

from pathway import cli
from pathway.internals import config, parse_graph
from pathway.tests.utils import AIRBYTE_CONNECTION_REL_PATH, UniquePortDispenser
from pathway.tests.utils import AIRBYTE_FAKER_CONNECTION_REL_PATH, UniquePortDispenser


@pytest.fixture(autouse=True)
@@ -91,7 +91,7 @@ def tmp_path_with_airbyte_config(tmp_path):
result = runner.invoke(
cli.create_source,
[
"new_source",
"faker",
"--image",
"airbyte/source-faker:6.2.10",
],
@@ -100,15 +100,15 @@ def tmp_path_with_airbyte_config(tmp_path):
finally:
os.chdir(start_dir)

with open(tmp_path / AIRBYTE_CONNECTION_REL_PATH, "r") as f:
with open(tmp_path / AIRBYTE_FAKER_CONNECTION_REL_PATH, "r") as f:
config = yaml.safe_load(f)

# https://docs.airbyte.com/integrations/sources/faker#reference
config["source"]["config"]["records_per_slice"] = 500
config["source"]["config"]["records_per_sync"] = 500
config["source"]["config"]["count"] = 500
config["source"]["config"]["always_updated"] = False
with open(tmp_path / AIRBYTE_CONNECTION_REL_PATH, "w") as f:
with open(tmp_path / AIRBYTE_FAKER_CONNECTION_REL_PATH, "w") as f:
yaml.dump(config, f)

return tmp_path
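The fixture pins the faker source to exactly 500 records, which is what lets the tests assert an exact row count (total_lines == 500 in the integration test above).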
17 changes: 14 additions & 3 deletions python/pathway/io/airbyte/__init__.py
@@ -17,7 +17,6 @@
AbstractAirbyteSource,
)

INCREMENTAL_SYNC_MODE = "incremental"
METHOD_PYPI = "pypi"
METHOD_DOCKER = "docker"

@@ -49,6 +48,10 @@ def _construct_local_source(
enforce_method: str | None = None,
) -> AbstractAirbyteSource:
with optional_imports("airbyte"):
from pathway.io.airbyte.logic import (
FULL_REFRESH_SYNC_MODE,
INCREMENTAL_SYNC_MODE,
)
from pathway.third_party.airbyte_serverless.sources import (
DockerAirbyteSource,
VenvAirbyteSource,
@@ -81,11 +84,18 @@ def _construct_local_source(
)

# Run airbyte connector locally and check streams
global_sync_mode = None
for stream in source.configured_catalog["streams"]:
name = stream["stream"]["name"]
sync_mode = stream["sync_mode"]
if sync_mode != INCREMENTAL_SYNC_MODE:
raise ValueError(f"Stream {name} doesn't support 'incremental' sync mode")
if sync_mode != INCREMENTAL_SYNC_MODE and sync_mode != FULL_REFRESH_SYNC_MODE:
raise ValueError(f"Stream {name} has unknown sync_mode: {sync_mode}")
global_sync_mode = global_sync_mode or sync_mode
if global_sync_mode != sync_mode:
raise ValueError(
"All streams within the same 'pw.io.airbyte.read' must have "
"the same 'sync_mode'"
)

return source

@@ -304,6 +314,7 @@ def read(
)
source = RemoteAirbyteSource(
config=config,
streams=streams,
job_id=job_id,
credentials=credentials,
region=gcp_region,
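With this validation in place, a single read call can no longer mix sync modes. A hypothetical call combining an incremental stream with a full_refresh-only one now fails fast (stream names invented for illustration):

    import pathway as pw

    # Raises ValueError: all streams within the same 'pw.io.airbyte.read'
    # must have the same 'sync_mode'.
    pw.io.airbyte.read(
        "./connection.yaml",
        streams=["incremental_stream", "full_refresh_only_stream"],
        mode="streaming",
        execution_type="local",
    )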
38 changes: 37 additions & 1 deletion python/pathway/io/airbyte/logic.py
@@ -3,6 +3,7 @@
import time
from collections.abc import Sequence

from pathway.internals import api
from pathway.io._utils import STATIC_MODE_NAME
from pathway.io.python import ConnectorSubject
from pathway.third_party.airbyte_serverless.destinations import (
@@ -12,6 +13,7 @@

MAX_RETRIES = 5
INCREMENTAL_SYNC_MODE = "incremental"
FULL_REFRESH_SYNC_MODE = "full_refresh"

AIRBYTE_STREAM_RECORD_PREFIX = "_airbyte_raw_"
AIRBYTE_DATA_RECORD_FIELD = "_airbyte_data"
@@ -144,10 +146,28 @@ def __init__(
self.mode = mode
self.refresh_interval = refresh_interval_ms / 1000.0
self.destination = _PathwayAirbyteDestination(
on_event=lambda payload: self.next_json({"data": payload}),
on_event=self.on_event,
on_state=self.on_state,
)
self.streams = streams
self.sync_mode = self._sync_mode()
self._cache: dict[api.Pointer, bytes] = {}
self._present_keys: set[api.Pointer] = set()

def on_event(self, payload):
if self.sync_mode == INCREMENTAL_SYNC_MODE:
self.next_json({"data": payload})
elif self.sync_mode == FULL_REFRESH_SYNC_MODE:
message = json.dumps(
{"data": payload}, ensure_ascii=False, sort_keys=True
).encode("utf-8")
key = api.ref_scalar(message)
if self._cache.get(key) != message:
self._cache[key] = message
self._add(key, message)
self._present_keys.add(key)
else:
raise RuntimeError(f"Unknown sync_mode: {self.sync_mode}")

def on_state(self, state):
self._report_offset(json.dumps(state).encode("utf-8"))
@@ -179,6 +199,17 @@ def run(self):

if self.mode == STATIC_MODE_NAME:
break
if self.sync_mode == FULL_REFRESH_SYNC_MODE:
absent_keys = set()
for key, message in self._cache.items():
if key not in self._present_keys:
self._remove(key, message)
absent_keys.add(key)
for key in absent_keys:
self._cache.pop(key)
self._present_keys.clear()
self._enable_commits()
self._disable_commits()

time_elapsed = time.time() - time_before_start
if time_elapsed < self.refresh_interval:
@@ -187,6 +218,11 @@
def on_stop(self):
self.source.on_stop()

def _sync_mode(self):
stream = self.source.configured_catalog["streams"][0]
sync_mode = stream["sync_mode"]
return sync_mode

def _seek(self, state):
self.destination.set_state(json.loads(state.decode("utf-8")))

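The full_refresh path amounts to upsert/delete diffing of consecutive snapshots. A self-contained sketch of the same pattern, with hash standing in for api.ref_scalar and print for the connector's _add/_remove calls:

    import json

    class SnapshotDiffer:
        """Emits additions and removals by diffing consecutive full snapshots."""

        def __init__(self):
            self._cache = {}            # key -> serialized record
            self._present_keys = set()  # keys seen during the current sync

        def on_record(self, payload):
            message = json.dumps({"data": payload}, ensure_ascii=False, sort_keys=True)
            key = hash(message)  # stand-in for api.ref_scalar
            if self._cache.get(key) != message:
                self._cache[key] = message
                print("ADD", message)  # stand-in for self._add(key, message)
            self._present_keys.add(key)

        def finish_sync(self):
            # Anything cached but absent from this sync was deleted upstream.
            for key in list(self._cache):
                if key not in self._present_keys:
                    print("REMOVE", self._cache.pop(key))
            self._present_keys.clear()

An unchanged record produces no event across syncs, so periodic re-reads of a full_refresh source behave incrementally at the output.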
18 changes: 11 additions & 7 deletions python/pathway/tests/test_io.py
@@ -24,7 +24,7 @@
from pathway.internals.parse_graph import G
from pathway.io.airbyte.logic import _PathwayAirbyteDestination
from pathway.tests.utils import (
AIRBYTE_CONNECTION_REL_PATH,
AIRBYTE_FAKER_CONNECTION_REL_PATH,
CountDifferentTimestampsCallback,
CsvLinesNumberChecker,
FileLinesNumberChecker,
@@ -3097,14 +3097,16 @@ def iteration():
@pytest.mark.parametrize("env_vars", [None, {"''": "\"''''\"\""}, {"KEY": "VALUE"}])
def test_airbyte_local_run(env_vars, tmp_path_with_airbyte_config):
table = pw.io.airbyte.read(
tmp_path_with_airbyte_config / AIRBYTE_CONNECTION_REL_PATH,
tmp_path_with_airbyte_config / AIRBYTE_FAKER_CONNECTION_REL_PATH,
["users"],
mode="static",
execution_type="local",
env_vars=env_vars,
)

with open(tmp_path_with_airbyte_config / AIRBYTE_CONNECTION_REL_PATH, "r") as f:
with open(
tmp_path_with_airbyte_config / AIRBYTE_FAKER_CONNECTION_REL_PATH, "r"
) as f:
config = yaml.safe_load(f)["source"]
airbyte_source = pw.io.airbyte._construct_local_source(
config,
@@ -3127,15 +3129,17 @@ def test_airbyte_local_run(env_vars, tmp_path_with_airbyte_config):
@pytest.mark.parametrize("env_vars", [None, {"''": "\"''''\"\""}, {"KEY": "VALUE"}])
def test_airbyte_local_docker_run(env_vars, tmp_path_with_airbyte_config):
table = pw.io.airbyte.read(
tmp_path_with_airbyte_config / AIRBYTE_CONNECTION_REL_PATH,
tmp_path_with_airbyte_config / AIRBYTE_FAKER_CONNECTION_REL_PATH,
["users"],
mode="static",
execution_type="local",
env_vars=env_vars,
enforce_method="docker",
)

with open(tmp_path_with_airbyte_config / AIRBYTE_CONNECTION_REL_PATH, "r") as f:
with open(
tmp_path_with_airbyte_config / AIRBYTE_FAKER_CONNECTION_REL_PATH, "r"
) as f:
config = yaml.safe_load(f)["source"]
airbyte_source = pw.io.airbyte._construct_local_source(
config,
@@ -3288,7 +3292,7 @@ def test_airbyte_persistence(enforce_method, tmp_path_with_airbyte_config):

def run_pathway_program(n_expected_records):
table = pw.io.airbyte.read(
tmp_path_with_airbyte_config / AIRBYTE_CONNECTION_REL_PATH,
tmp_path_with_airbyte_config / AIRBYTE_FAKER_CONNECTION_REL_PATH,
["users"],
mode="static",
execution_type="local",
Expand Down Expand Up @@ -3318,7 +3322,7 @@ def test_airbyte_persistence_error_message(tmp_path_with_airbyte_config):
output_path = tmp_path_with_airbyte_config / "table.jsonl"
pstorage_path = tmp_path_with_airbyte_config / "PStorage"
table = pw.io.airbyte.read(
tmp_path_with_airbyte_config / AIRBYTE_CONNECTION_REL_PATH,
tmp_path_with_airbyte_config / AIRBYTE_FAKER_CONNECTION_REL_PATH,
streams=["users", "purchases"],
mode="static",
)
2 changes: 1 addition & 1 deletion python/pathway/tests/utils.py
@@ -37,7 +37,7 @@
os.getenv("PATHWAY_THREADS", "1") != "1", reason="multiple threads"
)

AIRBYTE_CONNECTION_REL_PATH = "connections/new_source.yaml"
AIRBYTE_FAKER_CONNECTION_REL_PATH = "connections/faker.yaml"


def skip_on_multiple_workers() -> None:
39 changes: 23 additions & 16 deletions python/pathway/third_party/airbyte_serverless/executable_runner.py
@@ -19,10 +19,32 @@
MAX_GCP_ENV_VAR_LENGTH = 32768


def get_configured_catalog(catalog, streams):
configured_catalog = catalog
configured_catalog["streams"] = [
{
"stream": stream,
"sync_mode": (
"incremental"
if "incremental" in stream["supported_sync_modes"]
else "full_refresh"
),
"destination_sync_mode": "append",
"cursor_field": stream.get("default_cursor_field", []),
}
for stream in configured_catalog["streams"]
if not streams or stream["name"] in streams
]
return configured_catalog


class AbstractAirbyteSource(ABC):
@abstractmethod
def extract(self, state=None): ...

@property
def configured_catalog(self): ...

def on_stop(self):
pass

@@ -252,22 +274,7 @@ def catalog(self):

@property
def configured_catalog(self):
configured_catalog = self.catalog
configured_catalog["streams"] = [
{
"stream": stream,
"sync_mode": (
"incremental"
if "incremental" in stream["supported_sync_modes"]
else "full_refresh"
),
"destination_sync_mode": "append",
"cursor_field": stream.get("default_cursor_field", []),
}
for stream in configured_catalog["streams"]
if not self.streams or stream["name"] in self.streams
]
return configured_catalog
return get_configured_catalog(self.catalog, self.streams)

def load_cached_catalog(self, cached_catalog):
self._cached_catalog = cached_catalog
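A small illustration of the sync-mode selection in get_configured_catalog (the catalog dict is a made-up minimal example; real catalogs carry more fields):

    catalog = {
        "streams": [
            {"name": "users", "supported_sync_modes": ["full_refresh", "incremental"]},
            {"name": "files", "supported_sync_modes": ["full_refresh"]},
        ]
    }
    configured = get_configured_catalog(catalog, streams=["users"])
    # Only "users" survives the filter, and it is configured as "incremental",
    # since incremental is preferred whenever a stream supports it; a
    # full_refresh-only stream like "files" would fall back to "full_refresh".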