Update train_alpaca.py
jianzhnie committed May 28, 2023
1 parent 22068fc commit 6d164c2
Showing 1 changed file with 34 additions and 6 deletions.
examples/alpaca/train_alpaca.py: 34 additions & 6 deletions
@@ -105,7 +105,10 @@ class SupervisedDataset(Dataset):
}
IGNORE_INDEX = -100

-    def __init__(self, data_path: str, tokenizer: PreTrainedTokenizer):
+    def __init__(self,
+                 data_path: str,
+                 tokenizer: PreTrainedTokenizer,
+                 max_length: int = 1024):
"""
Initializes a SupervisedDataset object.
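For orientation, a brief usage sketch of the widened constructor (not part of this commit); the tokenizer checkpoint and data path are placeholders rather than values from the repo:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m',
                                          padding_side='right',
                                          use_fast=False)
train_dataset = SupervisedDataset(data_path='alpaca_data.json',
                                  tokenizer=tokenizer,
                                  max_length=512)
print(len(train_dataset))  # number of formatted prompt/response examples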
@@ -115,14 +118,14 @@ def __init__(self, data_path: str, tokenizer: PreTrainedTokenizer):
"""
super(SupervisedDataset, self).__init__()
-        logging.warning('Loading data...')
+        logging.warning(f'Loading dataset from {data_path}')
if data_path.endswith('.json') or data_path.endswith('.jsonl'):
list_data_dict = load_dataset('json',
data_files=data_path)['train']
else:
list_data_dict = load_dataset(data_path)['train']

-        logging.warning('Formatting inputs...')
+        logging.warning('Found %d rows', list_data_dict.num_rows)
prompt_input, prompt_no_input = self.PROMPT_DICT[
'prompt_input'], self.PROMPT_DICT['prompt_no_input']
self.sources = [
@@ -137,6 +140,7 @@ def __init__(self, data_path: str, tokenizer: PreTrainedTokenizer):

self.examples = [s + t for s, t in zip(self.sources, self.targets)]
self.tokenizer = tokenizer
+        self.max_length = max_length

def __len__(self) -> int:
"""
@@ -165,14 +169,14 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
example_tokenized = self.tokenizer(
example_txt,
padding='longest',
-            max_length=self.tokenizer.model_max_length,
+            max_length=self.max_length,
truncation=True,
)
source_txt = self.sources[idx]
source_tokenized = self.tokenizer(
source_txt,
padding='longest',
-            max_length=self.tokenizer.model_max_length,
+            max_length=self.max_length,
truncation=True,
)
# Extract the input_ids tensor
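The two tokenizer calls above encode the full example and the source prompt separately so that, as in the standard Alpaca recipe, the prompt portion can be masked out of the loss with IGNORE_INDEX. The exact masking lines fall outside the displayed hunk; a self-contained sketch of the idea, using toy token ids:

import torch

IGNORE_INDEX = -100  # positions carrying this label are ignored by the loss

source_ids = [1, 50, 51, 52]             # toy ids for the prompt
example_ids = source_ids + [60, 61, 2]   # prompt followed by the response

input_ids = torch.tensor(example_ids)
labels = input_ids.clone()
labels[:len(source_ids)] = IGNORE_INDEX  # supervise only the response tokens
print(labels)  # the four prompt positions are now -100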
@@ -233,12 +237,12 @@ def train() -> None:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
-        model_max_length=training_args.model_max_length,
padding_side='right',
use_fast=False,
)

# Resize the tokenizer's vocabulary size to accommodate additional special tokens, if necessary
+    tokenizer.pad_token = tokenizer.eos_token
special_tokens_dict = {}
if tokenizer.pad_token is None:
special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN
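The hunk below registers extra instruction-format tokens and hands them to smart_tokenizer_and_embedding_resize. That helper's body is not part of this diff; a hedged sketch of what such a helper typically does in Alpaca-style code (details may differ from the repo's actual implementation):

def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    """Add new special tokens and grow the embedding matrix to match (sketch)."""
    num_new = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if num_new > 0:
        input_emb = model.get_input_embeddings().weight.data
        output_emb = model.get_output_embeddings().weight.data
        # initialize the new rows with the mean of the pre-existing embeddings
        input_emb[-num_new:] = input_emb[:-num_new].mean(dim=0, keepdim=True)
        output_emb[-num_new:] = output_emb[:-num_new].mean(dim=0, keepdim=True)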
@@ -249,21 +253,41 @@ def train() -> None:
if tokenizer.unk_token is None:
special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN

+    special_tokens_dict['additional_special_tokens'] = [
+        '### Instruction:',
+        '### Response:\n',
+        '### End',
+    ]

if len(special_tokens_dict) > 0:
smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer,
model)

+    max_length = None
+    for length_setting in [
+            'n_positions', 'max_position_embeddings', 'seq_length'
+    ]:
+        max_length = getattr(model.config, length_setting, None)
+        if max_length:
+            logging.warning(f'Found max length: {max_length}')
+            break
+    if not max_length:
+        max_length = 1024
+        logging.warning(f'Using default max length: {max_length}')

# Create the training dataset and data collator
train_dataset = SupervisedDataset(
data_path=data_args.data_path,
tokenizer=tokenizer,
+        max_length=max_length,
)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

model.is_parallelizable = True
model.model_parallel = True

# Initialize the Trainer object and start training
+    logging.warning('Instantiating Trainer')
trainer = Trainer(
model=model,
tokenizer=tokenizer,
@@ -273,18 +297,22 @@
data_collator=data_collator,
)
model.config.use_cache = False
+    logging.warning('Training')

if training_args.resume_from_checkpoint and list(
pathlib.Path(training_args.output_dir).glob('checkpoint-*')):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()

+    logging.warning(f'Saving Model to {training_args.output_dir}')
trainer.save_state()
# Save the trained model
safe_save_model_for_hf_trainer(trainer=trainer,
output_dir=training_args.output_dir)

+    logging.warning('Done.')


if __name__ == '__main__':
train()
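
The max-length probing added in this commit tries a few config attribute names in order and falls back to 1024 when none is set. A small self-contained illustration of that behavior with stand-in config objects (real configs would come from transformers):

from types import SimpleNamespace


def probe_max_length(config, default=1024):
    # mirrors the loop added above: first populated attribute wins, else fall back
    for name in ('n_positions', 'max_position_embeddings', 'seq_length'):
        value = getattr(config, name, None)
        if value:
            return value
    return default


print(probe_max_length(SimpleNamespace(n_positions=1024)))              # GPT-2 style config -> 1024
print(probe_max_length(SimpleNamespace(max_position_embeddings=2048)))  # LLaMA style config -> 2048
print(probe_max_length(SimpleNamespace()))                              # nothing found -> default 1024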
