Skip to content
This repository was archived by the owner on Jul 16, 2024. It is now read-only.

Feat batched knn #212

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 85 additions & 64 deletions doc/source/notebooks/embedding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-08-25T12:57:01.715707Z",
"start_time": "2023-08-25T12:56:54.919200Z"
}
},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -51,7 +56,12 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-08-25T12:57:02.644919Z",
"start_time": "2023-08-25T12:57:01.723149Z"
}
},
"outputs": [],
"source": [
"content = [\"I have a dog.\", \"I like eating apples.\"]\n",
Expand Down Expand Up @@ -81,35 +91,17 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-08-25T12:57:08.645604Z",
"start_time": "2023-08-25T12:57:02.646625Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"\t<tr>\n",
"\t\t<th>id</th>\n",
"\t\t<th>content</th>\n",
"\t</tr>\n",
"\t<tr>\n",
"\t\t<td>0</td>\n",
"\t\t<td>I have a dog.</td>\n",
"\t</tr>\n",
"\t<tr>\n",
"\t\t<td>1</td>\n",
"\t\t<td>I like eating apples.</td>\n",
"\t</tr>\n",
"</table>"
],
"text/plain": [
"----------------------------\n",
" id | content \n",
"----+-----------------------\n",
" 0 | I have a dog. \n",
" 1 | I like eating apples. \n",
"----------------------------\n",
"(2 rows)"
]
"text/plain": "----------------------------\n id | content \n----+-----------------------\n 0 | I have a dog. \n 1 | I like eating apples. \n----------------------------\n(2 rows)",
"text/html": "<table>\n\t<tr>\n\t\t<th>id</th>\n\t\t<th>content</th>\n\t</tr>\n\t<tr>\n\t\t<td>0</td>\n\t\t<td>I have a dog.</td>\n\t</tr>\n\t<tr>\n\t\t<td>1</td>\n\t\t<td>I like eating apples.</td>\n\t</tr>\n</table>"
},
"execution_count": 3,
"metadata": {},
Expand All @@ -133,30 +125,17 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-08-25T12:57:14.069009Z",
"start_time": "2023-08-25T12:57:08.643273Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"\t<tr>\n",
"\t\t<th>id</th>\n",
"\t\t<th>content</th>\n",
"\t</tr>\n",
"\t<tr>\n",
"\t\t<td>1</td>\n",
"\t\t<td>I like eating apples.</td>\n",
"\t</tr>\n",
"</table>"
],
"text/plain": [
"----------------------------\n",
" id | content \n",
"----+-----------------------\n",
" 1 | I like eating apples. \n",
"----------------------------\n",
"(1 row)"
]
"text/plain": "----------------------------\n id | content \n----+-----------------------\n 1 | I like eating apples. \n----------------------------\n(1 row)",
"text/html": "<table>\n\t<tr>\n\t\t<th>id</th>\n\t\t<th>content</th>\n\t</tr>\n\t<tr>\n\t\t<td>1</td>\n\t\t<td>I like eating apples.</td>\n\t</tr>\n</table>"
},
"execution_count": 4,
"metadata": {},
Expand All @@ -169,39 +148,81 @@
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleaning All at Once"
]
"## Batched k-NN search"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"query = (\n",
" db.create_dataframe(columns={\"idd\": range(3), \"query\": [\"apple\", \"dog\", \"banana\"]})\n",
" .save_as(\n",
" table_name=\"query_sample\",\n",
" column_names=[\"idd\", \"query\"],\n",
" distribution_key={\"idd\"},\n",
" distribution_type=\"hash\",\n",
" )\n",
" .check_unique(columns={\"idd\"})\n",
" .embedding()\n",
" .create_index(column=\"query\", model=\"all-MiniLM-L6-v2\")\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-25T12:57:17.400047Z",
"start_time": "2023-08-25T12:57:14.059315Z"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://localhost:7000\n",
"Done.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
"text/plain": "-------------------------------------------\n idd | id | query | content \n-----+----+--------+-----------------------\n 1 | 0 | dog | I have a dog. \n 2 | 1 | banana | I like eating apples. \n 0 | 1 | apple | I like eating apples. \n-------------------------------------------\n(3 rows)",
"text/html": "<table>\n\t<tr>\n\t\t<th>idd</th>\n\t\t<th>id</th>\n\t\t<th>query</th>\n\t\t<th>content</th>\n\t</tr>\n\t<tr>\n\t\t<td>1</td>\n\t\t<td>0</td>\n\t\t<td>dog</td>\n\t\t<td>I have a dog.</td>\n\t</tr>\n\t<tr>\n\t\t<td>2</td>\n\t\t<td>1</td>\n\t\t<td>banana</td>\n\t\t<td>I like eating apples.</td>\n\t</tr>\n\t<tr>\n\t\t<td>0</td>\n\t\t<td>1</td>\n\t\t<td>apple</td>\n\t\t<td>I like eating apples.</td>\n\t</tr>\n</table>"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t.embedding().search(column=\"content\", query=query[\"query\"], top_k=1)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-25T12:57:18.305871Z",
"start_time": "2023-08-25T12:57:17.402679Z"
}
}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleaning All at Once"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext sql\n",
"%sql postgresql://localhost:7000\n",
"%sql DROP TABLE text_sample CASCADE;"
"%sql DROP TABLE query_sample CASCADE;"
]
}
],
Expand Down
161 changes: 118 additions & 43 deletions greenplumpython/experimental/embedding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections import abc
from typing import Any, Callable, cast
from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
from uuid import uuid4

import greenplumpython as gp
from greenplumpython.col import Column
from greenplumpython.row import Row
from greenplumpython.type import TypeCast

Expand Down Expand Up @@ -144,7 +145,13 @@ def create_index(self, column: str, model: str) -> gp.DataFrame:
)
return self._dataframe

def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame:
def search(
self,
column: str,
query: Any,
top_k: int,
query_unique_key_columns: Optional[Union[Dict[str, Optional[str]], Set[str]]] = None,
) -> gp.DataFrame:
"""
Search unstructured data based on semantic similarity of embeddings.

Expand All @@ -155,54 +162,122 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame:

Returns:
DataFrame containing the top k rows of `column` most similar to `query`.

Example:
Please refer to :ref:`embedding-example` for more details.
"""
assert self._dataframe._db is not None
embdedding_info = self._dataframe._db._execute(
f"""
WITH indexed_col_info AS (
SELECT attrelid, attnum
FROM pg_attribute

def find_embedding_df(df: gp.DataFrame, column_c: str):
assert df._db is not None

embdedding_info = df._db._execute(
f"""
WITH indexed_col_info AS (
SELECT attrelid, attnum
FROM pg_attribute
WHERE
attrelid = '{df._qualified_table_name}'::regclass::oid AND
attname = '{column_c}'
), reloptions AS (
SELECT unnest(reloptions) AS option
FROM pg_class, indexed_col_info
WHERE pg_class.oid = attrelid
), embedding_info_json AS (
SELECT split_part(option, '=', 2)::json AS val
FROM reloptions, indexed_col_info
WHERE option LIKE format('_pygp_emb_%s=%%', attnum)
), embedding_info AS (
SELECT *
FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text)
)
SELECT nspname, relname, attname, model
FROM embedding_info, pg_class, pg_namespace, pg_attribute
WHERE
attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND
attname = '{column}'
), reloptions AS (
SELECT unnest(reloptions) AS option
FROM pg_class, indexed_col_info
WHERE pg_class.oid = attrelid
), embedding_info_json AS (
SELECT split_part(option, '=', 2)::json AS val
FROM reloptions, indexed_col_info
WHERE option LIKE format('_pygp_emb_%s=%%', attnum)
), embedding_info AS (
SELECT *
FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text)
pg_class.oid = embedding_relid AND
relnamespace = pg_namespace.oid AND
embedding_relid = attrelid AND
pg_attribute.attnum = 2;
"""
)
SELECT nspname, relname, attname, model
FROM embedding_info, pg_class, pg_namespace, pg_attribute
WHERE
pg_class.oid = embedding_relid AND
relnamespace = pg_namespace.oid AND
embedding_relid = attrelid AND
pg_attribute.attnum = 2;
"""
)
row: Row = embdedding_info[0]
schema: str = row["nspname"]
embedding_table_name: str = row["relname"]
model = row["model"]
embedding_col_name = row["attname"]
embedding_df = self._dataframe._db.create_dataframe(
table_name=embedding_table_name, schema=schema
)
row: Row = embdedding_info[0]
schema: str = row["nspname"]
embedding_table_name: str = row["relname"]
model = row["model"]
embedding_col_name = row["attname"]
embedding_df = df._db.create_dataframe(table_name=embedding_table_name, schema=schema)
return embedding_df, embedding_table_name, embedding_col_name, model

def _bind(t: str, columns: Union[Dict[str, Optional[str]], Set[str]]) -> List[str]:
target_list: List[str] = []
for k in columns:
v = columns[k] if isinstance(columns, dict) else None
col_serialize = t + "." + k + (f' AS "{v}"' if v is not None else "")
target_list.append(col_serialize)
return target_list

(
self_embedding_df,
self_embedding_table_name,
self_embedding_col_name,
self_model,
) = find_embedding_df(self._dataframe, column)
assert self._dataframe.unique_key is not None
distance = gp.operator("<->") # L2 distance is the default operator class in pgvector
if isinstance(query, Column):
assert query._dataframe is not None
(_, query_embedding_table_name, query_embedding_col_name, _,) = find_embedding_df(
query._dataframe.embedding()._dataframe, query._name # type: ignore reportUnknownArgumentType
)
assert query._dataframe.unique_key is not None
joint_table_name = "cte_" + uuid4().hex
right_join_table_name = "cte_" + uuid4().hex
query_df_unique_keys: List[str] = list(query._dataframe.unique_key)
self_df_unique_keys: List[str] = list(self._dataframe.unique_key)
assert query_df_unique_keys is not None
assert self_df_unique_keys is not None
lateral_join_df = gp.DataFrame(
query=f"""
WITH {joint_table_name} as (
SELECT
{",".join(_bind(query_embedding_table_name, columns=query_unique_key_columns))
if query_unique_key_columns is not None
else ",".join(
[(query_embedding_table_name+"."+key) for key in query_df_unique_keys]
)},
{",".join([(right_join_table_name+"."+key) for key in self_df_unique_keys])},
{query_embedding_table_name}.{query_embedding_col_name},
{right_join_table_name}.{self_embedding_col_name}
FROM {query_embedding_table_name} CROSS JOIN LATERAL (
SELECT * FROM {self_embedding_table_name}
ORDER BY {self_embedding_table_name}.{self_embedding_col_name} <-> {query_embedding_table_name}.{query_embedding_col_name}
LIMIT {top_k}
) AS {right_join_table_name}
)

SELECT
{",".join(_bind(query._dataframe._qualified_table_name, columns=query_unique_key_columns))
if query_unique_key_columns is not None
else ",".join(
[(query._dataframe._qualified_table_name+"."+key) for key in query_df_unique_keys]
)},
{",".join([(self._dataframe._qualified_table_name+"."+key) for key in self_df_unique_keys])},
{query._dataframe._qualified_table_name}.{query._name},
{self._dataframe._qualified_table_name}.{column}
FROM {joint_table_name}
JOIN {query._dataframe._qualified_table_name}
ON {"AND".join([
(f"{query._dataframe._qualified_table_name}.{key} = {joint_table_name}.{query_unique_key_columns[key] if query_unique_key_columns is not None and key in query_unique_key_columns else key}")
for key in query_df_unique_keys
])}
JOIN {self._dataframe._qualified_table_name}
ON {"AND".join([(self._dataframe._qualified_table_name+"."+key+" = "+joint_table_name+"." + key) for key in self_df_unique_keys])}
""",
db=self._dataframe._db,
)
return lateral_join_df

return self._dataframe.join(
embedding_df.assign(
self_embedding_df.assign(
distance=lambda t: distance(
embedding_df[embedding_col_name], _generate_embedding(query, model)
self_embedding_df[self_embedding_col_name],
_generate_embedding(query, self_model),
)
).order_by("distance")[:top_k],
how="inner",
Expand Down