Merged
4 changes: 3 additions & 1 deletion .gitignore
@@ -9,4 +9,6 @@ results.csv
 results/
 
 .DS_Store
-env
+env
+
+adle_notebook.ipynb
2 changes: 1 addition & 1 deletion README.md
@@ -71,7 +71,7 @@ We send to the LLM a text description of the screen. The LLM decide on the next
 # Installation
 
 - Follow instructions in https://docs.diambra.ai/#installation
-- Download the ROM and put it in `~/.diambra/roms`
+- Download the ROM and put it in `~/.diambra/roms` (no need to unzip the content)
 - (Optional) Create and activate a [new python venv](https://docs.python.org/3/library/venv.html)
 - Install dependencies with `make install` or `pip install -r requirements.txt`
 - Create a `.env` file and fill it with the content like in the `.env.example` file
8 changes: 1 addition & 7 deletions agent/__init__.py
@@ -1,8 +1,2 @@
-# load env variables before importing any other module
-from dotenv import load_dotenv
-
-load_dotenv()
-
-from .robot import Robot
+from .robot import TextRobot, VisionRobot
 from .observer import KEN_GREEN, KEN_RED
-from .llm import get_client
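
Note: with load_dotenv() removed from agent/__init__.py, environment variables are presumably loaded by the caller before any client is constructed. A minimal sketch of such an entry point, assuming python-dotenv remains a dependency (the surrounding script is hypothetical, not part of this PR):

# Hypothetical entry point: load .env before using the agent package,
# since agent/__init__.py no longer calls load_dotenv() itself.
from dotenv import load_dotenv

load_dotenv()  # reads variables such as API keys from the local .env file

from agent import TextRobot, get_client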
35 changes: 34 additions & 1 deletion agent/llm.py
@@ -1,4 +1,5 @@
 from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.multi_modal_llms.base import MultiModalLLM
 
 
 def get_client(model_str: str) -> FunctionCallingLLM:
@@ -45,4 +46,36 @@ def get_client(model_str: str) -> FunctionCallingLLM:
 
         return Cerebras(model=model_name)
 
-    raise ValueError(f"Provider {provider} not found")
+    raise ValueError(f"Provider {provider} not found in models")
+
+
+def get_client_multimodal(model_str: str) -> MultiModalLLM:
+    split_result = model_str.split(":")
+    if len(split_result) == 1:
+        # Assume the default provider to be ollama
+        provider = "ollama"
+        model_name = split_result[0]
+    elif len(split_result) > 2:
+        # Some model names contain ":", so re-join the rest of the string
+        provider = split_result[0]
+        model_name = ":".join(split_result[1:])
+    else:
+        provider = split_result[0]
+        model_name = split_result[1]
+
+    if provider == "openai":
+        from llama_index.multi_modal_llms.openai import OpenAIMultiModal
+
+        return OpenAIMultiModal(model=model_name)
+
+    if provider == "ollama":
+        from llama_index.multi_modal_llms.ollama import OllamaMultiModal
+
+        return OllamaMultiModal(model=model_name)
+
+    if provider == "mistral":
+        from llama_index.multi_modal_llms.mistralai import MistralAIMultiModal
+
+        return MistralAIMultiModal(model=model_name)
+
+    raise ValueError(f"Provider {provider} not found in multimodal models")