Commit
added validate_df
dominikjedlinski committed Oct 25, 2023
1 parent 1e2d708 commit 99cbcec
Showing 2 changed files with 39 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/integration/flows/test_salesforce_to_adls.py
@@ -4,6 +4,7 @@

from viadot.flows import SalesforceToADLS
from viadot.tasks import AzureDataLakeRemove
from viadot.exceptions import ValidationError

ADLS_FILE_NAME = "test_salesforce.parquet"
ADLS_DIR_PATH = "raw/tests/"
@@ -32,3 +33,29 @@ def test_salesforce_to_adls():
        vault_name="azuwevelcrkeyv001s",
    )
    rm.run(sp_credentials_secret=credentials_secret)


def test_salesforce_to_adls_validate_success():
    credentials_secret = PrefectSecret(
        "AZURE_DEFAULT_ADLS_SERVICE_PRINCIPAL_SECRET"
    ).run()

    flow = SalesforceToADLS(
        "test_salesforce_to_adls_run_flow",
        query="SELECT IsDeleted, FiscalYear FROM Opportunity LIMIT 50",
        adls_sp_credentials_secret=credentials_secret,
        adls_dir_path=ADLS_DIR_PATH,
        adls_file_name=ADLS_FILE_NAME,
        validate_df_dict={"column_list_to_match": ["IsDeleted", "FiscalYear"]},
    )

    result = flow.run()
    assert result.is_successful()

    os.remove("test_salesforce_to_adls_run_flow.parquet")
    os.remove("test_salesforce_to_adls_run_flow.json")
    rm = AzureDataLakeRemove(
        path=ADLS_DIR_PATH + ADLS_FILE_NAME,
        vault_name="azuwevelcrkeyv001s",
    )
    rm.run(sp_credentials_secret=credentials_secret)
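
Note: the `ValidationError` import added above is not exercised by the hunk shown here; it presumably supports a failure-path test. A minimal sketch of what such a test could look like, assuming a hypothetical name `test_salesforce_to_adls_validate_fail` and Prefect 1.x's default behaviour of capturing task exceptions in the flow run state (an illustration only, not part of this commit):

def test_salesforce_to_adls_validate_fail():
    # Hypothetical companion test (not part of this commit): ask validate_df for a
    # column list the query does not return, so the validation step should fail the run.
    credentials_secret = PrefectSecret(
        "AZURE_DEFAULT_ADLS_SERVICE_PRINCIPAL_SECRET"
    ).run()

    flow = SalesforceToADLS(
        "test_salesforce_to_adls_validate_fail",
        query="SELECT IsDeleted, FiscalYear FROM Opportunity LIMIT 50",
        adls_sp_credentials_secret=credentials_secret,
        adls_dir_path=ADLS_DIR_PATH,
        adls_file_name=ADLS_FILE_NAME,
        validate_df_dict={"column_list_to_match": ["IsDeleted"]},
    )

    # In Prefect 1.x, task exceptions (including ValidationError) are captured in the
    # flow run state rather than raised, so the state is checked instead.
    result = flow.run()
    assert result.is_failed()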
12 changes: 12 additions & 0 deletions viadot/flows/salesforce_to_adls.py
@@ -16,6 +16,7 @@
    df_to_parquet,
    dtypes_to_json_task,
    update_dtypes_dict,
    validate_df,
)
from viadot.tasks import AzureDataLakeUpload, SalesforceToDF

@@ -41,6 +42,7 @@ def __init__(
        adls_file_name: str = None,
        adls_sp_credentials_secret: str = None,
        if_exists: str = "replace",
        validate_df_dict: Dict[str, Any] = None,
        timeout: int = 3600,
        *args: List[Any],
        **kwargs: Dict[str, Any],
@@ -70,6 +72,8 @@
                ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the Azure Data Lake.
                Defaults to None.
            if_exists (str, optional): What to do if the file exists. Defaults to "replace".
            validate_df_dict (Dict[str,Any], optional): A dictionary with optional list of tests to verify the output
                dataframe. If defined, triggers the `validate_df` task from task_utils. Defaults to None.
            timeout(int, optional): The amount of time (in seconds) to wait while running this task before
                a timeout occurs. Defaults to 3600.
        """
@@ -82,6 +86,7 @@ def __init__(
        self.env = env
        self.vault_name = vault_name
        self.credentials_secret = credentials_secret
        self.validate_df_dict = validate_df_dict

        # AzureDataLakeUpload
        self.adls_sp_credentials_secret = adls_sp_credentials_secret
@@ -135,6 +140,13 @@ def gen_flow(self) -> Flow:
        df_clean = df_clean_column.bind(df=df, flow=self)
        df_with_metadata = add_ingestion_metadata_task.bind(df_clean, flow=self)
        dtypes_dict = df_get_data_types_task.bind(df_with_metadata, flow=self)

        if self.validate_df_dict:
            validation_task = validate_df.bind(
                df, tests=self.validate_df_dict, flow=self
            )
            validation_task.set_upstream(df, flow=self)

        df_to_be_loaded = df_map_mixed_dtypes_for_parquet(
            df_with_metadata, dtypes_dict, flow=self
        )
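
For context, `validate_df` is imported from `viadot.task_utils` and is not defined in this diff. A minimal sketch of the kind of check such a task could perform for the `column_list_to_match` test used above, assuming Prefect 1.x and viadot's `ValidationError`; the real task likely supports more test types and differs in detail:

from typing import Any, Dict

import pandas as pd
from prefect import task

from viadot.exceptions import ValidationError


@task
def validate_df_sketch(df: pd.DataFrame, tests: Dict[str, Any] = None) -> None:
    # Illustrative only: fail when the dataframe's columns do not exactly match
    # the list configured under "column_list_to_match".
    if not tests:
        return
    expected = tests.get("column_list_to_match")
    if expected is not None and list(df.columns) != list(expected):
        raise ValidationError(
            f"Columns {list(df.columns)} do not match expected {list(expected)}."
        )

As a side note on the wiring in `gen_flow`: binding `df` as the task's input already makes it an upstream dependency in the Prefect graph, so the explicit `set_upstream(df, flow=self)` call appears to restate that edge rather than add a new one.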
