Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .cursorrules
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ Additional for integration tests:
# Run local tests
./bin/test-local

# Run a specific test file
./bin/test-local tests/unit/test_file.py

# ... or specific test from file
./bin/test-local tests/unit/test_file.py::TestClass::test_method

# Run specific test type
export TEST_TYPE="unit|integration"
export TOOLKIT_VERSION="local-build"
Expand Down
48 changes: 18 additions & 30 deletions deepnote_toolkit/ocelots/pandas/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import pandas as pd

from deepnote_toolkit.ocelots.constants import DEEPNOTE_INDEX_COLUMN
from deepnote_toolkit.ocelots.pandas.utils import (
is_numeric_or_temporal,
is_type_datetime_or_timedelta,
safe_convert_to_string,
)
from deepnote_toolkit.ocelots.types import ColumnsStatsRecord, ColumnStats


Expand All @@ -24,7 +29,10 @@ def _get_categories(np_array):
# special treatment for empty values
num_nans = pandas_series.isna().sum().item()

counter = Counter(pandas_series.dropna().astype(str))
try:
counter = Counter(pandas_series.dropna().astype(str))
except (TypeError, UnicodeDecodeError, AttributeError):
counter = Counter(pandas_series.dropna().apply(safe_convert_to_string))

max_items = 3
if num_nans > 0:
Expand All @@ -46,33 +54,9 @@ def _get_categories(np_array):
return [{"name": name, "count": count} for name, count in categories]


def _is_type_numeric(dtype):
"""
Returns True if dtype is numeric, False otherwise

Numeric means either a number (int, float, complex) or a datetime or timedelta.
It means e.g. that a range of these values can be plotted on a histogram.
"""

# datetime doesn't play nice with np.issubdtype, so we need to check explicitly
if pd.api.types.is_datetime64_any_dtype(dtype) or pd.api.types.is_timedelta64_dtype(
dtype
):
return True

try:
return np.issubdtype(dtype, np.number)
except TypeError:
# np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
return False


def _get_histogram(pd_series):
try:
if pd.api.types.is_datetime64_any_dtype(
pd_series
) or pd.api.types.is_timedelta64_dtype(pd_series):
# convert datetime or timedelta to an integer so that a histogram can be created
if is_type_datetime_or_timedelta(pd_series):
np_array = np.array(pd_series.dropna().astype(int))
else:
# let's drop infinite values because they break histograms
Expand Down Expand Up @@ -104,11 +88,15 @@ def _calculate_min_max(column):
"""
Calculate min and max values for a given column.
"""
if _is_type_numeric(column.dtype):
if not is_numeric_or_temporal(column.dtype):
return None, None

try:
min_value = str(min(column.dropna())) if len(column.dropna()) > 0 else None
max_value = str(max(column.dropna())) if len(column.dropna()) > 0 else None
return min_value, max_value
return None, None
except (TypeError, ValueError):
return None, None


def analyze_columns(
Expand Down Expand Up @@ -167,7 +155,7 @@ def analyze_columns(
unique_count=_count_unique(column), nan_count=column.isnull().sum().item()
)

if _is_type_numeric(column.dtype):
if is_numeric_or_temporal(column.dtype):
min_value, max_value = _calculate_min_max(column)
columns[i].stats.min = min_value
columns[i].stats.max = max_value
Expand All @@ -187,7 +175,7 @@ def analyze_columns(
for i in range(max_columns_to_analyze, len(df.columns)):
# Ignore columns that are not numeric
column = df.iloc[:, i]
if not _is_type_numeric(column.dtype):
if not is_numeric_or_temporal(column.dtype):
continue

column_name = columns[i].name
Expand Down
58 changes: 49 additions & 9 deletions deepnote_toolkit/ocelots/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
from deepnote_toolkit.ocelots.constants import MAX_STRING_CELL_LENGTH


def safe_convert_to_string(value):
    """
    Convert a value to its string form without ever raising.

    Falls back to a fixed placeholder when str() fails for exotic objects.
    Note: For bytes, this returns Python's standard string representation (e.g., b'hello')
    rather than base64 encoding, which is more human-readable.
    """
    try:
        result = str(value)
    except Exception:
        # str() can raise for objects with a broken __str__/__repr__;
        # return a placeholder instead of propagating the error
        result = "<unconvertible>"
    return result


# like fillna, but only fills NaT (not a time) values in datetime columns with the specified value
def fill_nat(df, value):
df_datetime_columns = df.select_dtypes(
Expand Down Expand Up @@ -76,36 +89,63 @@ def deduplicate_columns(df):
# Cast dataframe contents to strings and trim them to avoid sending too much data
def cast_objects_to_string(df):
    """
    Cast non-numeric column values of ``df`` to strings, truncated to
    MAX_STRING_CELL_LENGTH, to avoid sending too much data.

    Note: mutates ``df`` in place and also returns it.
    """

    def to_string_truncated(elem):
        # safe_convert_to_string never raises, unlike a bare str() call
        elem_string = safe_convert_to_string(elem)
        return (
            (elem_string[: MAX_STRING_CELL_LENGTH - 1] + "…")
            if len(elem_string) > MAX_STRING_CELL_LENGTH
            else elem_string
        )

    for column in df:
        # pure-numeric columns serialize as JSON numbers and must stay untouched;
        # everything else (objects, datetimes, etc.) becomes a truncated string
        if not is_pure_numeric(df[column].dtype):
            df[column] = df[column].apply(to_string_truncated)

    return df


def _is_type_number(dtype):
def is_type_datetime_or_timedelta(series_or_dtype):
    """
    Return True when the given series or dtype holds datetime or timedelta
    data, False otherwise.
    """
    temporal_checks = (
        pd.api.types.is_datetime64_any_dtype,
        pd.api.types.is_timedelta64_dtype,
    )
    return any(check(series_or_dtype) for check in temporal_checks)


The primary intent of this is to recognize a value that will converted to a JSON number during serialization.
def is_numeric_or_temporal(dtype):
    """
    Returns True if dtype is numeric or temporal (datetime/timedelta), False otherwise.

    This includes numbers (int, float), datetime, and timedelta types.
    Use this to determine if values can be plotted on a histogram or have min/max calculated.
    """
    if is_type_datetime_or_timedelta(dtype):
        return True

    try:
        # complex numbers are excluded — they have no total order, so min/max
        # and histogram bucketing are undefined for them
        return np.issubdtype(dtype, np.number) and not np.issubdtype(
            dtype, np.complexfloating
        )
    except TypeError:
        # np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
        return False


def is_pure_numeric(dtype):
    """
    Returns True if dtype is a pure number (int, float), False otherwise.

    Use this to determine if a value will be serialized as a JSON number.
    """
    if is_type_datetime_or_timedelta(dtype):
        # np.issubdtype(dtype, np.number) returns True for timedelta, which we don't want
        return False

    try:
        # complex numbers are excluded: they cannot be represented as JSON numbers
        return np.issubdtype(dtype, np.number) and not np.issubdtype(
            dtype, np.complexfloating
        )
    except TypeError:
        # np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
        return False
2 changes: 1 addition & 1 deletion deepnote_toolkit/ocelots/pyspark/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def select_column(field: StructField) -> Column:
# We slice binary field before encoding to avoid encoding potentially big blob. Round slicing to
# 4 bytes to avoid breaking multi-byte sequences
if isinstance(field.dataType, BinaryType):
sliced = F.substring(field, 1, keep_bytes)
sliced = F.substring(F.col(field.name), 1, keep_bytes)
return F.base64(sliced)

# String just needs to be trimmed
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/helpers/testing_dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,14 @@ def create_dataframe_with_duplicate_column_names():
datetime.datetime(2023, 1, 1, 12, 0, 0),
datetime.datetime(2023, 1, 2, 12, 0, 0),
],
"binary": [b"hello", b"world"],
}
),
"pyspark_schema": pst.StructType(
[
pst.StructField("list", pst.ArrayType(pst.IntegerType()), True),
pst.StructField("datetime", pst.TimestampType(), True),
pst.StructField("binary", pst.BinaryType(), True),
]
),
},
Expand Down
Loading
Loading