Skip to content

Commit

Permalink
add: BGEImageRetriever
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Sep 30, 2024
1 parent 3e058ed commit e6ca1be
Showing 1 changed file with 57 additions and 34 deletions.
91 changes: 57 additions & 34 deletions finance_multi_modal_rag/finance_multi_modal_rag/retrieval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

import numpy as np
import safetensors
Expand All @@ -12,55 +12,51 @@

class BGERetriever(weave.Model):
model_name: str
weave_chunked_dataset_address: Optional[str]
corpus: List[Union[Dict[str, str], str]] = []
weave_chunked_dataset_address: str
_corpus: List[Dict[str, str]] = []
_index: np.ndarray = None
_model: SentenceTransformer = None

def __init__(
self,
weave_chunked_dataset_address: str,
model_name: str,
weave_chunked_dataset_address: Optional[str] = None,
corpus: List[Dict[str, str]] = [],
index: Optional[np.ndarray] = None,
):
super().__init__(
model_name=model_name,
weave_chunked_dataset_address=weave_chunked_dataset_address,
corpus=corpus,
model_name=model_name,
)
self._index = index
self._model = SentenceTransformer(self.model_name)
self.corpus = (
[
dict(row)
for row in weave.ref(self.weave_chunked_dataset_address).get().rows
]
if len(corpus) == 0
else corpus
)
self._corpus = [
dict(row)
for row in weave.ref(self.weave_chunked_dataset_address).get().rows
]

@classmethod
def from_wandb_artifact(cls, artifact_address: str, model_name: str):
def from_wandb_artifact(
cls, artifact_address: str, weave_chunked_dataset_address: str, model_name: str
):
api = wandb.Api()
artifact = api.artifact(artifact_address)
artifact_dir = artifact.download()
with open(os.path.join(artifact_dir, "index.safetensors"), "rb") as f:
index = f.read()
index = safetensors.numpy.load(index)["index"]
return cls(model_name=model_name, index=index)
return cls(
weave_chunked_dataset_address=weave_chunked_dataset_address,
model_name=model_name,
index=index,
)

def create_index(
self,
index_persist_dir: Optional[str] = None,
artifact_name: Optional[str] = None,
):
self._index = self._model.encode(
sentences=(
[row["cleaned_content"] for row in self.corpus]
if self.weave_chunked_dataset_address
else self.corpus
),
sentences=[row["cleaned_content"] for row in self._corpus],
normalize_embeddings=True,
)
if index_persist_dir:
Expand Down Expand Up @@ -90,20 +86,10 @@ def search(self, query: str, top_k: int = 5):
top_k_indices = sorted_indices[:top_k].tolist()
retrieved_pages = []
for idx in top_k_indices:
retrieved_content = (
self.corpus[idx]["cleaned_content"]
if self.weave_chunked_dataset_address
else self.corpus[idx]
)
metadata = (
self.corpus[idx]["metadata"]
if self.weave_chunked_dataset_address
else {"idx": idx}
)
retrieved_pages.append(
{
"retrieved_content": retrieved_content,
"metadata": metadata,
"retrieved_content": self._corpus[idx]["cleaned_content"],
"metadata": self._corpus[idx]["metadata"],
}
)
return retrieved_pages
Expand All @@ -115,3 +101,40 @@ def predict(self, query: str, top_k: int = 5):
+ query,
top_k,
)


class BGEImageRetriever(weave.Model):
model_name: str
_model: SentenceTransformer = None

def __init__(self, model_name: str):
super().__init__(model_name=model_name)
self._model = SentenceTransformer(self.model_name)

@weave.op()
def search(self, query: str, image_descriptions: List[str], top_k: int = 5):
index = self._model.encode(
sentences=image_descriptions, normalize_embeddings=True
)
query_embeddings = self._model.encode([query], normalize_embeddings=True)
scores = query_embeddings @ index.T
sorted_indices = np.argsort(scores, axis=None)[::-1]
top_k_indices = sorted_indices[:top_k].tolist()
retrieved_pages = []
for idx in top_k_indices:
retrieved_pages.append(
{
"retrieved_image_description": image_descriptions[idx],
"image_idx": idx,
}
)
return retrieved_pages

@weave.op()
def predict(self, query: str, image_descriptions: List[str], top_k: int = 1):
return self.search(
"Generate a representation for this sentence that can be used to retrieve related articles:\n"
+ query,
image_descriptions,
top_k,
)

0 comments on commit e6ca1be

Please sign in to comment.