diff --git a/core/cat/mad_hatter/decorators/tool.py b/core/cat/mad_hatter/decorators/tool.py index 01afe2cb..89993ace 100644 --- a/core/cat/mad_hatter/decorators/tool.py +++ b/core/cat/mad_hatter/decorators/tool.py @@ -22,10 +22,6 @@ def augment_tool(self, cat_instance): if cat_arg_signature in self.description: self.description = self.description.replace(cat_arg_signature, ")") - # Tool embedding is saved in the "procedural" vector DB collection. - # During CheshireCat.bootstrap(), after memory is loaded, the mad_hatter will retrieve the embedding from memory or create one if not present, and assign this attribute - self.embedding = None - def _run(self, input_by_llm): return self.func(input_by_llm, cat=self.cat) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 858c5641..2ef679b4 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -167,39 +167,18 @@ def save_active_plugins_to_db(self, active_plugins): def embed_tools(self): # retrieve from vectorDB all tool embeddings - all_tools_points = self.ccat.memory.vectors.procedural.get_all_points() + embedded_tools = self.ccat.memory.vectors.procedural.get_all_points() - # easy access to plugin tools - plugins_tools_index = {t.description: t for t in self.tools} + # easy acces to (point_id, tool_description) + embedded_tools_ids = [t.id for t in embedded_tools] + embedded_tools_descriptions = [t.payload["page_content"] for t in embedded_tools] - points_to_be_deleted = [] - - vector_db = self.ccat.memory.vectors.vector_db - - # loop over vectors - for record in all_tools_points: - # if the tools is active in plugins, assign embedding - try: - tool_description = record.payload["page_content"] - plugins_tools_index[tool_description].embedding = record.vector - # log(plugins_tools_index[tool_description], "WARNING") - # else delete it - except Exception as e: - log(f"Deleting embedded tool: {record.payload['page_content']}", "WARNING") - points_to_be_deleted.append(record.id) - - if len(points_to_be_deleted) > 0: - vector_db.delete( - collection_name="procedural", - points_selector=points_to_be_deleted - ) - - # loop over tools + # loop over mad_hatter tools for tool in self.tools: - # if there is no embedding, create it - if not tool.embedding: - # save it to DB - ids_inserted = self.ccat.memory.vectors.procedural.add_texts( + # if the tool is not embedded + if tool.description not in embedded_tools_descriptions: + # embed the tool and save it to DB + self.ccat.memory.vectors.procedural.add_texts( [tool.description], [{ "source": "tool", @@ -209,15 +188,25 @@ def embed_tools(self): }], ) - # retrieve saved point and assign embedding to the Tool - records_inserted = vector_db.retrieve( - collection_name="procedural", - ids=ids_inserted, - with_vectors=True - ) - tool.embedding = records_inserted[0].vector - log(f"Newly embedded tool: {tool.description}", "WARNING") + + # easy access to mad hatter tools (found in plugins) + mad_hatter_tools_descriptions = [t.description for t in self.tools] + + # loop over embedded tools and delete the ones not present in active plugins + points_to_be_deleted = [] + for id, descr in zip(embedded_tools_ids, embedded_tools_descriptions): + # if the tool is not active, it inserts it in the list of points to be deleted + if descr not in mad_hatter_tools_descriptions: + log(f"Deleting embedded tool: {descr}", "WARNING") + points_to_be_deleted.append(id) + + # delete not active tools + if len(points_to_be_deleted) > 0: + self.ccat.memory.vectors.vector_db.delete( + collection_name="procedural", + points_selector=points_to_be_deleted + ) # activate / deactivate plugin def toggle_plugin(self, plugin_id): diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index 6261f2eb..8b01de4b 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -51,7 +51,7 @@ def test_instantiation_discovery(mad_hatter): assert "what time is it" in tool.docstring assert isfunction(tool.func) assert tool.return_direct == False - assert tool.embedding is None # not embedded yet + #assert tool.embedding is None # not embedded yet # list of active plugins in DB is correct active_plugins = mad_hatter.load_active_plugins_from_db() @@ -93,11 +93,6 @@ def test_plugin_install(mad_hatter: MadHatter): assert new_hook.priority == 2 assert id(new_hook) == id(mad_hatter.hooks[0]) # same object in memory! - # tool has been embedded - assert type(new_tool.embedding) == list - assert len(new_tool.embedding) == 128 # fake embedder - assert type(new_tool.embedding[0]) == float - # list of active plugins in DB is correct active_plugins = mad_hatter.load_active_plugins_from_db() assert len(active_plugins) == 2 diff --git a/core/tests/mocks/mock_plugin/mock_tool.py b/core/tests/mocks/mock_plugin/mock_tool.py index 3150d04e..501b10c3 100644 --- a/core/tests/mocks/mock_plugin/mock_tool.py +++ b/core/tests/mocks/mock_plugin/mock_tool.py @@ -2,7 +2,7 @@ @tool(return_direct=True) -def random_idea(topic, cat): - """Use to produce random ideas. Input is the topic.""" +def mock_tool(topic, cat): + """Used to test mock tools. Input is the topic.""" - return f"A random idea about {topic} :)" + return f"A mock about {topic} :)" diff --git a/core/tests/routes/plugins/test_plugin_toggle.py b/core/tests/routes/plugins/test_plugin_toggle.py index d37d30b7..13074517 100644 --- a/core/tests/routes/plugins/test_plugin_toggle.py +++ b/core/tests/routes/plugins/test_plugin_toggle.py @@ -22,16 +22,18 @@ def test_deactivate_plugin(client, just_installed_plugin): installed_plugins = response.json()["installed"] mock_plugin = [p for p in installed_plugins if p["id"] == "mock_plugin"] assert len(mock_plugin) == 1 # plugin installed - assert not mock_plugin[0]["active"] # plugin NOT active + assert mock_plugin[0]["active"] == False # plugin NOT active - # GET plugin info, plugin is not active + # GET single plugin info, plugin is not active response = client.get("/plugins/mock_plugin") - assert not response.json()["data"]["active"] + assert response.json()["data"]["active"] == False # tool has been taken away tools = get_embedded_tools(client) + assert len(tools) == 1 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert not "random_idea" in tool_names + assert "mock_tool" not in tool_names + assert "get_the_time" in tool_names # from core_plugin def test_reactivate_plugin(client, just_installed_plugin): @@ -47,13 +49,15 @@ def test_reactivate_plugin(client, just_installed_plugin): installed_plugins = response.json()["installed"] mock_plugin = [p for p in installed_plugins if p["id"] == "mock_plugin"] assert len(mock_plugin) == 1 # plugin installed - assert mock_plugin[0]["active"] # plugin active + assert mock_plugin[0]["active"] == True # plugin active - # GET plugin info, plugin is active + # GET single plugin info, plugin is active response = client.get("/plugins/mock_plugin") - assert response.json()["data"]["active"] + assert response.json()["data"]["active"] == True # tool has been re-embedded tools = get_embedded_tools(client) + assert len(tools) == 2 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "random_idea" in tool_names \ No newline at end of file + assert "mock_tool" in tool_names + assert "get_the_time" in tool_names # from core_plugin \ No newline at end of file diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index 992b81a8..ca2d5874 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -26,10 +26,12 @@ def test_plugin_install_upload_zip(client, just_installed_plugin): # plugin has been actually extracted in (mock) plugins folder assert os.path.exists(mock_plugin_final_folder) - # check whether new tools have been embedded + # check whether new tool has been embedded tools = get_embedded_tools(client) + assert len(tools) == 2 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "random_idea" in tool_names + assert "mock_tool" in tool_names + assert "get_the_time" in tool_names # from core_plugin def test_plugin_uninstall(client, just_installed_plugin): @@ -49,5 +51,7 @@ def test_plugin_uninstall(client, just_installed_plugin): # plugin tool disappeared tools = get_embedded_tools(client) + assert len(tools) == 1 tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "random_idea" not in tool_names + assert "mock_tool" not in tool_names + assert "get_the_time" in tool_names # from core_plugin