Skip to content

Commit ce564d9

Browse files
committed
[AIP-5284] Integration with Transforms
1 parent c85dc6d commit ce564d9

File tree

5 files changed

+80
-1
lines changed

5 files changed

+80
-1
lines changed

datasets/datasets_decorator.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import keyword
33
from typing import Callable, Optional
44

5-
from datasets import DatasetPlugin
65
from datasets.context import Context
76
from datasets.utils import _pascal_to_snake_case
7+
from datasets.txf_integration.txf_utils import add_txf_attributes
8+
from .dataset_plugin import DatasetPlugin
89

910

11+
# flake8: noqa: C901
1012
def dataset(
1113
name: str = None,
1214
field_name: Optional[str] = None,
@@ -27,6 +29,19 @@ def step_wrapper(*args, **kwargs):
2729
_snake_name = _pascal_to_snake_case(dataset.name)
2830
setattr(self, _snake_name, dataset)
2931

32+
# Transformation integration
33+
if type(self).__name__ in ["FlowSpec", "OnlineFlowSpec"]:
34+
add_txf_attributes(self)
35+
if not self._txf_registered_datasets.get(self.name, None):
36+
self._txf_registered_datasets[self.name] = [dataset]
37+
else:
38+
self._txf_registered_datasets[self.name].append(dataset)
39+
40+
if len(self._txf_callbacks):
41+
for _, callback in self._txf_callbacks.items():
42+
callback()
43+
# End Transformation integration
44+
3045
func(*args, **kwargs)
3146

3247
return step_wrapper
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pandas as pd
2+
from metaflow import FlowSpec, step
3+
4+
from datasets import Mode, dataset
5+
from datasets.txf_integration.txf_utils import (
6+
TXF_REGISTERED_DATASETS_ATTRIBUTE,
7+
TXF_CALLBACKS_ATTRIBUTE,
8+
TXF_METADATA_ATTRIBUTE,
9+
add_callback,
10+
)
11+
12+
13+
def udf():
    """No-op placeholder; the test flow registers it as the "udf" callback."""
    return None
15+
16+
17+
class TxTestFlow(FlowSpec):
    """Three-step flow (start -> ds -> end) used to check the txf attributes.

    ``start`` registers a callback on the flow, ``ds`` writes a small
    partitioned dataset and asserts that the txf bookkeeping attributes are
    present, and ``end`` does nothing.
    """

    @step
    def start(self):
        """Register ``udf`` under the key "udf", then continue to ``ds``."""
        add_callback(self, "udf", udf)
        self.next(self.ds)

    @dataset(name="TxfDataset", partition_by="region", mode=Mode.WRITE)
    @step
    def ds(self):
        """Write six rows through ``self.txf_dataset`` and check the txf state."""
        rows = {"region": ["A", "A", "A", "B", "B", "B"], "home_id": [1, 2, 3, 4, 5, 6]}
        self.txf_dataset.write(pd.DataFrame(rows))

        # Every txf attribute must exist on the flow at this point.
        for attribute in (
            TXF_REGISTERED_DATASETS_ATTRIBUTE,
            TXF_CALLBACKS_ATTRIBUTE,
            TXF_METADATA_ATTRIBUTE,
        ):
            assert hasattr(self, attribute)

        # Exactly one callback is expected: the one registered in ``start``.
        registered_callbacks = getattr(self, TXF_CALLBACKS_ATTRIBUTE)
        assert len(registered_callbacks) == 1

        self.next(self.end)

    @step
    def end(self):
        """Final step; intentionally empty."""
        pass
37+
38+
39+
# Instantiate the flow only when this file is executed as a script.
if __name__ == "__main__":
    TxTestFlow()

datasets/tests/test_txf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from datasets.tests.test_tutorials import run_flow
2+
3+
4+
def test_input_output_flow():
    """Execute the txf test flow script through ``run_flow``."""
    flow_path = "tests/resources/txf_test_flow.py"
    run_flow(flow_path)

datasets/txf_integration/__init__.py

Whitespace-only changes.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Callable
2+
3+
# Names of the private attributes that the transformation (txf) integration
# stores on a flow object. Always go through these constants rather than
# spelling the attribute names out, so the names live in a single place.
TXF_REGISTERED_DATASETS_ATTRIBUTE = "_txf_registered_datasets"
TXF_CALLBACKS_ATTRIBUTE = "_txf_callbacks"
TXF_METADATA_ATTRIBUTE = "_txf_metadata_accumulator"


def add_txf_attributes(flow):
    """Make sure ``flow`` has every txf bookkeeping attribute.

    Each attribute that is missing is initialised to its own empty dict.
    Attributes that already exist are left untouched, so repeated calls never
    discard previously stored state.

    Args:
        flow: Object to attach the attributes to; it must accept ``setattr``.
    """
    for attribute in (
        TXF_REGISTERED_DATASETS_ATTRIBUTE,
        TXF_CALLBACKS_ATTRIBUTE,
        TXF_METADATA_ATTRIBUTE,
    ):
        if not hasattr(flow, attribute):
            # A fresh dict per attribute: the three stores must not alias.
            setattr(flow, attribute, {})


def add_callback(flow, caller: str, func: Callable):
    """Store ``func`` as the callback registered under ``caller`` on ``flow``.

    The txf attributes are created first if ``flow`` does not have them yet.
    Registering a second callback under the same ``caller`` replaces the
    earlier one.

    Args:
        flow: Object holding the callback registry.
        caller: Key under which the callback is stored.
        func: Callable to store.
    """
    add_txf_attributes(flow)
    # Look the registry up via the constant instead of a hard-coded name so
    # it cannot drift from what add_txf_attributes creates.
    getattr(flow, TXF_CALLBACKS_ATTRIBUTE)[caller] = func

0 commit comments

Comments
 (0)