Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions audformat/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import shutil

import numpy as np
import oyaml as yaml


Expand Down Expand Up @@ -40,6 +41,23 @@
from audformat.core.table import Table


def _is_string_like_dtype(dtype) -> bool:
"""Check if dtype is a string-like dtype.

Args:
dtype: A pandas/numpy dtype to check.

Returns:
True if dtype is string-like (str, StringDtype, object with string data),
False otherwise.

"""
# Check for pandas StringDtype (e.g., "string", "string[python]", "string[pyarrow]")
if isinstance(dtype, pd.StringDtype) or pd.api.types.is_string_dtype(dtype):
return True
return False


class Database(HeaderBase):
r"""Database object.

Expand Down Expand Up @@ -696,12 +714,17 @@ def append_series(ys, y, column_id):
ys.append(y)

def dtypes_of_categories(objs):
dtypes = [
obj.dtype.categories.dtype
for obj in objs
if isinstance(obj.dtype, pd.CategoricalDtype)
]
return sorted(list(set(dtypes)))
dtypes = []
for obj in objs:
if isinstance(obj.dtype, pd.CategoricalDtype):
dtype = obj.dtype.categories.dtype
# Normalize string-like dtypes to object for consistency
# (pandas 3.0 may use 'str' or StringDtype for categories)
if _is_string_like_dtype(dtype):
dtype = np.dtype("O")
dtypes.append(dtype)
# Deduplicate and sort for consistent ordering
return sorted(list(set(dtypes)), key=str)

def empty_frame(name):
return pd.DataFrame(
Expand Down Expand Up @@ -832,6 +855,13 @@ def scheme_in_column(scheme_id, column, column_id):
ys[n] = y.astype(
pd.CategoricalDtype(y.array.dropna().unique().astype(dtype))
)
# Normalize all string-like categorical dtypes to "object" for consistency
# (pandas 3.0 may infer "str" or StringDtype for string categories)
for n, y in enumerate(ys):
cat_dtype = y.dtype.categories.dtype
if _is_string_like_dtype(cat_dtype):
new_categories = y.dtype.categories.astype("object")
ys[n] = y.astype(pd.CategoricalDtype(new_categories))
# Find union of categorical data
data = [y.array for y in ys]
try:
Expand Down
61 changes: 57 additions & 4 deletions tests/test_database_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def wrong_scheme_labels_db(tmpdir):
]
),
dtype=pd.CategoricalDtype(
["w1", "w2", "w3"],
pd.Index(["w1", "w2", "w3"], dtype="object"),
ordered=False,
),
),
Expand Down Expand Up @@ -603,7 +603,7 @@ def wrong_scheme_labels_db(tmpdir):
[0.2, 0.2, 0.5, 0.7],
),
dtype=pd.CategoricalDtype(
["s1", "s2", "s3"],
pd.Index(["s1", "s2", "s3"], dtype="object"),
ordered=False,
),
name="speaker",
Expand Down Expand Up @@ -849,7 +849,7 @@ def wrong_scheme_labels_db(tmpdir):
},
index=audformat.filewise_index(["f1.wav", "f2.wav"]),
dtype=pd.CategoricalDtype(
["female", "male"],
pd.Index(["female", "male"], dtype="str"),
ordered=False,
),
),
Expand Down Expand Up @@ -1253,7 +1253,7 @@ def test_database_get_aggregate_and_modify_function(
["s1"],
index=audformat.filewise_index(["f1.wav"]),
dtype=pd.CategoricalDtype(
["s1", "s2", "s3"],
pd.Index(["s1", "s2", "s3"], dtype="object"),
ordered=False,
),
name="speaker",
Expand Down Expand Up @@ -1825,3 +1825,56 @@ def test_database_get_errors(
tables=tables,
original_column_names=original_column_names,
)


def test_get_mixed_str_and_object_categorical_dtype(tmpdir):
"""Test that mixing 'str' and 'object' categorical dtypes normalizes to 'object'."""
db = audformat.Database("test")

# Create scheme with string labels
db.schemes["label"] = audformat.Scheme(
"str",
labels=["a", "b", "c"],
)

# Create two tables with the same scheme
index1 = audformat.filewise_index(["f1.wav", "f2.wav"])
db["table1"] = audformat.Table(index1)
db["table1"]["label"] = audformat.Column(scheme_id="label")
db["table1"]["label"].set(["a", "b"])

index2 = audformat.filewise_index(["f3.wav", "f4.wav"])
db["table2"] = audformat.Table(index2)
db["table2"]["label"] = audformat.Column(scheme_id="label")
db["table2"]["label"].set(["b", "c"])

# Manually set different category dtypes (object + string)
# (even if this might not happen in reality)
df1 = db["table1"].df
df2 = db["table2"].df
df1["label"] = df1["label"].astype(
pd.CategoricalDtype(
pd.Index(["a", "b", "c"], dtype="object"),
ordered=False,
)
)
db["table1"]._df = df1
df2["label"] = df2["label"].astype(
pd.CategoricalDtype(
pd.Index(["a", "b", "c"], dtype="string"),
ordered=False,
)
)
db["table2"]._df = df2

# Get combined data - this should work without errors
result = db.get("label")

# Verify result contains all labels
assert list(result["label"]) == ["a", "b", "b", "c"]

# Verify the resulting categorical has object dtype for categories
result_dtype = result["label"].dtype
assert isinstance(result_dtype, pd.CategoricalDtype)
# Categories should be normalized to object dtype
assert result_dtype.categories.dtype == np.dtype("O")