Skip to content

Commit cf00c2d

Browse files
committed
Sync contents
1 parent fc84491 commit cf00c2d

File tree

10 files changed

+18
-17
lines changed

10 files changed

+18
-17
lines changed

.sync_state

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
11
{
2-
"last_synced_sha": "e039f707084c266fc64d92b2228da6f0a30710b9",
3-
"last_sync_time": "2025-10-02T01:33:33.157309"
2+
"last_synced_sha": "09d45a371c67ee32d2a31ffcdc6709a9b8e10a18",
3+
"last_sync_time": "2025-10-05T23:54:36.006757"
44
}

tinker_cookbook/hyperparam_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,9 +148,9 @@ def get_lr(model_name: str, is_lora: bool = True) -> float:
148148
lora_multiplier = 10.0
149149

150150
lr = base_lr * lora_multiplier if is_lora else base_lr
151-
if "llama" in model_name:
151+
if "llama" in model_name.lower():
152152
exponent_model = 0.781
153-
elif "qwen" in model_name:
153+
elif "qwen" in model_name.lower():
154154
exponent_model = 0.0775
155155
else:
156156
assert False, f"Unknown model: {model_name}"

tinker_cookbook/recipes/math_rl/arithmetic_env.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from functools import partial
2+
from typing import Sequence
23

34
import chz
45
import numpy as np
@@ -62,7 +63,7 @@ def __init__(
6263
self.n_batches = n_batches
6364
self.include_fewshot = include_fewshot
6465

65-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
66+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
6667
self._rng.seed(index)
6768
return [self._make_env_group_builder(self._rng) for _ in range(self.batch_size)]
6869

tinker_cookbook/recipes/math_rl/math_env.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import math
22
import re
33
from functools import partial
4-
from typing import Literal, cast
4+
from typing import Literal, Sequence, cast
55

66
import chz
77
from datasets import Dataset, concatenate_datasets, get_dataset_config_names, load_dataset
@@ -153,7 +153,7 @@ def __init__(
153153
self.renderer = renderer
154154
self.convo_prefix = convo_prefix
155155

156-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
156+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
157157
batch_start = index * self.batch_size
158158
batch_end = min((index + 1) * self.batch_size, len(self.ds))
159159
assert batch_start < batch_end, "Incorrect batch size"
@@ -329,7 +329,7 @@ def __init__(
329329
def question_suffix(cls) -> str:
330330
return " Provide a numerical answer without units, written inside \\boxed{}."
331331

332-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
332+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
333333
batch_start = index * self.batch_size
334334
batch_end = min((index + 1) * self.batch_size, len(self.ds))
335335
assert batch_start < batch_end, "Incorrect batch size"

tinker_cookbook/recipes/multiplayer_rl/guess_number/env.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ class GuessNumberDataset(RLDataset):
112112
batch_size: int
113113
group_size: int
114114

115-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
115+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
116116
return [
117117
GuessNumberEnvGroupBuilder(
118118
answer=self.answers[index * self.batch_size + i],

tinker_cookbook/recipes/multiplayer_rl/text_arena/env.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,7 @@ def __init__(self, batch_size: int, builder: TwoPlayerEnvGroupBuilder, num_datap
234234
"num_datapoints must be divisible by num_players"
235235
)
236236

237-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
237+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
238238
return [
239239
self.builder
240240
for i in range(self.batch_size // self.builder.num_players)

tinker_cookbook/recipes/multiplayer_rl/twenty_questions/env.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ class TwentyQuestionsDataset(RLDataset):
165165
batch_size: int
166166
group_size: int
167167

168-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
168+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
169169
return [
170170
TwentyQuestionsEnvGroupBuilder(
171171
answerer=self.answerer,

tinker_cookbook/recipes/tool_use/search/search_env.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import string
77
from functools import partial, reduce
88
from pathlib import Path
9-
from typing import Literal, TypedDict, cast
9+
from typing import Literal, Sequence, TypedDict, cast
1010

1111
import chz
1212
import pandas as pd
@@ -315,7 +315,7 @@ def __init__(
315315
rng = random.Random(self.seed)
316316
rng.shuffle(self.ds)
317317

318-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
318+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
319319
return [
320320
self._make_env_group_builder(row, self.group_size)
321321
for row in self.ds[index * self.batch_size : (index + 1) * self.batch_size]

tinker_cookbook/rl/train.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import logging
99
import os
1010
import time
11-
from typing import Any, Callable, List
11+
from typing import Any, Callable, List, Sequence
1212

1313
import chz
1414
import numpy as np
@@ -567,7 +567,7 @@ async def save_checkpoint_and_get_sampling_client(
567567

568568
async def prepare_minibatch(
569569
cfg: Config,
570-
env_group_builders_P: list[EnvGroupBuilder],
570+
env_group_builders_P: Sequence[EnvGroupBuilder],
571571
trajectory_groups_P: list[TrajectoryGroup],
572572
tokenizer: Tokenizer,
573573
service_client: tinker.ServiceClient,
@@ -744,7 +744,7 @@ async def do_train_step_and_get_sampling_client(
744744
training_client: tinker.TrainingClient,
745745
service_client: tinker.ServiceClient,
746746
tokenizer: Tokenizer,
747-
env_group_builders_P: list[EnvGroupBuilder],
747+
env_group_builders_P: Sequence[EnvGroupBuilder],
748748
trajectory_groups_P: list[TrajectoryGroup],
749749
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
750750
metrics = {}

tinker_cookbook/rl/types.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ class RLDataset(ABC):
137137
"""
138138

139139
@abstractmethod
140-
def get_batch(self, index: int) -> list[EnvGroupBuilder]:
140+
def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
141141
pass
142142

143143
@abstractmethod

0 commit comments

Comments
 (0)