|
1 | 1 | import math |
2 | 2 | import re |
3 | 3 | from functools import partial |
4 | | -from typing import Literal, cast |
| 4 | +from typing import Literal, Sequence, cast |
5 | 5 |
|
6 | 6 | import chz |
7 | 7 | from datasets import Dataset, concatenate_datasets, get_dataset_config_names, load_dataset |
@@ -153,7 +153,7 @@ def __init__( |
153 | 153 | self.renderer = renderer |
154 | 154 | self.convo_prefix = convo_prefix |
155 | 155 |
|
156 | | - def get_batch(self, index: int) -> list[EnvGroupBuilder]: |
| 156 | + def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]: |
157 | 157 | batch_start = index * self.batch_size |
158 | 158 | batch_end = min((index + 1) * self.batch_size, len(self.ds)) |
159 | 159 | assert batch_start < batch_end, "Incorrect batch size" |
@@ -329,7 +329,7 @@ def __init__( |
329 | 329 | def question_suffix(cls) -> str: |
330 | 330 | return " Provide a numerical answer without units, written inside \\boxed{}." |
331 | 331 |
|
332 | | - def get_batch(self, index: int) -> list[EnvGroupBuilder]: |
| 332 | + def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]: |
333 | 333 | batch_start = index * self.batch_size |
334 | 334 | batch_end = min((index + 1) * self.batch_size, len(self.ds)) |
335 | 335 | assert batch_start < batch_end, "Incorrect batch size" |
|
0 commit comments