diff --git a/docs/02-streaming.ipynb b/docs/02-streaming.ipynb index 3857404..7190d28 100644 --- a/docs/02-streaming.ipynb +++ b/docs/02-streaming.ipynb @@ -15,24 +15,11 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/pydantic/_internal/_config.py:341: UserWarning: Valid config keys have changed in V2:\n", - "* 'allow_population_by_field_name' has been renamed to 'populate_by_name'\n", - "* 'smart_union' has been removed\n", - " warnings.warn(message, UserWarning)\n", - "/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ - "from semantic_router.llms import OpenAILLM\n", + "from openai import OpenAI\n", "\n", - "llm = OpenAILLM(name=\"gpt-4o-2024-08-06\")" + "client = OpenAI()" ] }, { @@ -48,7 +35,11 @@ " query: str = Field(description=\"Search query for internet information\")\n", "\n", "class Memory(BaseModel):\n", - " query: str = Field(description=\"Self-directed query to search information from your long term memory\")" + " query: str = Field(description=\"Self-directed query to search information from your long term memory\")\n", + "\n", + "class FinalAnswer(BaseModel):\n", + " answer: str = Field(description=\"Final answer to the user query, must be in markdown format\")\n", + " sources: str = Field(description=\"Sources used to answer the user query, must be in markdown format\")" ] }, { @@ -60,106 +51,155 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-09-07 19:54:39 INFO semantic_router.utils.logger JB TEMP !!!: stream=False\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:39 INFO semantic_router.utils.logger JB TEMP !!!: stream=False\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:39 INFO semantic_router.utils.logger JB TEMP !!!: stream=True\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:39 INFO semantic_router.utils.logger JB TEMP !!!: stream=True\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:39 INFO semantic_router.utils.logger JB TEMP !!!: stream=True\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:39 INFO semantic_router.utils.logger JB TEMP !!!: stream=False\u001b[0m\n" + "/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/pydantic/_internal/_config.py:341: UserWarning: Valid config keys have changed in V2:\n", + "* 'allow_population_by_field_name' has been renamed to 'populate_by_name'\n", + "* 'smart_union' has been removed\n", + " warnings.warn(message, UserWarning)\n", + "/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ - "import ast\n", + "import json\n", "import openai\n", "from graphai import router, node\n", - "from semantic_router.schema import Message\n", "\n", "\n", "@node(start=True)\n", "def node_start(input: dict):\n", " \"\"\"Descriptive string for the node.\"\"\"\n", - " print(\"node_a\")\n", + " print(\">>> node_start\")\n", " return {\"input\": input}\n", "\n", "\n", - "@router\n", - "def node_router(input: dict):\n", - " print(\"node_router\")\n", + "@router(stream=True)\n", + "def node_router(input: dict, callback):\n", + " print(\">>> node_router\")\n", " query = input[\"query\"]\n", " messages = [\n", - " Message(\n", - " role=\"system\",\n", - " content=\"\"\"You are a helpful assistant. Select the best route to answer the user query. ONLY choose one function.\"\"\",\n", - " ),\n", - " Message(role=\"user\", content=query),\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"\"\"You are a helpful assistant. Select the best route to answer the user query. ONLY choose one function.\"\"\",\n", + " },\n", + " {\"role\": \"user\", \"content\": query},\n", " ]\n", - " response = llm(\n", + " # we stream directly from the client\n", + " stream = client.chat.completions.create(\n", " messages=messages,\n", - " function_schemas=[\n", + " model=\"gpt-4o-mini\",\n", + " stream=True,\n", + " tools=[\n", " openai.pydantic_function_tool(Search),\n", " openai.pydantic_function_tool(Memory),\n", " ],\n", + " tool_choice=\"required\",\n", " )\n", - " choice = ast.literal_eval(response)[0]\n", - " print(\"choice\", choice)\n", + "\n", + " first_chunk = True # first chunk contains the tool name\n", + " args_str = \"\"\n", + " for chunk in stream:\n", + " choice = chunk.choices[0]\n", + " if first_chunk:\n", + " toolname = choice.delta.tool_calls[0].function.name.lower()\n", + " first_chunk = False\n", + " callback(f\"\")\n", + " elif choice.finish_reason == \"tool_calls\":\n", + " # this means we finished the tool call\n", + " pass\n", + " else:\n", + " chunk = choice.delta.tool_calls[0].function.arguments\n", + " callback(chunk)\n", + " args_str += chunk\n", + " args = json.loads(args_str)\n", " return {\n", - " \"choice\": choice[\"function_name\"].lower(),\n", - " \"input\": {**input, **choice[\"arguments\"]},\n", + " \"choice\": toolname,\n", + " \"input\": {**input, **args},\n", " }\n", "\n", "\n", "@node(stream=True)\n", - "def memory(input: dict, callback = None):\n", - " print(\"memory\")\n", - " query = input[\"query\"]\n", - " callback(query)\n", - " print(\"memory query\", query)\n", - " return {\"input\": {\"text\": \"The user is in Bali right now.\", **input}}\n", + "def memory(input: dict, callback):\n", + " print(\">>> memory\")\n", + " #query = input[\"query\"]\n", + " # dummy function for testing to simulate memory search\n", + " out = \"The user is in Bali right now.\"\n", + " callback(out)\n", + " return {\"input\": {\"text\": out, **input}}\n", "\n", "\n", "@node(stream=True)\n", - "def search(input: dict, callback = None):\n", - " print(\"search\")\n", - " query = input[\"query\"]\n", - " callback(query)\n", - " print(\"search query\", query)\n", + "def search(input: dict, callback):\n", + " print(\">>> search\")\n", + " #query = input[\"query\"]\n", + " # another dummy function for testing to simulate search\n", + " out = \"The most famous photo spot in Bali is the Uluwatu Temple.\"\n", + " callback(out)\n", " return {\n", " 
\"input\": {\n",
-        "            \"text\": \"The most famous photo spot in Bali is the Uluwatu Temple.\",\n",
+        "            \"text\": out,\n",
         "            **input,\n",
         "        }\n",
         "    }\n",
         "\n",
         "\n",
         "@node(stream=True)\n",
-        "def llm_node(input: dict, callback = None):\n",
-        "    print(\"llm_node\")\n",
+        "def llm_node(input: dict, callback):\n",
+        "    print(\">>> llm_node\")\n",
         "    chat_history = [\n",
-        "        Message(role=message[\"role\"], content=message[\"content\"])\n",
+        "        {\"role\": message[\"role\"], \"content\": message[\"content\"]}\n",
         "        for message in input[\"chat_history\"]\n",
         "    ]\n",
-        "\n",
+        "    # construct all messages\n",
         "    messages = [\n",
-        "        Message(role=\"system\", content=\"\"\"You are a helpful assistant.\"\"\"),\n",
+        "        {\"role\": \"system\", \"content\": \"\"\"You are a helpful assistant.\"\"\"},\n",
         "        *chat_history,\n",
-        "        Message(\n",
-        "            role=\"user\",\n",
-        "            content=(\n",
-        "                f\"Response to the following query from the user: {input['query']}\\n\"\n",
-        "                \"Here is additional context. You can use it to answer the user query. \"\n",
-        "                f\"But do not directly reference it: {input.get('text', '')}.\"\n",
-        "            ),\n",
-        "        ),\n",
+        "        {\"role\": \"user\", \"content\": input[\"query\"]},\n",
+        "        {\"role\": \"user\", \"content\": (\n",
+        "            f\"Respond to the following query from the user: {input['query']}\\n\"\n",
+        "            \"Here is additional context. You can use it to answer the user query. \"\n",
+        "            f\"But do not directly reference it: {input.get('text', '')}.\"\n",
+        "        )},\n",
         "    ]\n",
-        "    response = llm(messages=messages)\n",
-        "    return {\"output\": response}\n",
+        "    # we stream directly from the client\n",
+        "    stream = client.chat.completions.create(\n",
+        "        messages=messages,\n",
+        "        model=\"gpt-4o-mini\",\n",
+        "        stream=True,\n",
+        "        tools=[openai.pydantic_function_tool(FinalAnswer)],\n",
+        "        tool_choice=\"required\",\n",
+        "    )\n",
+        "\n",
+        "    first_chunk = True  # first chunk contains the tool name\n",
+        "    args_str = \"\"\n",
+        "    for chunk in stream:\n",
+        "        try:\n",
+        "            choice = chunk.choices[0]\n",
+        "            if first_chunk:\n",
+        "                toolname = choice.delta.tool_calls[0].function.name.lower()\n",
+        "                first_chunk = False\n",
+        "                callback(f\"\")\n",
+        "            elif choice.finish_reason == \"tool_calls\":\n",
+        "                # this means we finished the tool call\n",
+        "                pass\n",
+        "            else:\n",
+        "                chunk = choice.delta.tool_calls[0].function.arguments\n",
+        "                callback(chunk)\n",
+        "                args_str += chunk\n",
+        "        except (AttributeError, IndexError, TypeError):\n",
+        "            # skip chunks that carry no tool-call delta\n",
+        "            pass\n",
+        "    args = json.loads(args_str)\n",
+        "    return {\n",
+        "        \"choice\": toolname,\n",
+        "        \"input\": {**input, **args},\n",
+        "    }\n",
         "\n",
         "\n",
         "@node(end=True)\n",
         "def node_end(input: dict, callback = None):\n",
         "    \"\"\"Descriptive string for the node.\"\"\"\n",
-        "    print(\"node_end\")\n",
+        "    print(\">>> node_end\")\n",
+        "    # the graph closes the callback stream itself once execution finishes\n",
         "    return {\"output\": input[\"output\"]}"
       ]
      },
@@ -199,13 +239,6 @@
        "#graph.visualize()"
       ]
      },
-     {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {},
-      "outputs": [],
-      "source": []
-     },
      {
       "cell_type": "code",
       "execution_count": 7,
       "metadata": {},
@@ -233,253 +266,135 @@
       ]
      },
      {
-      "cell_type": "code",
-      "execution_count": 9,
+      "cell_type": "markdown",
       "metadata": {},
-      "outputs": [],
       "source": [
-       "async def chat(query: str):\n",
-       "    chat_history.append({\"role\": \"user\", \"content\": query})\n",
-       "    response = await graph.async_execute(\n",
-       "        input={\"input\": {\"query\": query, \"chat_history\": chat_history}}\n",
-       "    )\n",
-       "    chat_history.append({\"role\": \"assistant\", \"content\": response[\"output\"]})\n",
-       "    return response[\"output\"]"
+       "Now we can get 
a response:" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='node_start'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='node_router'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:41 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "node_a\n", - "node_router\n", - "choice {'function_name': 'Memory', 'arguments': {'query': 'user location'}}\n" + ">>> node_start\n", + ">>> node_router\n", + ">>> memory\n", + ">>> llm_node\n" ] }, - { - "ename": "ValueError", - "evalue": "No callback provided to graph. 
Please add it via `.add_callback`.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m chat(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdo you remember where I am?\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "Cell \u001b[0;32mIn[9], line 3\u001b[0m, in \u001b[0;36mchat\u001b[0;34m(query)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mchat\u001b[39m(query: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 2\u001b[0m chat_history\u001b[38;5;241m.\u001b[39mappend({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muser\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: query})\n\u001b[0;32m----> 3\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m graph\u001b[38;5;241m.\u001b[39masync_execute(\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquery\u001b[39m\u001b[38;5;124m\"\u001b[39m: query, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mchat_history\u001b[39m\u001b[38;5;124m\"\u001b[39m: chat_history}}\n\u001b[1;32m 5\u001b[0m )\n\u001b[1;32m 6\u001b[0m chat_history\u001b[38;5;241m.\u001b[39mappend({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: response[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput\u001b[39m\u001b[38;5;124m\"\u001b[39m]})\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", - "File \u001b[0;32m~/Documents/projects/aurelio-labs/graphai/graphai/graph.py:67\u001b[0m, in \u001b[0;36mGraph.async_execute\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m current_node\u001b[38;5;241m.\u001b[39mstream:\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 67\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo callback provided to graph. Please add it via `.add_callback`.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 68\u001b[0m \u001b[38;5;66;03m# add callback tokens and param here if we are streaming\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback\u001b[38;5;241m.\u001b[39mstart_node(node_name\u001b[38;5;241m=\u001b[39mcurrent_node\u001b[38;5;241m.\u001b[39mname)\n", - "\u001b[0;31mValueError\u001b[0m: No callback provided to graph. Please add it via `.add_callback`." 
- ] - } - ], - "source": [ - "await chat(\"do you remember where I am?\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "from graphai.callback import Callback\n", - "\n", - "callback = Callback()\n", - "\n", - "graph.add_callback(callback)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ { "data": { "text/plain": [ - "" + "{'choice': 'finalanswer',\n", + " 'input': {'text': 'The user is in Bali right now.',\n", + " 'query': 'user location',\n", + " 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'},\n", + " {'role': 'user', 'content': 'do you remember where I am?'}],\n", + " 'answer': \"It seems you're currently in Bali!\",\n", + " 'sources': ''}}" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "callback" + "query = \"do you remember where I am?\"\n", + "\n", + "chat_history.append({\"role\": \"user\", \"content\": query})\n", + "response = await graph.async_execute(\n", + " input={\"input\": {\"query\": query, \"chat_history\": chat_history}}\n", + ")\n", + "response" ] }, { - "cell_type": "code", - "execution_count": 14, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='node_start'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='node_router'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:52 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'do you remember where I am?', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "node_a\n", - "node_router\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-09-07 19:54:53 INFO 
semantic_router.utils.logger JB TEMP !!!: func.__name__='memory'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'user location', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger JB TEMP !!!: stream=True\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger JB TEMP !!!: callback=\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'user location', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}, 'callback': }\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'user location', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}, 'callback': }\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='llm_node'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger TEST: input={'input': {'text': 'The user is in Bali right now.', 'query': 'user location', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger JB TEMP !!!: stream=True\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger JB TEMP !!!: callback=\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger TEST 2: input={'input': {'text': 'The user is in Bali right now.', 'query': 'user location', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}, 'callback': }\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:53 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'text': 'The user is in Bali right now.', 'query': 'user location', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}]}, 'callback': }\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "choice {'function_name': 'Memory', 'arguments': {'query': 'user location'}}\n", - "memory\n", - "memory query user location\n", - "llm_node\n" - ] - }, - { - "data": { - "text/plain": [ - "\"You're currently in Bali. How can I assist you further?\"" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "await chat(\"do you remember where I am?\")" + "We can see the order that our graph nodes were called, and the final JSON output. 
However, none of this was streamed — for streaming we need to use the `callback` object that can be generated from our graph:" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='node_start'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'tell me a long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'tell me a long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'tell me a long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='node_router'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'tell me a long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'tell me a long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:55 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'tell me a long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "node_a\n", - "node_router\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='memory'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger TEST: input={'input': {'query': 'long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. 
How can I assist you further?\"}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: stream=True\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: callback=\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger TEST 2: input={'input': {'query': 'long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}, 'callback': }\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'query': 'long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}, 'callback': }\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: func.__name__='llm_node'\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger TEST: input={'input': {'text': 'The user is in Bali right now.', 'query': 'long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}}\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: stream=True\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: callback=\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger TEST 2: input={'input': {'text': 'The user is in Bali right now.', 'query': 'long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. How can I assist you further?\"}]}, 'callback': }\u001b[0m\n", - "\u001b[32m2024-09-07 19:54:56 INFO semantic_router.utils.logger JB TEMP !!!: args_dict={'input': {'text': 'The user is in Bali right now.', 'query': 'long story', 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'user', 'content': 'do you remember where I am?'}, {'role': 'assistant', 'content': \"You're currently in Bali. 
How can I assist you further?\"}]}, 'callback': }\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "choice {'function_name': 'Memory', 'arguments': {'query': 'long story'}}\n", - "memory\n", - "memory query long story\n", - "llm_node\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-09-07 19:54:57 INFO semantic_router.utils.logger [*] Stream Started\u001b[0m\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ + ">>> node_start\n", + ">>> node_router\n", + ">>> memory\n", + ">>> llm_node\n", + "\n", + "\n", + "{\"\n", + "query\n", + "\":\"\n", + "long\n", + " story\n", + "\"}\n", + "\n", "\n", - "user location\n", + "The user is in Bali right now.\n", "\n", "\n", + "\n", + "{\"\n", + "answer\n", + "\":\"\n", + "It\n", + " sounds\n", + " like\n", + " you've\n", + " got\n", + " something\n", + " interesting\n", + " going\n", + " on\n", + "!\n", + " If\n", + " you'd\n", + " like\n", + " to\n", + " share\n", + " more\n", + " about\n", + " your\n", + " long\n", + " story\n", + ",\n", + " I'm\n", + " here\n", + " to\n", + " listen\n", + ".\",\"\n", + "sources\n", + "\":\n", + "\"\"\n", + "}\n", "\n", - "\n", - "long story\n", - "\n", - "\n", - "\n" + "\n" ] } ], "source": [ + "import asyncio\n", + "\n", + "callback = graph.get_callback()\n", + "\n", "query = \"tell me a long story\"\n", - "response = await graph.async_execute(\n", + "response = asyncio.create_task(graph.async_execute(\n", " input={\"input\": {\"query\": query, \"chat_history\": chat_history}}\n", - ")\n", + "))\n", "\n", "async for token in callback.aiter():\n", " print(token)" diff --git a/graphai/callback.py b/graphai/callback.py index 16f1671..4de897a 100644 --- a/graphai/callback.py +++ b/graphai/callback.py @@ -26,14 +26,16 @@ async def acall(self, token: str, node_name: Optional[str] = None): self.queue.put_nowait(token) async def aiter(self) -> AsyncIterator[str]: - if log_stream: - logger.info("[*] Stream Started") + """Used by receiver to get the tokens from the stream queue. Creates + a generator that yields tokens from the queue until the END token is + received. 
+ """ while True: token = await self.queue.get() yield token - # await asyncio.sleep(10) - if log_stream: - logger.info("[X] Stream Closed") + self.queue.task_done() + if token == "": + break async def start_node(self, node_name: str, active: bool = True): self.current_node_name = node_name @@ -49,6 +51,9 @@ async def end_node(self, node_name: str): if self.active: self.queue.put_nowait(f"") + async def close(self): + self.queue.put_nowait("") + def _check_node_name(self, node_name: Optional[str] = None): if node_name: # we confirm this is the current node diff --git a/graphai/graph.py b/graphai/graph.py index 179f0eb..eb89c77 100644 --- a/graphai/graph.py +++ b/graphai/graph.py @@ -1,6 +1,7 @@ from typing import List from graphai.nodes.base import _Node from graphai.callback import Callback +from semantic_router.utils.logger import logger class Graph: @@ -9,6 +10,7 @@ def __init__(self): self.edges = [] self.start_node = None self.end_nodes = [] + self.Callback = Callback self.callback = None def add_node(self, node): @@ -58,12 +60,15 @@ def _is_valid(self): async def async_execute(self, input): # TODO JB: may need to add init callback here to init the queue on every new execution + if self.callback is None: + self.callback = self.get_callback() current_node = self.start_node state = input - while current_node not in self.end_nodes: + while True: # we invoke the node here if current_node.stream: if self.callback is None: + # TODO JB: can remove? raise ValueError("No callback provided to graph. Please add it via `.add_callback`.") # add callback tokens and param here if we are streaming await self.callback.start_node(node_name=current_node.name) @@ -82,6 +87,8 @@ async def async_execute(self, input): if current_node.is_end: break # TODO JB: may need to add end callback here to close the queue for every execution + if self.callback: + await self.callback.close() return state def execute(self, input): @@ -112,8 +119,9 @@ def execute(self, input): # TODO JB: may need to add end callback here to close the queue for every execution return state - def add_callback(self, callback: Callback): - self.callback = callback + def get_callback(self): + self.callback = self.Callback() + return self.callback def _get_node_by_name(self, node_name: str) -> _Node: for node in self.nodes: diff --git a/graphai/nodes/base.py b/graphai/nodes/base.py index 63739ea..b508f54 100644 --- a/graphai/nodes/base.py +++ b/graphai/nodes/base.py @@ -38,7 +38,6 @@ def _node( raise ValueError("Node must be a callable function.") func_signature = inspect.signature(func) - logger.info(f"JB TEMP !!!: {stream=}") class NodeClass: _func_signature = func_signature @@ -61,7 +60,6 @@ def execute(self): bound_args.apply_defaults() # Prepare arguments, including callback if stream is True args_dict = bound_args.arguments.copy() # Copy arguments to modify safely - logger.info(f"JB TEMP !!!: {args_dict=}") return func(**args_dict) # Pass only the necessary arguments @classmethod @@ -82,18 +80,13 @@ def get_signature(cls): @classmethod def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None): - logger.info(f"JB TEMP !!!: {func.__name__=}") - logger.info(f"TEST: {input=}") if callback: - logger.info(f"JB TEMP !!!: {stream=}") if stream: - logger.info(f"JB TEMP !!!: {callback=}") input["callback"] = callback else: raise ValueError( f"Error in node {func.__name__}. When callback provided, stream must be True." ) - logger.info(f"TEST 2: {input=}") instance = cls(**input) out = instance.execute() return out
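
Aside: the streaming tool-call pattern that the new `node_router` and `llm_node` share can be exercised outside the graph. Below is a minimal sketch assuming an `OpenAI` client and the notebook's `Search` model; the helper name `stream_tool_args` and its guard conditions are illustrative, not part of graphai or the notebook.

```python
import json

import openai
from openai import OpenAI
from pydantic import BaseModel, Field


class Search(BaseModel):
    query: str = Field(description="Search query for internet information")


def stream_tool_args(client: OpenAI, query: str, callback) -> tuple[str, dict]:
    """Stream a forced tool call, forwarding argument fragments to `callback`
    and returning (tool_name, parsed_args) once the stream is exhausted."""
    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": query}],
        stream=True,
        tools=[openai.pydantic_function_tool(Search)],
        tool_choice="required",  # force the model to call a tool
    )
    toolname, args_str = None, ""
    for chunk in stream:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if not delta.tool_calls:
            continue  # e.g. the final chunk that only carries finish_reason
        call = delta.tool_calls[0]
        if call.function.name:  # only the first fragment names the tool
            toolname = call.function.name.lower()
        if call.function.arguments:
            callback(call.function.arguments)  # emit the fragment downstream
            args_str += call.function.arguments
    # args_str concatenates into one complete JSON object once the stream ends
    return toolname, json.loads(args_str)
```

Called as `stream_tool_args(OpenAI(), "where is Uluwatu?", print)`, it prints argument fragments as they arrive and returns the assembled tool name and arguments, which is exactly what the router needs to pick the next node.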
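
The `Callback` changes above reduce to a simple queue contract: producer nodes push tokens, `close()` enqueues an end sentinel, and `aiter()` drains the queue until it sees that sentinel. A self-contained toy reproduction of that contract follows; `MiniCallback` and the `END` value are stand-ins for illustration only (the real end-token string appears to have been stripped from this diff), not graphai's actual API.

```python
import asyncio
from typing import AsyncIterator

END = "<END>"  # illustrative sentinel; graphai's actual end token differs


class MiniCallback:
    """Toy stand-in for graphai's Callback: producer nodes push tokens onto
    an asyncio.Queue and a single receiver drains it via `aiter`."""

    def __init__(self) -> None:
        self.queue: "asyncio.Queue[str]" = asyncio.Queue()

    def __call__(self, token: str) -> None:
        # nodes call the callback synchronously, so use put_nowait
        self.queue.put_nowait(token)

    async def close(self) -> None:
        # enqueue the end sentinel so the receiver knows to stop
        self.queue.put_nowait(END)

    async def aiter(self) -> AsyncIterator[str]:
        # mirrors the diff: yield every token, stop after the sentinel
        while True:
            token = await self.queue.get()
            yield token
            self.queue.task_done()
            if token == END:
                break


async def main() -> None:
    cb = MiniCallback()

    async def producer() -> None:
        for tok in ("Ulu", "watu ", "Temple"):
            cb(tok)  # what a streaming node does with each fragment
        await cb.close()

    task = asyncio.create_task(producer())  # run the producer while we consume
    async for tok in cb.aiter():
        print(tok, end="", flush=True)
    await task


asyncio.run(main())
```

Running the producer as a task while iterating the queue is the same move the final notebook cell makes with `asyncio.create_task(graph.async_execute(...))`: it lets tokens print as they arrive instead of after execution returns.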