Commit f9a32cd

Merge pull request #439 from RelevanceAI/development

v0.32.0

2 parents b44811a + 0dad922

File tree

5 files changed: +159 additions, -31 deletions

ai_transform/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "0.31.3"
+__version__ = "0.32.0"
 
 from ai_transform.timer import Timer
 

ai_transform/engine/abstract_engine.py
Lines changed: 23 additions & 25 deletions

@@ -1,5 +1,4 @@
 import time
-import logging
 import warnings
 
 from json import JSONDecodeError
@@ -8,7 +7,7 @@
 
 from tqdm.auto import tqdm
 
-from ai_transform.logger import format_logging_info, ic
+from ai_transform.logger import ic
 from ai_transform.types import Filter
 from ai_transform.dataset.dataset import Dataset
 from ai_transform.operator.abstract_operator import AbstractOperator
@@ -107,8 +106,10 @@ def __init__(
             filters = []
         assert isinstance(filters, list), "Filters must be applied as a list of Dictionaries"
 
-        if not refresh:
-            filters += self._get_refresh_filter(select_fields, dataset)
+        self._refresh = refresh
+        self._after_id = after_id
+
+        filters += self._get_refresh_filter()
         filters += self._get_workflow_filter()
 
         self._filters = filters
@@ -118,9 +119,6 @@ def __init__(
         else:
             self._size = dataset.len(filters=filters) if self._limit_documents is None else self._limit_documents
 
-        self._refresh = refresh
-        self._after_id = after_id
-
         self._successful_documents = 0
         self._success_ratio = None
 
@@ -206,36 +204,36 @@ def _operate(self, mini_batch):
         self._successful_documents += len(mini_batch)
         return transformed_batch
 
-    def _get_refresh_filter(self, select_fields: List[str], dataset: Dataset):
+    def _get_refresh_filter(self):
         # initialize the refresh filter container
-        refresh_filters = {"filter_type": "or", "condition_value": []}
+        input_field_filters = {"filter_type": "or", "condition_value": []}
 
         # initialize where the filters are going
-        input_field_filters = []
         output_field_filters = {"filter_type": "or", "condition_value": []}
 
-        # We want documents where all select_fields exists
+        # We want documents where any of the select_fields exists
         # as these are needed for operator ...
-        for field in select_fields:
-            input_field_filters += dataset[field].exists()
-
-        # ... and where any of its output_fields dont exist
-        for operator in self.operators:
-            if operator.output_fields is not None:
-                for output_field in operator.output_fields:
-                    output_field_filters["condition_value"] += dataset[output_field].not_exists()
-
         # We construct this as:
         #
-        # input_field1 and input_field2 and (not output_field1 or not output_field2)
+        # (input_field1 or input_field2) and (not output_field1 or not output_field2)
        #
         # This use case here is for two input fields and two output fields
         # tho this extends to arbitrarily many.
-        refresh_filters["condition_value"] = input_field_filters
-        refresh_filters["condition_value"] += [output_field_filters]
+        for field in self._select_fields:
+            input_field_filters["condition_value"] += self.dataset[field].exists()
+
+        # ... and where any of its output_fields dont exist
+        if not self._refresh:
+            for operator in self.operators:
+                if operator.output_fields is not None:
+                    for output_field in operator.output_fields:
+                        output_field_filters["condition_value"] += self.dataset[output_field].not_exists()
+
+            return [input_field_filters, output_field_filters]
 
-        # Wrap in list at end
-        return [refresh_filters]
+        else:
+            # Wrap in list at end
+            return [input_field_filters]
 
     def _get_workflow_filter(self, field: str = "_id"):
         # Get the required workflow filter as an environment variable
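The reworked _get_refresh_filter keeps the input-field condition in all cases and only appends the "output field missing" condition when refresh is False. A minimal sketch of that logic follows, using placeholder exists()/not_exists() helpers instead of the real Relevance AI filter API (the engine itself calls self.dataset[field].exists() and .not_exists()); the field names are made up for illustration:

    # Sketch only: placeholder filter helpers, not the actual Relevance AI filter format.
    def exists(field):
        return [{"field": field, "filter_type": "exists"}]

    def not_exists(field):
        return [{"field": field, "filter_type": "not_exists"}]

    def get_refresh_filter(select_fields, output_fields, refresh):
        # Any one of the input fields must exist ...
        input_field_filters = {"filter_type": "or", "condition_value": []}
        for field in select_fields:
            input_field_filters["condition_value"] += exists(field)

        if not refresh:
            # ... and at least one output field must still be missing, giving
            # (input_1 or input_2) and (not output_1 or not output_2).
            output_field_filters = {"filter_type": "or", "condition_value": []}
            for field in output_fields:
                output_field_filters["condition_value"] += not_exists(field)
            return [input_field_filters, output_field_filters]

        # With refresh=True every document carrying an input field is reprocessed,
        # so only the input-field condition is returned.
        return [input_field_filters]

    print(get_refresh_filter(["text_1", "text_2"], ["label_1", "label_2"], refresh=False))

The practical effect is that documents missing some of the select_fields are no longer excluded outright: any one input field is enough for a document to be picked up, which the new partial-dataset tests below exercise.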

examples/fail_example.py
Lines changed: 2 additions & 1 deletion

@@ -40,7 +40,8 @@ def __init__(
 
     def transform(self, documents: List[Document]) -> List[Document]:
         try:
-            text = [document[self.text_field] for document in documents]
+            raise ValueError
+
         except:
             # pass
             raise UserFacingError(

tests/conftest.py
Lines changed: 70 additions & 1 deletion

@@ -63,6 +63,55 @@ def full_dataset(test_client: Client) -> Dataset:
     test_client.delete_dataset(dataset_id)
 
 
+@pytest.fixture(scope="class")
+def partial_dataset(test_client: Client) -> Dataset:
+    salt = "".join(random.choices(string.ascii_lowercase, k=10))
+    dataset_id = f"_sample_dataset_{salt}"
+    dataset = test_client.Dataset(dataset_id, expire=True)
+    documents = mock_documents(1000)
+    fields = ["sample_1_label", "sample_2_label", "sample_3_label"]
+    for document in documents:
+        for field in random.sample(fields, k=random.randint(1, 3)):
+            document.pop(field)
+    dataset.insert_documents(documents)
+    yield dataset
+    test_client.delete_dataset(dataset_id)
+
+
+@pytest.fixture(scope="class")
+def simple_partial_dataset(test_client: Client) -> Dataset:
+    salt = "".join(random.choices(string.ascii_lowercase, k=10))
+    dataset_id = f"_sample_dataset_{salt}"
+    dataset = test_client.Dataset(dataset_id, expire=True)
+    documents = mock_documents(1000)
+    fields = ["sample_1_label"]
+    for document in documents:
+        if random.random() < 0.5:
+            document.pop(fields[0])
+    dataset.insert_documents(documents)
+    yield dataset
+    test_client.delete_dataset(dataset_id)
+
+
+@pytest.fixture(scope="class")
+def partial_dataset_with_outputs(test_client: Client) -> Dataset:
+    salt = "".join(random.choices(string.ascii_lowercase, k=10))
+    dataset_id = f"_sample_dataset_{salt}"
+    dataset = test_client.Dataset(dataset_id, expire=True)
+    documents = mock_documents(1000)
+    fields = ["sample_1_label", "sample_2_label", "sample_3_label"]
+    for document in documents:
+        for field in random.sample(fields, k=random.randint(1, 3)):
+            document.pop(field)
+    for document in documents:
+        for field in fields:
+            if document.get(field) and random.random() < 0.5:
+                document[field + "_output"] = document[field] + "_output"
+    dataset.insert_documents(documents)
+    yield dataset
+    test_client.delete_dataset(dataset_id)
+
+
 @pytest.fixture(scope="class")
 def mixed_dataset(test_client: Client) -> Dataset:
     salt = "".join(random.choices(string.ascii_lowercase, k=10))
@@ -150,6 +199,26 @@ def transform(self, documents: DocumentList) -> DocumentList:
     return ExampleOperator()
 
 
+@pytest.fixture(scope="function")
+def test_partial_operator() -> AbstractOperator:
+    class PartialOperator(AbstractOperator):
+        def __init__(self, fields):
+            super().__init__(input_fields=fields, output_fields=[field + "_output" for field in fields])
+
+        def transform(self, documents: DocumentList) -> DocumentList:
+            """
+            Main transform function
+            """
+            for input_field, output_field in zip(self.input_fields, self.output_fields):
+                for document in documents:
+                    if document.get(input_field):
+                        document[output_field] = document[input_field] + "_output"
+
+            return documents
+
+    return PartialOperator
+
+
 @pytest.fixture(scope="function")
 def test_paid_operator() -> AbstractOperator:
     class ExampleOperator(AbstractOperator):
@@ -243,7 +312,7 @@ def test_user_facing_error_workflow_token(test_client: Client) -> str:
         job_id=job_id,
         dataset_id=dataset_id,
         authorizationToken=test_client.credentials.token,
-        text_field="sample_1_description_not_in_dataset",
+        text_field="sample_1_description",
     )
     config_string = json.dumps(config)
     config_bytes = config_string.encode()
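For orientation, here is what a single document from the new partial_dataset fixture can look like and how the test_partial_operator fixture fills in outputs. This assumes mock_documents() emits the sample_*_label fields the fixture pops at random; the concrete values below are made up:

    # Hypothetical document after the fixture popped "sample_2_label".
    doc = {"_id": "0", "sample_1_label": "label_3", "sample_3_label": "label_9"}

    # PartialOperator only writes an output field when the matching input field is present.
    for input_field in ["sample_1_label", "sample_2_label", "sample_3_label"]:
        if doc.get(input_field):
            doc[input_field + "_output"] = doc[input_field] + "_output"

    print(doc)
    # {'_id': '0', 'sample_1_label': 'label_3', 'sample_3_label': 'label_9',
    #  'sample_1_label_output': 'label_3_output', 'sample_3_label_output': 'label_9_output'}

partial_dataset_with_outputs additionally pre-populates some of these *_output fields, which exercises the "output field missing" half of the refresh filter when refresh=False.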
Lines changed: 63 additions & 3 deletions

@@ -1,20 +1,80 @@
+import uuid
+
+from typing import Type
+
 from ai_transform.dataset.dataset import Dataset
 from ai_transform.engine.stable_engine import StableEngine
 from ai_transform.engine.small_batch_stable_engine import SmallBatchStableEngine
 
 from ai_transform.operator.abstract_operator import AbstractOperator
-from ai_transform.workflow.abstract_workflow import AbstractWorkflow
+from ai_transform.workflow.abstract_workflow import Workflow
+
+
+def _random_id():
+    return str(uuid.uuid4())
 
 
 class TestStableEngine:
     def test_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator):
         engine = StableEngine(full_dataset, test_operator, worker_number=0)
-        workflow = AbstractWorkflow(name="workflow_test123", engine=engine, job_id="test_job123")
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
         workflow.run()
         assert engine.success_ratio == 1
 
     def test_small_batch_stable_engine(self, full_dataset: Dataset, test_operator: AbstractOperator):
         engine = SmallBatchStableEngine(full_dataset, test_operator)
-        workflow = AbstractWorkflow(name="workflow_test123", engine=engine, job_id="test_job123")
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
+        workflow.run()
+        assert engine.success_ratio == 1
+
+
+class TestStableEngineFilters:
+    _SELECTED_FIELDS = ["sample_1_label", "sample_2_label", "sample_3_label"]
+
+    def test_stable_engine_filters1(self, partial_dataset: Dataset, test_partial_operator: Type[AbstractOperator]):
+        prev_health = partial_dataset.health()
+        operator = test_partial_operator(self._SELECTED_FIELDS)
+
+        engine = StableEngine(partial_dataset, operator, select_fields=self._SELECTED_FIELDS)
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
+        workflow.run()
+
+        post_health = partial_dataset.health()
+        for input_field, output_field in zip(operator.input_fields, operator.output_fields):
+            assert prev_health[input_field]["exists"] == post_health[output_field]["exists"]
+
+        assert engine.success_ratio == 1
+
+    def test_stable_engine_filters2(
+        self, partial_dataset_with_outputs: Dataset, test_partial_operator: Type[AbstractOperator]
+    ):
+        prev_health = partial_dataset_with_outputs.health()
+        operator = test_partial_operator(self._SELECTED_FIELDS)
+
+        engine = StableEngine(
+            partial_dataset_with_outputs, operator, select_fields=self._SELECTED_FIELDS, refresh=False
+        )
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
+        workflow.run()
+
+        post_health = partial_dataset_with_outputs.health()
+        for input_field, output_field in zip(operator.input_fields, operator.output_fields):
+            assert prev_health[input_field]["exists"] == post_health[output_field]["exists"]
+
+        assert engine.success_ratio == 1
+
+    def test_stable_engine_filters3(
+        self, simple_partial_dataset: Dataset, test_partial_operator: Type[AbstractOperator]
+    ):
+        prev_health = simple_partial_dataset.health()
+        operator = test_partial_operator(["sample_1_label"])
+
+        engine = StableEngine(simple_partial_dataset, operator, select_fields=["sample_1_label"], refresh=False)
+        workflow = Workflow(name=_random_id(), engine=engine, job_id=_random_id())
         workflow.run()
+
+        post_health = simple_partial_dataset.health()
+        for input_field, output_field in zip(operator.input_fields, operator.output_fields):
+            assert prev_health[input_field]["exists"] == post_health[output_field]["exists"]
+
         assert engine.success_ratio == 1
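The assertion pattern in these new tests relies on Dataset.health() reporting, per field, how many documents carry it: if the engine's filters and the operator behave correctly, every document that had an input field ends up with the corresponding output field, so the two "exists" counts must match. A toy illustration with made-up counts (the real values come from the fixtures above):

    # Hypothetical health snapshots taken before and after a workflow run.
    prev_health = {"sample_1_label": {"exists": 640}}
    post_health = {"sample_1_label_output": {"exists": 640}}
    assert prev_health["sample_1_label"]["exists"] == post_health["sample_1_label_output"]["exists"]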
