add infer-retrieve-rerank pack (#911)
jerryjliu authored Jan 29, 2024
1 parent ee80a9d commit f1f7c1f
Showing 9 changed files with 1,269 additions and 403 deletions.
5 changes: 5 additions & 0 deletions llama_hub/llama_packs/library.json
@@ -277,5 +277,10 @@
"id": "llama_packs/vanna",
"author": "jerryjliu",
"keywords": ["vanna", "sql", "ai", "text-to-sql"]
},
"InferRetrieveRerankPack": {
"id": "llama_packs/research/infer_retrieve_rerank",
"author": "jerryjliu",
"keywords": ["infer", "retrieve", "rerank", "retriever", "rag"]
}
}
66 changes: 66 additions & 0 deletions llama_hub/llama_packs/research/infer_retrieve_rerank/README.md
@@ -0,0 +1,66 @@
# Infer-Retrieve-Rerank LlamaPack

This is our implementation of the paper ["In-Context Learning for Extreme Multi-Label Classification"](https://arxiv.org/pdf/2401.12178.pdf) by Oosterlinck et al.

The paper proposes "infer-retrieve-rerank", a simple paradigm that uses frozen LLM and retriever models to perform "extreme" multi-label classification (classification over a very large label space):

1. Given a user query, use an LLM to infer an initial set of candidate labels.
2. For each prediction, retrieve the closest actual label from the label corpus.
3. Given the final set of retrieved labels, rerank them with an LLM.

All of these steps can be implemented with LlamaIndex abstractions, as sketched below.
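
A condensed sketch of the three stages, mirroring this pack's `base.py` (the `llm`, `retriever`, `infer_prompt`, `output_parser`, `pred_context`, and `query` objects are illustrative and assumed to already exist):

```python
from llama_index.query_pipeline import QueryPipeline
from llama_index.postprocessor.rankGPT_rerank import RankGPTRerank

# 1. Infer: prompt the LLM for an initial comma-separated set of predictions.
infer_prompt_c = infer_prompt.as_query_component(partial={"pred_context": pred_context})
infer_pipeline = QueryPipeline(chain=[infer_prompt_c, llm, output_parser])
preds = infer_pipeline.run(query)

# 2. Retrieve: ground each raw prediction in actual labels from the corpus.
all_nodes = []
for pred in preds:
    all_nodes.extend(retriever.retrieve(str(pred)))

# 3. Rerank: have the LLM order the retrieved labels by relevance.
reranker = RankGPTRerank(llm=llm, top_n=3)
final_labels = [
    n.get_content() for n in reranker.postprocess_nodes(all_nodes, query_str=query)
]
```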

A full notebook guide can be found [here](https://github.com/run-llama/llama-hub/blob/main/llama_hub/llama_packs/research/infer_retrieve_rerank/infer_retrieve_rerank.ipynb).

## CLI Usage

You can download LlamaPacks directly using `llamaindex-cli`, which comes installed with the `llama-index` Python package:

```bash
llamaindex-cli download-llamapack InferRetrieveRerankPack --download-dir ./infer_retrieve_rerank_pack
```

You can then inspect the files at `./infer_retrieve_rerank_pack` and use them as a template for your own project!

## Code Usage

You can download the pack to a `./infer_retrieve_rerank_pack` directory:

```python
from llama_index.llama_pack import download_llama_pack

# download and install dependencies
InferRetrieveRerankPack = download_llama_pack(
    "InferRetrieveRerankPack", "./infer_retrieve_rerank_pack"
)
```

From here, you can use the pack, or inspect and modify the pack in `./infer_retrieve_rerank_pack`.

Then, you can set up the pack like so:

```python
# create the pack
pack = InferRetrieveRerankPack(
    labels,  # list of all label strings
    llm=llm,
    pred_context="<pred_context>",
    reranker_top_n=3,
    verbose=True,
)
```

The `run()` function runs predictions over a batch of input texts:

```python
pred_reactions = pack.run(inputs=[s["text"] for s in samples])
```
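
Each element of `pred_reactions` is the list of up to `reranker_top_n` reranked label strings for the corresponding input, so you can pair outputs back up with their inputs (continuing the hypothetical `samples` list above):

```python
for sample, labels in zip(samples, pred_reactions):
    print(sample["text"][:80], "->", labels)
```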

You can also use modules individually.

```python
# access the underlying modules directly
llm = pack.llm
label_retriever = pack.label_retriever
```
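
For example, you might call the pack's LLM or label retriever directly (the query strings here are illustrative):

```python
# complete a prompt with the pack's LLM
response = llm.complete("List three possible labels for: persistent dry cough")
print(response.text)

# retrieve the closest actual labels for a raw prediction
nodes = label_retriever.retrieve("cough")
print([n.get_content() for n in nodes])
```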
157 changes: 157 additions & 0 deletions llama_hub/llama_packs/research/infer_retrieve_rerank/base.py
@@ -0,0 +1,157 @@
"""Infer-Retrieve-Rerank Pack.
Taken from this paper: https://arxiv.org/pdf/2401.12178.pdf.
"""


from typing import Any, Dict, List, Optional

from llama_index.llama_pack.base import BaseLlamaPack
from llama_index.schema import TextNode
from llama_index.embeddings import OpenAIEmbedding
from llama_index.ingestion import IngestionPipeline
from llama_index import VectorStoreIndex
from llama_index.retrievers import BaseRetriever
from llama_index.llms.llm import LLM
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.query_pipeline import QueryPipeline
from llama_index.postprocessor.rankGPT_rerank import RankGPTRerank
from llama_index.output_parsers import ChainableOutputParser

INFER_PROMPT_STR = """\
Your job is to output a list of predictions given context from a given piece of text. \
The text context, and information regarding the set of valid predictions, is given below. \
Return the predictions as a comma-separated list of strings.
Text Context:
{doc_context}
Prediction Info:
{pred_context}
Predictions: """

INFER_PROMPT_TMPL = PromptTemplate(INFER_PROMPT_STR)


class PredsOutputParser(ChainableOutputParser):
    """Predictions output parser."""

    def parse(self, output: str) -> List[str]:
        """Parse a comma-separated LLM completion into a list of predictions."""
        # e.g. "label_a, label_b" -> ["label_a", "label_b"]
        tokens = output.split(",")
        return [t.strip() for t in tokens]


preds_output_parser = PredsOutputParser()


RERANK_PROMPT_STR = """\
Given a piece of text, rank the {num} labels above based on their relevance \
to this piece of text. The labels \
should be listed in descending order using identifiers. \
The most relevant labels should be listed first. \
The output format should be [] > [], e.g., [1] > [2]. \
Only respond with the ranking results; \
do not add any other words or explanation. \
Here is a given piece of text: {query}.
"""
RERANK_PROMPT_TMPL = PromptTemplate(RERANK_PROMPT_STR)


def infer_retrieve_rerank(
    query: str,
    retriever: BaseRetriever,
    llm: LLM,
    pred_context: str,
    infer_prompt: PromptTemplate,
    rerank_prompt: PromptTemplate,
    reranker_top_n: int = 3,
) -> List[str]:
    """Run infer-retrieve-rerank for a single query."""
    # 1. Infer: prompt the LLM for an initial comma-separated set of predictions.
    infer_prompt_c = infer_prompt.as_query_component(
        partial={"pred_context": pred_context}
    )
    infer_pipeline = QueryPipeline(chain=[infer_prompt_c, llm, preds_output_parser])
    preds = infer_pipeline.run(query)

    # 2. Retrieve: ground each raw prediction in actual labels from the corpus.
    all_nodes = []
    for pred in preds:
        nodes = retriever.retrieve(str(pred))
        all_nodes.extend(nodes)

    # 3. Rerank: have the LLM order the retrieved labels by relevance.
    reranker = RankGPTRerank(
        llm=llm,
        top_n=reranker_top_n,
        rankgpt_rerank_prompt=rerank_prompt,
    )
    reranked_nodes = reranker.postprocess_nodes(all_nodes, query_str=query)
    return [n.get_content() for n in reranked_nodes]


class InferRetrieveRerankPack(BaseLlamaPack):
    """Infer Retrieve Rerank pack."""

    def __init__(
        self,
        labels: List[str],
        llm: Optional[LLM] = None,
        pred_context: str = "",
        reranker_top_n: int = 3,
        infer_prompt: Optional[PromptTemplate] = None,
        rerank_prompt: Optional[PromptTemplate] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        # NOTE: we use a 16k-context model by default to fit longer contexts
        self.llm = llm or OpenAI(model="gpt-3.5-turbo-16k")
        # embed every label string and index them for retrieval
        label_nodes = [TextNode(text=label) for label in labels]
        pipeline = IngestionPipeline(transformations=[OpenAIEmbedding()])
        label_nodes_w_embed = pipeline.run(documents=label_nodes)

        index = VectorStoreIndex(label_nodes_w_embed, show_progress=verbose)
        self.label_retriever = index.as_retriever(similarity_top_k=2)
        self.pred_context = pred_context
        self.reranker_top_n = reranker_top_n
        self.verbose = verbose

        self.infer_prompt = infer_prompt or INFER_PROMPT_TMPL
        self.rerank_prompt = rerank_prompt or RERANK_PROMPT_TMPL

    def get_modules(self) -> Dict[str, Any]:
        """Get modules."""
        return {
            "llm": self.llm,
            "label_retriever": self.label_retriever,
        }

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run the pipeline over a batch of input texts."""
        inputs = kwargs.get("inputs", [])
        pred_reactions = []
        for idx, input_text in enumerate(inputs):
            if self.verbose:
                print(
                    f"\n\n> Generating predictions for input {idx}: {input_text[:300]}"
                )
            cur_pred_reactions = infer_retrieve_rerank(
                input_text,
                self.label_retriever,
                self.llm,
                self.pred_context,
                self.infer_prompt,
                self.rerank_prompt,
                reranker_top_n=self.reranker_top_n,
            )
            if self.verbose:
                print(f"> Generated predictions: {cur_pred_reactions}")

            pred_reactions.append(cur_pred_reactions)

        return pred_reactions