Skip to content

Commit 177a5a8

Browse files
committed
fixed multiple epochs (keep the base output_dir constant so per-iteration paths no longer nest across STaR iterations)
1 parent b269150 commit 177a5a8

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

examples/star/star.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Main STaR Loop"""
22

3+
from copy import deepcopy
34
from datasets import Dataset, DatasetDict, load_dataset
45
from inference import generate_predictions
56
from train import train
@@ -21,13 +22,13 @@ def main():
2122
ds[split] = ds[split].add_column(name="text", column=texts)
2223

2324
model_name = args.model_name_or_path
24-
ds["train"] = ds["train"].select(range(10))
25+
output_dir = deepcopy(args.output_dir)
2526
for i in range(args.iteration):
2627
# sample
2728
all_samples = generate_predictions(
2829
model_name, ds["train"], args.temperature, args.n
2930
)
30-
ds["train"].add_column(name="sample", column=all_samples).to_json(f"{args.output_dir}/data/samples-iter{i}.json")
31+
ds["train"].add_column(name="sample", column=all_samples).to_json(f"{output_dir}/data/samples-iter{i}.json")
3132
assert len(ds["train"]) == len(all_samples)
3233

3334
# verify and construct the training set
@@ -43,10 +44,10 @@ def main():
4344
passed_examples.append(example)
4445
break
4546
raw_datasets = DatasetDict({"train": Dataset.from_list(passed_examples), "validation": ds["validation"]})
46-
raw_datasets["train"].to_json(f"{args.output_dir}/data/verified-samples-iter{i}.json")
47+
raw_datasets["train"].to_json(f"{output_dir}/data/verified-samples-iter{i}.json")
4748

4849
# train
49-
args.output_dir = f"{args.output_dir}/models-iter{i}"
50+
args.output_dir = f"{output_dir}/models-iter{i}"
5051
train(raw_datasets, model_name, args)
5152
model_name = args.output_dir
5253

examples/star/train.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,6 @@ def tokenize_function(examples):
254254
# The trackers initialize automatically on the main process.
255255
if args.with_tracking:
256256
experiment_config = vars(args)
257-
# TensorBoard cannot log Enums, need the raw value
258-
experiment_config["lr_scheduler_type"] = experiment_config[
259-
"lr_scheduler_type"
260-
].value
261257
accelerator.init_trackers("clm_no_trainer", experiment_config)
262258

263259
# Train!
@@ -407,8 +403,6 @@ def tokenize_function(examples):
407403
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
408404
json.dump({"perplexity": perplexity}, f)
409405
cleanup(model)
410-
#cleanup(optimizer)
411-
#cleanup(lr_scheduler)
412406

413407

414408
if __name__ == "__main__":

0 commit comments

Comments
 (0)