Skip to content

Commit

Permalink
Merge branch 'Pingdred-embed_tools_refactoring' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
pieroit committed Aug 4, 2023
2 parents 62baa69 + a536d6b commit 1dd6362
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 62 deletions.
4 changes: 0 additions & 4 deletions core/cat/mad_hatter/decorators/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ def augment_tool(self, cat_instance):
if cat_arg_signature in self.description:
self.description = self.description.replace(cat_arg_signature, ")")

# Tool embedding is saved in the "procedural" vector DB collection.
# During CheshireCat.bootstrap(), after memory is loaded, the mad_hatter will retrieve the embedding from memory or create one if not present, and assign this attribute
self.embedding = None

def _run(self, input_by_llm):
return self.func(input_by_llm, cat=self.cat)

Expand Down
65 changes: 27 additions & 38 deletions core/cat/mad_hatter/mad_hatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,39 +167,18 @@ def save_active_plugins_to_db(self, active_plugins):
def embed_tools(self):

# retrieve from vectorDB all tool embeddings
all_tools_points = self.ccat.memory.vectors.procedural.get_all_points()
embedded_tools = self.ccat.memory.vectors.procedural.get_all_points()

# easy access to plugin tools
plugins_tools_index = {t.description: t for t in self.tools}
# easy access to (point_id, tool_description)
embedded_tools_ids = [t.id for t in embedded_tools]
embedded_tools_descriptions = [t.payload["page_content"] for t in embedded_tools]

points_to_be_deleted = []

vector_db = self.ccat.memory.vectors.vector_db

# loop over vectors
for record in all_tools_points:
# if the tool is active in plugins, assign embedding
try:
tool_description = record.payload["page_content"]
plugins_tools_index[tool_description].embedding = record.vector
# log(plugins_tools_index[tool_description], "WARNING")
# else delete it
except Exception as e:
log(f"Deleting embedded tool: {record.payload['page_content']}", "WARNING")
points_to_be_deleted.append(record.id)

if len(points_to_be_deleted) > 0:
vector_db.delete(
collection_name="procedural",
points_selector=points_to_be_deleted
)

# loop over tools
# loop over mad_hatter tools
for tool in self.tools:
# if there is no embedding, create it
if not tool.embedding:
# save it to DB
ids_inserted = self.ccat.memory.vectors.procedural.add_texts(
# if the tool is not embedded
if tool.description not in embedded_tools_descriptions:
# embed the tool and save it to DB
self.ccat.memory.vectors.procedural.add_texts(
[tool.description],
[{
"source": "tool",
Expand All @@ -209,15 +188,25 @@ def embed_tools(self):
}],
)

# retrieve saved point and assign embedding to the Tool
records_inserted = vector_db.retrieve(
collection_name="procedural",
ids=ids_inserted,
with_vectors=True
)
tool.embedding = records_inserted[0].vector

log(f"Newly embedded tool: {tool.description}", "WARNING")

# easy access to mad hatter tools (found in plugins)
mad_hatter_tools_descriptions = [t.description for t in self.tools]

# loop over embedded tools and delete the ones not present in active plugins
points_to_be_deleted = []
for id, descr in zip(embedded_tools_ids, embedded_tools_descriptions):
# if the tool is not active, add it to the list of points to be deleted
if descr not in mad_hatter_tools_descriptions:
log(f"Deleting embedded tool: {descr}", "WARNING")
points_to_be_deleted.append(id)

# delete not active tools
if len(points_to_be_deleted) > 0:
self.ccat.memory.vectors.vector_db.delete(
collection_name="procedural",
points_selector=points_to_be_deleted
)

# activate / deactivate plugin
def toggle_plugin(self, plugin_id):
Expand Down
7 changes: 1 addition & 6 deletions core/tests/mad_hatter/test_mad_hatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_instantiation_discovery(mad_hatter):
assert "what time is it" in tool.docstring
assert isfunction(tool.func)
assert tool.return_direct == False
assert tool.embedding is None # not embedded yet
#assert tool.embedding is None # not embedded yet

# list of active plugins in DB is correct
active_plugins = mad_hatter.load_active_plugins_from_db()
Expand Down Expand Up @@ -93,11 +93,6 @@ def test_plugin_install(mad_hatter: MadHatter):
assert new_hook.priority == 2
assert id(new_hook) == id(mad_hatter.hooks[0]) # same object in memory!

# tool has been embedded
assert type(new_tool.embedding) == list
assert len(new_tool.embedding) == 128 # fake embedder
assert type(new_tool.embedding[0]) == float

# list of active plugins in DB is correct
active_plugins = mad_hatter.load_active_plugins_from_db()
assert len(active_plugins) == 2
Expand Down
6 changes: 3 additions & 3 deletions core/tests/mocks/mock_plugin/mock_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


@tool(return_direct=True)
def random_idea(topic, cat):
"""Use to produce random ideas. Input is the topic."""
def mock_tool(topic, cat):
"""Used to test mock tools. Input is the topic."""

return f"A random idea about {topic} :)"
return f"A mock about {topic} :)"
20 changes: 12 additions & 8 deletions core/tests/routes/plugins/test_plugin_toggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ def test_deactivate_plugin(client, just_installed_plugin):
installed_plugins = response.json()["installed"]
mock_plugin = [p for p in installed_plugins if p["id"] == "mock_plugin"]
assert len(mock_plugin) == 1 # plugin installed
assert not mock_plugin[0]["active"] # plugin NOT active
assert mock_plugin[0]["active"] == False # plugin NOT active

# GET plugin info, plugin is not active
# GET single plugin info, plugin is not active
response = client.get("/plugins/mock_plugin")
assert not response.json()["data"]["active"]
assert response.json()["data"]["active"] == False

# tool has been taken away
tools = get_embedded_tools(client)
assert len(tools) == 1
tool_names = list(map(lambda t: t["metadata"]["name"], tools))
assert not "random_idea" in tool_names
assert "mock_tool" not in tool_names
assert "get_the_time" in tool_names # from core_plugin


def test_reactivate_plugin(client, just_installed_plugin):
Expand All @@ -47,13 +49,15 @@ def test_reactivate_plugin(client, just_installed_plugin):
installed_plugins = response.json()["installed"]
mock_plugin = [p for p in installed_plugins if p["id"] == "mock_plugin"]
assert len(mock_plugin) == 1 # plugin installed
assert mock_plugin[0]["active"] # plugin active
assert mock_plugin[0]["active"] == True # plugin active

# GET plugin info, plugin is active
# GET single plugin info, plugin is active
response = client.get("/plugins/mock_plugin")
assert response.json()["data"]["active"]
assert response.json()["data"]["active"] == True

# tool has been re-embedded
tools = get_embedded_tools(client)
assert len(tools) == 2
tool_names = list(map(lambda t: t["metadata"]["name"], tools))
assert "random_idea" in tool_names
assert "mock_tool" in tool_names
assert "get_the_time" in tool_names # from core_plugin
10 changes: 7 additions & 3 deletions core/tests/routes/plugins/test_plugins_install_uninstall.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def test_plugin_install_upload_zip(client, just_installed_plugin):
# plugin has been actually extracted in (mock) plugins folder
assert os.path.exists(mock_plugin_final_folder)

# check whether new tools have been embedded
# check whether new tool has been embedded
tools = get_embedded_tools(client)
assert len(tools) == 2
tool_names = list(map(lambda t: t["metadata"]["name"], tools))
assert "random_idea" in tool_names
assert "mock_tool" in tool_names
assert "get_the_time" in tool_names # from core_plugin


def test_plugin_uninstall(client, just_installed_plugin):
Expand All @@ -49,5 +51,7 @@ def test_plugin_uninstall(client, just_installed_plugin):

# plugin tool disappeared
tools = get_embedded_tools(client)
assert len(tools) == 1
tool_names = list(map(lambda t: t["metadata"]["name"], tools))
assert "random_idea" not in tool_names
assert "mock_tool" not in tool_names
assert "get_the_time" in tool_names # from core_plugin

0 comments on commit 1dd6362

Please sign in to comment.