Skip to content

Commit

Permalink
retrieve from weaviate using weaviate instances
Browse files Browse the repository at this point in the history
  • Loading branch information
Hakimovich99 committed Nov 21, 2023
1 parent 73b1116 commit bb5cb2c
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/components/load_from_csv/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ image: ghcr.io/ml6team/load_from_csv:dev
produces:
text: #TODO: fill in here
fields:
question:
data:
type: string

args:
Expand Down
31 changes: 31 additions & 0 deletions src/components/load_from_hf_hub/fondant_component.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Load from huggingface hub
description: Component that loads a dataset from huggingface hub
image: fndnt/load_from_hf_hub:0.6.2

produces:
text:
fields:
data:
type: string

args:
dataset_name:
description: Name of dataset on the hub
type: str
column_name_mapping:
description: Mapping of the consumed hub dataset to fondant column names
type: dict
default: {}
image_column_names:
description: Optional argument, a list containing the original image column names in case the
dataset on the hub contains them. Used to format the image from HF hub format to a byte string.
type: list
default: []
n_rows_to_load:
description: Optional argument that defines the number of rows to load. Useful for testing pipeline runs on a small scale
type: int
default: None
index_column:
description: Column to set index to in the load component, if not specified a default globally unique index will be set
type: str
default: None
10 changes: 3 additions & 7 deletions src/components/retrieve_from_weaviate/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
FROM --platform=linux/amd64 python:3.8-slim as base

# System dependencies
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y

# Install requirements
COPY requirements.txt ./
COPY requirements.txt /
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component/src

Expand Down
14 changes: 6 additions & 8 deletions src/components/retrieve_from_weaviate/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ image: ghcr.io/ml6team/retrieve_from_weaviate:dev
consumes:
text: #TODO: fill in here
fields:
question:
data:
type: string
embedding:
type: array
items:
type: float32

produces:
text: #TODO: fill in here
fields:
question:
data:
type: string
retrieved_chunks:
type: array
Expand All @@ -29,12 +33,6 @@ args:
The name of the weaviate class that is queried to retrieve the stored chunks.
Should follow the weaviate naming conventions.
type: str
text_property_name:
description: Name set for the text stored in the vectorDB
type: str
hf_embed_model:
description: Embedding model used to store the vectors in the vectorDB
type: str
top_k:
description: Number of chunks to retrieve
type: int
7 changes: 2 additions & 5 deletions src/components/retrieve_from_weaviate/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
llama-index==0.8.68
weaviate-client
transformers
torch
fondant[docker]==0.6.2
weaviate-client==3.24.1
fondant[component]==0.7.0
52 changes: 21 additions & 31 deletions src/components/retrieve_from_weaviate/src/main.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,45 @@
import dask
import pandas as pd
from fondant.component import PandasTransformComponent
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.vector_stores import WeaviateVectorStore

import weaviate

dask.config.set({"dataframe.convert-string": False})


class RetrieveChunks(PandasTransformComponent):
def __init__(
self,
*_,
weaviate_url: str,
class_name: str,
text_property_name: str,
hf_embed_model: str,
top_k: int
) -> None:
self,
*_,
weaviate_url: str,
class_name: str,
top_k: int
) -> None:
"""
Args:
weaviate_url: URL of the weaviate instance to connect to
class_name: Name of the weaviate class to query
top_k: Number of chunks to retrieve
"""
# Initialize your component here based on the arguments
self.client = weaviate.Client(weaviate_url)
self.class_name = class_name
self.text_property_name = text_property_name
self.model = HuggingFaceEmbedding(hf_embed_model)
self.k = top_k
self.retriever = self._set_retriever(
self.client, self.class_name, self.model, self.k
)

def _set_retriever(self, client, class_name, model, k):
vector_store = WeaviateVectorStore(
weaviate_client=client, index_name=class_name
)
service_context = ServiceContext.from_defaults(llm=None, embed_model=model)
indexed_vector_db = VectorStoreIndex.from_vector_store(
vector_store=vector_store, service_context=service_context

def retrieve_chunks(self, vector_query: str):
"""Get results from weaviate database"""
result = (
self.client.query
.get(self.class_name, ["passage"])
.with_near_vector({"vector":vector_query})
.with_limit(self.k)
.with_additional(["distance"])
.do()
)
return indexed_vector_db.as_retriever(similarity_top_k=k)
result_dict = result["data"]["Get"][self.class_name]
text = [retrieved_chunk["passage"] for retrieved_chunk in result_dict]

def retrieve_chunks(self, query: str):
retrievals = self.retriever.retrieve(query)
return [chunk.metadata[self.text_property_name] for chunk in retrievals]
return text

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
dataframe[("text", "retrieved_chunks")] = dataframe[("text", "question")].apply(
dataframe[("text", "retrieved_chunks")] = dataframe[("text", "embedding")].apply(
self.retrieve_chunks
)
return dataframe

0 comments on commit bb5cb2c

Please sign in to comment.