Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions audformat/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pandas as pd

from audformat.core import define
from audformat.core import utils
from audformat.core.typing import Files
from audformat.core.typing import Timestamps

Expand All @@ -27,9 +26,13 @@ def to_array(value: object) -> list | np.ndarray:
def to_timedelta(times):
    r"""Convert time value to pd.Timedelta.

    Numeric input is interpreted as seconds.
    Input that pandas refuses to combine with an explicit unit
    (e.g. strings like ``'1s'``)
    is parsed again without a unit.
    The result is normalized to nanosecond resolution,
    so that the returned dtype does not depend
    on the default resolution pandas picks for the input.

    Args:
        times: scalar or list-like time value(s)

    Returns:
        ``pd.Timedelta`` in nanosecond resolution for scalar input
        (``pd.NaT`` is passed through unchanged),
        otherwise the converted object cast to ``timedelta64[ns]``

    """
    try:
        result = pd.to_timedelta(times, unit="s")
    except ValueError:  # catches values like '1s'
        result = pd.to_timedelta(times)
    # A scalar NaT is neither a pd.Timedelta instance
    # nor does it provide astype(),
    # so hand it back before normalizing the resolution
    if result is pd.NaT:
        return result
    if isinstance(result, pd.Timedelta):
        return result.as_unit("ns")
    return result.astype("timedelta64[ns]")


def assert_index(
Expand Down Expand Up @@ -354,21 +357,18 @@ def segmented_index(
)

index = pd.MultiIndex.from_arrays(
[files, to_timedelta(starts), to_timedelta(ends)],
[
# Enforce string dtype from pandas<3.0 to get same hash values
pd.Index(files, dtype="string"),
to_timedelta(starts),
to_timedelta(ends),
],
names=[
define.IndexField.FILE,
define.IndexField.START,
define.IndexField.END,
],
)
index = utils.set_index_dtypes(
index,
{
define.IndexField.FILE: "string",
define.IndexField.START: "timedelta64[ns]",
define.IndexField.END: "timedelta64[ns]",
},
)
assert_index(index)

return index
3 changes: 2 additions & 1 deletion audformat/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from audformat.core.index import is_filewise_index
from audformat.core.index import is_segmented_index
from audformat.core.index import segmented_index
from audformat.core.index import to_timedelta
from audformat.core.media import Media
from audformat.core.rater import Rater
from audformat.core.scheme import Scheme
Expand Down Expand Up @@ -115,7 +116,7 @@ def add_table(

"""
if isinstance(file_duration, str):
file_duration = pd.Timedelta(file_duration)
file_duration = to_timedelta(file_duration)

audio_format = "wav"
if media_id and db.media[media_id].format:
Expand Down
35 changes: 22 additions & 13 deletions tests/test_misc_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,20 @@ def create_misc_table(
create_misc_table(
pd.Series(
[1.0],
index=pd.Index(["a"], name="idx", dtype="object"),
index=pd.Index(["a"], name="idx", dtype="string"),
),
),
create_misc_table(
pd.Series(
index=pd.Index([], name="idx", dtype="object"),
index=pd.Index([], name="idx", dtype="string"),
dtype="float",
),
),
],
create_misc_table(
pd.Series(
[1.0],
index=pd.Index(["a"], name="idx", dtype="object"),
index=pd.Index(["a"], name="idx", dtype="string"),
),
),
),
Expand Down Expand Up @@ -370,12 +370,6 @@ def test_copy(table):
"Int64",
audformat.define.DataType.INTEGER,
),
(
[],
str,
"object",
audformat.define.DataType.OBJECT,
),
(
[],
"string",
Expand Down Expand Up @@ -426,7 +420,7 @@ def test_copy(table):
),
(
["0"],
None,
"object",
"object",
audformat.define.DataType.OBJECT,
),
Expand Down Expand Up @@ -1137,9 +1131,14 @@ def test_drop_extend_and_pick_index_order():
),
# table empty
(
create_misc_table(pd.Index([], name="idx")),
create_misc_table(pd.Index([], name="idx", dtype="object")),
pd.Index(["a", "b"], name="idx", dtype="object"),
pd.Index([], name="idx"),
pd.Index([], name="idx", dtype="object"),
),
(
create_misc_table(pd.Index([], name="idx", dtype="string")),
pd.Index(["a", "b"], name="idx", dtype="string"),
pd.Index([], name="idx", dtype="string"),
),
# index empty
(
Expand Down Expand Up @@ -1201,7 +1200,7 @@ def test_extend_index():

# empty and invalid

db["misc"] = audformat.MiscTable(pd.Index([], name="idx"))
db["misc"] = audformat.MiscTable(pd.Index([], name="idx", dtype="object"))
db["misc"].extend_index(pd.Index([], name="idx", dtype="object"))
assert db["misc"].get().empty
with pytest.raises(
Expand Down Expand Up @@ -1377,12 +1376,22 @@ def test_load_old_pickle(tmpdir):
pd.Index(["a", "b"], name="idx", dtype="object"),
pd.Index([], name="idx", dtype="object"),
),
(
create_misc_table(pd.Index([], name="idx", dtype="string")),
pd.Index(["a", "b"], name="idx", dtype="string"),
pd.Index([], name="idx", dtype="string"),
),
# index empty
(
create_misc_table(pd.Index(["a", "b"], name="idx", dtype="object")),
pd.Index([], name="idx", dtype="object"),
pd.Index([], name="idx", dtype="object"),
),
(
create_misc_table(pd.Index(["a", "b"], name="idx", dtype="string")),
pd.Index([], name="idx", dtype="string"),
pd.Index([], name="idx", dtype="string"),
),
# index and table identical
(
create_misc_table(pd.Index(["a", "b"], name="idx")),
Expand Down
5 changes: 3 additions & 2 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,8 +944,8 @@ def test_extend_index():
index = pd.MultiIndex.from_arrays(
[
["3.wav"],
[pd.Timedelta(0)],
[pd.Timedelta(4, unit="s")],
[pd.Timedelta(0).as_unit("ns")],
[pd.Timedelta(4, unit="s").as_unit("ns")],
],
names=["file", "start", "end"],
)
Expand Down Expand Up @@ -1860,6 +1860,7 @@ def test_type():
pd.TimedeltaIndex(
[pd.NaT],
name=audformat.define.IndexField.END,
dtype="timedelta64[ns]",
),
)
pd.testing.assert_index_equal(db["files"].index, db["files"].index)
Expand Down
43 changes: 25 additions & 18 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re

import numpy as np
from packaging import version
import pandas as pd
import pytest

Expand Down Expand Up @@ -1048,19 +1049,19 @@ def test_index_has_overlap(obj, expected):
(
[
pd.MultiIndex.from_arrays(
[[0, 1], ["a", "b"]],
[
[0, 1],
pd.Index(["a", "b"], dtype="object"),
],
names=["idx1", "idx2"],
),
],
audformat.utils.set_index_dtypes(
pd.MultiIndex.from_arrays(
[[], []],
names=["idx1", "idx2"],
),
{
"idx1": "Int64",
"idx2": "object",
},
pd.MultiIndex.from_arrays(
[
pd.Index([], dtype="Int64"),
pd.Index([], dtype="object"),
],
names=["idx1", "idx2"],
),
),
(
Expand Down Expand Up @@ -1245,8 +1246,11 @@ def test_intersect(objs, expected):
),
(
[
pd.Index(["a", "b", "c"], name="l"),
pd.MultiIndex.from_arrays([[10, 20]], names=["l"]),
pd.Index(["a", "b", "c"], name="l", dtype="object"),
pd.MultiIndex.from_arrays(
[pd.Index([10, 20], dtype="int")],
names=["l"],
),
],
"Found different level dtypes: ['object', 'int']",
),
Expand Down Expand Up @@ -1286,15 +1290,15 @@ def test_intersect(objs, expected):
[
pd.MultiIndex.from_arrays(
[
["a", "b", "c"],
pd.Index(["a", "b", "c"], dtype="object"),
[1, 2, 3],
],
names=["l1", "l2"],
),
pd.MultiIndex.from_arrays(
[
[10],
["10"],
pd.Index(["10"], dtype="object"),
],
names=["l1", "l2"],
),
Expand All @@ -1305,15 +1309,15 @@ def test_intersect(objs, expected):
[
pd.MultiIndex.from_arrays(
[
["a", "b", "c"],
pd.Index(["a", "b", "c"], dtype="object"),
[1, 2, 3],
],
names=["l1", "l2"],
),
pd.MultiIndex.from_arrays(
[
[],
[],
pd.Index([], dtype="object"),
pd.Index([], dtype="object"),
],
names=["l1", "l2"],
),
Expand Down Expand Up @@ -1923,6 +1927,9 @@ def test_read_csv(csv, result):
obj = audformat.utils.read_csv(csv, as_dataframe=True)
if isinstance(result, pd.Index):
result = pd.DataFrame([], columns=[], index=result)
# Fix expected column type under pandas 3.0.0
if version.parse(pd.__version__) >= version.parse("3.0.0"):
result.columns = result.columns.astype("str")
elif isinstance(result, pd.Series):
result = result.to_frame()
pd.testing.assert_frame_equal(obj, result)
Expand Down Expand Up @@ -2143,7 +2150,7 @@ def test_replace_file_extension(index, extension, pattern, expected_index):
(
pd.MultiIndex.from_arrays(
[
["f1", "f2"],
pd.Index(["f1", "f2"], dtype="object"),
[0, int(1e9)],
[pd.NaT, pd.NaT],
],
Expand Down
7 changes: 5 additions & 2 deletions tests/test_utils_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,7 @@ def test_concat_aggregate_function(objs, aggregate_function, expected):
pd.Series(
["a", "a"],
pd.Index(["a", "b"]),
dtype="object",
dtype="str",
),
),
(
Expand Down Expand Up @@ -1771,7 +1771,10 @@ def test_concat_overwrite_aggregate_function(
),
pd.Series( # default dtype is object
[2.0],
pd.MultiIndex.from_arrays([["f1"]], names=["idx"]),
pd.MultiIndex.from_arrays(
[pd.Index(["f1"], dtype="object")],
names=["idx"],
),
),
],
None,
Expand Down