diff --git a/avalanche/benchmarks/scenarios/task_aware.py b/avalanche/benchmarks/scenarios/task_aware.py index 0b5459761..cd3d07508 100644 --- a/avalanche/benchmarks/scenarios/task_aware.py +++ b/avalanche/benchmarks/scenarios/task_aware.py @@ -96,9 +96,10 @@ def with_task_labels(obj): def _add_task_labels(exp): tls = exp.dataset.targets_task_labels.uniques + # tls is a set, we need to convert to list to call __getitem__ + tls = list(tls) if len(tls) == 1: - # tls is a set. we need to convert to list to call __getitem__ - exp.task_label = list(tls)[0] + exp.task_label = tls[0] exp.task_labels = tls return exp diff --git a/tests/benchmarks/scenarios/test_task_aware.py b/tests/benchmarks/scenarios/test_task_aware.py index a7822f661..693ce490b 100644 --- a/tests/benchmarks/scenarios/test_task_aware.py +++ b/tests/benchmarks/scenarios/test_task_aware.py @@ -14,7 +14,7 @@ class TestsTaskAware(unittest.TestCase): def test_taskaware(self): - """Common use case: add tas labels to class-incremental benchmark.""" + """Common use case: add task labels to class-incremental benchmark.""" n_classes, n_samples_per_class, n_features = 10, 3, 7 for _ in range(10000): @@ -58,6 +58,7 @@ def test_taskaware(self): ci_train = bm_ci.train_stream for eid, exp in enumerate(bm_ti.train_stream): assert exp.task_label == eid + assert isinstance(exp.task_labels, list) assert len(ci_train[eid].dataset) == len(exp.dataset)