diff --git a/flashrag/generator/generator.py b/flashrag/generator/generator.py
index 69d9e3c..dab1f31 100644
--- a/flashrag/generator/generator.py
+++ b/flashrag/generator/generator.py
@@ -53,11 +53,15 @@ def __init__(self, config):
                 self.model = FiDT5.from_pretrained(self.model_path)
             else:
-                self.model = T5ForConditionalGeneration.from_pretrained(self.model_path)
+                self.model = T5ForConditionalGeneration.from_pretrained(
+                    self.model_path
+                )
         else:
             if self.fid:
                 assert False, "FiD only support T5"
-            self.model = BartForConditionalGeneration.from_pretrained(self.model_path)
+            self.model = BartForConditionalGeneration.from_pretrained(
+                self.model_path
+            )
         self.model.cuda()
         self.model.eval()
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
 
@@ -95,10 +99,18 @@ def generate(self, input_list: List, batch_size=None, **params):
             from flashrag.generator.stop_word_criteria import StopWordCriteria
 
             stop_sym = generation_params.pop("stop")
-            stopping_criteria = [StopWordCriteria(tokenizer=self.tokenizer, prompts=input_list, stop_words=stop_sym)]
+            stopping_criteria = [
+                StopWordCriteria(
+                    tokenizer=self.tokenizer,
+                    prompts=input_list,
+                    stop_words=stop_sym,
+                )
+            ]
             generation_params["stopping_criteria"] = stopping_criteria
 
-        max_tokens = params.pop("max_tokens", None) or params.pop("max_new_tokens", None)
+        max_tokens = params.pop("max_tokens", None) or params.pop(
+            "max_new_tokens", None
+        )
         if max_tokens is not None:
             generation_params["max_new_tokens"] = max_tokens
         else:
@@ -108,21 +120,36 @@ def generate(self, input_list: List, batch_size=None, **params):
         generation_params.pop("max_tokens", None)
 
         responses = []
-        for idx in trange(0, len(input_list), batch_size, desc="Generation process: "):
+        for idx in trange(
+            0, len(input_list), batch_size, desc="Generation process: "
+        ):
             batched_prompts = input_list[idx : idx + batch_size]
             if self.fid:
                 # assume each input in input_list is a list, contains K string
-                input_ids, attention_mask = self.encode_passages(batched_prompts)
-                inputs = {"input_ids": input_ids.to(self.device), "attention_mask": attention_mask.to(self.device)}
+                input_ids, attention_mask = self.encode_passages(
+                    batched_prompts
+                )
+                inputs = {
+                    "input_ids": input_ids.to(self.device),
+                    "attention_mask": attention_mask.to(self.device),
+                }
             else:
                 inputs = self.tokenizer(
-                    batched_prompts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_input_len
+                    batched_prompts,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=self.max_input_len,
                 ).to(self.device)
 
             # TODO: multi-gpu inference
             outputs = self.model.generate(**inputs, **generation_params)
-            outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+            outputs = self.tokenizer.batch_decode(
+                outputs,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False,
+            )
             responses += outputs
 
@@ -146,7 +173,11 @@ def __init__(self, config):
         else:
             tensor_parallel_size = self.gpu_num
 
-        self.lora_path = None if "generator_lora_path" not in config else config["generator_lora_path"]
+        self.lora_path = (
+            None
+            if "generator_lora_path" not in config
+            else config["generator_lora_path"]
+        )
         self.use_lora = False
         if self.lora_path is not None:
             self.use_lora = True
@@ -166,10 +197,18 @@ def __init__(self, config):
             gpu_memory_utilization=gpu_memory_utilization,
             max_logprobs=32016,
         )
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
 
     @torch.inference_mode(mode=True)
-    def generate(self, input_list: List[str], return_raw_output=False, return_scores=False, **params):
+    def generate(
+        self,
+        input_list: List[str],
+        return_raw_output=False,
+        return_scores=False,
+        **params,
+    ):
         from vllm import SamplingParams
 
         if isinstance(input_list, str):
@@ -180,7 +219,9 @@ def generate(self, input_list: List[str], return_raw_output=False, return_scores
         if "do_sample" in generation_params:
             generation_params.pop("do_sample")
 
-        max_tokens = params.pop("max_tokens", None) or params.pop("max_new_tokens", None)
+        max_tokens = params.pop("max_tokens", None) or params.pop(
+            "max_new_tokens", None
+        )
         if max_tokens is not None:
             generation_params["max_tokens"] = max_tokens
         else:
@@ -205,7 +246,9 @@ def generate(self, input_list: List[str], return_raw_output=False, return_scores
             from vllm.lora.request import LoRARequest
 
             outputs = self.model.generate(
-                input_list, sampling_params, lora_request=LoRARequest("lora_module", 1, self.lora_path)
+                input_list,
+                sampling_params,
+                lora_request=LoRARequest("lora_module", 1, self.lora_path),
             )
         else:
             outputs = self.model.generate(input_list, sampling_params)
@@ -219,7 +262,12 @@ def generate(self, input_list: List[str], return_raw_output=False, return_scores
             scores = []
             for output in outputs:
                 logprobs = output.outputs[0].logprobs
-                scores.append([np.exp(list(score_dict.values())[0].logprob) for score_dict in logprobs])
+                scores.append(
+                    [
+                        np.exp(list(score_dict.values())[0].logprob)
+                        for score_dict in logprobs
+                    ]
+                )
             return base_output, scores
         else:
             return base_output
@@ -231,7 +279,11 @@ class HFCausalLMGenerator(BaseGenerator):
     def __init__(self, config, model=None):
        super().__init__(config)
         self.config = config
-        lora_path = None if "generator_lora_path" not in config else config["generator_lora_path"]
+        lora_path = (
+            None
+            if "generator_lora_path" not in config
+            else config["generator_lora_path"]
+        )
         self.model, self.tokenizer = self._load_model(model=model)
         self.use_lora = False
         if lora_path is not None:
@@ -242,20 +294,66 @@ def _load_model(self, model=None):
         r"""Load model and tokenizer for generator."""
         if model is None:
             model = AutoModelForCausalLM.from_pretrained(
-                self.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
+                self.model_path,
+                torch_dtype="auto",
+                device_map="auto",
+                trust_remote_code=True,
             )
         else:
             model.cuda()
         model.eval()
-        tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
         if "qwen" not in self.model_name:
             tokenizer.pad_token = tokenizer.eos_token
             tokenizer.padding_side = "left"
 
         return model, tokenizer
 
+    def add_new_tokens(
+        self, token_embedding_path, token_name_func=lambda idx: f"[ref{idx+1}]"
+    ):
+        # get original embedding weight matrix
+        embedding_layer = self.model.get_input_embeddings()
+        embedding_weights = embedding_layer.weight
+        original_vocab_size, embedding_dim = embedding_weights.shape
+
+        new_tokens_weights = torch.load(
+            token_embedding_path, map_location=embedding_layer.weight.device
+        )
+        new_tokens_length = new_tokens_weights.shape[0]
+
+        # expand vocabulary
+        new_tokens = [token_name_func(idx) for idx in range(new_tokens_length)]
+        self.tokenizer.add_tokens(new_tokens)
+
+        # create new embedding matrix
+        new_vocab_size = original_vocab_size + new_tokens_length
+        new_embedding_weights = torch.zeros(new_vocab_size, embedding_dim)
+
+        # copy original embeddings to the new weights
+        new_embedding_weights[:original_vocab_size, :] = embedding_weights
+
+        # append virtual token embeddings to the new weights
+        for token, embedding in zip(new_tokens, new_tokens_weights):
+            token_id = self.tokenizer.convert_tokens_to_ids(token)
+            new_embedding_weights[token_id] = embedding
+
+        # update the embedding table
+        # note: we should avoid using the function resize_token_embeddings() because this function will also change the lm_head of the model
+        embedding_layer.weight.data = new_embedding_weights
+        self.model.cuda()
+
     @torch.inference_mode(mode=True)
-    def generate(self, input_list: List[str], batch_size=None, return_scores=False, return_dict=False, **params):
+    def generate(
+        self,
+        input_list: List[str],
+        batch_size=None,
+        return_scores=False,
+        return_dict=False,
+        **params,
+    ):
         """Generate batches one by one. The generated content needs to exclude input."""
 
         if isinstance(input_list, str):
@@ -272,10 +370,18 @@ def generate(self, input_list: List[str], batch_size=None, return_scores=False,
             from flashrag.generator.stop_word_criteria import StopWordCriteria
 
             stop_sym = generation_params.pop("stop")
-            stopping_criteria = [StopWordCriteria(tokenizer=self.tokenizer, prompts=input_list, stop_words=stop_sym)]
+            stopping_criteria = [
+                StopWordCriteria(
+                    tokenizer=self.tokenizer,
+                    prompts=input_list,
+                    stop_words=stop_sym,
+                )
+            ]
             generation_params["stopping_criteria"] = stopping_criteria
 
-        max_tokens = params.pop("max_tokens", None) or params.pop("max_new_tokens", None)
+        max_tokens = params.pop("max_tokens", None) or params.pop(
+            "max_new_tokens", None
+        )
         if max_tokens is not None:
             generation_params["max_new_tokens"] = max_tokens
         else:
@@ -286,7 +392,10 @@ def generate(self, input_list: List[str], batch_size=None, return_scores=False,
 
         # set eos token for llama
         if "llama" in self.model_name.lower():
-            extra_eos_tokens = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+            extra_eos_tokens = [
+                self.tokenizer.eos_token_id,
+                self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+            ]
             if "eos_token_id" in generation_params:
                 generation_params["eos_token_id"].extend(extra_eos_tokens)
             else:
@@ -297,41 +406,79 @@ def generate(self, input_list: List[str], batch_size=None, return_scores=False,
         generated_token_ids = []
         generated_token_logits = []
 
-        for idx in trange(0, len(input_list), batch_size, desc="Generation process: "):
+        for idx in trange(
+            0, len(input_list), batch_size, desc="Generation process: "
+        ):
             torch.cuda.empty_cache()
             batched_prompts = input_list[idx : idx + batch_size]
             inputs = self.tokenizer(
-                batched_prompts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_input_len
+                batched_prompts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=self.max_input_len,
             ).to(self.model.device)
             outputs = self.model.generate(
-                **inputs, output_scores=True, return_dict_in_generate=True, **generation_params
+                **inputs,
+                output_scores=True,
+                return_dict_in_generate=True,
+                **generation_params,
             )
             generated_ids = outputs.sequences
             logits = torch.stack(outputs.scores, dim=1).softmax(-1)
 
             generated_ids = generated_ids[:, inputs["input_ids"].shape[-1] :]
-            gen_score = torch.gather(logits, 2, generated_ids[:, :, None]).squeeze(-1).cpu().tolist()
+            gen_score = (
+                torch.gather(logits, 2, generated_ids[:, :, None])
+                .squeeze(-1)
+                .cpu()
+                .tolist()
+            )
             scores.extend(gen_score)
 
             # get additinoal info
             if return_dict:
                 batch_generated_token_ids = generated_ids.detach().cpu()
                 batch_generated_token_logits = (
-                    torch.cat([token_scores.unsqueeze(1) for token_scores in outputs.scores], dim=1).detach().cpu()
+                    torch.cat(
+                        [
+                            token_scores.unsqueeze(1)
+                            for token_scores in outputs.scores
+                        ],
+                        dim=1,
+                    )
+                    .detach()
+                    .cpu()
                 )
-                if batch_generated_token_ids.shape[1] < generation_params["max_new_tokens"]:
-                    real_batch_size, num_generated_tokens = batch_generated_token_ids.shape
-                    padding_length = generation_params["max_new_tokens"] - num_generated_tokens
+                if (
+                    batch_generated_token_ids.shape[1]
+                    < generation_params["max_new_tokens"]
+                ):
+                    real_batch_size, num_generated_tokens = (
+                        batch_generated_token_ids.shape
+                    )
+                    padding_length = (
+                        generation_params["max_new_tokens"]
+                        - num_generated_tokens
+                    )
                     padding_token_ids = torch.zeros(
-                        (real_batch_size, padding_length), dtype=batch_generated_token_ids.dtype
+                        (real_batch_size, padding_length),
+                        dtype=batch_generated_token_ids.dtype,
                     ).fill_(self.tokenizer.pad_token_id)
                     padding_token_logits = torch.zeros(
-                        (real_batch_size, padding_length, batch_generated_token_logits.shape[-1]),
+                        (
+                            real_batch_size,
+                            padding_length,
+                            batch_generated_token_logits.shape[-1],
+                        ),
                         dtype=batch_generated_token_logits.dtype,
                     )
-                    batch_generated_token_ids = torch.cat([batch_generated_token_ids, padding_token_ids], dim=1)
+                    batch_generated_token_ids = torch.cat(
+                        [batch_generated_token_ids, padding_token_ids], dim=1
+                    )
                     batch_generated_token_logits = torch.cat(
-                        [batch_generated_token_logits, padding_token_logits], dim=1
+                        [batch_generated_token_logits, padding_token_logits],
+                        dim=1,
                     )
                 generated_token_ids.append(batch_generated_token_ids)
                 generated_token_logits.append(batch_generated_token_logits)
@@ -339,7 +486,9 @@ def generate(self, input_list: List[str], batch_size=None, return_scores=False,
             for i, generated_sequence in enumerate(outputs.sequences):
                 input_ids = inputs["input_ids"][i]
                 text = self.tokenizer.decode(
-                    generated_sequence, skip_special_tokens=True, clean_up_tokenization_spaces=False
+                    generated_sequence,
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=False,
                 )
                 if input_ids is None:
                     prompt_length = 0
@@ -413,7 +562,11 @@ def _load_model(self, model=None):
         def get_gpu_memory(max_gpus=None):
             """Get available memory for each GPU."""
             gpu_memory = []
-            num_gpus = torch.cuda.device_count() if max_gpus is None else min(max_gpus, torch.cuda.device_count())
+            num_gpus = (
+                torch.cuda.device_count()
+                if max_gpus is None
+                else min(max_gpus, torch.cuda.device_count())
+            )
             for gpu_id in range(num_gpus):
                 with torch.cuda.device(gpu_id):
                     device = torch.cuda.current_device()
@@ -434,7 +587,10 @@ def get_gpu_memory(max_gpus=None):
         max_gpu_memory = None
         if self.gpu_num != 1:
             available_gpu_memory = get_gpu_memory(self.gpu_num)
-            max_gpu_memory = str(int(min(available_gpu_memory) * gpu_memory_utilization)) + "GiB"
+            max_gpu_memory = (
+                str(int(min(available_gpu_memory) * gpu_memory_utilization))
+                + "GiB"
+            )
 
         model, tokenizer = load_model(
             self.model_path,
@@ -448,10 +604,14 @@ def get_gpu_memory(max_gpus=None):
         else:
             model.cuda()
-            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.model_path, trust_remote_code=True
+            )
 
         model.eval()
-        tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
         if "qwen" not in self.model_name:
             tokenizer.pad_token = tokenizer.eos_token
             tokenizer.padding_side = "left"
 
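Note on the new HFCausalLMGenerator.add_new_tokens method above (illustration, not part of the patch): it grows the tokenizer vocabulary and the input-embedding matrix by hand instead of calling resize_token_embeddings(), so the model's lm_head is left untouched. The standalone sketch below shows the same technique; the GPT-2 tokenizer, the 768-dim nn.Embedding, and the random "pretrained" weights are stand-ins for self.tokenizer, model.get_input_embeddings(), and the tensor loaded from token_embedding_path.

    # Standalone sketch of the embedding-extension trick (all concrete values are stand-ins).
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    embedding = torch.nn.Embedding(len(tokenizer), 768)  # stands in for model.get_input_embeddings()

    new_token_weights = torch.randn(3, 768)              # would come from torch.load(token_embedding_path)
    new_tokens = [f"[ref{i + 1}]" for i in range(3)]     # default token_name_func naming: [ref1], [ref2], ...
    tokenizer.add_tokens(new_tokens)

    old_vocab, dim = embedding.weight.shape
    new_matrix = torch.zeros(old_vocab + len(new_tokens), dim)
    new_matrix[:old_vocab] = embedding.weight.data       # keep the original rows
    for tok, emb in zip(new_tokens, new_token_weights):
        new_matrix[tokenizer.convert_tokens_to_ids(tok)] = emb  # one row per new [refN] token

    embedding.weight.data = new_matrix                   # swap in the enlarged table, no resize_token_embeddings()
    print(tokenizer.convert_tokens_to_ids("[ref1]"), tuple(embedding.weight.shape))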
diff --git a/flashrag/pipeline/pipeline.py b/flashrag/pipeline/pipeline.py
index 8ea5895..cc58761 100644
--- a/flashrag/pipeline/pipeline.py
+++ b/flashrag/pipeline/pipeline.py
@@ -45,14 +45,19 @@ def evaluate(self, dataset, do_eval=True, pred_process_fun=None):
 
 
 class SequentialPipeline(BasicPipeline):
-    def __init__(self, config, prompt_template=None):
+    def __init__(
+        self, config, prompt_template=None, retriever=None, generator=None
+    ):
         """
         inference stage:
             query -> pre-retrieval -> retriever -> post-retrieval -> generator
         """
         super().__init__(config, prompt_template)
-        self.retriever = get_retriever(config)
+        if retriever is None:
+            self.retriever = get_retriever(config)
+        else:
+            self.retriever = retriever
 
         # TODO: add rewriter module
 
@@ -60,22 +65,33 @@ def __init__(self, config, prompt_template=None):
 
         self.generator = None
         if config["refiner_name"] is not None:
-            if "kg" in config["refiner_name"].lower():
-                self.generator = get_generator(config)
             self.refiner = get_refiner(config, self.retriever, self.generator)
+
+            # For refiners other than kg, do not load the generator for now to save memory
+            if "kg" in config["refiner_name"].lower():
+                self.generator = (
+                    get_generator(config) if generator is None else generator
+                )
         else:
             self.refiner = None
-            self.generator = get_generator(config)
+            self.generator = (
+                get_generator(config) if generator is None else generator
+            )
 
     def naive_run(self, dataset, do_eval=True, pred_process_fun=None):
         # direct generation without RAG
-        input_prompts = [self.prompt_template.get_string(question=q) for q in dataset.question]
+        input_prompts = [
+            self.prompt_template.get_string(question=q)
+            for q in dataset.question
+        ]
         dataset.update_output("prompt", input_prompts)
 
         pred_answer_list = self.generator.generate(input_prompts)
         dataset.update_output("pred", pred_answer_list)
 
-        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )
         return dataset
 
     def run(self, dataset, do_eval=True, pred_process_fun=None):
@@ -89,7 +105,9 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
             if "llmlingua" in self.refiner.name and input_prompt_flag:
                 # input prompt
                 input_prompts = [
-                    self.prompt_template.get_string(question=q, retrieval_result=r)
+                    self.prompt_template.get_string(
+                        question=q, retrieval_result=r
+                    )
                     for q, r in zip(dataset.question, dataset.retrieval_result)
                 ]
                 dataset.update_output("prompt", input_prompts)
@@ -99,7 +117,9 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
                 refine_results = self.refiner.batch_run(dataset)
                 dataset.update_output("refine_result", refine_results)
                 input_prompts = [
-                    self.prompt_template.get_string(question=q, formatted_reference=r)
+                    self.prompt_template.get_string(
+                        question=q, formatted_reference=r
+                    )
                     for q, r in zip(dataset.question, refine_results)
                 ]
 
@@ -125,7 +145,9 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
 
         pred_answer_list = self.generator.generate(input_prompts)
         dataset.update_output("pred", pred_answer_list)
 
-        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )
         return dataset
 
@@ -160,11 +182,15 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
 
         pos_dataset = self.sequential_pipeline.run(pos_dataset, do_eval=False)
 
         self.sequential_pipeline.prompt_template = self.zero_shot_templete
-        neg_dataset = self.sequential_pipeline.naive_run(neg_dataset, do_eval=False)
+        neg_dataset = self.sequential_pipeline.naive_run(
+            neg_dataset, do_eval=False
+        )
 
         # merge datasets into original format
         dataset = merge_dataset(pos_dataset, neg_dataset, judge_result)
 
-        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )
         return dataset
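Note on the pipeline.py changes (illustration, not part of the patch): SequentialPipeline now accepts an already-built retriever and generator, so heavy components can be constructed once and shared instead of every pipeline instance calling get_retriever()/get_generator() again. A minimal wiring sketch, assuming the usual flashrag.config.Config and flashrag.utils factory entry points and placeholder YAML paths:

    # Hedged sketch: share one retriever/generator across two pipelines (paths are placeholders).
    from flashrag.config import Config
    from flashrag.utils import get_retriever, get_generator
    from flashrag.pipeline import SequentialPipeline

    config_a = Config("config_a.yaml")   # main experiment config (placeholder)
    config_b = Config("config_b.yaml")   # a variant that should reuse the same models (placeholder)

    retriever = get_retriever(config_a)  # loaded once
    generator = get_generator(config_a)  # loaded once

    pipeline_a = SequentialPipeline(config_a, retriever=retriever, generator=generator)
    pipeline_b = SequentialPipeline(config_b, retriever=retriever, generator=generator)
    # Omitting the keyword arguments keeps the old behaviour: each pipeline loads its own components.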