From 104a200c8166d80707e4c69c390457e944162c2c Mon Sep 17 00:00:00 2001 From: Pingdred <67059270+Pingdred@users.noreply.github.com> Date: Fri, 4 Aug 2023 13:53:17 +0200 Subject: [PATCH 1/5] embed_tools refactoring --- core/cat/mad_hatter/mad_hatter.py | 60 +++++++++++++------------------ 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 66717e86..917da65a 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -162,39 +162,17 @@ def save_active_plugins_to_db(self, active_plugins): def embed_tools(self): # retrieve from vectorDB all tool embeddings - all_tools_points = self.ccat.memory.vectors.procedural.get_all_points() + embedded_tools = self.ccat.memory.vectors.procedural.get_all_points() - # easy access to plugin tools - plugins_tools_index = {t.description: t for t in self.tools} - - points_to_be_deleted = [] - - vector_db = self.ccat.memory.vectors.vector_db - - # loop over vectors - for record in all_tools_points: - # if the tools is active in plugins, assign embedding - try: - tool_description = record.payload["page_content"] - plugins_tools_index[tool_description].embedding = record.vector - # log(plugins_tools_index[tool_description], "WARNING") - # else delete it - except Exception as e: - log(f"Deleting embedded tool: {record.payload['page_content']}", "WARNING") - points_to_be_deleted.append(record.id) - - if len(points_to_be_deleted) > 0: - vector_db.delete( - collection_name="procedural", - points_selector=points_to_be_deleted - ) + # easy acces to (point_id, tool_description) + embedded_tools_description = [(t.id, t.payload["page_content"]) for t in embedded_tools] # loop over tools for tool in self.tools: - # if there is no embedding, create it - if not tool.embedding: + # if the tool is not embedded + if tool.description not in embedded_tools_description: # save it to DB - ids_inserted = self.ccat.memory.vectors.procedural.add_texts( + self.ccat.memory.vectors.procedural.add_texts( [tool.description], [{ "source": "tool", @@ -204,15 +182,25 @@ def embed_tools(self): }], ) - # retrieve saved point and assign embedding to the Tool - records_inserted = vector_db.retrieve( - collection_name="procedural", - ids=ids_inserted, - with_vectors=True - ) - tool.embedding = records_inserted[0].vector - log(f"Newly embedded tool: {tool.description}", "WARNING") + + points_to_be_deleted = [] + + tools_description = [t.description for t in self.tools] + + # loop over embedded tools + for tool_embedded_description in embedded_tools_description: + # if the tool is not active, it inserts it in the list of points to be delete + if tool_embedded_description[1] not in tools_description: + log(f"Deleted embedded tool: {tool_embedded_description[1]}", "WARNING") + points_to_be_deleted.append(tool_embedded_description[0]) + + # delete not active tools + if len(points_to_be_deleted) > 0: + self.ccat.memory.vectors.vector_db.delete( + collection_name="procedural", + points_selector=points_to_be_deleted + ) # activate / deactivate plugin def toggle_plugin(self, plugin_id): From e61428a2199d0bf3172770b533e02201da3cefde Mon Sep 17 00:00:00 2001 From: Pingdred <67059270+Pingdred@users.noreply.github.com> Date: Fri, 4 Aug 2023 13:54:08 +0200 Subject: [PATCH 2/5] Removed attribute embedding from CatTool --- core/cat/mad_hatter/decorators/tool.py | 4 ---- core/tests/mad_hatter/test_mad_hatter.py | 7 +------ 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/core/cat/mad_hatter/decorators/tool.py b/core/cat/mad_hatter/decorators/tool.py index 01afe2cb..89993ace 100644 --- a/core/cat/mad_hatter/decorators/tool.py +++ b/core/cat/mad_hatter/decorators/tool.py @@ -22,10 +22,6 @@ def augment_tool(self, cat_instance): if cat_arg_signature in self.description: self.description = self.description.replace(cat_arg_signature, ")") - # Tool embedding is saved in the "procedural" vector DB collection. - # During CheshireCat.bootstrap(), after memory is loaded, the mad_hatter will retrieve the embedding from memory or create one if not present, and assign this attribute - self.embedding = None - def _run(self, input_by_llm): return self.func(input_by_llm, cat=self.cat) diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index 1c7a14b0..380f5932 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -51,7 +51,7 @@ def test_instantiation_discovery(mad_hatter): assert "what time is it" in tool.docstring assert isfunction(tool.func) assert tool.return_direct == False - assert tool.embedding is None # not embedded yet + #assert tool.embedding is None # not embedded yet # list of active plugins in DB is correct active_plugins = mad_hatter.load_active_plugins_from_db() @@ -93,11 +93,6 @@ def test_plugin_install(mad_hatter: MadHatter): assert new_hook.priority == 2 assert id(new_hook) == id(mad_hatter.hooks[0]) # same object in memory! - # tool has been embedded - assert type(new_tool.embedding) == list - assert len(new_tool.embedding) == 128 # fake embedder - assert type(new_tool.embedding[0]) == float - # list of active plugins in DB is correct active_plugins = mad_hatter.load_active_plugins_from_db() assert len(active_plugins) == 2 From 2646774493b681ce2cfd11f65b1af99fa1c59114 Mon Sep 17 00:00:00 2001 From: pieroit Date: Fri, 4 Aug 2023 21:07:35 +0200 Subject: [PATCH 3/5] fix bug in MadHatter.embed_tools() --- core/cat/mad_hatter/mad_hatter.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index a59ed9db..2ef679b4 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -170,13 +170,14 @@ def embed_tools(self): embedded_tools = self.ccat.memory.vectors.procedural.get_all_points() # easy acces to (point_id, tool_description) - embedded_tools_description = [(t.id, t.payload["page_content"]) for t in embedded_tools] + embedded_tools_ids = [t.id for t in embedded_tools] + embedded_tools_descriptions = [t.payload["page_content"] for t in embedded_tools] - # loop over tools + # loop over mad_hatter tools for tool in self.tools: # if the tool is not embedded - if tool.description not in embedded_tools_description: - # save it to DB + if tool.description not in embedded_tools_descriptions: + # embed the tool and save it to DB self.ccat.memory.vectors.procedural.add_texts( [tool.description], [{ @@ -189,16 +190,16 @@ def embed_tools(self): log(f"Newly embedded tool: {tool.description}", "WARNING") - points_to_be_deleted = [] - - tools_description = [t.description for t in self.tools] + # easy access to mad hatter tools (found in plugins) + mad_hatter_tools_descriptions = [t.description for t in self.tools] - # loop over embedded tools - for tool_embedded_description in embedded_tools_description: - # if the tool is not active, it inserts it in the list of points to be delete - if tool_embedded_description[1] not in tools_description: - log(f"Deleted embedded tool: {tool_embedded_description[1]}", "WARNING") - points_to_be_deleted.append(tool_embedded_description[0]) + # loop over embedded tools and delete the ones not present in active plugins + points_to_be_deleted = [] + for id, descr in zip(embedded_tools_ids, embedded_tools_descriptions): + # if the tool is not active, it inserts it in the list of points to be deleted + if descr not in mad_hatter_tools_descriptions: + log(f"Deleting embedded tool: {descr}", "WARNING") + points_to_be_deleted.append(id) # delete not active tools if len(points_to_be_deleted) > 0: From 36f9d17b410291f5483306d0f8ce4373d6fbcdd0 Mon Sep 17 00:00:00 2001 From: pieroit Date: Fri, 4 Aug 2023 21:20:46 +0200 Subject: [PATCH 4/5] rename mock_plugin tool; fix ambiguous assertion --- core/tests/mocks/mock_plugin/mock_tool.py | 6 +++--- core/tests/routes/plugins/test_plugin_toggle.py | 16 ++++++++-------- .../plugins/test_plugins_install_uninstall.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/core/tests/mocks/mock_plugin/mock_tool.py b/core/tests/mocks/mock_plugin/mock_tool.py index 3150d04e..501b10c3 100644 --- a/core/tests/mocks/mock_plugin/mock_tool.py +++ b/core/tests/mocks/mock_plugin/mock_tool.py @@ -2,7 +2,7 @@ @tool(return_direct=True) -def random_idea(topic, cat): - """Use to produce random ideas. Input is the topic.""" +def mock_tool(topic, cat): + """Used to test mock tools. Input is the topic.""" - return f"A random idea about {topic} :)" + return f"A mock about {topic} :)" diff --git a/core/tests/routes/plugins/test_plugin_toggle.py b/core/tests/routes/plugins/test_plugin_toggle.py index d37d30b7..73b62f25 100644 --- a/core/tests/routes/plugins/test_plugin_toggle.py +++ b/core/tests/routes/plugins/test_plugin_toggle.py @@ -22,16 +22,16 @@ def test_deactivate_plugin(client, just_installed_plugin): installed_plugins = response.json()["installed"] mock_plugin = [p for p in installed_plugins if p["id"] == "mock_plugin"] assert len(mock_plugin) == 1 # plugin installed - assert not mock_plugin[0]["active"] # plugin NOT active + assert mock_plugin[0]["active"] == False # plugin NOT active - # GET plugin info, plugin is not active + # GET single plugin info, plugin is not active response = client.get("/plugins/mock_plugin") - assert not response.json()["data"]["active"] + assert response.json()["data"]["active"] == False # tool has been taken away tools = get_embedded_tools(client) tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert not "random_idea" in tool_names + assert not "mock_tool" in tool_names def test_reactivate_plugin(client, just_installed_plugin): @@ -47,13 +47,13 @@ def test_reactivate_plugin(client, just_installed_plugin): installed_plugins = response.json()["installed"] mock_plugin = [p for p in installed_plugins if p["id"] == "mock_plugin"] assert len(mock_plugin) == 1 # plugin installed - assert mock_plugin[0]["active"] # plugin active + assert mock_plugin[0]["active"] == True # plugin active - # GET plugin info, plugin is active + # GET single plugin info, plugin is active response = client.get("/plugins/mock_plugin") - assert response.json()["data"]["active"] + assert response.json()["data"]["active"] == True # tool has been re-embedded tools = get_embedded_tools(client) tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "random_idea" in tool_names \ No newline at end of file + assert "mock_tool" in tool_names \ No newline at end of file diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index 992b81a8..cd53b110 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -29,7 +29,7 @@ def test_plugin_install_upload_zip(client, just_installed_plugin): # check whether new tools have been embedded tools = get_embedded_tools(client) tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "random_idea" in tool_names + assert "mock_tool" in tool_names def test_plugin_uninstall(client, just_installed_plugin): @@ -50,4 +50,4 @@ def test_plugin_uninstall(client, just_installed_plugin): # plugin tool disappeared tools = get_embedded_tools(client) tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "random_idea" not in tool_names + assert "mock_tool" not in tool_names From a536d6b5cbb19d6dad6a567e40e67339d402d910 Mon Sep 17 00:00:00 2001 From: pieroit Date: Fri, 4 Aug 2023 21:31:36 +0200 Subject: [PATCH 5/5] more precise tests for tool sync --- core/tests/routes/plugins/test_plugin_toggle.py | 8 ++++++-- .../routes/plugins/test_plugins_install_uninstall.py | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/core/tests/routes/plugins/test_plugin_toggle.py b/core/tests/routes/plugins/test_plugin_toggle.py index 73b62f25..13074517 100644 --- a/core/tests/routes/plugins/test_plugin_toggle.py +++ b/core/tests/routes/plugins/test_plugin_toggle.py @@ -30,8 +30,10 @@ def test_deactivate_plugin(client, just_installed_plugin): # tool has been taken away tools = get_embedded_tools(client) + assert len(tools) == 1 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert not "mock_tool" in tool_names + assert "mock_tool" not in tool_names + assert "get_the_time" in tool_names # from core_plugin def test_reactivate_plugin(client, just_installed_plugin): @@ -55,5 +57,7 @@ def test_reactivate_plugin(client, just_installed_plugin): # tool has been re-embedded tools = get_embedded_tools(client) + assert len(tools) == 2 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" in tool_names \ No newline at end of file + assert "mock_tool" in tool_names + assert "get_the_time" in tool_names # from core_plugin \ No newline at end of file diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index cd53b110..ca2d5874 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -26,10 +26,12 @@ def test_plugin_install_upload_zip(client, just_installed_plugin): # plugin has been actually extracted in (mock) plugins folder assert os.path.exists(mock_plugin_final_folder) - # check whether new tools have been embedded + # check whether new tool has been embedded tools = get_embedded_tools(client) + assert len(tools) == 2 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) assert "mock_tool" in tool_names + assert "get_the_time" in tool_names # from core_plugin def test_plugin_uninstall(client, just_installed_plugin): @@ -49,5 +51,7 @@ def test_plugin_uninstall(client, just_installed_plugin): # plugin tool disappeared tools = get_embedded_tools(client) + assert len(tools) == 1 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) assert "mock_tool" not in tool_names + assert "get_the_time" in tool_names # from core_plugin