Skip to content

Add LocalAI provider #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/shelloracle/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def get_provider(name: str) -> type[Provider]:
:param name: the provider name
:return: the requested provider
"""
from .providers import Ollama, OpenAI
from .providers import Ollama, OpenAI, LocalAI
providers = {
Ollama.name: Ollama,
OpenAI.name: OpenAI
OpenAI.name: OpenAI,
LocalAI.name: LocalAI
}
return providers[name]
1 change: 1 addition & 0 deletions src/shelloracle/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .ollama import Ollama
from .openai import OpenAI
from .localai import LocalAI
49 changes: 49 additions & 0 deletions src/shelloracle/providers/localai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from collections.abc import AsyncIterator

from openai import APIError
from openai import AsyncOpenAI as OpenAIClient

from ..config import Setting
from ..provider import Provider, ProviderError


class LocalAI(Provider):
name = "LocalAI"

host = Setting(default="localhost")
port = Setting(default=8080)
model = Setting(default="mistral-openorca")
system_prompt = Setting(
default=(
"Based on the following user description, generate a corresponding Bash command. Focus solely "
"on interpreting the requirements and translating them into a single, executable Bash command. "
"Ensure accuracy and relevance to the user's description. The output should be a valid Bash "
"command that directly aligns with the user's intent, ready for execution in a command-line "
"environment. Output nothing except for the command. No code block, no English explanation, "
"no start/end tags."
)
)

@property
def endpoint(self) -> str:
return f"http://{self.host}:{self.port}"

def __init__(self):
# Use a placeholder API key so the client will work
self.client = OpenAIClient(api_key="sk-xxx", base_url=self.endpoint)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
stream = await self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
],
stream=True,
)
async for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except APIError as e:
raise ProviderError(f"Something went wrong while querying LocalAI: {e}") from e