@@ -429,7 +429,7 @@ def get_dataset(
         project_id: str,
         location: str,
         retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | _MethodDefault = DEFAULT,
+        timeout: float | None | _MethodDefault = DEFAULT,
         metadata: Sequence[tuple[str, str]] = (),
     ) -> automl_translation.Dataset:
         """
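Note: the widened `timeout` annotation above is type-only; runtime behavior is unchanged. A minimal sketch of what it newly permits, assuming an already-instantiated `TranslateHook` named `hook` (ids are placeholders, not values from this PR) — passing `timeout=None` now type-checks, and google-api-core treats `None` as "no client-side deadline":

# Sketch, not code from this PR; `hook` is an assumed TranslateHook instance.
dataset = hook.get_dataset(
    dataset_id="my-dataset",    # placeholder id
    project_id="my-project",    # placeholder id
    location="us-central1",
    timeout=None,               # now accepted by the annotation: no client-side deadline
)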
@@ -149,7 +149,7 @@ class TranslationNativeDatasetLink(BaseGoogleLink):
     """

     name = "Translation Native Dataset"
-    key = "translation_naive_dataset"
+    key = "translation_native_dataset"
     format_str = TRANSLATION_NATIVE_DATASET_LINK
@@ -37,6 +37,7 @@
     TranslationNativeDatasetLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import DatasetImportDataResultsCheckHelper
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

 if TYPE_CHECKING:
@@ -575,7 +576,7 @@ def execute(self, context: Context):
         return result_ids


-class TranslateImportDataOperator(GoogleCloudBaseOperator):
+class TranslateImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
     """
     Import data to the translation dataset.

@@ -602,6 +603,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
     """

     template_fields: Sequence[str] = (
@@ -627,6 +629,7 @@ def __init__(
         retry: Retry | _MethodDefault = DEFAULT,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        raise_for_empty_result: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -639,9 +642,21 @@ def __init__(
         self.retry = retry
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.raise_for_empty_result = raise_for_empty_result

     def execute(self, context: Context):
         hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        initial_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
         self.log.info("Importing data to dataset...")
         operation = hook.import_dataset_data(
             dataset_id=self.dataset_id,
@@ -660,7 +675,22 @@ def execute(self, context: Context):
             location=self.location,
         )
         hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
+
+        result_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
+        if self.raise_for_empty_result:
+            self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
         self.log.info("Importing data finished!")
+        return {"total_imported": int(result_dataset_size) - int(initial_dataset_size)}


 class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
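With these changes, TranslateImportDataOperator snapshots the dataset's example_count before and after the import, optionally fails when nothing was added, and returns the delta (pushed to XCom). A minimal usage sketch of the new flag follows; the ids, GCS path, and the input_config parameter name are illustrative assumptions, not values confirmed by this PR:

# Hedged sketch: ids and paths are placeholders.
from airflow.providers.google.cloud.operators.translate import TranslateImportDataOperator

import_data = TranslateImportDataOperator(
    task_id="import_translation_data",
    dataset_id="my-dataset-id",    # placeholder
    project_id="my-project",       # placeholder
    location="us-central1",
    input_config={
        "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "gs://my-bucket/data.tsv"}}]
    },
    raise_for_empty_result=True,  # new: raise AirflowException if example_count did not grow
)
# Downstream tasks can read the return value from XCom:
# {"total_imported": <number of items added by this import>}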
@@ -26,6 +26,7 @@
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig

+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
 from airflow.providers.google.cloud.links.vertex_ai import VertexAIDatasetLink, VertexAIDatasetListLink
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -335,7 +336,21 @@ def execute(self, context: Context):
         self.log.info("Export was done successfully")


-class ImportDataOperator(GoogleCloudBaseOperator):
+class DatasetImportDataResultsCheckHelper:
+    """Helper utils to verify import dataset data results."""
+
+    @staticmethod
+    def _get_number_of_ds_items(dataset, total_key_name):
+        number_of_items = type(dataset).to_dict(dataset).get(total_key_name, 0)
+        return number_of_items
+
+    @staticmethod
+    def _raise_for_empty_import_result(dataset_id, initial_size, size_after_import):
+        if int(size_after_import) - int(initial_size) <= 0:
+            raise AirflowException(f"Empty results of data import for the dataset_id {dataset_id}.")
+
+
+class ImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
     """
     Imports data into a Dataset.

@@ -356,6 +371,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
     """

     template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
@@ -372,6 +388,7 @@ def __init__(
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        raise_for_empty_result: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -384,13 +401,24 @@ def __init__(
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.raise_for_empty_result = raise_for_empty_result

     def execute(self, context: Context):
         hook = DatasetHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-
+        initial_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                region=self.region,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="data_item_count",
+        )
         self.log.info("Importing data: %s", self.dataset_id)
         operation = hook.import_data(
             project_id=self.project_id,
@@ -402,7 +430,21 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         hook.wait_for_operation(timeout=self.timeout, operation=operation)
+        result_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                region=self.region,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="data_item_count",
+        )
+        if self.raise_for_empty_result:
+            self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
         self.log.info("Import was done successfully")
+        return {"total_data_items_imported": int(result_dataset_size) - int(initial_dataset_size)}


 class ListDatasetsOperator(GoogleCloudBaseOperator):
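The Vertex AI ImportDataOperator gains the same pre/post check through the shared DatasetImportDataResultsCheckHelper mixin, reading data_item_count instead of example_count. A self-contained sketch of how the helper counts items (the dataset values are stand-ins, not code from this PR); note that proto-plus to_dict() may render int64 counters as strings, which is why the operator casts with int() before subtracting:

# Sketch with placeholder values.
from google.cloud.aiplatform_v1.types import Dataset

from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
    DatasetImportDataResultsCheckHelper,
)

# A stand-in Vertex AI dataset proto with 42 items.
ds = Dataset({"name": "ds", "display_name": "demo", "data_item_count": 42})

# type(dataset).to_dict(dataset) round-trips the proto into a plain dict,
# so the same helper serves Translation datasets ("example_count") and
# Vertex AI datasets ("data_item_count").
count = DatasetImportDataResultsCheckHelper._get_number_of_ds_items(
    dataset=ds, total_key_name="data_item_count"
)
assert int(count) == 42  # int(): int64 fields may round-trip as strings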
@@ -22,6 +22,7 @@
 from google.api_core.gapic_v1.method import DEFAULT
 from google.cloud.translate_v3.types import (
     BatchTranslateDocumentResponse,
+    Dataset,
     TranslateDocumentResponse,
     automl_translation,
     translation_service,
@@ -331,6 +332,19 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
             "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "import data gcs path"}}]
         }
         mock_hook.return_value.import_dataset_data.return_value = mock.MagicMock()
+
+        SAMPLE_DATASET = {
+            "name": "sample_translation_dataset",
+            "example_count": None,
+            "source_language_code": "en",
+            "target_language_code": "es",
+        }
+        INITIAL_DS_SIZE = 1
+        FINAL_DS_SIZE = 101
+        INITIAL_DS = {**SAMPLE_DATASET, "example_count": INITIAL_DS_SIZE}
+        FINAL_DS = {**SAMPLE_DATASET, "example_count": FINAL_DS_SIZE}
+
+        mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS), Dataset(FINAL_DS)]
         op = TranslateImportDataOperator(
             task_id="task_id",
             dataset_id=DATASET_ID,
@@ -343,7 +357,7 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
             retry=DEFAULT,
         )
         context = mock.MagicMock()
-        op.execute(context=context)
+        res = op.execute(context=context)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
@@ -363,6 +377,7 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
             location=LOCATION,
             project_id=PROJECT_ID,
         )
+        assert res["total_imported"] == FINAL_DS_SIZE - INITIAL_DS_SIZE


 class TestTranslateDeleteData:
@@ -26,6 +26,7 @@

 from google.api_core.gapic_v1.method import DEFAULT
 from google.api_core.retry import Retry
+from google.cloud.aiplatform_v1.types.dataset import Dataset

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
@@ -1362,9 +1363,8 @@ def test_execute(self, mock_hook, to_dict_mock):


 class TestVertexAIImportDataOperator:
-    @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
     @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
-    def test_execute(self, mock_hook, to_dict_mock):
+    def test_execute(self, mock_hook):
         op = ImportDataOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -1377,7 +1377,20 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={})
+        SAMPLE_DATASET = {
+            "name": "sample_translation_dataset",
+            "display_name": "VertexAI dataset",
+            "data_item_count": None,
+        }
+        INITIAL_DS_SIZE = 1
+        FINAL_DS_SIZE = 101
+        INITIAL_DS = {**SAMPLE_DATASET, "data_item_count": INITIAL_DS_SIZE}
+        FINAL_DS = {**SAMPLE_DATASET, "data_item_count": FINAL_DS_SIZE}
+
+        mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS), Dataset(FINAL_DS)]
+
+        res = op.execute(context={})
+
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.import_data.assert_called_once_with(
             region=GCP_LOCATION,
@@ -1388,6 +1401,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        assert res["total_data_items_imported"] == FINAL_DS_SIZE - INITIAL_DS_SIZE


 class TestVertexAIListDatasetsOperator: