Da 24/reading comprehension #74
Changes from 1 commit
@@ -9,7 +9,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset

import argparse

Comment: Can you add
Comment: I think, given the number of options that exist for trackers, separating ourselves from any particular tracker may be the best way forward.
Comment: Yeah, can we just default to no tracking? Then users can add it as needed, including libraries.
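The "default to no tracking" suggestion could look like the sketch below at the argument-parsing level (illustrative only; in transformers, the parsed value would typically be passed through as `TrainingArguments(report_to=...)`, where the string "none" disables reporting):

```python
import argparse

parser = argparse.ArgumentParser()
# Default to no tracking; users opt in explicitly ("wandb", "tensorboard", ...).
parser.add_argument("--log_with", type=str, default="none",
                    help="experiment tracker to report to; 'none' disables tracking")

args = parser.parse_args([])
print(args.log_with)  # "none"
```

This keeps the script tracker-agnostic while still letting users wire up whichever tracker their environment has installed.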
@@ -20,17 +19,13 @@ def create_datasets(
    size_valid_set: Optional[int],
    streaming: bool,
    shuffle_buffer: int,
    seq_length: int,
    num_workers: int,
    tokenizer: AutoTokenizer,
    formatting_func: callable,
    local_dataset: bool = True,
):
    if local_dataset:
        dataset = load_from_disk(
            dataset_name,
        )
        streaming = False
    else:
        dataset = load_dataset(
            dataset_name,
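The local-versus-hub branching above can be sketched in isolation. In the sketch below the function name is hypothetical and the loader callables are stubs standing in for `datasets.load_from_disk` / `datasets.load_dataset`; the point is only the control flow:

```python
def load_splits(dataset_name, local_dataset=True, streaming=False,
                load_from_disk=None, load_dataset=None):
    """Sketch of the PR's branching: a local (on-disk) dataset is always
    loaded eagerly, so streaming is forced off in that branch."""
    if local_dataset:
        dataset = load_from_disk(dataset_name)
        streaming = False  # loading from disk has no streaming mode
    else:
        dataset = load_dataset(dataset_name, streaming=streaming)
    return dataset, streaming

# Demo with stub loaders (placeholders, not the real datasets API):
ds, streaming = load_splits(
    "my/dataset", local_dataset=True, streaming=True,
    load_from_disk=lambda name: {"train": [], "test": []},
    load_dataset=lambda name, streaming: {"train": []},
)
print(streaming)  # False: forced off for local datasets
```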
@@ -49,34 +44,15 @@ def create_datasets(
    valid_data = dataset["test"]
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer, formatting_func)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
    return train_data, valid_data

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=formatting_func,
        infinite=True,
        seq_length=seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=formatting_func,
        infinite=False,
        seq_length=seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset
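For context on what the removed `ConstantLengthDataset` calls were doing, here is a greatly simplified stdlib sketch of the same packing idea (trl's real class also supports infinite iteration and sizes its read buffer via `chars_per_token`; the names here are illustrative):

```python
def pack_examples(texts, tokenize, seq_length):
    """Greedy packing in the spirit of trl's ConstantLengthDataset:
    concatenate token streams and emit fixed-length blocks, carrying
    the remainder over to the next block."""
    buffer = []
    for text in texts:
        buffer.extend(tokenize(text))
        while len(buffer) >= seq_length:
            yield buffer[:seq_length]
            buffer = buffer[seq_length:]

blocks = list(pack_examples(["a b c", "d e", "f g h i"],
                            tokenize=str.split, seq_length=4))
print(blocks)  # [['a', 'b', 'c', 'd'], ['e', 'f', 'g', 'h']]
```

Note that the trailing partial block (`['i']` here) is dropped, which is why every emitted sequence has exactly `seq_length` tokens.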
-def chars_token_ratio(dataset, tokenizer, formatting_func, nb_examples=400):
+def chars_token_ratio(dataset, tokenizer, formatting_func, sample_size=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
-    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
+    for _, example in tqdm(zip(range(sample_size), iter(dataset)), total=sample_size):
        text = formatting_func(example)
        total_characters += len(text)
        if tokenizer.is_fast:
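The helper being renamed can be sketched standalone. This stdlib version uses hypothetical names, replaces the tokenizer with a plain `tokenize` callable, and omits tqdm, but computes the same estimate:

```python
from itertools import islice

def estimate_chars_per_token(dataset, tokenize, formatting_func, sample_size=400):
    """Estimate the average number of characters per token, measured
    over at most sample_size examples from the dataset."""
    total_characters = total_tokens = 0
    for example in islice(iter(dataset), sample_size):
        text = formatting_func(example)
        total_characters += len(text)
        total_tokens += len(tokenize(text))
    return total_characters / total_tokens

ratio = estimate_chars_per_token(
    [{"text": "hello world"}, {"text": "foo bar baz"}],
    tokenize=str.split,
    formatting_func=lambda ex: ex["text"],
)
print(f"{ratio:.2f}")  # 4.40: 22 characters over 5 whitespace tokens
```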
@@ -91,9 +67,7 @@ def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-alpha", help="the model name")
    parser.add_argument("--log_with", type=str, default="wandb", help="use 'wandb' to log with wandb")
-    parser.add_argument(
-        "--dataset_name", type=str, default="arcee-ai/azure-reading-comprehension-dataset", help="the dataset name"
-    )
+    parser.add_argument("--dataset_name", type=str, required=True, help="the dataset name (a repo id); the format should be ChatML")
    parser.add_argument("--split", type=str, default="train", help="the split to use")
    parser.add_argument("--size_valid_set", type=int, default=4000, help="the size of the validation set")
    parser.add_argument("--streaming", type=bool, default=False, help="whether to stream the dataset")

Comment (on --model_name): Should this be
Comment (on --size_valid_set): It seems easier for users to express this as a percentage of the training set rather than an absolute size, e.g. 20%.
Comment (on --size_valid_set): I believe this argument is for streaming; in that scenario a preset size is required.
Comment (on --size_valid_set): Ah, got it: for streaming we don't know the full size of the dataset until we load it. I think it's confusing and misleading, though, to advertise this parameter without mentioning that it is streaming-only. WDYT of changing to this?
Comment (on --streaming): Any reason not to just default to
Comment (on --streaming): I haven't tested the advantages of streaming over vanilla loading, so it's hard for me to set that as the default for the user.
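One caveat about the `--streaming` argument above: `type=bool` does not parse booleans the way the help text implies. argparse passes the raw command-line string to `bool()`, and any non-empty string is truthy, so `--streaming False` still yields `True`. A minimal stdlib demonstration, with a flag-style alternative:

```python
import argparse

# The pitfall: bool("False") is True, because "False" is a non-empty string.
buggy = argparse.ArgumentParser()
buggy.add_argument("--streaming", type=bool, default=False)
print(buggy.parse_args(["--streaming", "False"]).streaming)  # True, not False

# A flag (action="store_true") avoids the surprise:
safe = argparse.ArgumentParser()
safe.add_argument("--streaming", action="store_true")
print(safe.parse_args([]).streaming)               # False
print(safe.parse_args(["--streaming"]).streaming)  # True
```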
@@ -189,6 +163,7 @@ def train_generator(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
+    gradient_checkpointing=gradient_checkpointing,
    per_device_eval_batch_size=per_device_eval_batch_size,
    learning_rate=learning_rate,
    logging_steps=logging_steps,
@@ -204,6 +179,8 @@ def train_generator(
    bf16=True,
    remove_unused_columns=False,
    run_name="generator_tuning",
+    weight_decay=weight_decay,
+    log_freq=log_freq,
)

def prepare_sample_text(example):
@@ -223,6 +200,9 @@ def prepare_sample_text(example):
    prepare_sample_text,
)

+chars_per_token = chars_token_ratio(train_dataset, tokenizer, prepare_sample_text)
+print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

tleyden marked this conversation as resolved.
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,

@@ -233,6 +213,7 @@ def prepare_sample_text(example):
    tokenizer=tokenizer,
    args=training_args,
    neftune_noise_alpha=neftune_noise_alpha,
+    chars_per_token=chars_per_token,
)

trainer.train()
Comment: I think trl needs to be added to the project.toml dependencies.
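If that suggestion is adopted, the entry might look like the sketch below (assuming a PEP 621-style `pyproject.toml`; the exact file name and any version pin in this project are not shown in the thread, so this is illustrative only):

```toml
[project]
dependencies = [
    "trl",
]
```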