Skip to content

Commit

Permalink
lint formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Nov 9, 2023
1 parent fb29c71 commit 940d3fd
Showing 1 changed file with 18 additions and 27 deletions.
45 changes: 18 additions & 27 deletions examples/llm/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,13 @@
logger = logging.getLogger(__name__)


def build_huggingface_embeddings(model_name: str,
model_kwargs: dict = None,
encode_kwargs: dict = None):
embeddings = HuggingFaceEmbeddings(model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs)
def build_huggingface_embeddings(model_name: str, model_kwargs: dict = None, encode_kwargs: dict = None):
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)

return embeddings


def build_llm_service(model_name: str, llm_service: str,
tokens_to_generate: int, **model_kwargs):
def build_llm_service(model_name: str, llm_service: str, tokens_to_generate: int, **model_kwargs):
lowered_llm_service = llm_service.lower()
if (lowered_llm_service == 'nemollm'):
model_kwargs['tokens_to_generate'] = tokens_to_generate
Expand All @@ -62,15 +57,13 @@ def build_milvus_config(embedding_size: int):
},
},
"schema_conf": {
"enable_dynamic_field":
True,
"enable_dynamic_field": True,
"schema_fields": [
pymilvus.FieldSchema(
name="id",
dtype=pymilvus.DataType.INT64,
description="Primary key for the collection",
is_primary=True,
auto_id=True).to_dict(),
pymilvus.FieldSchema(name="id",
dtype=pymilvus.DataType.INT64,
description="Primary key for the collection",
is_primary=True,
auto_id=True).to_dict(),
pymilvus.FieldSchema(name="title",
dtype=pymilvus.DataType.VARCHAR,
description="The title of the RSS Page",
Expand All @@ -83,30 +76,28 @@ def build_milvus_config(embedding_size: int):
dtype=pymilvus.DataType.VARCHAR,
description="The summary of the RSS Page",
max_length=65_535).to_dict(),
pymilvus.FieldSchema(
name="page_content",
dtype=pymilvus.DataType.VARCHAR,
description="A chunk of text from the RSS Page",
max_length=65_535).to_dict(),
pymilvus.FieldSchema(name="page_content",
dtype=pymilvus.DataType.VARCHAR,
description="A chunk of text from the RSS Page",
max_length=65_535).to_dict(),
pymilvus.FieldSchema(name="embedding",
dtype=pymilvus.DataType.FLOAT_VECTOR,
description="Embedding vectors",
dim=embedding_size).to_dict(),
],
"description":
"Test collection schema"
"description": "Test collection schema"
}
}

return milvus_resource_kwargs


def build_milvus_service(embedding_size: int,
uri: str = "http://localhost:19530"):
def build_milvus_service(embedding_size: int, uri: str = "http://localhost:19530"):
milvus_resource_kwargs = build_milvus_config(embedding_size)

vdb_service: MilvusVectorDBService = VectorDBServiceFactory.create_instance(
"milvus", uri=uri, **milvus_resource_kwargs)
vdb_service: MilvusVectorDBService = VectorDBServiceFactory.create_instance("milvus",
uri=uri,
**milvus_resource_kwargs)

return vdb_service

Expand Down

0 comments on commit 940d3fd

Please sign in to comment.