Merge pull request #560 from yaoguany/main
allow exceeding the model's maximum length during training & inference
research4pan authored Jul 21, 2023
2 parents 1449cd4 + 78159d0 commit 0105829
Showing 4 changed files with 40 additions and 14 deletions.
8 changes: 8 additions & 0 deletions src/lmflow/args.py
@@ -198,6 +198,14 @@ class ModelArguments:
             )
         }
     )
+    truncate_to_model_max_length: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "whether to truncate the dataset to the model's max length."
+            )
+        }
+    )
     use_int8: bool = field(
         default=False,
         metadata={"help": "whether to load int8 quantization for inference"}
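
For reference, a minimal sketch of how the new option is consumed, assuming ModelArguments is parsed with Hugging Face's HfArgumentParser as in the rest of LMFlow; the trimmed dataclass and the example script invocation are illustrative, not the full LMFlow code:

# Minimal sketch (not part of the diff): a trimmed ModelArguments parsed with
# HfArgumentParser, assuming the usual LMFlow-style CLI entry point.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    truncate_to_model_max_length: bool = field(
        default=True,
        metadata={"help": "whether to truncate the dataset to the model's max length."},
    )


parser = HfArgumentParser(ModelArguments)
# Equivalent to launching e.g.: python examples/finetune.py --truncate_to_model_max_length False
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--truncate_to_model_max_length", "False"]
)
print(model_args.truncate_to_model_max_length)  # -> False

Because the field defaults to True, existing behaviour is unchanged unless the flag is explicitly set to False.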
4 changes: 3 additions & 1 deletion src/lmflow/models/hf_decoder_model.py
@@ -248,7 +248,9 @@ def __init__(
             # We resize the embeddings only when necessary to avoid index errors.
             # If you are creating a model from scratch on a small vocab and want a
             # smaller embedding size, remove this test.
-            embedding_size = model.get_input_embeddings().weight.shape[0]
+            with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None):
+                weights = model.get_input_embeddings().weight
+                embedding_size = weights.shape[0]
             if len(tokenizer) > embedding_size:
                 model.resize_token_embeddings(len(tokenizer))

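
The gather is needed because, under DeepSpeed ZeRO stage 3, parameters are partitioned across ranks, so reading the embedding weight's shape directly would only see the local shard rather than the full vocabulary size. A minimal sketch of the pattern, assuming `model` is a Hugging Face causal LM already initialised under ZeRO-3 and `tokenizer` is already loaded:

# Sketch of the read-only gather pattern above (assumes an existing ZeRO-3
# model and tokenizer; outside ZeRO-3 the context manager is a no-op).
import deepspeed

embeddings = model.get_input_embeddings()
# GatheredParameters temporarily reassembles the full tensor inside the
# context; modifier_rank=None means no rank modifies it, so nothing needs
# to be re-partitioned afterwards.
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
    embedding_size = embeddings.weight.shape[0]

if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))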
13 changes: 8 additions & 5 deletions src/lmflow/pipeline/evaluator.py
@@ -323,11 +323,14 @@ def _evaluate_ppl(self, model, dataset: Dataset, verbose=True):
         texts = [ instance["text"] for instance in data_dict["instances"] ]
         encodings = model.get_tokenizer()("\n\n".join(texts), return_tensors="pt")
         # Define some constant
-        try:
-            max_length = min(model.get_backend_model().config.n_positions, model.get_max_length())
-        except:
-            max_length = min(1024, model.get_max_length())
-
+        if self.model_args.truncate_to_model_max_length:
+            try:
+                max_length = min(model.get_backend_model().config.n_positions, model.get_max_length())
+            except:
+                max_length = min(1024, model.get_max_length())
+        else:
+            max_length = self.block_size
+
         if verbose:
             print(f"The maximum sequence length : {max_length}")
         seq_len = encodings.input_ids.size(1)
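
The try/except exists because only some configs (e.g. GPT-2) expose n_positions; configs without it (e.g. LLaMA, which uses max_position_embeddings instead) fall back to 1024. A standalone sketch of the selection logic; the helper name and parameters are illustrative, not LMFlow API:

# Illustrative helper mirroring the branch above (names are not LMFlow API).
from types import SimpleNamespace


def pick_max_length(config, tokenizer_max_length, block_size, truncate=True):
    if truncate:
        try:
            # GPT-2-style configs expose n_positions; others raise AttributeError.
            return min(config.n_positions, tokenizer_max_length)
        except AttributeError:
            return min(1024, tokenizer_max_length)
    # Truncation disabled: trust the user-supplied block_size, even if it
    # exceeds the model's nominal maximum.
    return block_size


print(pick_max_length(SimpleNamespace(n_positions=1024), 2048, 4096))  # 1024
print(pick_max_length(SimpleNamespace(), 2048, 4096, truncate=False))  # 4096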
29 changes: 21 additions & 8 deletions src/lmflow/pipeline/finetuner.py
@@ -130,14 +130,27 @@ def group_text(self, tokenized_datasets, model_max_length):
                 block_size = 1024
         else:
             if data_args.block_size > model_max_length:
-                logger.warning(
-                    f"The block_size passed ({data_args.block_size}) is larger"
-                    f" than the maximum length for the model"
-                    f"({model_max_length})."
-                    f" Using block_size={model_max_length}."
-                )
-            block_size = min(data_args.block_size, model_max_length)
-
+                if self.model_args.truncate_to_model_max_length:
+                    logger.warning(
+                        f"The block_size passed ({data_args.block_size}) is larger"
+                        f" than the maximum length for the model"
+                        f" ({model_max_length})."
+                        f" Using block_size={model_max_length}."
+                        f" If you would like to use a block_size longer than the"
+                        f" maximum length supported by the model, you can override"
+                        f" this behavior with `--truncate_to_model_max_length False`."
+                    )
+                    block_size = model_max_length
+                else:
+                    logger.warning(
+                        f"The block_size passed ({data_args.block_size}) is larger"
+                        f" than the maximum length for the model"
+                        f" ({model_max_length})."
+                        f" Using block_size={data_args.block_size}."
+                    )
+                    block_size = data_args.block_size
+            else:
+                block_size = data_args.block_size
         # Main data processing function that will concatenate all texts from
         # our dataset and generate chunks of block_size.
         def group_texts(examples):
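
For context, group_texts (defined just below this hunk) follows the standard Hugging Face run_clm recipe: concatenate every tokenized field, drop the tail that does not fill a whole block, and slice the rest into block_size chunks. A sketch of that recipe, not the verbatim LMFlow body:

# Sketch of the standard concatenate-and-chunk step (run_clm-style recipe,
# not the verbatim LMFlow implementation).
def group_texts(examples, block_size):
    # Concatenate every field (input_ids, attention_mask, ...) end to end.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the tail that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size tokens.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal LM training the labels are the inputs themselves.
    result["labels"] = result["input_ids"].copy()
    return result


example_batch = {"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8]], "attention_mask": [[1] * 5, [1] * 3]}
chunks = group_texts(example_batch, block_size=4)
print(chunks["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]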
