Skip to content

Commit 6df08c2

Browse files
authored
Merge branch 'main' into tolik0/add-api-budget
2 parents 245fb3e + 74631d8 commit 6df08c2

File tree

10 files changed

+426
-17
lines changed

10 files changed

+426
-17
lines changed

.github/workflows/pypi_publish.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,7 @@ jobs:
146146
(github.event_name == 'push' &&
147147
startsWith(github.ref, 'refs/tags/v')
148148
) || github.event.inputs.publish_to_pypi == 'true'
149-
uses: pypa/gh-action-pypi-publish@v1.12.3
149+
uses: pypa/gh-action-pypi-publish@v1.12.4
150150

151151
publish_sdm:
152152
name: Publish SDM to DockerHub

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
#
44

55
from dataclasses import InitVar, dataclass, field
6-
from datetime import timedelta
6+
from datetime import datetime, timedelta
77
from typing import Any, List, Mapping, MutableMapping, Optional, Union
88

99
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
@@ -232,8 +232,13 @@ def get_refresh_request_headers(self) -> Mapping[str, Any]:
232232
return self._refresh_request_headers.eval(self.config)
233233

234234
def get_token_expiry_date(self) -> AirbyteDateTime:
235+
if not self._has_access_token_been_initialized():
236+
return AirbyteDateTime.from_datetime(datetime.min)
235237
return self._token_expiry_date # type: ignore # _token_expiry_date is an AirbyteDateTime. It is never None despite what mypy thinks
236238

239+
def _has_access_token_been_initialized(self) -> bool:
240+
return self._access_token is not None
241+
237242
def set_token_expiry_date(self, value: Union[str, int]) -> None:
238243
self._token_expiry_date = self._parse_token_expiration_date(value)
239244

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
#
44

55
import logging
6-
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple
6+
from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple
77

88
from airbyte_cdk.models import (
99
AirbyteCatalog,
@@ -224,6 +224,7 @@ def _group_streams(
224224
stream_state = self._connector_state_manager.get_stream_state(
225225
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
226226
)
227+
stream_state = self._migrate_state(declarative_stream, stream_state)
227228

228229
retriever = self._get_retriever(declarative_stream, stream_state)
229230

@@ -331,6 +332,8 @@ def _group_streams(
331332
stream_state = self._connector_state_manager.get_stream_state(
332333
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
333334
)
335+
stream_state = self._migrate_state(declarative_stream, stream_state)
336+
334337
partition_router = declarative_stream.retriever.stream_slicer._partition_router
335338

336339
perpartition_cursor = (
@@ -521,3 +524,14 @@ def _remove_concurrent_streams_from_catalog(
521524
if stream.stream.name not in concurrent_stream_names
522525
]
523526
)
527+
528+
@staticmethod
529+
def _migrate_state(
530+
declarative_stream: DeclarativeStream, stream_state: MutableMapping[str, Any]
531+
) -> MutableMapping[str, Any]:
532+
for state_migration in declarative_stream.state_migrations:
533+
if state_migration.should_migrate(stream_state):
534+
# The state variable is expected to be mutable but the migrate method returns an immutable mapping.
535+
stream_state = dict(state_migration.migrate(stream_state))
536+
537+
return stream_state

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -968,6 +968,17 @@ def create_concurrency_level(
968968
parameters={},
969969
)
970970

971+
@staticmethod
972+
def apply_stream_state_migrations(
973+
stream_state_migrations: List[Any] | None, stream_state: MutableMapping[str, Any]
974+
) -> MutableMapping[str, Any]:
975+
if stream_state_migrations:
976+
for state_migration in stream_state_migrations:
977+
if state_migration.should_migrate(stream_state):
978+
# The state variable is expected to be mutable but the migrate method returns an immutable mapping.
979+
stream_state = dict(state_migration.migrate(stream_state))
980+
return stream_state
981+
971982
def create_concurrent_cursor_from_datetime_based_cursor(
972983
self,
973984
model_type: Type[BaseModel],
@@ -977,6 +988,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
977988
config: Config,
978989
message_repository: Optional[MessageRepository] = None,
979990
runtime_lookback_window: Optional[datetime.timedelta] = None,
991+
stream_state_migrations: Optional[List[Any]] = None,
980992
**kwargs: Any,
981993
) -> ConcurrentCursor:
982994
# Per-partition incremental streams can dynamically create child cursors which will pass their current
@@ -987,6 +999,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
987999
if "stream_state" not in kwargs
9881000
else kwargs["stream_state"]
9891001
)
1002+
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)
9901003

9911004
component_type = component_definition.get("type")
9921005
if component_definition.get("type") != model_type.__name__:
@@ -1222,6 +1235,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
12221235
config: Config,
12231236
stream_state: MutableMapping[str, Any],
12241237
partition_router: PartitionRouter,
1238+
stream_state_migrations: Optional[List[Any]] = None,
12251239
**kwargs: Any,
12261240
) -> ConcurrentPerPartitionCursor:
12271241
component_type = component_definition.get("type")
@@ -1270,8 +1284,10 @@ def create_concurrent_cursor_from_perpartition_cursor(
12701284
stream_namespace=stream_namespace,
12711285
config=config,
12721286
message_repository=NoopMessageRepository(),
1287+
stream_state_migrations=stream_state_migrations,
12731288
)
12741289
)
1290+
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)
12751291

12761292
# Return the concurrent cursor and state converter
12771293
return ConcurrentPerPartitionCursor(
@@ -1780,6 +1796,7 @@ def _merge_stream_slicers(
17801796
stream_name=model.name or "",
17811797
stream_namespace=None,
17821798
config=config or {},
1799+
stream_state_migrations=model.state_migrations,
17831800
)
17841801
return (
17851802
self._create_component_from_model(model=model.incremental_sync, config=config)

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,9 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim
261261
262262
:return: expiration datetime
263263
"""
264+
if not value and not self.token_has_expired():
265+
# No expiry token was provided but the previous one is not expired so it's fine
266+
return self.get_token_expiry_date()
264267

265268
if self.token_expiry_is_time_of_expiration:
266269
if not self.token_expiry_date_format:

unit_tests/sources/declarative/auth/test_oauth.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,6 +301,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
301301
client_id="{{ config['client_id'] }}",
302302
client_secret="{{ config['client_secret'] }}",
303303
token_expiry_date=timestamp,
304+
access_token_value="some_access_token",
304305
refresh_token="some_refresh_token",
305306
config={
306307
"refresh_endpoint": "refresh_end",
@@ -313,6 +314,34 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
313314
assert isinstance(oauth._token_expiry_date, AirbyteDateTime)
314315
assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date)
315316

317+
def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token(
318+
self,
319+
) -> None:
320+
expiry_date = ab_datetime_now().add(timedelta(days=1))
321+
oauth = DeclarativeOauth2Authenticator(
322+
token_refresh_endpoint="https://refresh_endpoint.com/",
323+
client_id="some_client_id",
324+
client_secret="some_client_secret",
325+
token_expiry_date=expiry_date.isoformat(),
326+
refresh_token="some_refresh_token",
327+
config={},
328+
parameters={},
329+
grant_type="client",
330+
)
331+
332+
with HttpMocker() as http_mocker:
333+
http_mocker.post(
334+
HttpRequest(
335+
url="https://refresh_endpoint.com/",
336+
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
337+
),
338+
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
339+
)
340+
oauth.get_access_token()
341+
342+
assert oauth.access_token == "new_access_token"
343+
assert oauth._token_expiry_date == expiry_date
344+
316345
@pytest.mark.parametrize(
317346
"expires_in_response, token_expiry_date_format",
318347
[
Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,47 @@
1+
#
2+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
3+
#
4+
5+
from typing import Any, Mapping
6+
7+
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
8+
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
9+
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
10+
from airbyte_cdk.sources.types import Config
11+
12+
13+
class CustomStateMigration(StateMigration):
14+
declarative_stream: DeclarativeStream
15+
config: Config
16+
17+
def __init__(self, declarative_stream: DeclarativeStream, config: Config):
18+
self._config = config
19+
self.declarative_stream = declarative_stream
20+
self._cursor = declarative_stream.incremental_sync
21+
self._parameters = declarative_stream.parameters
22+
self._cursor_field = InterpolatedString.create(
23+
self._cursor.cursor_field, parameters=self._parameters
24+
).eval(self._config)
25+
26+
def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
27+
return True
28+
29+
def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
30+
if not self.should_migrate(stream_state):
31+
return stream_state
32+
updated_at = stream_state[self._cursor.cursor_field]
33+
34+
migrated_stream_state = {
35+
"states": [
36+
{
37+
"partition": {"type": "type_1"},
38+
"cursor": {self._cursor.cursor_field: updated_at},
39+
},
40+
{
41+
"partition": {"type": "type_2"},
42+
"cursor": {self._cursor.cursor_field: updated_at},
43+
},
44+
]
45+
}
46+
47+
return migrated_stream_state

unit_tests/sources/declarative/decoders/test_composite_decoder.py

Lines changed: 37 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -30,35 +30,35 @@ def compress_with_gzip(data: str, encoding: str = "utf-8"):
3030
return buf.getvalue()
3131

3232

33-
def generate_csv(encoding: str) -> bytes:
34-
"""
35-
Generate CSV data with tab-separated values (\t).
36-
"""
33+
def generate_csv(
34+
encoding: str = "utf-8", delimiter: str = ",", should_compress: bool = False
35+
) -> bytes:
3736
data = [
38-
{"id": 1, "name": "John", "age": 28},
39-
{"id": 2, "name": "Alice", "age": 34},
40-
{"id": 3, "name": "Bob", "age": 25},
37+
{"id": "1", "name": "John", "age": "28"},
38+
{"id": "2", "name": "Alice", "age": "34"},
39+
{"id": "3", "name": "Bob", "age": "25"},
4140
]
4241

4342
output = StringIO()
44-
writer = csv.DictWriter(output, fieldnames=["id", "name", "age"], delimiter="\t")
43+
writer = csv.DictWriter(output, fieldnames=["id", "name", "age"], delimiter=delimiter)
4544
writer.writeheader()
4645
for row in data:
4746
writer.writerow(row)
4847

49-
# Ensure the pointer is at the beginning of the buffer before compressing
5048
output.seek(0)
49+
csv_data = output.read()
5150

52-
# Compress the CSV data with Gzip
53-
compressed_data = compress_with_gzip(output.read(), encoding=encoding)
54-
55-
return compressed_data
51+
if should_compress:
52+
return compress_with_gzip(csv_data, encoding=encoding)
53+
return csv_data.encode(encoding)
5654

5755

5856
@pytest.mark.parametrize("encoding", ["utf-8", "utf", "iso-8859-1"])
5957
def test_composite_raw_decoder_gzip_csv_parser(requests_mock, encoding: str):
6058
requests_mock.register_uri(
61-
"GET", "https://airbyte.io/", content=generate_csv(encoding=encoding)
59+
"GET",
60+
"https://airbyte.io/",
61+
content=generate_csv(encoding=encoding, delimiter="\t", should_compress=True),
6262
)
6363
response = requests.get("https://airbyte.io/", stream=True)
6464

@@ -175,3 +175,26 @@ def test_composite_raw_decoder_raises_traced_exception_when_both_parsers_fail(re
175175
with patch("json.loads", side_effect=Exception("test")):
176176
with pytest.raises(AirbyteTracedException):
177177
list(composite_raw_decoder.decode(response))
178+
179+
180+
@pytest.mark.parametrize("encoding", ["utf-8", "utf", "iso-8859-1"])
181+
@pytest.mark.parametrize("delimiter", [",", "\t", ";"])
182+
def test_composite_raw_decoder_csv_parser_values(requests_mock, encoding: str, delimiter: str):
183+
requests_mock.register_uri(
184+
"GET",
185+
"https://airbyte.io/",
186+
content=generate_csv(encoding=encoding, delimiter=delimiter, should_compress=False),
187+
)
188+
response = requests.get("https://airbyte.io/", stream=True)
189+
190+
parser = CsvParser(encoding=encoding, delimiter=delimiter)
191+
composite_raw_decoder = CompositeRawDecoder(parser=parser)
192+
193+
expected_data = [
194+
{"id": "1", "name": "John", "age": "28"},
195+
{"id": "2", "name": "Alice", "age": "34"},
196+
{"id": "3", "name": "Bob", "age": "25"},
197+
]
198+
199+
parsed_records = list(composite_raw_decoder.decode(response))
200+
assert parsed_records == expected_data

0 commit comments

Comments
 (0)