diff --git a/examples/llm/common/utils.py b/examples/llm/common/utils.py index 406c7eee6d..23e9f69559 100644 --- a/examples/llm/common/utils.py +++ b/examples/llm/common/utils.py @@ -24,18 +24,13 @@ logger = logging.getLogger(__name__) -def build_huggingface_embeddings(model_name: str, - model_kwargs: dict = None, - encode_kwargs: dict = None): - embeddings = HuggingFaceEmbeddings(model_name=model_name, - model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs) +def build_huggingface_embeddings(model_name: str, model_kwargs: dict = None, encode_kwargs: dict = None): + embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs) return embeddings -def build_llm_service(model_name: str, llm_service: str, - tokens_to_generate: int, **model_kwargs): +def build_llm_service(model_name: str, llm_service: str, tokens_to_generate: int, **model_kwargs): lowered_llm_service = llm_service.lower() if (lowered_llm_service == 'nemollm'): model_kwargs['tokens_to_generate'] = tokens_to_generate @@ -62,15 +57,13 @@ def build_milvus_config(embedding_size: int): }, }, "schema_conf": { - "enable_dynamic_field": - True, + "enable_dynamic_field": True, "schema_fields": [ - pymilvus.FieldSchema( - name="id", - dtype=pymilvus.DataType.INT64, - description="Primary key for the collection", - is_primary=True, - auto_id=True).to_dict(), + pymilvus.FieldSchema(name="id", + dtype=pymilvus.DataType.INT64, + description="Primary key for the collection", + is_primary=True, + auto_id=True).to_dict(), pymilvus.FieldSchema(name="title", dtype=pymilvus.DataType.VARCHAR, description="The title of the RSS Page", @@ -83,30 +76,28 @@ def build_milvus_config(embedding_size: int): dtype=pymilvus.DataType.VARCHAR, description="The summary of the RSS Page", max_length=65_535).to_dict(), - pymilvus.FieldSchema( - name="page_content", - dtype=pymilvus.DataType.VARCHAR, - description="A chunk of text from the RSS Page", - max_length=65_535).to_dict(), + pymilvus.FieldSchema(name="page_content", + dtype=pymilvus.DataType.VARCHAR, + description="A chunk of text from the RSS Page", + max_length=65_535).to_dict(), pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, description="Embedding vectors", dim=embedding_size).to_dict(), ], - "description": - "Test collection schema" + "description": "Test collection schema" } } return milvus_resource_kwargs -def build_milvus_service(embedding_size: int, - uri: str = "http://localhost:19530"): +def build_milvus_service(embedding_size: int, uri: str = "http://localhost:19530"): milvus_resource_kwargs = build_milvus_config(embedding_size) - vdb_service: MilvusVectorDBService = VectorDBServiceFactory.create_instance( - "milvus", uri=uri, **milvus_resource_kwargs) + vdb_service: MilvusVectorDBService = VectorDBServiceFactory.create_instance("milvus", + uri=uri, + **milvus_resource_kwargs) return vdb_service