Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion audformat/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,14 @@ def segmented_index(
define.IndexField.END,
],
)
index = utils.set_index_dtypes(index, {define.IndexField.FILE: "string"})
index = utils.set_index_dtypes(
index,
{
define.IndexField.FILE: "string",
define.IndexField.START: "timedelta64[ns]",
define.IndexField.END: "timedelta64[ns]",
},
)
assert_index(index)

return index
11 changes: 3 additions & 8 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,14 +1479,14 @@ def set_index_dtypes(
index with new dtypes

Examples:
>>> index1 = pd.Index(["a", "b"])
>>> index1 = pd.Index(["a", "b"], dtype="object")
>>> index1
Index(['a', 'b'], dtype='object')
>>> index2 = set_index_dtypes(index1, "string")
>>> index2
Index(['a', 'b'], dtype='string')
>>> index3 = pd.MultiIndex.from_arrays(
... [["a", "b"], [1, 2]],
... [pd.Index(["a", "b"], dtype="object"), [1, 2]],
... names=["level1", "level2"],
... )
>>> index3.dtypes
Expand All @@ -1498,11 +1498,6 @@ def set_index_dtypes(
level1 object
level2 float64
dtype: object
>>> index5 = set_index_dtypes(index3, "string")
>>> index5.dtypes
level1 string[python]
level2 string[python]
dtype: object

"""
levels = index.names if isinstance(index, pd.MultiIndex) else [index.name]
Expand Down Expand Up @@ -1533,7 +1528,7 @@ def set_index_dtypes(
if pd.api.types.is_timedelta64_dtype(dtype):
# avoid: TypeError: Cannot cast DatetimeArray
# to dtype timedelta64[ns]
df[level] = pd.to_timedelta(list(df[level]))
df[level] = pd.to_timedelta(list(df[level])).astype(dtype)
else:
df[level] = df[level].astype(dtype)
index = pd.MultiIndex.from_frame(df)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,26 @@ def test_create_segmented_index(files, starts, ends):
] * len(files)


@pytest.mark.parametrize(
"files, starts, ends",
[
# normal case with sub-second values
(["f1.wav"], [0.001], [0.002]),
# NaT in ends
(["f1.wav"], [0], [pd.NaT]),
# NaT in starts and ends
(["f1.wav"], [pd.NaT], [pd.NaT]),
# empty index
(None, None, None),
],
)
def test_segmented_index_timedelta_dtype(files, starts, ends):
"""Ensure segmented_index always returns timedelta64[ns]."""
index = audformat.segmented_index(files, starts=starts, ends=ends)
assert index.get_level_values("start").dtype == "timedelta64[ns]"
assert index.get_level_values("end").dtype == "timedelta64[ns]"


@pytest.mark.parametrize(
"index, index_type",
[
Expand Down
40 changes: 35 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,7 +1913,7 @@ def test_replace_file_extension(index, extension, pattern, expected_index):
"index, dtypes, expected",
[
(
pd.Index([]),
pd.Index([], dtype="object"),
"string",
pd.Index([], dtype="string"),
),
Expand All @@ -1923,7 +1923,7 @@ def test_replace_file_extension(index, extension, pattern, expected_index):
pd.Index([]),
),
(
pd.Index(["a", "b"]),
pd.Index(["a", "b"], dtype="object"),
"string",
pd.Index(["a", "b"], dtype="string"),
),
Expand All @@ -1933,7 +1933,7 @@ def test_replace_file_extension(index, extension, pattern, expected_index):
pd.Index(["a", "b"], dtype="string"),
),
(
pd.Index(["a", "b"], name="idx"),
pd.Index(["a", "b"], name="idx", dtype="object"),
{"idx": "string"},
pd.Index(["a", "b"], name="idx", dtype="string"),
),
Expand Down Expand Up @@ -2034,7 +2034,7 @@ def test_replace_file_extension(index, extension, pattern, expected_index):
pd.MultiIndex.from_arrays(
[
[1, 2],
pd.to_timedelta([0, 1], unit="s"),
pd.to_timedelta([0, 1], unit="s").astype("timedelta64[ns]"),
],
names=["idx", "time"],
),
Expand All @@ -2054,7 +2054,7 @@ def test_replace_file_extension(index, extension, pattern, expected_index):
pd.MultiIndex.from_arrays(
[
[1, 2],
[pd.NaT, pd.NaT],
pd.to_datetime([pd.NaT, pd.NaT]).astype("datetime64[ns]"),
],
names=["idx", "date"],
),
Expand Down Expand Up @@ -2108,6 +2108,36 @@ def test_replace_file_extension(index, extension, pattern, expected_index):
None,
marks=pytest.mark.xfail(raises=ValueError),
),
(
pd.MultiIndex.from_arrays(
[
pd.Index([], dtype="string"),
pd.Index([], dtype="int64"),
pd.Index([], dtype="object"),
],
names=[
define.IndexField.FILE,
define.IndexField.START,
define.IndexField.END,
],
),
{
define.IndexField.START: "timedelta64[ns]",
define.IndexField.END: "timedelta64[ns]",
},
pd.MultiIndex.from_arrays(
[
pd.Index([], dtype="string"),
pd.Index([], dtype="timedelta64[ns]"),
pd.Index([], dtype="timedelta64[ns]"),
],
names=[
define.IndexField.FILE,
define.IndexField.START,
define.IndexField.END,
],
),
),
],
)
def test_set_index_dtypes(index, dtypes, expected):
Expand Down