Skip to content

Commit 12797a0

Browse files
committed
minor fix
linter fixes poetry reformat more formatting linter linter linter linter lint lint lint tests
1 parent bb1c472 commit 12797a0

File tree

2 files changed

+104
-60
lines changed

2 files changed

+104
-60
lines changed

libs/community/langchain_community/vectorstores/deeplake.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77

88
try:
99
import deeplake
10-
if deeplake.__version__.startswit("3."):
10+
11+
if deeplake.__version__.startswith("3."):
1112
from deeplake import VectorStore as DeepLakeVectorStore
1213
from deeplake.core.fast_forwarding import version_compare
1314
from deeplake.util.exceptions import SampleExtendError
1415
else:
15-
from deeplake_vector_search import DeepLakeVectorStore
16+
from langchain_community.vectorstores.deeplake_vector_search import (
17+
DeepLakeVectorStore,
18+
)
1619
_DEEPLAKE_INSTALLED = True
1720
except ImportError:
1821
_DEEPLAKE_INSTALLED = False
@@ -933,7 +936,8 @@ def ds(self) -> Any:
933936
return self.vectorstore.dataset
934937

935938
@classmethod
936-
def _validate_kwargs(cls, kwargs, method_name): # type: ignore[no-untyped-def]
939+
# type: ignore[no-untyped-def]
940+
def _validate_kwargs(cls, kwargs, method_name):
937941
if kwargs:
938942
valid_items = cls._get_valid_args(method_name)
939943
unsupported_items = cls._get_unsupported_items(kwargs, valid_items)
@@ -952,7 +956,8 @@ def _get_valid_args(cls, method_name): # type: ignore[no-untyped-def]
952956
return []
953957

954958
@staticmethod
955-
def _get_unsupported_items(kwargs, valid_items): # type: ignore[no-untyped-def]
959+
# type: ignore[no-untyped-def]
960+
def _get_unsupported_items(kwargs, valid_items):
956961
kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
957962
unsupported_items = None
958963
if kwargs:
Lines changed: 95 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,34 @@
1-
2-
import deeplake
31
import uuid
2+
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
3+
4+
from langchain_core.embeddings import Embeddings
5+
6+
try:
7+
import deeplake
8+
9+
_DEEPLAKE_INSTALLED = True
10+
except ImportError:
11+
_DEEPLAKE_INSTALLED = False
12+
13+
414
class DeepLakeVectorStore:
5-
def __init__(self,
6-
path,
7-
embedding_function,
8-
read_only,
9-
token,
10-
exec_option,
11-
verbose,
12-
runtime,
13-
index_params,
14-
**kwargs: any):
15+
def __init__(
16+
self,
17+
path: str,
18+
embedding_function: Optional[Embeddings] = None,
19+
read_only: bool = False,
20+
token: Optional[str] = None,
21+
exec_option: Optional[str] = None,
22+
verbose: bool = False,
23+
runtime: Optional[Dict] = None,
24+
index_params: Optional[Dict[str, Union[int, str]]] = None,
25+
**kwargs: Any,
26+
):
27+
if _DEEPLAKE_INSTALLED is False:
28+
raise ImportError(
29+
"Could not import deeplake python package. "
30+
"Please install it with `pip install deeplake[enterprise]`."
31+
)
1532
self.path = path
1633
self.embedding_function = embedding_function
1734
self.read_only = read_only
@@ -29,72 +46,93 @@ def __init__(self,
2946
except deeplake.LogNotexistsError:
3047
self.__create_dataset()
3148

32-
def tensors(self):
49+
def tensors(self) -> list[str]:
3350
return [c.name for c in self.ds.schema.columns]
3451

35-
def add(self,
36-
text,
37-
metadata,
38-
embedding_data,
39-
embedding_tensor,
40-
embedding_function,
41-
return_ids: bool,
42-
**tensors):
43-
52+
def add(
53+
self,
54+
text: List[str],
55+
metadata: Optional[List[dict]],
56+
embedding_data: Iterable[str],
57+
embedding_tensor: str,
58+
embedding_function: Optional[Callable],
59+
return_ids: bool,
60+
**tensors: Any,
61+
) -> Optional[list[str]]:
4462
if embedding_function is not None:
4563
embedding_data = embedding_function(text)
4664
if embedding_tensor is not None:
4765
embedding_tensor = "embedding"
48-
_id = tensors['id'] if 'id' in tensors else [
49-
str(uuid.uuid1()) for _ in range(len(text))]
50-
self.ds.append({
51-
"text": text,
52-
"metadata": metadata,
53-
"id": _id,
54-
embedding_tensor: embedding_data,
55-
})
66+
_id = (
67+
tensors["id"]
68+
if "id" in tensors
69+
else [str(uuid.uuid1()) for _ in range(len(text))]
70+
)
71+
self.ds.append(
72+
{
73+
"text": text,
74+
"metadata": metadata,
75+
"id": _id,
76+
embedding_tensor: embedding_data,
77+
}
78+
)
5679
self.ds.commit()
5780
if return_ids:
5881
return _id
5982
else:
6083
return None
6184

62-
def search(self, query, exec_options):
85+
def search_tql(self, query: str, exec_options: Optional[str]) -> Dict[str, Any]:
6386
view = self.ds.query(query)
6487
return self.__view_to_docs(view)
6588

66-
def search(self, embedding,
67-
k,
68-
distance_metric,
69-
filter,
70-
exec_option,
71-
return_tensors,
72-
deep_memory):
89+
def search(
90+
self,
91+
embedding: Union[str, List[float]],
92+
k: int,
93+
distance_metric: str,
94+
filter: Optional[Dict[str, Any]],
95+
exec_option: Optional[str],
96+
return_tensors: List[str],
97+
deep_memory: Optional[bool],
98+
query: Optional[str] = None,
99+
) -> Dict[str, Any]:
100+
if query is None and embedding is None:
101+
raise ValueError(
102+
"Both `embedding` and `query` were specified."
103+
" Please specify either one or the other."
104+
)
105+
if query is not None:
106+
return self.search_tql(query, exec_option)
107+
73108
if isinstance(embedding, str):
74-
embedding = self.embedding_function(embedding)
109+
if self.embedding_function is None:
110+
raise ValueError(
111+
"embedding_function is required when embedding is a string"
112+
)
113+
embedding = self.embedding_function.embed_documents([embedding])[0]
75114
emb_str = ", ".join([str(e) for e in embedding])
76115

77-
column_list = (
78-
" * " if return_tensors else ", ".join(return_tensors))
116+
column_list = " * " if return_tensors else ", ".join(return_tensors)
79117

80118
metric = self.__metric_to_function(distance_metric)
81119
order_by = " ASC "
82120
if metric == "cosine_similarity":
83121
order_by = " DESC "
84-
column_list += f", {self.__metric_to_function(
85-
distance_metric)}(embedding, ARRAY[{emb_str}]) as score"
86-
query = f"SELECT {column_list} ORDER BY {self.__metric_to_function(
87-
distance_metric)}(embedding, ARRAY[{emb_str}])+1 {order_by} LIMIT {k}"
122+
dp = f"(embedding, ARRAY[{emb_str}])"
123+
column_list += f", {self.__metric_to_function(distance_metric)}{dp} as score"
124+
mf = self.__metric_to_function(distance_metric)
125+
query = f"SELECT {column_list} ORDER BY {mf}{dp} {order_by} LIMIT {k}"
88126
view = self.ds.query(query)
89127
return self.__view_to_docs(view)
90128

91-
def delete(self, ids, filter, delete_all):
129+
def delete(self, ids: List[str], filter: Dict[str, Any], delete_all: bool) -> None:
92130
raise NotImplementedError
93131

94-
def dataset(self):
132+
def dataset(self) -> Any:
95133
return self.ds
96134

97-
def __view_to_docs(self, view):
135+
def __view_to_docs(self, view: Any) -> Dict[str, Any]:
98136
docs = {}
99137
tenors = [(c.name, str(c.dtype)) for c in view.schema.columns]
100138
for name, type in tenors:
@@ -107,23 +145,24 @@ def __view_to_docs(self, view):
107145
docs[name] = view[name][:]
108146
return docs
109147

110-
def __metric_to_function(self, metric):
148+
def __metric_to_function(self, metric: str) -> str:
111149
if metric is None or metric == "cosine" or metric == "cosine_similarity":
112150
return "cosine_similarity"
113151
elif metric == "l2" or metric == "l2_norm":
114152
return "l2_norm"
115153
else:
116-
raise ValueError(f"Unknown metric: {metric}, should be one of ['cosine', 'cosine_similarity', 'l2', 'l2_norm']")
154+
raise ValueError(
155+
f"Unknown metric: {metric}, should be one of "
156+
"['cosine', 'cosine_similarity', 'l2', 'l2_norm']"
157+
)
117158

118-
def __create_dataset(self):
159+
def __create_dataset(self) -> None:
119160
if self.embedding_function is None:
120-
raise ValueError(
121-
"embedding_function is required to create a new dataset")
122-
emb_size = len(self.embedding_function.embed_documents("test")[0])
161+
raise ValueError("embedding_function is required to create a new dataset")
162+
emb_size = len(self.embedding_function.embed_documents(["test"])[0])
123163
self.ds = deeplake.create(self.path, self.token)
124164
self.ds.add_column("text", deeplake.types.Text("inverted"))
125165
self.ds.add_column("metadata", deeplake.types.Dict())
126-
self.ds.add_column(
127-
"embedding", deeplake.types.Embedding(size=emb_size))
166+
self.ds.add_column("embedding", deeplake.types.Embedding(size=emb_size))
128167
self.ds.add_column("id", deeplake.types.Text)
129168
self.ds.commit()

0 commit comments

Comments
 (0)