diff --git a/docs/how-to/Predict-Missing-Data.ipynb b/docs/how-to/Predict-Missing-Data.ipynb
index cfeb7ec..66b7359 100644
--- a/docs/how-to/Predict-Missing-Data.ipynb
+++ b/docs/how-to/Predict-Missing-Data.ipynb
@@ -9,7 +9,9 @@
 "\n",
 "The framework is designed to support different kinds of inference, including rule-based and LLMs. This notebook shows simple ML-based inference using scikit-learn DecisionTrees.\n",
 "\n",
- "We will use the Iris dataset:"
+ "This how-to walks through training and inference with scikit-learn DecisionTrees using the `linkml-store` command line tool. The same operations can also be performed programmatically via the Python API, or via the Web API.\n",
+ "\n",
+ "We will use a subset of the classic [Iris dataset](https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html), converted to jsonl (JSON Lines) format:"
 ],
 "metadata": {
 "collapsed": false
@@ -18,7 +20,18 @@
 },
 {
 "cell_type": "code",
- "execution_count": 18,
+ "source": [
+ "%%bash\n",
+ "linkml-store -i ../../tests/input/iris.jsonl describe"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2024-08-23T22:15:36.754913Z",
+ "start_time": "2024-08-23T22:15:33.366042Z"
+ }
+ },
+ "id": "d2ef6e85292b5a20",
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 ]
 }
 ],
- "source": [
- "%%bash\n",
- "linkml-store -i ../../tests/input/iris.jsonl describe"
- ],
+ "execution_count": 2
 },
 {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## The Infer Command",
+ "id": "335516b2c129363a"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-23T22:20:41.635957Z",
+ "start_time": "2024-08-23T22:20:38.428284Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%%bash\n",
+ "linkml-store infer --help"
+ ],
+ "id": "e38efeb1addfe697",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Usage: linkml-store infer [OPTIONS]\n",
+ "\n",
+ "  Predict a complete object from a partial object.\n",
+ "\n",
+ "  Currently two main prediction methods are provided: RAG and sklearn\n",
+ "\n",
+ "  ## RAG:\n",
+ "\n",
+ "  The RAG approach will use Retrieval Augmented Generation to inference the\n",
+ "  missing attributes of an object.\n",
+ "\n",
+ "  Example:\n",
+ "\n",
+ "      linkml-store -i countries.jsonl inference -t rag -q 'name: Uruguay'\n",
+ "\n",
+ "  Result:\n",
+ "\n",
+ "      capital: Montevideo, code: UY, continent: South America, languages:\n",
+ "      [Spanish]\n",
+ "\n",
+ "  You can pass in configurations as follows:\n",
+ "\n",
+ "      linkml-store -i countries.jsonl inference -t\n",
+ "      rag:llm_config.model_name=llama-3 -q 'name: Uruguay'\n",
+ "\n",
+ "  ## SKLearn:\n",
+ "\n",
+ "  This uses scikit-learn (defaulting to simple decision trees) to do the\n",
+ "  prediction.\n",
+ "\n",
+ "      linkml-store -i tests/input/iris.csv inference -t sklearn -q\n",
+ "      '{\"sepal_length\": 5.1, \"sepal_width\": 3.5, \"petal_length\": 1.4,\n",
+ "      \"petal_width\": 0.2}'\n",
+ "\n",
+ "Options:\n",
+ "  -O, --output-type [json|jsonl|yaml|yamll|tsv|csv|python|parquet|formatted|table|duckdb|postgres|mongodb]\n",
+ "                                  Output format\n",
+ "  -o, --output PATH               Output file path\n",
+ "  -T, --target-attribute TEXT     Target attributes for inference\n",
+ "  -F, --feature-attributes TEXT   Feature attributes for inference (comma\n",
+ "                                  separated)\n",
+ "  -Y, 
--inference-config-file PATH\n",
+ "                                  Path to inference configuration file\n",
+ "  -E, --export-model PATH         Export model to file\n",
+ "  -L, --load-model PATH           Load model from file\n",
+ "  -M, --model-format [pickle|onnx|pmml|pfa|joblib|png|linkml_expression|rulebased|rag_index]\n",
+ "                                  Format for model\n",
+ "  -S, --training-test-data-split ...\n",
+ "                                  Training/test data split\n",
+ "  -t, --predictor-type TEXT       Type of predictor  [default: sklearn]\n",
+ "  -n, --evaluation-count INTEGER  Number of examples to evaluate over\n",
+ "  --evaluation-match-function TEXT\n",
+ "                                  Name of function to use for matching objects\n",
+ "                                  in eval\n",
+ "  -q, --query TEXT                query term\n",
+ "  --help                          Show this message and exit.\n"
+ ]
+ }
+ ],
+ "execution_count": 5
 },
 {
 "cell_type": "markdown",
 "source": [
 "## Training and Inference\n",
 "\n",
- "We can perform training and inference in a single step:"
+ "We can perform training and inference in a single step.\n",
+ "\n",
+ "For features, we use:\n",
+ "\n",
+ "- `petal_length`\n",
+ "- `petal_width`\n",
+ "- `sepal_length`\n",
+ "- `sepal_width`\n",
+ "\n",
+ "These can be explicitly specified using `-F`, but in this case we are specifying a query, so\n",
+ "the features are inferred from the query.\n",
+ "\n",
+ "We specify the target attribute using `-T`. In this case, we are predicting the `species` of the iris.\n"
 ],
 "metadata": {
 "collapsed": false
 },
 {
 "cell_type": "code",
- "execution_count": 9,
+ "source": [
+ "%%bash\n",
+ "linkml-store -i ../../tests/input/iris.jsonl infer -t sklearn -T species -q \"{petal_length: 2.5, petal_width: 0.5, sepal_length: 5.0, sepal_width: 3.5}\" "
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2024-08-23T22:17:38.972690Z",
+ "start_time": "2024-08-23T22:17:35.558907Z"
+ }
+ },
+ "id": "4984aeb4016df154",
 "outputs": [
 {
 "name": "stderr",
@@ -76,29 +186,27 @@
 "text": [
 "predicted_object:\n",
 "  species: setosa\n",
- "confidence: 1.0\n"
+ "confidence: 1.0\n",
+ "\n"
 ]
 }
 ],
- "source": [
- "%%bash\n",
- "linkml-store -i ../../tests/input/iris.jsonl infer -t sklearn -T species -q \"{petal_length: 2.5, petal_width: 0.5, sepal_length: 5.0, sepal_width: 3.5}\" "
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-08-12T19:35:08.172872Z",
- "start_time": "2024-08-12T19:35:05.095856Z"
- }
- },
- "id": "4984aeb4016df154"
+ "execution_count": 4
 },
 {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "The data model for the output consists of a `predicted_object` slot and a `confidence` slot. 
Note that for standard ML operations, the predicted object will typically have one attribute only, but other kinds of inference (OWL reasoning, LLMs) may be able to predict complex objects.",
+ "id": "dfcbdae846f56ada"
 },
 {
 "cell_type": "markdown",
 "source": [
 "## Saving the Model\n",
 "\n",
- "Performing training and inference in a single step is convenient where training is fast, but more typically we'd want to save the model for later use:"
+ "Performing training and inference in a single step is convenient where training is fast, but more typically we'd want to save the model for later use.\n",
+ "\n",
+ "We can do this with the `-E` option:"
 ],
 "metadata": {
 "collapsed": false
@@ -181,48 +289,29 @@
 },
 {
 "cell_type": "code",
- "execution_count": 15,
- "outputs": [],
 "source": [
 "%%bash\n",
- "linkml-store -i ../../tests/input/iris.jsonl infer -t sklearn -L \"tmp/iris-model.joblib\" -E \"tmp/iris-model.png\""
+ "linkml-store -i ../../tests/input/iris.jsonl infer -t sklearn -T species -L tmp/iris-model.joblib -E input/iris-model.png"
 ],
 "metadata": {
 "collapsed": false,
 "ExecuteTime": {
- "end_time": "2024-08-12T19:57:43.145521Z",
- "start_time": "2024-08-12T19:57:40.441893Z"
+ "end_time": "2024-08-23T22:23:18.451362Z",
+ "start_time": "2024-08-23T22:23:15.571984Z"
 }
 },
- "id": "d7d14edd77e9e1fe"
+ "id": "d7d14edd77e9e1fe",
+ "outputs": [],
+ "execution_count": 9
 },
 {
 "cell_type": "markdown",
- "source": [
- "![img](tmp/iris-model.png)"
- ],
+ "source": "![img](input/iris-model.png)",
 "metadata": {
 "collapsed": false
 },
 "id": "cca55edf629f8c26"
 },
- {
- "cell_type": "code",
- "execution_count": 29,
- "outputs": [],
- "source": [
- "%%bash\n",
- "linkml-store -i ../../tests/input/iris.jsonl infer -t sklearn -L tmp/iris-model.joblib -E tmp/iris-model.rulebased.yaml"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-08-12T21:59:26.805316Z",
- "start_time": "2024-08-12T21:59:24.343197Z"
- }
- },
- "id": "acb7c57ecb3be9b"
- },
 {
 "cell_type": "markdown",
 "source": [
@@ -244,8 +333,20 @@
 "id": "3ef8a6bc39b5e667"
 },
 {
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2024-08-23T22:24:16.457340Z",
+ "start_time": "2024-08-23T22:24:13.977990Z"
+ }
+ },
 "cell_type": "code",
- "execution_count": 30,
+ "source": [
+ "%%bash\n",
+ "linkml-store -i ../../tests/input/iris.jsonl infer -t sklearn -T species -L tmp/iris-model.joblib -E tmp/iris-model.rulebased.yaml\n",
+ "cat tmp/iris-model.rulebased.yaml"
+ ],
+ "id": "acb7c57ecb3be9b",
 "outputs": [
 {
 "name": "stdout",
@@ -266,17 +367,13 @@
 ]
 }
 ],
- "source": [
- "%%bash\n",
- "cat tmp/iris-model.rulebased.yaml"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "start_time": "2024-08-12T21:59:52.936844Z"
- }
- },
- "id": "4fdea226f501455e"
+ "execution_count": 10
 },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "We can then apply this model to new data:",
+ "id": "50f9cd9df60b41c9"
+ },
 {
 "cell_type": "code",
@@ -310,14 +407,26 @@
 "id": "4df0d87dff96e667"
 },
 {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## More Advanced ML Models\n",
+ "\n",
+ "Currently only Decision Trees are supported. Additionally, most of the underlying functionality of scikit-learn is hidden.\n",
+ "\n",
+ "For more advanced ML, you are encouraged to use linkml-store for *data management* and then export to standard tabular or dataframe formats in order to do more advanced ML in Python. linkml-store is *not* intended as an ML platform. Instead, a limited set of operations is provided to assist with data exploration and with constructing deterministic rules. The export pattern is sketched below (the Python API calls shown are illustrative; see the Python API how-to for the canonical forms):\n",
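+ "\n",
+ "```python\n",
+ "# A sketch only -- the Client calls below are illustrative assumptions;\n",
+ "# collection.find() is the same query method the CLI uses internally.\n",
+ "import pandas as pd\n",
+ "from linkml_store import Client\n",
+ "from linkml_store.utils.format_utils import load_objects\n",
+ "\n",
+ "client = Client()\n",
+ "db = client.attach_database(\"duckdb\", alias=\"default\")\n",
+ "collection = db.create_collection(\"iris\")\n",
+ "collection.insert(load_objects(\"../../tests/input/iris.jsonl\"))\n",
+ "\n",
+ "# hand the rows to pandas / scikit-learn as a plain DataFrame\n",
+ "df = pd.DataFrame(collection.find(None, limit=200).rows)\n",
+ "```\n",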
+ "\n",
+ "For inference using LLMs and Retrieval Augmented Generation, see the how-to guide on those topics.\n"
+ ],
+ "id": "d1b583ce2d75c0e0"
+ },
 {
+ "metadata": {},
 "cell_type": "code",
- "execution_count": null,
 "outputs": [],
- "source": [],
- "metadata": {
- "collapsed": false
- },
- "id": "cef5b6e4ee9cb5f5"
+ "execution_count": null,
+ "source": "",
+ "id": "c8d9e36761d3088d"
 }
 ],
 "metadata": {
diff --git a/docs/how-to/input/iris-model.png b/docs/how-to/input/iris-model.png
new file mode 100644
index 0000000..2694c6a
Binary files /dev/null and b/docs/how-to/input/iris-model.png differ
diff --git a/src/linkml_store/api/collection.py b/src/linkml_store/api/collection.py
index 324a809..08e0160 100644
--- a/src/linkml_store/api/collection.py
+++ b/src/linkml_store/api/collection.py
@@ -470,6 +470,7 @@ def search(
         where: Optional[Any] = None,
         index_name: Optional[str] = None,
         limit: Optional[int] = None,
+        mmr_relevance_factor: Optional[float] = None,
         **kwargs,
     ) -> QueryResult:
         """
@@ -534,7 +535,7 @@ def search(
         index_col = ix.index_field
         # TODO: optimize this for large indexes
         vector_pairs = [(row, np.array(row[index_col], dtype=float)) for row in qr.rows]
-        results = ix.search(query, vector_pairs, limit=limit)
+        results = ix.search(query, vector_pairs, limit=limit, mmr_relevance_factor=mmr_relevance_factor, **kwargs)
         for r in results:
             del r[1][index_col]
         new_qr = QueryResult(num_rows=len(results))
diff --git a/src/linkml_store/cli.py b/src/linkml_store/cli.py
index b25cb98..4e618e0 100644
--- a/src/linkml_store/cli.py
+++ b/src/linkml_store/cli.py
@@ -1,8 +1,9 @@
 import logging
 import sys
 import warnings
+from collections import defaultdict
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional, Tuple

 import click
 import yaml
@@ -415,14 +416,6 @@ def list_collections(ctx, **kwargs):
 def fq(ctx, where, limit, columns, output_type, wide, output):
     """
     Query facets from the specified collection.
-
-    :param ctx:
-    :param where:
-    :param limit:
-    :param columns:
-    :param output_type:
-    :param output:
-    :return:
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
@@ -488,6 +481,41 @@ def describe(ctx, where, output_type, output, limit):
     write_output(df.describe(include="all").transpose(), output_type, target=output)


+@cli.command()
+@click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--index", "-I", help="Attributes to index on in pivot")
+@click.option("--columns", "-A", help="Attributes to use as columns in pivot")
+@click.option("--values", "-V", help="Attributes to use as values in pivot")
+@click.pass_context
+def pivot(ctx, where, limit, index, columns, values, output_type, output):
+    """
+    Pivot query results into wide form, keyed by the --index attributes.
+    """
+    collection = ctx.obj["settings"].collection
+    if not (index and columns and values):
+        raise click.UsageError("--index, --columns, and --values must all be provided")
+    where_clause = yaml.safe_load(where) if where else None
+    column_atts = columns.split(",")
+    value_atts = values.split(",")
+    index_atts = index.split(",")
+    results = collection.find(where_clause, limit=limit)
+    pivoted = defaultdict(dict)
+    for row in results.rows:
+        index_key = tuple([row.get(att) for att in index_atts])
+        column_key = tuple([row.get(att) for att in column_atts])
+        value_key = tuple([row.get(att) for att in value_atts])
+        pivoted[index_key][column_key] = value_key
+    pivoted_objs = []
+
+    def detuple(t: Tuple) -> Any:
+        if len(t) == 1:
+            return t[0]
+        return str(t)
+
+    for index_key, data in pivoted.items():
+        obj = {att: key for att, key in zip(index_atts, index_key)}
+        for column_key, value_key in data.items():
+            obj[detuple(column_key)] = detuple(value_key)
+        pivoted_objs.append(obj)
+    write_output(pivoted_objs, output_type, target=output)
+
+
 @cli.command()
 @click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
diff --git a/src/linkml_store/index/indexer.py b/src/linkml_store/index/indexer.py
index 3a887a8..70e227b 100644
--- a/src/linkml_store/index/indexer.py
+++ b/src/linkml_store/index/indexer.py
@@ -3,6 +3,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy as np
+from linkml_store.utils.vector_utils import mmr_diversified_search, pairwise_cosine_similarity
 from pydantic import BaseModel

 INDEX_ITEM = np.ndarray
@@ -19,20 +20,6 @@ class TemplateSyntaxEnum(str, Enum):
     fstring = "fstring"


-def cosine_similarity(vector1, vector2) -> float:
-    """
-    Calculate the cosine similarity between two vectors
-
-    :param vector1:
-    :param vector2:
-    :return:
-    """
-    dot_product = np.dot(vector1, vector2)
-    norm1 = np.linalg.norm(vector1)
-    norm2 = np.linalg.norm(vector2)
-    return dot_product / (norm1 * norm2)
-
-
 class Indexer(BaseModel):
     """
     An indexer operates on a collection in order to search for objects.
@@ -79,7 +66,7 @@ class Indexer(BaseModel):
     to get a sense of how they work.
>>> vectors = indexer.objects_to_vectors([{"name": "Aardvark"}, {"name": "Aardwolf"}, {"name": "Zesty"}])
-    >>> assert cosine_similarity(vectors[0], vectors[1]) > cosine_similarity(vectors[0], vectors[2])
+    >>> assert pairwise_cosine_similarity(vectors[0], vectors[1]) > pairwise_cosine_similarity(vectors[0], vectors[2])

     Note you should consult the documentation for the specific indexer you are using
     for more details on how text is converted to vectors.
@@ -167,7 +154,8 @@ def object_to_text(self, obj: Dict[str, Any]) -> str:
         return str(obj)

     def search(
-        self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None
+        self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None,
+        mmr_relevance_factor: Optional[float] = None
     ) -> List[Tuple[float, Any]]:
         """
         Use the indexer to search against a database of vectors.
@@ -183,13 +171,29 @@
         # Convert the query string to a vector
         query_vector = self.text_to_vector(query, cache=False)

+        if mmr_relevance_factor is not None:
+            vlist = [v for _, v in vectors]
+            idlist = [item_id for item_id, _ in vectors]
+            # mmr_diversified_search truncates to top_n=limit, and returns
+            # all results when limit is None
+            sorted_indices = mmr_diversified_search(
+                query_vector, vlist,
+                relevance_factor=mmr_relevance_factor, top_n=limit)
+            results = []
+            # TODO: the MMR search itself is inefficient when limit is high
+            for pos in sorted_indices:
+                score = pairwise_cosine_similarity(query_vector, vlist[pos])
+                results.append((score, idlist[pos]))
+            return results
+
         distances = []

         # Iterate over each indexed item
         for item_id, item_vector in vectors:
-            # Calculate the Euclidean distance between the query vector and the item vector
+            # Calculate the cosine similarity between the query vector and the item vector
             # distance = 1-np.linalg.norm(query_vector - item_vector)
-            distance = cosine_similarity(query_vector, item_vector)
+            distance = pairwise_cosine_similarity(query_vector, item_vector)
             distances.append((distance, item_id))

         # Sort the distances in ascending order
diff --git a/src/linkml_store/inference/implementations/rag_inference_engine.py b/src/linkml_store/inference/implementations/rag_inference_engine.py
index 6f8bf6d..e0a4a66 100644
--- a/src/linkml_store/inference/implementations/rag_inference_engine.py
+++ b/src/linkml_store/inference/implementations/rag_inference_engine.py
@@ -15,6 +15,10 @@
 logger = logging.getLogger(__name__)

+MAX_ITERATIONS = 5
+DEFAULT_NUM_EXAMPLES = 20
+DEFAULT_MMR_RELEVANCE_FACTOR = 0.8
+
 SYSTEM_PROMPT = """
 You are a {llm_config.role}, your task is to infer the YAML object output given the YAML object input.
I will provide you @@ -32,6 +36,10 @@ class TrainedModel(BaseModel, extra="forbid"): config: Optional[InferenceConfig] = None +class RAGInference(Inference): + iterations: int = 0 + + @dataclass class RAGInferenceEngine(InferenceEngine): """ @@ -103,7 +111,7 @@ def initialize_model(self, **kwargs): def object_to_text(self, object: OBJECT) -> str: return yaml.dump(object) - def derive(self, object: OBJECT) -> Optional[Inference]: + def derive(self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None) -> Optional[RAGInference]: import llm from tiktoken import encoding_for_model @@ -113,15 +121,17 @@ def derive(self, object: OBJECT) -> Optional[Inference]: model_name = self.config.llm_config.model_name feature_attributes = self.config.feature_attributes target_attributes = self.config.target_attributes - num_examples = self.config.llm_config.number_of_few_shot_examples or 5 + num_examples = self.config.llm_config.number_of_few_shot_examples or DEFAULT_NUM_EXAMPLES query_text = self.object_to_text(object) + mmr_relevance_factor = DEFAULT_MMR_RELEVANCE_FACTOR if not self.rag_collection: # TODO: zero-shot mode examples = [] else: if not self.rag_collection.indexers: raise ValueError("RAG collection must have an indexer attached") - rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm") + rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm", + mmr_relevance_factor=mmr_relevance_factor) examples = rs.rows if not examples: raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}") @@ -143,23 +153,43 @@ def derive(self, object: OBJECT) -> Optional[Inference]: ) prompt_clauses.append(prompt_clause) - prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n" system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config) + system_prompt += "\n".join(additional_prompt_texts or []) + prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n" - def make_text(texts): - return "\n".join(prompt_clauses) + prompt_end + def make_text(texts: List[str]): + return "\n".join(texts) + prompt_end try: encoding = encoding_for_model(model_name) except KeyError: encoding = encoding_for_model("gpt-4") token_limit = get_token_limit(model_name) - prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit) + prompt = render_formatted_text(make_text, values=prompt_clauses, + encoding=encoding, token_limit=token_limit, + additional_text=system_prompt) logger.info(f"Prompt: {prompt}") response = model.prompt(prompt, system_prompt) yaml_str = response.text() logger.info(f"Response: {yaml_str}") - return Inference(predicted_object=self._parse_yaml_payload(yaml_str)) + predicted_object = self._parse_yaml_payload(yaml_str, strict=True) + if self.config.validate_results: + base_collection = self.training_data.base_collection + errs = list(base_collection.iter_validate_collection([predicted_object])) + if errs: + print(f"{iteration} // FAILED TO VALIDATE: {yaml_str}") + print(f"PARSED: {predicted_object}") + print(f"ERRORS: {errs}") + if iteration > MAX_ITERATIONS: + raise ValueError(f"Validation errors: {errs}") + extra_texts = [ + "Make sure results conform to the schema. 
Previously you provided:\n",
+                yaml_str,
+                "\nThis was invalid.\n",
+                "Validation errors:\n",
+            ] + [self.object_to_text(e) for e in errs]
+            return self.derive(object, iteration=iteration+1, additional_prompt_texts=extra_texts)
+        return RAGInference(predicted_object=predicted_object, iterations=iteration+1, query=object)

     def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
         if "```" in yaml_str:
diff --git a/src/linkml_store/inference/inference_config.py b/src/linkml_store/inference/inference_config.py
index 0dc5bf4..1556d27 100644
--- a/src/linkml_store/inference/inference_config.py
+++ b/src/linkml_store/inference/inference_config.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 from pydantic import BaseModel, ConfigDict, Field
@@ -36,6 +36,7 @@ class InferenceConfig(BaseModel, extra="forbid"):
     train_test_split: Optional[Tuple[float, float]] = None
     llm_config: Optional[LLMConfig] = None
     random_seed: Optional[int] = None
+    validate_results: Optional[bool] = None

     @classmethod
     def from_file(cls, file_path: str, format: Optional[Format] = None) -> "InferenceConfig":
@@ -58,6 +59,7 @@ class Inference(BaseModel, extra="forbid"):
     """
     Result of an inference derivation.
     """
-
+    query: Optional[OBJECT] = Field(default=None, description="The query object.")
     predicted_object: OBJECT = Field(..., description="The predicted object.")
     confidence: Optional[float] = Field(default=None, description="The confidence of the prediction.", le=1.0, ge=0.0)
+    explanation: Optional[Any] = Field(default=None, description="Explanation of the prediction.")
diff --git a/src/linkml_store/utils/vector_utils.py b/src/linkml_store/utils/vector_utils.py
new file mode 100644
index 0000000..3a79a75
--- /dev/null
+++ b/src/linkml_store/utils/vector_utils.py
@@ -0,0 +1,165 @@
+import logging
+from typing import List, Tuple
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+# Type alias: a "list of lists" of floats, i.e. a list of vectors
+LOL = List[List[float]]
+
+def pairwise_cosine_similarity(vector1: np.ndarray, vector2: np.ndarray) -> float:
+    """
+    Calculate the cosine similarity between two vectors.
+
+    >>> v100 = np.array([1, 0, 0])
+    >>> v010 = np.array([0, 1, 0])
+    >>> v001 = np.array([0, 0, 1])
+    >>> v011 = np.array([0, 1, 1])
+    >>> pairwise_cosine_similarity(v100, v010)
+    0.0
+    >>> pairwise_cosine_similarity(v100, v001)
+    0.0
+    >>> pairwise_cosine_similarity(v010, v001)
+    0.0
+    >>> pairwise_cosine_similarity(v100, v100)
+    1.0
+    >>> f"{pairwise_cosine_similarity(v010, v011):0.3f}"
+    '0.707'
+
+    :param vector1:
+    :param vector2:
+    :return:
+    """
+    dot_product = np.dot(vector1, vector2)
+    norm1 = np.linalg.norm(vector1)
+    norm2 = np.linalg.norm(vector2)
+    return dot_product / (norm1 * norm2)
+
+
+def compute_cosine_similarity_matrix(list1: LOL, list2: LOL) -> np.ndarray:
+    """
+    Compute cosine similarity between two lists of vectors.
+
+    The result is a 2-D matrix sim[ROW][COL], where ROW indexes list1 and COL indexes list2.
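+
+    Example (a minimal doctest check with unit vectors, chosen so the values are exact):
+
+    >>> m = compute_cosine_similarity_matrix([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 1.0]])
+    >>> m.shape
+    (2, 2)
+    >>> [f"{x:0.3f}" for x in m[0]]
+    ['1.000', '0.707']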
+
+    :param list1:
+    :param list2:
+    :return:
+    """
+    # Convert lists to numpy arrays
+    matrix1 = np.array(list1)
+    matrix2 = np.array(list2)
+
+    # Normalize the vectors in both matrices
+    matrix1_norm = matrix1 / np.linalg.norm(matrix1, axis=1)[:, np.newaxis]
+    matrix2_norm = matrix2 / np.linalg.norm(matrix2, axis=1)[:, np.newaxis]
+
+    # Compute dot products (resulting in cosine similarity values)
+    cosine_similarity_matrix = np.dot(matrix1_norm, matrix2_norm.T)
+
+    return cosine_similarity_matrix
+
+
+def top_matches(cosine_similarity_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Find the top match for each row in the cosine similarity matrix.
+
+    :param cosine_similarity_matrix:
+    :return:
+    """
+    # Find the index of the maximum value in each row
+    top_match_indices = np.argmax(cosine_similarity_matrix, axis=1)
+
+    # Find the maximum similarity value in each row
+    top_match_values = np.amax(cosine_similarity_matrix, axis=1)
+
+    return top_match_indices, top_match_values
+
+
+def top_n_matches(
+    cosine_similarity_matrix: np.ndarray, n: int = 10
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Find the top n matches for each row in the cosine similarity matrix.
+
+    :param cosine_similarity_matrix:
+    :param n: number of matches to return per row
+    :return: indices and values of the top n matches, per row
+    """
+    # Find the indices that would sort each row in descending order
+    sorted_indices = np.argsort(-cosine_similarity_matrix, axis=1)
+
+    # Take the first n indices from the sorted indices to get the top n matches
+    top_n_indices = sorted_indices[:, :n]
+
+    # Take the first n values from the sorted values to get the top n match values
+    top_n_values = -np.sort(-cosine_similarity_matrix, axis=1)[:, :n]
+
+    return top_n_indices, top_n_values
+
+
+def mmr_diversified_search(
+    query_vector: np.ndarray, document_vectors: List[np.ndarray], relevance_factor=0.5, top_n=None
+) -> List[int]:
+    """
+    Perform diversified search using Maximal Marginal Relevance (MMR).
+
+    :param query_vector: The vector representing the query.
+    :param document_vectors: The vectors representing the documents.
+    :param relevance_factor: The balance parameter between relevance and diversity.
+    :param top_n: The number of results to return. If None, return all.
+    :return: A list of indices representing the diversified order of documents.
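+
+    Example (with relevance_factor=1.0 the diversity penalty is zero, so this reduces
+    to a pure cosine-relevance ranking; lower values interleave dissimilar documents):
+
+    >>> q = np.array([1.0, 0.0])
+    >>> docs = [np.array([0.0, 1.0]), np.array([1.0, 0.0]), np.array([0.7, 0.7])]
+    >>> mmr_diversified_search(q, docs, relevance_factor=1.0)
+    [1, 2, 0]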
+ """ + if top_n is None: + # If no specific number of results is specified, return all + top_n = len(document_vectors) + + if top_n == 0: + return [] + + # Calculate cosine similarities between query and all documents + norms_query = np.linalg.norm(query_vector) + norms_docs = np.linalg.norm(document_vectors, axis=1) + similarities = np.dot(document_vectors, query_vector) / (norms_docs * norms_query) + + # Initialize set of selected indices and results list + selected_indices = set() + result_indices = [] + + # Diversified search loop + for _ in range(top_n): + max_mmr = float("-inf") + best_index = None + + # Loop over all documents + for idx, _doc_vector in enumerate(document_vectors): + if idx not in selected_indices: + relevance = relevance_factor * similarities[idx] + diversity = 0 + + # Penalize based on similarity to already selected documents + if selected_indices: + max_sim_to_selected = max( + [ + np.dot(document_vectors[idx], document_vectors[s]) + / ( + np.linalg.norm(document_vectors[idx]) + * np.linalg.norm(document_vectors[s]) + ) + for s in selected_indices + ] + ) + diversity = (1 - relevance_factor) * max_sim_to_selected + + mmr_score = relevance - diversity + + # Update best MMR score and index + if mmr_score > max_mmr: + max_mmr = mmr_score + best_index = idx + + # Add the best document to the result and mark it as selected + if best_index is None: + logger.warning(f"No best index found over {len(document_vectors)} documents.") + continue + result_indices.append(best_index) + selected_indices.add(best_index) + + return result_indices + + + diff --git a/tests/test_inference/test_rag_engine.py b/tests/test_inference/test_rag_engine.py index 2fd1099..7cf6e24 100644 --- a/tests/test_inference/test_rag_engine.py +++ b/tests/test_inference/test_rag_engine.py @@ -1,6 +1,9 @@ import logging import pytest +from linkml_runtime import SchemaView +from linkml_runtime.dumpers import yaml_dumper +from linkml_runtime.utils.schema_builder import SchemaBuilder from linkml_store.inference import InferenceConfig, get_inference_engine from linkml_store.inference.implementations.rag_inference_engine import RAGInferenceEngine from linkml_store.inference.implementations.rule_based_inference_engine import RuleBasedInferenceEngine @@ -86,8 +89,8 @@ def test_inference_nested(handle): assert obj assert "triples" in obj assert isinstance(obj["triples"], list) + # don't enforce a strict match for now assert any(t for t in obj["triples"] if t["subject"] == "a" and t["object"] == "b") - # TODO: fuzzy matches for complex objects - we don't expect precise matches check_accuracy2(ie, targets, threshold=0.33, features=features) ie2 = roundtrip(ie) check_accuracy2(ie2, targets, threshold=0.33, features=features, test_data=ie.testing_data.as_dataframe()) @@ -98,4 +101,90 @@ def test_inference_nested(handle): ie = get_inference_engine("rag", config=config) ie.load_and_split_data(collection) ie.initialize_model() + # TODO: check why roundtrip doesn't clear the cache # check_accuracy2(ie2, targets, threshold=0.33, features=features, test_data=ie.testing_data.as_dataframe()) + + + +@pytest.mark.integration +@pytest.mark.parametrize("handle", SCHEMES) +def test_with_validation(handle): + """ + Test RAG inference in validation mode. + + In validation mode, the results of the RAG inference are validated using the LinkML schema. + If it fails, the error is presented to the LLM on a second iteration. 
+
+    We test this using a simple extraction schema, where we have training examples that pair
+    texts with extracted relationships/triples of the subject-predicate-object form.
+
+    We will attempt to foil the engine with a deliberately hard-to-guess enumeration permissible value
+    for the predicate ("played_a_leading_role_in").
+
+    This value is not present in the training set, and we do not present the schema ahead of time,
+    so we do not expect the LLM to succeed on the first iteration, in which it will make up a predicate.
+    This will fail validation, but the validation error includes the actual permissible values, so
+    we expect this to succeed the second time.
+    """
+    client = create_client(handle)
+    db = client.get_database()
+    db.import_database(INPUT_DIR / "nested-target.yaml", Format.YAML, collection_name="test_rag")
+    collection = db.get_collection("test_rag")
+    collection.metadata.type = "Extraction"
+    features = ["paper.abstract"]
+    targets = ["triples.subject", "triples.predicate", "triples.object"]
+    config = InferenceConfig(target_attributes=targets, feature_attributes=features)
+    ie = get_inference_engine("rag", config=config)
+    assert isinstance(ie.config, InferenceConfig)
+    ie.config.validate_results = True
+    # tr_coll = ie.training_data.collection
+    sb = SchemaBuilder()
+    sb.add_class("Triple", ["subject", "predicate", "object"])
+    sb.add_class("Paper", ["abstract"])
+    sb.add_class("Extraction", ["triples", "paper"])
+    sb.add_slot("triples", multivalued=True, inlined_as_list=True, range="Triple", replace_if_present=True)
+    sb.add_defaults()
+    schema = sb.schema
+    sv = SchemaView(schema)
+    # print(yaml_dumper.dumps(sv.schema))
+    collection.parent.set_schema_view(sv)
+    assert collection.target_class_name == "Extraction"
+    cd = collection.class_definition()
+    assert cd.name == "Extraction"
+    assert cd.slots
+    print(yaml_dumper.dumps(sv.schema))
+    # split into test and training;
+    # for RAG the "training" set is the set used as the RAG database
+    ie.load_and_split_data(collection)
+    ie.initialize_model()
+    assert isinstance(ie, RAGInferenceEngine)
+    result = ie.derive({"paper": {"abstract": "a precedes b, and b precedes c"}})
+    assert result
+    obj = result.predicted_object
+    assert obj
+    assert "triples" in obj
+    assert isinstance(obj["triples"], list)
+    # don't enforce a strict match for now
+    assert any(t for t in obj["triples"] if t["subject"] == "a" and t["object"] == "b")
+    check_accuracy2(ie, targets, threshold=0.33, features=features)
+    # now we will attempt to foil the engine by restricting the schema
+    sb.add_enum("PredicateType", ["likes", "has_part", "is_a", "created", "consumed", "played_a_leading_role_in"])
+    sb.add_slot("predicate", range="PredicateType", replace_if_present=True)
+    sv = SchemaView(sb.schema)
+    collection.parent.set_schema_view(sv)
+    errs = list(collection.iter_validate_collection([{"triples": [{"subject": "a", "predicate": "unknown", "object": "b"}]}]))
+    assert len(errs) == 1
+    result = ie.derive({"paper": {"abstract": "Mark Hamill played a starring role in the movie Star Wars"}})
+    assert result
+    obj = result.predicted_object
+    assert obj
+    print(obj)
+    assert any(t for t in obj["triples"] if t["predicate"] == "played_a_leading_role_in")
+    # highly unlikely to solve this on the first go, because it requires out of band knowledge
+    # (note that in future this unit test could conceivably be used in training models, in which case
+    # it will need to be modified to a different hard-to-guess predicate)
+    assert result.iterations > 1