Commit: add infer-retrieve-rerank pack (#911)
Showing 9 changed files with 1,269 additions and 403 deletions.
llama_hub/llama_packs/research/infer_retrieve_rerank/README.md (66 additions, 0 deletions)
# Infer-Retrieve-Rerank LlamaPack

This is our implementation of the paper ["In-Context Learning for Extreme Multi-Label Classification"](https://arxiv.org/pdf/2401.12178.pdf) by Oosterlinck et al.

The paper proposes "infer-retrieve-rerank", a simple paradigm using frozen LLM/retriever models that can do "extreme"-label classification (where the label space is huge):

1. Given a user query, use an LLM to predict an initial set of labels.
2. For each prediction, retrieve the actual label from the corpus.
3. Given the final set of labels, rerank them using an LLM.

All of these steps can be implemented as LlamaIndex abstractions, as sketched below.
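At a glance, the three stages compose as follows (a minimal illustrative sketch, not the pack's actual code; `base.py` wires the same stages together with LlamaIndex query pipelines and `RankGPTRerank`):

```python
# Illustrative sketch: llm, retriever, and reranker are assumed to be
# LlamaIndex objects (e.g. an OpenAI LLM, a VectorStoreIndex retriever,
# and a RankGPTRerank postprocessor).
def infer_retrieve_rerank_sketch(query, llm, retriever, reranker):
    # 1. Infer: ask the LLM for an initial, possibly noisy, set of labels
    raw_preds = llm.complete(f"Predict labels for: {query}").text.split(",")

    # 2. Retrieve: map each raw prediction onto actual labels from the corpus
    candidate_nodes = []
    for pred in raw_preds:
        candidate_nodes.extend(retriever.retrieve(pred.strip()))

    # 3. Rerank: have the LLM order the candidate labels by relevance
    reranked = reranker.postprocess_nodes(candidate_nodes, query_str=query)
    return [n.get_content() for n in reranked]
```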

A full notebook guide can be found [here](https://github.com/run-llama/llama-hub/blob/main/llama_hub/llama_packs/research/infer_retrieve_rerank/infer_retrieve_rerank.ipynb).
## CLI Usage

You can download llamapacks directly using `llamaindex-cli`, which comes installed with the `llama-index` python package:

```bash
llamaindex-cli download-llamapack InferRetrieveRerankPack --download-dir ./infer_retrieve_rerank_pack
```

You can then inspect the files at `./infer_retrieve_rerank_pack` and use them as a template for your own project!
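For reference, this commit adds a README, `base.py`, and an empty file to the pack's source directory, so a quick listing should show something like the following (a sketch; the exact layout after download depends on the CLI):

```bash
ls ./infer_retrieve_rerank_pack
# base.py  README.md  ...
```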

## Code Usage

You can download the pack to a `./infer_retrieve_rerank_pack` directory:

```python
from llama_index.llama_pack import download_llama_pack

# download and install dependencies
InferRetrieveRerankPack = download_llama_pack(
    "InferRetrieveRerankPack", "./infer_retrieve_rerank_pack"
)
```

From here, you can use the pack, or inspect and modify the pack in `./infer_retrieve_rerank_pack`.

Then, you can set up the pack like so:

```python
# create the pack
pack = InferRetrieveRerankPack(
    labels,  # list of all label strings
    llm=llm,
    pred_context="<pred_context>",
    reranker_top_n=3,
    verbose=True,
)
```

The `run()` function runs predictions:

```python
pred_reactions = pack.run(inputs=[s["text"] for s in samples])
```
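Each element of `pred_reactions` is the reranked label list for the corresponding input (at most `reranker_top_n` labels each, per `run()` in `base.py`), so you can pair results back up with your samples:

```python
# print each input alongside its predicted labels
for sample, labels in zip(samples, pred_reactions):
    print(sample["text"][:80], "->", labels)
```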

You can also use modules individually:

```python
# grab the underlying modules
llm = pack.llm
label_retriever = pack.label_retriever

# e.g. call the LLM or the label retriever directly
response = llm.complete("<your prompt>")
nodes = label_retriever.retrieve("<a candidate label>")
```
llama_hub/llama_packs/research/infer_retrieve_rerank/base.py (157 additions, 0 deletions)

```python
"""Infer-Retrieve-Rerank Pack. | ||
Taken from this paper: https://arxiv.org/pdf/2401.12178.pdf. | ||
""" | ||
|
||
|
||
from typing import Any, Dict, Optional | ||
|
||
from llama_index.llama_pack.base import BaseLlamaPack | ||
from llama_index.schema import TextNode | ||
from llama_index.embeddings import OpenAIEmbedding | ||
from llama_index.ingestion import IngestionPipeline | ||
from llama_index import VectorStoreIndex | ||
from llama_index.retrievers import BaseRetriever | ||
from llama_index.llms.llm import LLM | ||
from llama_index.llms import OpenAI | ||
from llama_index.prompts import PromptTemplate | ||
from llama_index.query_pipeline import QueryPipeline | ||
from llama_index.postprocessor.rankGPT_rerank import RankGPTRerank | ||
from llama_index.output_parsers import ChainableOutputParser | ||
from typing import List | ||

INFER_PROMPT_STR = """\
Your job is to output a list of predictions given context from a given piece of text. The text context,
and information regarding the set of valid predictions is given below.
Return the predictions as a comma-separated list of strings.
Text Context:
{doc_context}
Prediction Info:
{pred_context}
Predictions: """

INFER_PROMPT_TMPL = PromptTemplate(INFER_PROMPT_STR)


class PredsOutputParser(ChainableOutputParser):
    """Predictions output parser."""

    def parse(self, output: str) -> List[str]:
        """Parse predictions."""
        tokens = output.split(",")
        return [t.strip() for t in tokens]


preds_output_parser = PredsOutputParser()
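# Illustrative example (not in the original file): the parser turns the
# LLM's comma-separated completion into a clean list of labels, e.g.
#   preds_output_parser.parse("nausea, headache, rash")
#   -> ["nausea", "headache", "rash"]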


RERANK_PROMPT_STR = """\
Given a piece of text, rank the {num} labels above based on their relevance \
to this piece of text. The labels \
should be listed in descending order using identifiers. \
The most relevant labels should be listed first. \
The output format should be [] > [], e.g., [1] > [2]. \
Only respond with the ranking results, \
do not say any word or explain. \
Here is a given piece of text: {query}.
"""
RERANK_PROMPT_TMPL = PromptTemplate(RERANK_PROMPT_STR)


def infer_retrieve_rerank(
    query: str,
    retriever: BaseRetriever,
    llm: LLM,
    pred_context: str,
    infer_prompt: PromptTemplate,
    rerank_prompt: PromptTemplate,
    reranker_top_n: int = 3,
) -> List[str]:
    """Run infer-retrieve-rerank for a single query string."""
    # 1. Infer: prompt the LLM for raw label predictions and parse them
    infer_prompt_c = infer_prompt.as_query_component(
        partial={"pred_context": pred_context}
    )
    infer_pipeline = QueryPipeline(chain=[infer_prompt_c, llm, preds_output_parser])
    preds = infer_pipeline.run(query)

    # 2. Retrieve: map each raw prediction onto actual labels from the index
    all_nodes = []
    for pred in preds:
        nodes = retriever.retrieve(str(pred))
        all_nodes.extend(nodes)

    # 3. Rerank: let the LLM order the retrieved candidate labels
    reranker = RankGPTRerank(
        llm=llm,
        top_n=reranker_top_n,
        rankgpt_rerank_prompt=rerank_prompt,
        # verbose=True,
    )
    reranked_nodes = reranker.postprocess_nodes(all_nodes, query_str=query)
    return [n.get_content() for n in reranked_nodes]


class InferRetrieveRerankPack(BaseLlamaPack):
    """Infer Retrieve Rerank pack."""

    def __init__(
        self,
        labels: List[str],
        llm: Optional[LLM] = None,
        pred_context: str = "",
        reranker_top_n: int = 3,
        infer_prompt: Optional[PromptTemplate] = None,
        rerank_prompt: Optional[PromptTemplate] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        # NOTE: we use the 16k model by default to fit longer contexts
        self.llm = llm or OpenAI(model="gpt-3.5-turbo-16k")
        # embed every label and index it, so that raw LLM predictions can be
        # mapped back onto actual labels via retrieval
        label_nodes = [TextNode(text=label) for label in labels]
        pipeline = IngestionPipeline(transformations=[OpenAIEmbedding()])
        label_nodes_w_embed = pipeline.run(documents=label_nodes)

        index = VectorStoreIndex(label_nodes_w_embed, show_progress=verbose)
        self.label_retriever = index.as_retriever(similarity_top_k=2)
        self.pred_context = pred_context
        self.reranker_top_n = reranker_top_n
        self.verbose = verbose

        self.infer_prompt = infer_prompt or INFER_PROMPT_TMPL
        self.rerank_prompt = rerank_prompt or RERANK_PROMPT_TMPL

    def get_modules(self) -> Dict[str, Any]:
        """Get modules."""
        return {
            "llm": self.llm,
            "label_retriever": self.label_retriever,
        }

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run the pipeline over a list of input strings."""
        inputs = kwargs.get("inputs", [])
        pred_reactions = []
        for idx, input_str in enumerate(inputs):
            if self.verbose:
                print(f"\n\n> Generating predictions for input {idx}: {input_str[:300]}")
            cur_pred_reactions = infer_retrieve_rerank(
                input_str,
                self.label_retriever,
                self.llm,
                self.pred_context,
                self.infer_prompt,
                self.rerank_prompt,
                reranker_top_n=self.reranker_top_n,
            )
            if self.verbose:
                print(f"> Generated predictions: {cur_pred_reactions}")

            pred_reactions.append(cur_pred_reactions)

        return pred_reactions
```