
Da 24/reading comprehension #74

Merged — 74 commits, merged Dec 4, 2023

Changes from 1 commit
298e100
Broken reading-comprehension generators and generator training
metric-space Nov 15, 2023
ce4a23f
Refactor out shared utils and add domain tokenizer training as a opti…
metric-space Nov 16, 2023
9987c06
Corrections
metric-space Nov 16, 2023
b842667
Further corrections
metric-space Nov 16, 2023
f6d253a
Add q&a extractor as a util and to the pipeline example script
metric-space Nov 16, 2023
571b41c
Util correction and context addition to output
metric-space Nov 16, 2023
5c349ca
Revert previous corrections
metric-space Nov 16, 2023
f11797b
ChatML output format for regex-based rc and remove domain keyword
metric-space Nov 17, 2023
21b6210
Generator additions
metric-space Nov 17, 2023
3c49970
More generator corrections and additions
metric-space Nov 17, 2023
c7b63f6
Add train num epochs as an arg to the generator training script
metric-space Nov 17, 2023
8a9ec67
Add new dependencies
metric-space Nov 18, 2023
c0c530f
1. json dump when writing to file
metric-space Nov 18, 2023
fcdfa3b
Post pipeline test run corrections
metric-space Nov 19, 2023
0f7b7e0
Remove (explicit) use of ConstantLengthDataset
metric-space Nov 21, 2023
7515b0a
Lift state management to a higher level and some corrections
metric-space Nov 21, 2023
75c7811
Add option to save dataset as huggingface dataset
metric-space Nov 21, 2023
979408c
Training script corrections
metric-space Nov 21, 2023
42f77f5
Trainer script cleanup and corrections
metric-space Nov 21, 2023
de7f60c
Add cloud friendly logger
metric-space Nov 21, 2023
ead9e7b
Reformatted synthetic dataset generation + corrections
metric-space Nov 21, 2023
3efa61b
More formatting for the synth-gen script
metric-space Nov 21, 2023
1d39491
More corrections to synth-gen
metric-space Nov 21, 2023
3abc635
Regex-gen changes and banner addition
metric-space Nov 21, 2023
980cc6a
Missing comma
metric-space Nov 21, 2023
1868423
Type hint all the functions in utils
metric-space Nov 21, 2023
ce9b352
Lightly refactor the pipeline
metric-space Nov 21, 2023
3760fc5
Address proper negation of lags
metric-space Nov 21, 2023
fe85703
Switch out generator for iterator when type hinting
metric-space Nov 21, 2023
331eaed
Util typing correction
metric-space Nov 21, 2023
8c1a35a
Correct all linting issues
metric-space Nov 22, 2023
e77d273
Pipeline corrections
metric-space Nov 22, 2023
2f91c25
More corrections for output type of generator
metric-space Nov 22, 2023
3b8e270
More corrections to the pipeline
metric-space Nov 22, 2023
a37b5d3
Appeasing the linter for the pipeline code
metric-space Nov 22, 2023
280e2dd
Appeasing the linter for llm synth script
metric-space Nov 22, 2023
3facf92
Appeasing the linter for the training script
metric-space Nov 22, 2023
3f50761
Linter based corrections for utils
metric-space Nov 22, 2023
8bb10c4
More appeasing of the linter and work arounds
metric-space Nov 22, 2023
42330eb
Incorporate csv reading and associated changes
metric-space Nov 22, 2023
194cf91
Unicode decoding revisit
metric-space Nov 22, 2023
451015b
More fixes
metric-space Nov 22, 2023
2e3fd08
Forgot to put in replace line
metric-space Nov 22, 2023
9a430e5
Better logging and removal of statefile and more corrections
metric-space Nov 23, 2023
3eb3b44
Add missing general spm input validation line to pipeline script
metric-space Nov 23, 2023
14f65c5
More validation lines for pipeline
metric-space Nov 23, 2023
55f03ef
More corrections
metric-space Nov 23, 2023
4d5802d
Banner correction, corrections
metric-space Nov 23, 2023
3f17238
Start of README.md and add general sentencepiece model to resources
metric-space Nov 23, 2023
88a74c1
Add defaults for cli args
metric-space Nov 23, 2023
4ab1211
Add more detail to README.md
metric-space Nov 23, 2023
b6fa0ae
Add defaults to function
metric-space Nov 23, 2023
644c294
Defaults
metric-space Nov 23, 2023
9c2f00d
README.md for rc pipeline
metric-space Nov 23, 2023
1d4ec97
transformers version dependency constraint
metric-space Nov 23, 2023
9171734
alpha -> beta
metric-space Nov 23, 2023
1af81cd
Better warning message
metric-space Nov 23, 2023
3039050
Correct description in README.md
metric-space Nov 23, 2023
4d93d4a
Stream arg correction for trainer
metric-space Nov 23, 2023
5fb1252
Add prompt link to README
metric-space Nov 23, 2023
958f1f4
Add general spm to resources
metric-space Nov 27, 2023
749f11c
- Better input content generator (deals with directory of csv(s))
metric-space Nov 29, 2023
903ba18
Vocab size second-try and key error fix
metric-space Nov 29, 2023
bc63323
Correct logging
metric-space Nov 29, 2023
a6f0e6e
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
0b89e1d
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
340b969
Update dalm/datasets/reading_comprehension_generation/synthetic_based.py
metric-space Dec 2, 2023
9ae906d
Update dalm/datasets/reading_comprehension_generation/utils.py
metric-space Dec 2, 2023
1d3ed52
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
a7ab91c
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
9cb16ee
Corrections
metric-space Dec 2, 2023
4025e45
Post linting
metric-space Dec 2, 2023
39b7f1e
Update README with suggested corrections
metric-space Dec 2, 2023
81a2ea0
grammar corrections
metric-space Dec 2, 2023
Remove (explicit) use of ConstantLengthDataset
metric-space committed Nov 21, 2023
commit 0f7b7e02fa7d795d22d0613d58536f0e20d0d3aa
41 changes: 11 additions & 30 deletions dalm/training/generator_only/trainer.py
@@ -9,7 +9,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import SFTTrainer
Contributor:

I think trl needs to be added to the pyproject.toml dependencies.

from trl.trainer import ConstantLengthDataset

Contributor:

Can you add import wandb here? This will fail fast when wandb is not installed, rather than failing in the middle of a pipeline with this stacktrace:

  File "/opt/conda/lib/python3.10/site-packages/dalm/pipelines/reading_comprehension_pipeline.py", line 180, in pipeline
    train_generator(
  File "/opt/conda/lib/python3.10/site-packages/dalm/training/generator_only/trainer.py", line 240, in train_generator
    trainer = SFTTrainer(
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 252, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 525, in __init__
    self.callback_handler = CallbackHandler(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 306, in __init__
    self.add_callback(cb)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 323, in add_callback
    cb = callback() if isinstance(callback, type) else callback
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 669, in __init__
    raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
RuntimeError: WandbCallback requires wandb to be installed. Run `pip install wandb`.
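One way to fail fast on a missing optional dependency, sketched with a hypothetical helper (not code from this PR):

```python
import importlib.util


def require_package(name: str) -> None:
    """Raise at startup if an optional dependency is missing,
    instead of deep inside trainer construction."""
    if importlib.util.find_spec(name) is None:
        raise RuntimeError(
            f"{name} is required for this run. Install it with `pip install {name}`."
        )


# Calling this at module import time surfaces the error before any
# training work begins, e.g.:
# require_package("wandb")
```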

Contributor (author):

I think, given the number of tracker options that exist, decoupling ourselves from any particular tracker may be the best way forward.

Contributor:

Yeah, can we just default to no tracking? Then users can add it as needed, including libs
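For reference, transformers supports tracker-free runs via `TrainingArguments(report_to="none")`; a minimal sketch of mapping a CLI flag onto it (the helper name is hypothetical):

```python
from typing import List, Optional


def resolve_report_to(log_with: Optional[str]) -> List[str]:
    """Map a --log_with CLI value onto TrainingArguments.report_to.

    Returning ["none"] tells transformers to skip every tracking
    integration, so no tracker library needs to be installed."""
    return [log_with] if log_with else ["none"]


# Usage sketch:
# TrainingArguments(..., report_to=resolve_report_to(args.log_with))
```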

import argparse

@@ -20,17 +19,13 @@ def create_datasets(
size_valid_set: Optional[int],
streaming: bool,
shuffle_buffer: int,
seq_length: int,
num_workers: int,
tokenizer: AutoTokenizer,
formatting_func: callable,
local_dataset: bool = True,
):
if local_dataset:
dataset = load_from_disk(
dataset_name,
)
streaming = False
else:
dataset = load_dataset(
dataset_name,
@@ -49,34 +44,15 @@
valid_data = dataset["test"]
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

chars_per_token = chars_token_ratio(train_data, tokenizer, formatting_func)
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
return train_data, valid_data

train_dataset = ConstantLengthDataset(
tokenizer,
train_data,
formatting_func=formatting_func,
infinite=True,
seq_length=seq_length,
chars_per_token=chars_per_token,
)
valid_dataset = ConstantLengthDataset(
tokenizer,
valid_data,
formatting_func=formatting_func,
infinite=False,
seq_length=seq_length,
chars_per_token=chars_per_token,
)
return train_dataset, valid_dataset


def chars_token_ratio(dataset, tokenizer, formatting_func, nb_examples=400):
def chars_token_ratio(dataset, tokenizer, formatting_func, sample_size=400):
"""
Estimate the average number of characters per token in the dataset.
"""
total_characters, total_tokens = 0, 0
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
for _, example in tqdm(zip(range(sample_size), iter(dataset)), total=sample_size):
text = formatting_func(example)
total_characters += len(text)
if tokenizer.is_fast:
@@ -91,9 +67,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-alpha", help="the model name")
Contributor:

Should this be zephyr-7b-beta?

parser.add_argument("--log_with", type=str, default="wandb", help="use 'wandb' to log with wandb")
parser.add_argument(
"--dataset_name", type=str, default="arcee-ai/azure-reading-comprehension-dataset", help="the dataset name"
)
parser.add_argument("--dataset_name", type=str, required=True, help="the dataset name or corresponding repo; format should be ChatML")
parser.add_argument("--split", type=str, default="train", help="the split to use")
parser.add_argument("--size_valid_set", type=int, default=4000, help="the size of the validation set")
Contributor:

It seems easier for users to express this as a percentage of the training set rather than an absolute size, e.g. 20%.

Contributor (author):

I believe this argument is for streaming; in that scenario a preset size is required.

Contributor:

Ah, got it: for streaming we don't know the full size of the dataset until we load it. I think it's confusing and misleading, though, to advertise this parameter without mentioning that it's streaming-only.

WDYT of changing to this?

    parser.add_argument(
      "--size_valid_set_streaming", 
      type=int, 
      default=4000, 
      help="the size of the validation set when used in streaming mode, ignored otherwise"
    )

parser.add_argument("--streaming", type=bool, default=False, help="whether to stream the dataset")
Contributor:

Any reason not to just default to True to keep memory footprint low by default? What are the downsides of streaming the dataset?

Contributor (author):

I haven't tested the advantages of streaming over vanilla loading, so it's hard for me to set that as the default for users.
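The constraint this thread is circling can be sketched with stdlib stand-ins (not the datasets API): a streamed dataset has unknown length, so a validation split can only be taken as a fixed-size prefix, whereas a fully loaded dataset could be split by percentage.

```python
from itertools import islice
from typing import Dict, Iterable, Iterator, List, Tuple


def split_stream(
    stream: Iterable[Dict], size_valid_set: int
) -> Tuple[List[Dict], Iterator[Dict]]:
    """Take the first `size_valid_set` examples as validation data;
    everything after them in the (possibly unbounded) stream is training data."""
    it = iter(stream)
    valid = list(islice(it, size_valid_set))
    return valid, it
```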

@@ -189,6 +163,7 @@ def train_generator(
output_dir=output_dir,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_checkpointing=gradient_checkpointing,
per_device_eval_batch_size=per_device_eval_batch_size,
learning_rate=learning_rate,
logging_steps=logging_steps,
@@ -204,6 +179,8 @@
bf16=True,
remove_unused_columns=False,
run_name="generator_tuning",
weight_decay=weight_decay,
log_freq=log_freq,
)

def prepare_sample_text(example):
@@ -223,6 +200,9 @@ def prepare_sample_text(example):
prepare_sample_text,
)

chars_per_token = chars_token_ratio(train_dataset, tokenizer, prepare_sample_text)
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

tleyden marked this conversation as resolved.
trainer = SFTTrainer(
model=base_model,
train_dataset=train_dataset,
Expand All @@ -233,6 +213,7 @@ def prepare_sample_text(example):
tokenizer=tokenizer,
args=training_args,
neftune_noise_alpha=neftune_noise_alpha,
chars_per_token=chars_per_token,
)

trainer.train()