ColBERT Upstream Updates #19

Merged
merged 2 commits into from
Jun 20, 2023
1 change: 1 addition & 0 deletions README.md
@@ -20,6 +20,7 @@ fast**RAG** is a research framework designed to facilitate the building of retri

## Updates

- **June 2023**: ColBERT index modification: adding/removing documents; see [IndexUpdater](libs/colbert/colbert/index_updater.py).
- **May 2023**: [RAG with LLM and dynamic prompt synthesis example](examples/rag-prompt-hf.ipynb).
- **April 2023**: Qdrant `DocumentStore` support.

2 changes: 1 addition & 1 deletion fastrag/__init__.py
@@ -4,7 +4,7 @@
from fastrag import image_generators, kg_creators, rankers, readers, retrievers, stores
from fastrag.utils import add_timing_to_pipeline

- __version__ = "1.2.0"
+ __version__ = "1.3.0"


def load_pipeline(config_path: str) -> Pipeline:
22 changes: 20 additions & 2 deletions libs/colbert/README.md
@@ -1,3 +1,8 @@
## 🚨 **Announcements**

* (1/29/23) We have merged a new index updater feature and support for additional Hugging Face models! These are in beta so please give us feedback as you try them out.
* (1/24/23) If you're looking for the **DSP** framework for composing ColBERTv2 and LLMs, it's at: https://github.com/stanfordnlp/dsp

# ColBERT (v2)

### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.
@@ -18,7 +23,7 @@ These rich interactions allow ColBERT to surpass the quality of _single-vector_
* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21).
* [**Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval**](https://arxiv.org/abs/2101.00436) (NeurIPS'21).
* [**ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction**](https://arxiv.org/abs/2112.01488) (NAACL'22).
- * [**PLAID: An Efficient Engine for Late Interaction Retrieval**](https://arxiv.org/abs/2205.09707) (preprint).
+ * [**PLAID: An Efficient Engine for Late Interaction Retrieval**](https://arxiv.org/abs/2205.09707) (CIKM'22).

----

@@ -29,7 +34,7 @@ The ColBERTv1 code from the SIGIR'20 paper is in the [`colbertv1` branch](https:

## Installation

- ColBERT requires Python 3.7+ and Pytorch 1.9+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.
+ ColBERT requires Python 3.7+ and Pytorch 1.9+ and uses the [Hugging Face Transformers](https://github.com/huggingface/transformers) library.

We strongly recommend creating a conda environment using the commands below. (If you don't have conda, follow the official [conda installation guide](https://docs.anaconda.com/anaconda/install/linux/#installation).)

@@ -161,6 +166,19 @@ if __name__=='__main__':
print(f"Saved checkpoint to {checkpoint_path}...")
```

## Running a lightweight ColBERTv2 server
We provide a script that runs a lightweight server, which returns the top k (up to 100) results in ranked order for a given search query, in JSON format. This script can be used to power DSP programs.

To run the server, update the environment variables `INDEX_ROOT` and `INDEX_NAME` in the `.env` file to point to the appropriate ColBERT index. Then run the following command:
```
python server.py
```

A sample query:
```
http://localhost:8893/api/search?query=Who won the 2022 FIFA world cup&k=25
```
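
The sample query contains spaces, which should be percent-encoded when the request is sent from code. The following is a minimal client sketch, not part of the upstream README. It assumes only what the sample query shows (host `localhost`, port `8893`, path `/api/search`, parameters `query` and `k`); the helper names `build_search_url` and `search` are hypothetical, and the shape of the returned JSON is not assumed.

```python
import json
from urllib.parse import quote, urlencode
from urllib.request import urlopen


def build_search_url(query, k=25, host="localhost", port=8893):
    """Return the search URL with the query string percent-encoded."""
    # quote_via=quote encodes spaces as %20 rather than '+'.
    params = urlencode({"query": query, "k": k}, quote_via=quote)
    return f"http://{host}:{port}/api/search?{params}"


def search(query, k=25, **kwargs):
    """Send the request and decode the JSON body (needs a running server)."""
    with urlopen(build_search_url(query, k, **kwargs)) as response:
        return json.loads(response.read().decode("utf-8"))


# Building the URL does not need a server; calling search() does.
url = build_search_url("Who won the 2022 FIFA world cup", k=25)
print(url)
```

Calling `search(...)` only works once `python server.py` is running against a valid index.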

## Branches

### Supported branches
1 change: 1 addition & 0 deletions libs/colbert/colbert/__init__.py
@@ -1,3 +1,4 @@
from .index_updater import IndexUpdater
from .indexer import Indexer
from .modeling.checkpoint import Checkpoint
from .searcher import Searcher
50 changes: 50 additions & 0 deletions libs/colbert/colbert/distillation/ranking_scorer.py
@@ -0,0 +1,50 @@
from collections import defaultdict

import tqdm
import ujson
from colbert.data import Ranking
from colbert.distillation.scorer import Scorer
from colbert.infra import Run
from colbert.infra.provenance import Provenance
from colbert.utility.utils.save_metadata import get_metadata_only
from colbert.utils.utils import print_message, zipstar


class RankingScorer:
    def __init__(self, scorer: Scorer, ranking: Ranking):
        self.scorer = scorer
        self.ranking = ranking.tolist()
        self.__provenance = Provenance()

        print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!")

    def provenance(self):
        return self.__provenance

    def run(self):
        print_message(f"#> Starting..")

        qids, pids, *_ = zipstar(self.ranking)
        distillation_scores = self.scorer.launch(qids, pids)

        scores_by_qid = defaultdict(list)

        for qid, pid, score in tqdm.tqdm(zip(qids, pids, distillation_scores)):
            scores_by_qid[qid].append((score, pid))

        with Run().open("distillation_scores.json", "w") as f:
            for qid in tqdm.tqdm(scores_by_qid):
                obj = (qid, scores_by_qid[qid])
                f.write(ujson.dumps(obj) + "\n")

            output_path = f.name
            print_message(f"#> Saved the distillation_scores to {output_path}")

        with Run().open(f"{output_path}.meta", "w") as f:
            d = {}
            d["metadata"] = get_metadata_only()
            d["provenance"] = self.provenance()
            line = ujson.dumps(d, indent=4)
            f.write(line)

        return output_path
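
The grouping and line-per-query output in `run()` can be illustrated in isolation. The sketch below is not part of the PR: it swaps `ujson` for the standard `json` module and an in-memory buffer for `Run().open(...)`, so the exact whitespace may differ from the real `distillation_scores.json`. The helper names and the sample IDs and scores are made up for illustration.

```python
import io
import json
from collections import defaultdict


def group_scores(qids, pids, scores):
    """Group (score, pid) pairs under their query ID, keeping input order."""
    scores_by_qid = defaultdict(list)
    for qid, pid, score in zip(qids, pids, scores):
        scores_by_qid[qid].append((score, pid))
    return scores_by_qid


def write_jsonl(scores_by_qid, handle):
    """Write one JSON array per line: [qid, [[score, pid], ...]]."""
    for qid, pairs in scores_by_qid.items():
        handle.write(json.dumps((qid, pairs)) + "\n")


# Made-up IDs and scores, for illustration only.
grouped = group_scores([1, 1, 2], [10, 11, 12], [0.5, -1.25, 3.0])
buffer = io.StringIO()
write_jsonl(grouped, buffer)
lines = buffer.getvalue().splitlines()
print(lines)
```

Each line is an independent JSON value, so a consumer can stream the file line by line instead of loading it whole.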
75 changes: 75 additions & 0 deletions libs/colbert/colbert/distillation/scorer.py
@@ -0,0 +1,75 @@
import torch
import tqdm
from colbert.infra import Run, RunConfig
from colbert.infra.launcher import Launcher
from colbert.modeling.reranker.electra import ElectraReranker
from colbert.utils.utils import flatten
from transformers import AutoModelForSequenceClassification, AutoTokenizer

DEFAULT_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"


class Scorer:
    def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256):
        self.queries = queries
        self.collection = collection
        self.model = model

        self.maxlen = maxlen
        self.bsize = bsize

    def launch(self, qids, pids):
        launcher = Launcher(self._score_pairs_process, return_all=True)
        outputs = launcher.launch(Run().config, qids, pids)

        return flatten(outputs)

    def _score_pairs_process(self, config, qids, pids):
        assert len(qids) == len(pids), (len(qids), len(pids))
        share = 1 + len(qids) // config.nranks
        offset = config.rank * share
        endpos = (1 + config.rank) * share

        return self._score_pairs(
            qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1)
        )

    def _score_pairs(self, qids, pids, show_progress=False):
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda()

        assert len(qids) == len(pids), (len(qids), len(pids))

        scores = []

        model.eval()
        with torch.inference_mode():
            with torch.cuda.amp.autocast():
                for offset in tqdm.tqdm(
                    range(0, len(qids), self.bsize), disable=(not show_progress)
                ):
                    endpos = offset + self.bsize

                    queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
                    passages_ = [self.collection[pid] for pid in pids[offset:endpos]]

                    features = tokenizer(
                        queries_,
                        passages_,
                        padding="longest",
                        truncation=True,
                        return_tensors="pt",
                        max_length=self.maxlen,
                    ).to(model.device)

                    scores.append(model(**features).logits.flatten())

        scores = torch.cat(scores)
        scores = scores.tolist()

        Run().print(f"Returning with {len(scores)} scores")

        return scores


# LONG-TERM TODO: This can be sped up by sorting by length in advance.
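
The slicing arithmetic in `_score_pairs_process` can be checked without GPUs or the launcher. The sketch below is a standalone illustration, not part of the PR: `shard_bounds` is a hypothetical helper that copies the `share`/`offset`/`endpos` formulas, and the integers stand in for (qid, pid) pairs.

```python
def shard_bounds(num_pairs, nranks, rank):
    """Return the [offset, endpos) slice bounds that one rank would score."""
    share = 1 + num_pairs // nranks
    offset = rank * share
    endpos = (1 + rank) * share
    return offset, endpos


def shard(pairs, nranks):
    """Split pairs into one slice per rank using the bounds above."""
    return [pairs[slice(*shard_bounds(len(pairs), nranks, r))] for r in range(nranks)]


pairs = list(range(10))  # stand-in for 10 (qid, pid) pairs
shards = shard(pairs, nranks=4)
print(shards)

# With 8 pairs and 4 ranks, share is 3, so the last rank gets an empty slice.
even_shards = shard(list(range(8)), nranks=4)
print(even_shards)
```

Because `share` rounds up by one, the slices always cover every pair in order, but the last rank can receive fewer pairs than the others or none at all.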
3 changes: 1 addition & 2 deletions libs/colbert/colbert/evaluation/loaders.py
@@ -176,8 +176,7 @@ def load_collection(collection_path):
print(f"{line_idx // 1000 // 1000}M", end=" ", flush=True)

pid, passage, *rest = line.strip("\n\r ").split("\t")
- # id could be either "id" (the first line), a number or have the format "docNUM"
- assert pid == "id" or int(pid if pid.isnumeric() else pid[3:]) == line_idx
+ assert pid == "id" or int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}"

if len(rest) >= 1:
title = rest[0]
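
The effect of the stricter assertion can be seen on a few hand-written TSV lines. This is a standalone sketch, not the body of `load_collection`: `parse_line` is a hypothetical helper that reuses only the split and the new assertion from the hunk above, and the sample rows are made up.

```python
def parse_line(line, line_idx):
    """Split one TSV line and apply the new pid check from the diff."""
    pid, passage, *rest = line.strip("\n\r ").split("\t")
    assert pid == "id" or int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}"
    title = rest[0] if len(rest) >= 1 else None
    return pid, passage, title


# A header row followed by rows whose pid equals their line index.
rows = ["id\ttext\ttitle", "1\tfirst passage\tTitle A", "2\tsecond passage"]
parsed = [parse_line(row, idx) for idx, row in enumerate(rows)]
print(parsed)

# A "docNUM"-style pid, which the removed line accepted, now fails in int().
try:
    parse_line("doc3\tthird passage", 3)
    rejected = None
except ValueError as err:
    rejected = type(err).__name__
print(rejected)
```

Note that a `docNUM` pid surfaces as a `ValueError` from `int()` before the assertion message is ever built, so the new message is shown only for numeric pids that do not match the line index.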