23 changes: 17 additions & 6 deletions src/datacustomcode/io/reader/query_api.py
@@ -43,7 +43,7 @@
logger = logging.getLogger(__name__)


SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {}"
SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {} LIMIT {}"
PANDAS_TYPE_MAPPING = {
"object": StringType(),
"int64": LongType(),
@@ -85,29 +85,40 @@ def __init__(self, spark: SparkSession) -> None:
)

def read_dlo(
self, name: str, schema: Union[AtomicType, StructType, str, None] = None
self,
name: str,
schema: Union[AtomicType, StructType, str, None] = None,
row_limit: int = 1000,
) -> PySparkDataFrame:
"""
Read a Data Lake Object (DLO) from the Data Cloud.
Read a Data Lake Object (DLO) from the Data Cloud, limited to a maximum number of rows.

Args:
name (str): The name of the DLO.
schema (Optional[Union[AtomicType, StructType, str]]): Schema of the DLO.
row_limit (int): Maximum number of rows to fetch.

Returns:
PySparkDataFrame: The PySpark DataFrame.
"""
pandas_df = self._conn.get_pandas_dataframe(SQL_QUERY_TEMPLATE.format(name))
pandas_df = self._conn.get_pandas_dataframe(
SQL_QUERY_TEMPLATE.format(name, row_limit)
)
if not schema:
# auto infer schema
schema = _pandas_to_spark_schema(pandas_df)
spark_dataframe = self.spark.createDataFrame(pandas_df, schema)
return spark_dataframe

def read_dmo(
self, name: str, schema: Union[AtomicType, StructType, str, None] = None
self,
name: str,
schema: Union[AtomicType, StructType, str, None] = None,
row_limit: int = 1000,
) -> PySparkDataFrame:
pandas_df = self._conn.get_pandas_dataframe(SQL_QUERY_TEMPLATE.format(name))
pandas_df = self._conn.get_pandas_dataframe(
SQL_QUERY_TEMPLATE.format(name, row_limit)
)
if not schema:
# auto infer schema
schema = _pandas_to_spark_schema(pandas_df)
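For context, a minimal usage sketch of the new row_limit parameter (assumes a SparkSession and Data Cloud connection are already configured; the DLO name is illustrative, not part of this diff):

from pyspark.sql import SparkSession

from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader

spark = SparkSession.builder.getOrCreate()
reader = QueryAPIDataCloudReader(spark)

# Default caps the fetch at 1000 rows; pass row_limit explicitly to change it.
preview_df = reader.read_dlo("Account_Home__dll", row_limit=100)
preview_df.show()

# row_limit=0 returns an empty frame with just the columns, which is what
# PrintDataCloudWriter's validation below relies on.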
55 changes: 54 additions & 1 deletion src/datacustomcode/io/writer/print.py
@@ -14,20 +14,73 @@
# limitations under the License.


from pyspark.sql import DataFrame as PySparkDataFrame
from typing import Optional

from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession

from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode


class PrintDataCloudWriter(BaseDataCloudWriter):
CONFIG_NAME = "PrintDataCloudWriter"

def __init__(
self, spark: SparkSession, reader: Optional[QueryAPIDataCloudReader] = None
) -> None:
super().__init__(spark)
self.reader = QueryAPIDataCloudReader(self.spark) if reader is None else reader

def validate_dataframe_columns_against_dlo(
self,
dataframe: PySparkDataFrame,
dlo_name: str,
) -> None:
"""
Validates that all columns in the given dataframe exist in the DLO schema.

Args:
dataframe (PySparkDataFrame): The DataFrame to validate.
dlo_name (str): The name of the DLO to check against.

Raises:
ValueError: If any columns in the dataframe are not present in the DLO
schema.
"""
# Get DLO schema (no data, just schema)
dlo_df = self.reader.read_dlo(dlo_name, row_limit=0)
dlo_columns = set(dlo_df.columns)
df_columns = set(dataframe.columns)

# Find columns in dataframe not present in DLO
extra_columns = df_columns - dlo_columns
if extra_columns:
raise ValueError(
"The following columns are not present in the \n"
f"DLO '{dlo_name}': {sorted(extra_columns)}.\n"
"To fix this error, you can either:\n"
" - Drop these columns from your DataFrame before writing, e.g.,\n"
" dataframe = dataframe.drop({cols})\n"
" - Or, add these columns to the DLO schema in Data Cloud.".format(
cols=sorted(extra_columns)
)
)

def write_to_dlo(
self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
) -> None:

# Validate columns before proceeding
self.validate_dataframe_columns_against_dlo(dataframe, name)

dataframe.show()

def write_to_dmo(
self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
) -> None:
# Column validation against a DLO schema does not apply here: the DMO
# may not exist yet, so just show the dataframe.

dataframe.show()
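A short sketch of how the new validation surfaces when writing (the DLO name and columns are made up for illustration; the retry mirrors the guidance in the ValueError message):

from pyspark.sql import SparkSession

from datacustomcode.io.writer.base import WriteMode
from datacustomcode.io.writer.print import PrintDataCloudWriter

spark = SparkSession.builder.getOrCreate()
writer = PrintDataCloudWriter(spark)  # reader defaults to QueryAPIDataCloudReader(spark)

df = spark.createDataFrame([(1, "a", True)], ["id", "name", "debug_flag"])

try:
    writer.write_to_dlo("Contact_Home__dll", df, WriteMode.OVERWRITE)
except ValueError:
    # 'debug_flag' is not in the DLO schema; drop it and retry, as the error suggests.
    writer.write_to_dlo("Contact_Home__dll", df.drop("debug_flag"), WriteMode.OVERWRITE)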
55 changes: 5 additions & 50 deletions src/datacustomcode/scan.py
@@ -16,6 +16,7 @@

import ast
import os
import sys
from typing import (
Any,
ClassVar,
@@ -40,6 +41,8 @@
},
}

STANDARD_LIBS = set(sys.stdlib_module_names)


class DataAccessLayerCalls(pydantic.BaseModel):
read_dlo: frozenset[str]
@@ -137,54 +140,6 @@ def found(self) -> DataAccessLayerCalls:
class ImportVisitor(ast.NodeVisitor):
"""AST Visitor that extracts external package imports from Python code."""

# Standard library modules that should be excluded from requirements
STANDARD_LIBS: ClassVar[set[str]] = {
"abc",
"argparse",
"ast",
"asyncio",
"base64",
"collections",
"configparser",
"contextlib",
"copy",
"csv",
"datetime",
"enum",
"functools",
"glob",
"hashlib",
"http",
"importlib",
"inspect",
"io",
"itertools",
"json",
"logging",
"math",
"os",
"pathlib",
"pickle",
"random",
"re",
"shutil",
"site",
"socket",
"sqlite3",
"string",
"subprocess",
"sys",
"tempfile",
"threading",
"time",
"traceback",
"typing",
"uuid",
"warnings",
"xml",
"zipfile",
}

# Additional packages to exclude from requirements.txt
EXCLUDED_PACKAGES: ClassVar[set[str]] = {
"datacustomcode", # Internal package
@@ -200,7 +155,7 @@ def visit_Import(self, node: ast.Import) -> None:
# Get the top-level package name
package = name.name.split(".")[0]
if (
package not in self.STANDARD_LIBS
package not in STANDARD_LIBS
and package not in self.EXCLUDED_PACKAGES
and not package.startswith("_")
):
@@ -213,7 +168,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
# Get the top-level package
package = node.module.split(".")[0]
if (
package not in self.STANDARD_LIBS
package not in STANDARD_LIBS
and package not in self.EXCLUDED_PACKAGES
and not package.startswith("_")
):
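The switch to sys.stdlib_module_names removes the hand-maintained list; a minimal sketch of the filter it enables (module names below are only examples, and the attribute requires Python 3.10+):

import sys

STANDARD_LIBS = set(sys.stdlib_module_names)

def is_external(package: str) -> bool:
    # Mirrors the visitor's stdlib filter; the real code also checks
    # EXCLUDED_PACKAGES and skips names starting with "_".
    return package not in STANDARD_LIBS and not package.startswith("_")

print(is_external("json"))    # False - standard library
print(is_external("pandas"))  # True  - belongs in requirements.txt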
8 changes: 4 additions & 4 deletions tests/io/reader/test_query_api.py
@@ -143,7 +143,7 @@ def test_read_dlo(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dlo")
SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
)

# Verify DataFrame was created with auto-inferred schema
@@ -172,7 +172,7 @@ def test_read_dlo_with_schema(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dlo")
SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
)

# Verify DataFrame was created with provided schema
@@ -192,7 +192,7 @@ def test_read_dmo(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dmo")
SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
)

# Verify DataFrame was created
@@ -220,7 +220,7 @@ def test_read_dmo_with_schema(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dmo")
SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
)

# Verify DataFrame was created with provided schema
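The updated tests only pin the default of 1000; a hypothetical follow-up test for a custom limit could look like the sketch below (fixture names reader, mock_connection, and sample_pandas_df are assumed from the existing module, not verified by this diff):

def test_read_dlo_custom_row_limit(self, reader, mock_connection, sample_pandas_df):
    """A custom row_limit should be interpolated into the SQL template."""
    mock_connection.get_pandas_dataframe.return_value = sample_pandas_df

    reader.read_dlo("test_dlo", row_limit=5)

    mock_connection.get_pandas_dataframe.assert_called_once_with(
        SQL_QUERY_TEMPLATE.format("test_dlo", 5)
    )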
43 changes: 40 additions & 3 deletions tests/io/writer/test_print.py
@@ -23,18 +23,33 @@ def mock_dataframe(self):
return df

@pytest.fixture
def print_writer(self, mock_spark_session):
def mock_reader(self):
"""Create a mock QueryAPIDataCloudReader."""
reader = MagicMock()
mock_dlo_df = MagicMock()
mock_dlo_df.columns = ["col1", "col2"]
reader.read_dlo.return_value = mock_dlo_df
return reader

@pytest.fixture
def print_writer(self, mock_spark_session, mock_reader):
"""Create a PrintDataCloudWriter instance."""
return PrintDataCloudWriter(mock_spark_session)
return PrintDataCloudWriter(mock_spark_session, mock_reader)

def test_write_to_dlo(self, print_writer, mock_dataframe):
"""Test write_to_dlo method calls dataframe.show()."""
# Mock the validate_dataframe_columns_against_dlo method
print_writer.validate_dataframe_columns_against_dlo = MagicMock()

# Call the method
print_writer.write_to_dlo("test_dlo", mock_dataframe, WriteMode.OVERWRITE)

# Verify show() was called
mock_dataframe.show.assert_called_once()

# Verify validate_dataframe_columns_against_dlo was called
print_writer.validate_dataframe_columns_against_dlo.assert_called_once()

def test_write_to_dmo(self, print_writer, mock_dataframe):
"""Test write_to_dmo method calls dataframe.show()."""
# Call the method
@@ -59,9 +74,31 @@ def test_ignores_name_and_write_mode(self, print_writer, mock_dataframe):
for name, write_mode in test_cases:
# Reset mock before each call
mock_dataframe.show.reset_mock()

# Mock the validate_dataframe_columns_against_dlo method
print_writer.validate_dataframe_columns_against_dlo = MagicMock()
# Call method
print_writer.write_to_dlo(name, mock_dataframe, write_mode)

# Verify show() was called with no arguments
mock_dataframe.show.assert_called_once_with()

print_writer.validate_dataframe_columns_against_dlo.assert_called_once()

def test_validate_dataframe_columns_against_dlo(self, print_writer, mock_dataframe):
"""Test validate_dataframe_columns_against_dlo method."""
# The mock_reader fixture already provides a mocked DLO with columns ["col1", "col2"]

# Set up mock dataframe columns
mock_dataframe.columns = ["col1", "col2", "col3"]

# Test that validation raises ValueError for extra columns
with pytest.raises(ValueError) as exc_info:
print_writer.validate_dataframe_columns_against_dlo(
mock_dataframe, "test_dlo"
)

assert "col3" in str(exc_info.value)

# Test successful validation with matching columns
mock_dataframe.columns = ["col1", "col2"]
print_writer.validate_dataframe_columns_against_dlo(mock_dataframe, "test_dlo")