@@ -19,7 +19,6 @@
    "outputs": [],
    "source": [
     "import os\n",
-    "import openai\n",
     "\n",
     "# OpenAI env variables\n",
     "os.environ[\"OPENAI_API_KEY\"] = \"YOUR_OPENAI_API_KEY_HERE\"\n",

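Dropping the module-level `import openai` is safe here: nothing in the notebook calls the module directly anymore, since the client is created via `from openai import OpenAI` below. As a minimal sketch (not part of this diff; the `client` name is illustrative), the v1-style client picks the key up from the environment on its own:

    import os

    from openai import OpenAI

    os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY_HERE"

    # No api_key argument needed: OpenAI() falls back to the
    # OPENAI_API_KEY environment variable set above.
    client = OpenAI()
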
@@ -58,13 +57,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import random\n",
-    "import time\n",
+    "from typing import List\n",
     "\n",
     "import numpy as np\n",
     "from openai import OpenAI\n",
-    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
     "from sklearn.metrics.pairwise import cosine_similarity\n",
+    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
     "\n",
     "from openlayer.lib import trace, trace_openai"
    ]

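The import cell now pulls in `typing.List` for the new annotations and keeps both `trace` and `trace_openai` from `openlayer.lib`. As a hedged sketch of how `trace_openai` is typically wired up (the `client` variable name is an assumption, not shown in this diff), the OpenAI client is wrapped once and then used as usual:

    from openai import OpenAI

    from openlayer.lib import trace_openai

    # Wrapping the client makes every chat completion call show up
    # as a traced step alongside the @trace()-decorated functions.
    client = trace_openai(OpenAI())
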
@@ -93,31 +91,35 @@
     "\n",
     "        Answers to a user query with the LLM.\n",
     "        \"\"\"\n",
-    "        context = self.retrieve_context(user_query)\n",
+    "        context = self.retrieve_contexts(user_query)\n",
     "        prompt = self.inject_prompt(user_query, context)\n",
     "        answer = self.generate_answer_with_gpt(prompt)\n",
     "        return answer\n",
     "\n",
     "    @trace()\n",
-    "    def retrieve_context(self, query: str) -> str:\n",
+    "    def retrieve_contexts(self, query: str) -> List[str]:\n",
     "        \"\"\"Context retriever.\n",
     "\n",
     "        Given the query, returns the most similar context (using TFIDF).\n",
     "        \"\"\"\n",
     "        query_vector = self.vectorizer.transform([query])\n",
     "        cosine_similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()\n",
     "        most_relevant_idx = np.argmax(cosine_similarities)\n",
-    "        return self.context_sections[most_relevant_idx]\n",
+    "        contexts = [self.context_sections[most_relevant_idx]]\n",
+    "        return contexts\n",
     "\n",
-    "    @trace()\n",
-    "    def inject_prompt(self, query: str, context: str):\n",
+    "    # You can also specify the name of the `context_kwarg` to unlock RAG metrics that\n",
+    "    # evaluate the performance of the context retriever. The value of the `context_kwarg`\n",
+    "    # should be a list of strings.\n",
+    "    @trace(context_kwarg=\"contexts\")\n",
+    "    def inject_prompt(self, query: str, contexts: List[str]) -> List[dict]:\n",
     "        \"\"\"Combines the query with the context and returns\n",
     "        the prompt (formatted to conform with OpenAI models).\"\"\"\n",
     "        return [\n",
     "            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
     "            {\n",
     "                \"role\": \"user\",\n",
-    "                \"content\": f\"Answer the user query using only the following context: {context}. \\nUser query: {query}\",\n",
+    "                \"content\": f\"Answer the user query using only the following context: {contexts[0]}. \\nUser query: {query}\",\n",
     "            },\n",
     "        ]\n",
     "\n",

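This hunk is the substance of the change: `retrieve_contexts` now returns a one-element `List[str]` instead of a bare string, which is the shape `@trace(context_kwarg="contexts")` expects so Openlayer can compute RAG metrics over the retrieved context. Since `inject_prompt` only reads `contexts[0]`, runtime behavior is unchanged. A self-contained sketch of the same TF-IDF retrieval step, with invented context sections for illustration:

    from typing import List

    import numpy as np
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    # Toy corpus -- these sections are made up for illustration only.
    context_sections = [
        "Openlayer traces each pipeline step for evaluation.",
        "TF-IDF weighs terms by in-document frequency and corpus rarity.",
    ]
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(context_sections)

    def retrieve_contexts(query: str) -> List[str]:
        """Return the single most similar section, wrapped in a list."""
        query_vector = vectorizer.transform([query])
        scores = cosine_similarity(query_vector, tfidf_matrix).flatten()
        return [context_sections[int(np.argmax(scores))]]

    print(retrieve_contexts("How does tracing work?"))
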
@@ -172,7 +174,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f960a36f-3438-4c81-8cdb-ca078aa509cd",
+   "id": "a45d5562",
    "metadata": {},
    "outputs": [],
    "source": []