Skip to content

Commit

Permalink
fix clip_labels() bug
Browse files Browse the repository at this point in the history
resolves BoxedAnnotations.clip_labels is producing all-true labels in develop branch #1061

The issue was that the function find_overlapping_idxs_in_clip_df matched any index in clip_df based on start/end time alone, even if the file didn't match. With many files it ended up making almost all entries True, because labels from an annotation in any one file were applied, based only on time, to the clips of all files.

fixed and added test
  • Loading branch information
sammlapp committed Oct 3, 2024
1 parent 84367c2 commit 6c643c2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 21 deletions.
47 changes: 26 additions & 21 deletions opensoundscape/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@ def from_crowsetta(cls, annotations):
@classmethod
def from_csv(cls, path):
"""load csv from path and creates BoxedAnnotations object
Note: the .annotation_files and .audio_files attributes will be none
Args:
path: file path of csv.
see __init__() docstring for required column names
Expand Down Expand Up @@ -919,25 +922,24 @@ class names for each clip; also returns a second value, the list of class names
output_df = clip_df.copy()

# how we store labels depends on `multihot` argument, either
# multi-hot or lists of integer class indices
# multi-hot 2d array of 0/1 or lists of integer class indices
if return_type == "multihot":
# add columns for each class with 0s. We will add 1s in the loop below
output_df[classes] = False

# add the annotations by adding class index positions to appropriate rows
for class_name in classes:
# get just the annotations for this class
class_df = df[df["annotation"] == class_name]
for _, row in class_df.iterrows():
annotation_start = row["start_time"]
annotation_end = row["end_time"]
# find the overlapping rows, gets the multi-index back
class_annotations = df[df["annotation"] == class_name]
for _, row in class_annotations.iterrows():
# find the rows sufficiently overlapped by this annotation, gets the multi-index back
df_idxs = find_overlapping_idxs_in_clip_df(
annotation_start,
annotation_end,
clip_df,
min_label_overlap,
min_label_fraction,
file=row["audio_file"],
annotation_start=row["start_time"],
annotation_end=row["end_time"],
clip_df=clip_df,
min_label_overlap=min_label_overlap,
min_label_fraction=min_label_fraction,
)
if len(df_idxs) > 0:
output_df.loc[df_idxs, class_name] = True
Expand All @@ -948,17 +950,16 @@ class names for each clip; also returns a second value, the list of class names
# add the annotations by adding the integer class indices to row label lists
for class_idx, class_name in enumerate(classes):
# get just the annotations for this class
class_df = df[df["annotation"] == class_name]
for _, row in class_df.iterrows():
annotation_start = row["start_time"]
annotation_end = row["end_time"]
class_annotations = df[df["annotation"] == class_name]
for _, row in class_annotations.iterrows():
# find the rows that overlap with the annotation enough in time
df_idxs = find_overlapping_idxs_in_clip_df(
annotation_start,
annotation_end,
clip_df,
min_label_overlap,
min_label_fraction,
file=row["audio_file"],
annotation_start=row["start_time"],
annotation_end=row["end_time"],
clip_df=clip_df,
min_label_overlap=min_label_overlap,
min_label_fraction=min_label_fraction,
)

for idx in df_idxs:
Expand Down Expand Up @@ -1026,7 +1027,7 @@ def clip_labels(
'classes': returns a dataframe with 'labels' column containing lists of
class names for each clip
'CategoricalLabels': returns a CategoricalLabels object
**kwargs (such as overlap_fraction, final_clip) are passed to
**kwargs (such as clip_step, final_clip) are passed to
opensoundscape.utils.generate_clip_times_df() via make_clip_df()
Returns: depends on `return_type` argument
'multihot': [default] returns a dataframe with a column for each class
Expand Down Expand Up @@ -1346,6 +1347,7 @@ def _df_to_crowsetta_bboxes(df):


def find_overlapping_idxs_in_clip_df(
file,
annotation_start,
annotation_end,
clip_df,
Expand All @@ -1355,6 +1357,7 @@ def find_overlapping_idxs_in_clip_df(
"""
Finds the (file, start_time, end_time) index values for the rows in the clip_df that overlap with the annotation_start and annotation_end
Args:
file: audio file path/name the annotation corresponds to; clip_df is filtered to this file
annotation_start: start time of the annotation
annotation_end: end time of the annotation
clip_df: dataframe with multi-index ['file', 'start_time', 'end_time']
Expand All @@ -1375,6 +1378,8 @@ def find_overlapping_idxs_in_clip_df(
Returns:
[(file, start_time, end_time)]) Multi-index values for the rows in the clip_df that overlap with the annotation_start and annotation_end.
"""
# filter to rows corresponding to this file
clip_df = clip_df.loc[clip_df.index.get_level_values(0) == file]
# ignore all rows that start after the annotation ends. Start is level 1 of multi-index
clip_df = clip_df.loc[clip_df.index.get_level_values(1) < annotation_end]
# and all rows that end before the annotation starts. End is level 2 of multi-index
Expand Down
46 changes: 46 additions & 0 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,21 @@ def raven_file():
return "tests/raven_annots/MSD-0003_20180427_2minstart00.Table.1.selections.txt"


@pytest.fixture()
def audio_2min():
    """Relative path of the MSD-0003 "2minstart00" test audio file (.wav)."""
    wav_path = "tests/audio/MSD-0003_20180427_2minstart00.wav"
    return wav_path


@pytest.fixture()
def raven_file_empty():
    """Relative path of the "EmptyExample" Raven selection table."""
    selections_path = "tests/raven_annots/EmptyExample.Table.1.selections.txt"
    return selections_path


@pytest.fixture()
def audio_silence():
    """Relative path of the "silence_10s" test audio file (.mp3)."""
    mp3_path = "tests/audio/silence_10s.mp3"
    return mp3_path


@pytest.fixture()
def saved_raven_file(request):
path = Path("tests/raven_annots/audio_file.selections.txt")
Expand Down Expand Up @@ -399,6 +409,42 @@ def test_labels_on_index_overlap(boxed_annotations):
assert np.array_equal(labels.values, np.array([[1, 1, 0, 0, 0]]).transpose())


def test_clip_labels_with_audio_file(
    raven_file, audio_2min, raven_file_empty, audio_silence
):
    """Test that clip_labels works properly with multiple audio+raven files.

    Regression test for Issue #1061: annotations from one audio file must not
    be applied to clips of another audio file that merely share the same
    start/end times.

    Args:
        raven_file: fixture, path of a Raven selection table with annotations
        audio_2min: fixture, path of the audio file paired with `raven_file`
        raven_file_empty: fixture, path of the "EmptyExample" Raven table
        audio_silence: fixture, path of the audio file paired with
            `raven_file_empty`
    """
    boxed_annotations = BoxedAnnotations.from_raven_files(
        raven_files=[raven_file, raven_file_empty],
        audio_files=[audio_2min, audio_silence],
    )
    labels = boxed_annotations.clip_labels(
        full_duration=None, clip_duration=5, clip_overlap=0, min_label_overlap=0
    )
    # should get back 2 min & 10 s audio file labels for 5s clips
    assert len(labels) == 24 + 2

    # no label on silent audio!
    # The silent file is passed second, so its two 5 s clips are expected to be
    # the last two rows. (The previous `labels.head(0)` selected zero rows,
    # which made this assertion pass unconditionally.)
    assert labels.tail(2).sum().sum() == 0

    # check for correct clip labels on the first 6 clips (0-30 s) of the
    # first audio file
    expected_first_clips = np.array(
        [
            [False, False, False],
            [False, False, False],
            [True, True, False],
            [True, True, False],
            [True, True, True],
            [False, True, False],
        ]
    )
    assert np.array_equal(labels.head(6).values, expected_first_clips)

    # no labels on any clip after the first 6 (i.e. after 30 seconds);
    # tail(-6) returns every row except the first 6
    assert labels.tail(-6).sum().sum() == 0


def test_clip_labels(boxed_annotations):
# test "multihot" return type
labels = boxed_annotations.clip_labels(
Expand Down

0 comments on commit 6c643c2

Please sign in to comment.