diff --git a/examples/training/swallow-tart/args.py b/examples/training/swallow-tart/args.py
new file mode 100644
index 000000000..d7d2767dc
--- /dev/null
+++ b/examples/training/swallow-tart/args.py
@@ -0,0 +1,37 @@
+import json
+from dataclasses import dataclass, field
+from typing import Optional
+
+from peft import get_peft_config
+from transformers import TrainingArguments as STTrainingArguments
+
+__all__ = ["STModelArguments", "STDataArguments", "STTrainingArguments"]
+
+
+@dataclass
+class STModelArguments:
+    model_name: str = "bert-base-uncased"
+    peft_config_path: Optional[str] = None
+    use_flash_attention: bool = False
+
+    def __post_init__(self):
+        if self.peft_config_path is not None:
+            with open(self.peft_config_path, "r") as f:
+                peft_config_data = json.load(f)
+            self.peft_config = get_peft_config(peft_config_data)
+        else:
+            self.peft_config = None
+
+
+@dataclass
+class STDataArguments:
+    data_dir: str
+    hf_dataset_dir: str
+    task_names: list[str] = field(default_factory=list)
+    max_length: int = 512
+    n_dev_sample: int = 100
+    query_file_name: str = "tuple_beir/queries.jsonl"
+    corpus_file_name: str = "tuple_beir/corpus.jsonl"
+    qrel_file_name: str = "tuple_beir/qrels/train.tsv"
+    hard_negatives_file_name: str = "negatives/hard_negative.jsonl"
+    num_proc: int = 1
diff --git a/examples/training/swallow-tart/configs/ds_config_zero3.json b/examples/training/swallow-tart/configs/ds_config_zero3.json
new file mode 100644
index 000000000..ac7eeb18b
--- /dev/null
+++ b/examples/training/swallow-tart/configs/ds_config_zero3.json
@@ -0,0 +1,60 @@
+{
+    "fp16": {
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 10,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+    "bf16": {
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 10,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "weight_decay": "auto"
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "reduce_bucket_size": "auto",
+        "stage3_prefetch_bucket_size": "auto",
+        "stage3_param_persistence_threshold": "auto",
+        "sub_group_size": 1e9,
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_16bit_weights_on_model_save": "auto"
+    },
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
diff --git a/examples/training/swallow-tart/configs/lora_config.json b/examples/training/swallow-tart/configs/lora_config.json
new file mode 100644
index 000000000..26705720e
--- /dev/null
+++ b/examples/training/swallow-tart/configs/lora_config.json
@@ -0,0 +1,27 @@
+{
+  "auto_mapping": null,
+  "base_model_name_or_path": "tokyotech-llm/Swallow-7b-hf",
+  "bias": "none",
+  "fan_in_fan_out": false,
+  "inference_mode": false,
+  "init_lora_weights": true,
+  "layers_pattern": null,
+  "layers_to_transform": null,
+  "lora_alpha": 256,
+  "lora_dropout": 0.1,
+  "modules_to_save": null,
+  "peft_type": "LORA",
+  "r": 128,
+  "revision": null,
+  "target_modules": [
"q_proj", + "k_proj", + "v_proj", + "o_proj", + "down_proj", + "up_proj", + "gate_proj" + ], + "task_type": "FEATURE_EXTRACTION", + "use_rslora": true +} diff --git a/examples/training/swallow-tart/data.py b/examples/training/swallow-tart/data.py new file mode 100644 index 000000000..39d36722f --- /dev/null +++ b/examples/training/swallow-tart/data.py @@ -0,0 +1,359 @@ +import os +import json +import random +from collections import defaultdict +from pathlib import Path +from typing import Callable, Optional, Tuple + +import datasets +import torch +from datasets import load_from_disk +from sentence_transformers.huggingface import SENTENCE_KEYS +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import PreTrainedTokenizer, BatchEncoding +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class MNRLDataset(Dataset): + # https://github.com/texttron/tevatron/blob/main/examples/repllama/data.py#L162 + def __init__( + self, + dataset: datasets.Dataset, + tokenizer: PreTrainedTokenizer, + max_length: int, + ): + self.train_data = dataset + self.tok = tokenizer + + self.max_length = max_length + self.total_len = len(self.train_data) + + def create_one_example(self, text_encoding: list[int]) -> BatchEncoding: + """Add eos token""" + item = self.tok.prepare_for_model( + text_encoding + [self.tok.eos_token_id], + truncation="only_first", + max_length=self.max_length - 2, # for bos and margin + padding=False, + ) + return item + + def __len__(self): + # Return query size + return self.total_len + + def __getitem__(self, item) -> dict[str, BatchEncoding]: + # https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder_mnrl.py#L215 + group = self.train_data[item] + query_encoding = self.create_one_example(group["query"]) + + target_pos_ids = group["positives"].pop(0) + target_pos_encoding = self.create_one_example(target_pos_ids) + group["positives"].append(target_pos_ids) + + negative_pos_ids = group["negatives"].pop(0) + negative_pos_encoding = self.create_one_example(negative_pos_ids) + group["negatives"].append(negative_pos_ids) + + label = 0 # 学習には使用しないが、引数に指定されている + + anchor_name, pos_name, neg_name = SENTENCE_KEYS + data = { + anchor_name: query_encoding, + pos_name: target_pos_encoding, + neg_name: negative_pos_encoding, + "label": label, + } + return data + + +class TokenizeProcessor: + def __init__( + self, + tokenizer: PreTrainedTokenizer, + max_length: int, + ) -> None: + self.tokenizer = tokenizer + self.max_length = max_length + + def __call__(self, example): + query_tokenized = self.tokenizer.encode( + example["query"], + add_special_tokens=False, + truncation=True, + max_length=self.max_length - 3, # For bos, eos and margin + ) + positive_tokenizeds = [] + for positive in example["positives"]: + positive_tokenizeds.append( + self.tokenizer.encode( + positive, + add_special_tokens=False, + truncation=True, + max_length=self.max_length - 3, # For bos and eos + ) + ) + negative_tokenizeds = [] + for negative in example["negatives"]: + negative_tokenizeds.append( + self.tokenizer.encode( + negative, + add_special_tokens=False, + truncation=True, + max_length=self.max_length - 3, # For bos and eos + ) + ) + return {"query": query_tokenized, "positives": positive_tokenizeds, "negatives": negative_tokenizeds} + + +class TokenizeBatchProcessor(TokenizeProcessor): + def __call__(self, examples): + query_tokenized = self.tokenizer( + examples["query"], + add_special_tokens=False, + 
+            truncation=True,
+            max_length=self.max_length - 3,  # For bos, eos and margin
+        )["input_ids"]
+        positive_tokenizeds = []
+        for one_batch in examples["positives"]:
+            positive_tokenizeds.append(
+                self.tokenizer(
+                    one_batch,
+                    add_special_tokens=False,
+                    truncation=True,
+                    max_length=self.max_length - 3,  # For bos and eos
+                )["input_ids"]
+            )
+        negative_tokenizeds = []
+        for one_batch in examples["negatives"]:
+            negative_tokenizeds.append(
+                self.tokenizer(
+                    one_batch,
+                    add_special_tokens=False,
+                    truncation=True,
+                    max_length=self.max_length - 3,  # For bos and eos
+                )["input_ids"]
+            )
+        return {"query": query_tokenized, "positives": positive_tokenizeds, "negatives": negative_tokenizeds}
+
+
+class IRCollator:
+    def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int):
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __call__(self, batch: list[dict[str, BatchEncoding]]) -> tuple[list[BatchEncoding], torch.Tensor]:
+        # this function is based on sentence_transformers.SentenceTransformer.smart_batching_collate
+        texts = []
+        for example in batch:
+            temp_texts = []
+            for key in SENTENCE_KEYS:
+                temp_texts.append(example[key])
+            texts.append(temp_texts)
+
+        transposed_texts = [
+            self.tokenizer.pad(sentences, padding="max_length", max_length=self.max_length, return_tensors="pt")
+            for sentences in zip(*texts)
+        ]
+        labels = torch.tensor([example["label"] for example in batch])
+
+        return transposed_texts, labels
+
+
+def load_queries(queries_path: str) -> dict[str, str]:
+    queries = {}
+    with open(queries_path, "r") as f:
+        for line in f:
+            data = json.loads(line)
+            queries[data["_id"]] = data["text"]
+    return queries
+
+
+def load_corpus(corpus_path: str) -> dict[str, str]:
+    corpus = {}
+    with open(corpus_path, "r") as f:
+        for line in f:
+            data = json.loads(line)
+            corpus[data["_id"]] = data["text"]
+    return corpus
+
+
+def load_qrels(qrels_path: str) -> dict[str, list[str]]:
+    """Load qrels.
+
+    qrel format:
+        query_id\tdocument_id\tlabel
+    """
+    qrels = defaultdict(list)
+    with open(qrels_path, "r") as f:
+        for idx, line in enumerate(f):
+            if idx == 0:
+                continue
+            data = line.strip().split("\t")
+            qid = data[0]
+            did = data[1]
+            qrels[qid].append(did)
+    return dict(qrels)
+
+
+def load_hard_negatives(hard_negatives_path: str) -> dict[str, list[str]]:
+    """Load hard negatives.
+
+    hard negative format:
+        {"query_id": str, "hard_negative": [str, str, ...]}
+    """
+    hard_negative = defaultdict(list)
+    with open(hard_negatives_path, "r") as f:
+        for line in f:
+            data = json.loads(line)
+            qid = data["query_id"]
+            hard_negative[qid].extend(data["hard_negative"])
+    return dict(hard_negative)
+
+
+def prepare_ir_dataset(
+    task_names: list[str],
+    input_data_dir: str,
+    query_file_name: str,
+    corpus_file_name: str,
+    qrel_file_name: str,
+    hard_negatives_file_name: str,
+) -> datasets.Dataset:
+    # load dataset
+    # {"query": str, "positives": list[str], "negatives": list[str]}
+    target_datasets: list[datasets.Dataset] = []
+    for task_idx, task_name in enumerate(task_names):
+        target_path = {
+            "queries": os.path.join(input_data_dir, task_name, query_file_name),
+            "corpus": os.path.join(input_data_dir, task_name, corpus_file_name),
+            "qrels": os.path.join(input_data_dir, task_name, qrel_file_name),
+            "hard_negatives": os.path.join(input_data_dir, task_name, hard_negatives_file_name),
+        }
+
+        queries = load_queries(target_path["queries"])
+        corpus = load_corpus(target_path["corpus"])
+        qrels = load_qrels(target_path["qrels"])
+        hard_negatives = load_hard_negatives(target_path["hard_negatives"])
+
+        logger.info(f"...Task: {task_name}")
+        current_dataset = []
+        for qid, query in tqdm(queries.items()):
+            if qid not in qrels:
+                logger.info(f"......qid: {qid} is not included in the qrels. skip this query.")
+                continue
+            positive_ids = qrels[qid]
+
+            positives = []
+            for pos_id in positive_ids:
+                if pos_id not in corpus:
+                    continue
+                positive_text = corpus[pos_id]
+                if positive_text is not None:
+                    positives.append(positive_text)
+            if len(positives) == 0:
+                logger.info(f"......qid: {qid} doesn't have a positive passage. skip this query.")
+                continue
+            random.shuffle(positives)
+
+            if qid not in hard_negatives:
+                continue
+            negative_ids = hard_negatives[qid]
+
+            negatives = []
+            for neg_id in negative_ids:
+                if neg_id not in corpus:
+                    continue
+                negative_text = corpus[neg_id]
+                if negative_text is not None:
+                    negatives.append(negative_text)
+            if len(negatives) == 0:
+                logger.info(f"......qid: {qid} doesn't have a negative passage. skip this query.")
skip this query.") + continue + random.shuffle(negatives) + + current_dataset.append({"query": query, "positives": positives, "negatives": negatives, "label": task_idx}) + + target_datasets.append(datasets.Dataset.from_list(current_dataset)) + + target_concat_dataset = datasets.concatenate_datasets(target_datasets) + return target_concat_dataset + + +def load_ir_dataset( + dataset_path: Path, + task_names: list[str], + input_data_dir: str, + query_file_name: str, + corpus_file_name: str, + qrel_file_name: str, + hard_negatives_file_name: str, + n_each_dev_sample: int, +) -> datasets.Dataset: + if not dataset_path.exists(): + logger.info("Build huggingface datasets.") + hf_dataset = prepare_ir_dataset( + task_names, input_data_dir, query_file_name, corpus_file_name, qrel_file_name, hard_negatives_file_name + ) + logger.info("Split train/dev dataset.") + hf_dataset = hf_dataset.class_encode_column("label") + n_dev_sample = n_each_dev_sample * len(task_names) + hf_dataset = hf_dataset.train_test_split(test_size=n_dev_sample, shuffle=True, stratify_by_column="label") + + logger.info(f"Save DatasetDict to {str(dataset_path)}.") + hf_dataset.save_to_disk(str(dataset_path), max_shard_size="1GB") + + hf_dataset = load_from_disk(dataset_path) + return hf_dataset + + +def get_dataset( + hf_dataset_dir: str, + task_names: list[str], + input_data_dir: str, + query_file_name: str, + corpus_file_name: str, + qrel_file_name: str, + hard_negatives_file_name: str, + tokenizer: PreTrainedTokenizer, + max_length: int, + n_each_dev_sample: int = 0, + process_func: Optional[Callable] = None, + num_proc: int = 1, +) -> Tuple[Dataset, Dataset]: + # build HF Dataset + logger.info("Load huggingface datasets.") + hf_dataset = load_ir_dataset( + Path(hf_dataset_dir), + task_names, + input_data_dir, + query_file_name, + corpus_file_name, + qrel_file_name, + hard_negatives_file_name, + n_each_dev_sample, + ) + + # apply preprocess (mainly tokenization (make word ids)) + logger.info("Apply preprocessing.") + remove_column_names = hf_dataset.column_names["train"].remove("label") + hf_dataset = hf_dataset.map( + process_func, + batched=True, + remove_columns=remove_column_names, + num_proc=num_proc, + desc="Running Tokenizer on dataset", + ) + + # split train/dev dataset + train_dataset = hf_dataset["train"] + dev_dataset = hf_dataset["test"] + logger.info(f"Train dataset size: {len(train_dataset)}") + logger.info(f"Dev dataset size: {len(dev_dataset)}") + + # build Torch Dataset and Return ones. 
diff --git a/examples/training/swallow-tart/run_train.py b/examples/training/swallow-tart/run_train.py
new file mode 100644
index 000000000..ddc9253ba
--- /dev/null
+++ b/examples/training/swallow-tart/run_train.py
@@ -0,0 +1,136 @@
+"""Train embeddings with Sentence-Transformers-HF
+
+lr:
+    llm-jp: 2e-5 https://llm-jp.nii.ac.jp/blog/2024/02/09/v1.1-tuning.html#%E3%83%8F%E3%82%A4%E3%83%91%E3%83%BC%E3%83%91%E3%83%A9%E3%83%A1%E3%83%BC%E3%82%BF
+    repLLaMA: 1e-4 https://llm-jp.nii.ac.jp/blog/2024/02/09/v1.1-tuning.html#%E3%83%8F%E3%82%A4%E3%83%91%E3%83%BC%E3%83%91%E3%83%A9%E3%83%A1%E3%83%BC%E3%82%BF
+"""
+import os
+import sys
+
+from sentence_transformers import losses
+from sentence_transformers.huggingface import (
+    MNRLSentenceTransformersTrainer,
+    MNRLSentenceTransformer,
+)
+from sentence_transformers.models import Transformer, Pooling, Normalize
+from transformers import HfArgumentParser, set_seed
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import logging
+
+from args import STDataArguments, STModelArguments, STTrainingArguments
+from data import get_dataset, TokenizeBatchProcessor, IRCollator
+
+logger = logging.get_logger(__name__)
+
+
+def main():
+    parser = HfArgumentParser((STDataArguments, STModelArguments, STTrainingArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        data_args, model_args, training_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
+    else:
+        data_args, model_args, training_args = parser.parse_args_into_dataclasses()
+
+    logger.warning(
+        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+        training_args.local_rank,
+        training_args.device,
+        training_args.n_gpu,
+        bool(training_args.local_rank != -1),
+        training_args.fp16,
+    )
+    logger.info("Training/evaluation parameters %s", training_args)
+    logger.info("MODEL parameters %s", model_args)
+
+    set_seed(training_args.seed)
+
+    # define model
+    logger.info("Build SentenceTransformer")
+    if model_args.use_flash_attention:
+        # validate fp16 or bf16
+        assert training_args.fp16 or training_args.bf16, "use_flash_attention requires fp16 or bf16"
+        model_kwargs = {"attn_implementation": "flash_attention_2"}
+    else:
+        model_kwargs = {}
+    tf_model = Transformer(
+        model_args.model_name,
+        model_args=model_kwargs,
+        peft_config=model_args.peft_config,
+        is_gradient_checkpointing=training_args.gradient_checkpointing,
+    )
+    pooler = Pooling(tf_model.get_word_embedding_dimension(), pooling_mode="lasttoken")
+    normalize = Normalize()
+    model = MNRLSentenceTransformer(modules=[tf_model, pooler, normalize])
+    tokenizer = model.tokenizer
+    # https://github.com/texttron/tevatron/blob/2e5d00ee21d5a7db0bd2ea1463c9150a572106d4/examples/repllama/train.py#L68-L69
+    tokenizer.pad_token_id = tokenizer.unk_token_id
+    tokenizer.pad_token = tokenizer.unk_token
+    max_length = min(data_args.max_length, tokenizer.model_max_length)
+    tokenizer.model_max_length = max_length
+    loss = losses.MultipleNegativesRankingLoss(model=model)
+    ir_collator = IRCollator(tokenizer, max_length)
+
+    # define train/eval dataset
+    logger.info("Load dataset")
+    logger.info(f"Target task names: {data_args.task_names}")
+    preprocessor = TokenizeBatchProcessor(tokenizer, data_args.max_length)
+    train_dataset, eval_dataset = get_dataset(
+        data_args.hf_dataset_dir,
+        data_args.task_names,
+        data_args.data_dir,
+        data_args.query_file_name,
+        data_args.corpus_file_name,
+        data_args.qrel_file_name,
+        data_args.hard_negatives_file_name,
+        tokenizer,
+        data_args.max_length,
+        data_args.n_dev_sample,
+        preprocessor,
+        data_args.num_proc,
+    )
+
+    trainer = MNRLSentenceTransformersTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        data_collator=ir_collator,
+        tokenizer=tokenizer,
+        loss=loss,
+        text_columns=[],
+    )
+
+    # detecting last checkpoint
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    checkpoint = None
+    if last_checkpoint is not None:
+        checkpoint = last_checkpoint
+    elif training_args.resume_from_checkpoint is not None:
+        checkpoint = training_args.resume_from_checkpoint
+
+    logger.info("Start training")
+    train_result = trainer.train(resume_from_checkpoint=checkpoint)
+    metrics = train_result.metrics
+    trainer.save_model()
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+    trainer.save_state()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/install-deepspeed.sh b/install-deepspeed.sh
new file mode 100755
index 000000000..35414b901
--- /dev/null
+++ b/install-deepspeed.sh
@@ -0,0 +1,2 @@
+#!/bin/sh
+DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_FUSED_LAMB=1 DS_BUILD_UTILS=1 pip install deepspeed --no-cache-dir
diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index feb5975ac..845039ac4 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -1163,3 +1163,10 @@ def _target_device(self) -> torch.device:
     @_target_device.setter
     def _target_device(self, device: Optional[Union[int, str, torch.device]] = None) -> None:
         self.to(device)
+
+    @property
+    def config(self):
+        return self._first_module().config
+
+    def gradient_checkpointing_enable(self, *args, **kwargs):
+        return self._first_module().gradient_checkpointing_enable(*args, **kwargs)
diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index e61b268d6..b979119cd 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -1,3 +1,4 @@
+import torch
 from torch import nn
 from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config
 from peft import PeftConfig, get_peft_model
@@ -29,6 +30,7 @@ def __init__(
         do_lower_case: bool = False,
         tokenizer_name_or_path: str = None,
         peft_config: Optional[PeftConfig] = None,
+        is_gradient_checkpointing: bool = False,
     ):
         super(Transformer, self).__init__()
         self.config_keys = ["max_seq_length", "do_lower_case"]
@@ -38,6 +40,13 @@ def __init__(
         self._load_model(model_name_or_path, config, cache_dir, **model_args)
 
         if peft_config is not None:
+            if is_gradient_checkpointing:
+                for param in self.auto_model.parameters():
+                    param.requires_grad = True
+                    if param.ndim == 1:
+                        param.data = param.data.to(torch.float32)
+                self.auto_model.gradient_checkpointing_enable()
+                self.auto_model.enable_input_require_grads()
             self.auto_model = get_peft_model(self.auto_model, peft_config)
 
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -190,3 +199,10 @@ def load(input_path: str):
         if "model_args" in config:
             config["model_args"].pop("trust_remote_code")
         return Transformer(model_name_or_path=input_path, **config)
+
+    @property
+    def config(self):
+        return self.auto_model.config
+
+    def gradient_checkpointing_enable(self, *args, **kwargs):
+        return self.auto_model.gradient_checkpointing_enable(*args, **kwargs)
diff --git a/setup.py b/setup.py
index fbf48ae97..9c342f02b 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,7 @@
         "huggingface-hub>=0.15.1",
         "Pillow",
         "peft",
+        "datasets",
     ],
     classifiers=[
         "Development Status :: 5 - Production/Stable",
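As a rough sketch (not prescribed by the patch), run_train.py could be launched with torchrun and the configs above. Every path, task name, and hyperparameter below is a placeholder; the flags simply map onto the STDataArguments/STModelArguments fields and standard Hugging Face TrainingArguments, with --deepspeed pointing at the ZeRO-3 config.

    # Illustrative launch command; adjust paths, task names, and hyperparameters.
    torchrun --nproc_per_node=8 examples/training/swallow-tart/run_train.py \
        --model_name tokyotech-llm/Swallow-7b-hf \
        --peft_config_path examples/training/swallow-tart/configs/lora_config.json \
        --use_flash_attention True \
        --data_dir /path/to/beir_data \
        --hf_dataset_dir /path/to/hf_dataset_cache \
        --task_names example-task \
        --max_length 512 \
        --output_dir output/swallow-tart \
        --deepspeed examples/training/swallow-tart/configs/ds_config_zero3.json \
        --bf16 True \
        --gradient_checkpointing True \
        --learning_rate 1e-4 \
        --per_device_train_batch_size 4 \
        --num_train_epochs 1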