Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,7 @@
from gptcli.providers.anthropic import AnthropicCompletionProvider
from gptcli.providers.cohere import CohereCompletionProvider
from gptcli.providers.azure_openai import AzureOpenAICompletionProvider


class AssistantConfig(TypedDict, total=False):
messages: List[Message]
model: str
openai_base_url_override: Optional[str]
openai_api_key_override: Optional[str]
temperature: float
top_p: float
from gptcli.config import ModelConfig, AssistantConfig


CONFIG_DEFAULTS = {
Expand Down Expand Up @@ -73,17 +65,7 @@ def get_completion_provider(
openai_base_url_override: Optional[str] = None,
openai_api_key_override: Optional[str] = None,
) -> CompletionProvider:
if (
model.startswith("gpt")
or model.startswith("ft:gpt")
or model.startswith("oai-compat:")
or model.startswith("chatgpt")
or model.startswith("o1")
):
return OpenAICompletionProvider(
openai_base_url_override, openai_api_key_override
)
elif model.startswith("oai-azure:"):
if model.startswith("oai-azure:"):
return AzureOpenAICompletionProvider()
elif model.startswith("claude"):
return AnthropicCompletionProvider()
Expand All @@ -94,15 +76,18 @@ def get_completion_provider(
elif model.startswith("gemini"):
return GoogleCompletionProvider()
else:
raise ValueError(f"Unknown model: {model}")
return OpenAICompletionProvider(
openai_base_url_override, openai_api_key_override
)


class Assistant:
def __init__(self, config: AssistantConfig):
def __init__(self, config: AssistantConfig, model_configs: Optional[Dict[str, ModelConfig]] = None):
self.config = config
self.model_configs = model_configs or {}

@classmethod
def from_config(cls, name: str, config: AssistantConfig):
def from_config(cls, name: str, config: AssistantConfig, model_configs: Optional[Dict[str, ModelConfig]] = None):
config = config.copy()
if name in DEFAULT_ASSISTANTS:
# Merge the config with the default config
Expand All @@ -112,7 +97,7 @@ def from_config(cls, name: str, config: AssistantConfig):
if config.get(key) is None:
config[key] = default_config[key]

return cls(config)
return cls(config, model_configs=model_configs)

def init_messages(self) -> List[Message]:
return self.config.get("messages", [])[:]
Expand All @@ -124,17 +109,31 @@ def _param(self, param: str) -> Any:

def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEvent]:
    """Run a chat completion for *messages* with this assistant's model.

    Resolves per-model overrides (API key, base URL, pricing) from
    ``self.model_configs`` before choosing a completion provider, then
    delegates to that provider's ``complete``.
    """
    model = self._param("model")
    # A per-model configuration override takes precedence over the
    # assistant-level OpenAI overrides; unset fields fall back to the
    # assistant config.
    if model in self.model_configs:
        model_conf = self.model_configs[model]
        # FIX(review): ModelConfig is a dataclass, not a mapping —
        # subscripting (model_conf['api_key']) raises TypeError, so use
        # attribute access.  Also removed a leftover debug print().
        openai_api_key_override = (
            model_conf.api_key or self.config.get("openai_api_key_override")
        )
        openai_base_url_override = (
            model_conf.base_url or self.config.get("openai_base_url_override")
        )
        pricing_override = model_conf.pricing
    else:
        openai_api_key_override = self._param("openai_api_key_override")
        openai_base_url_override = self._param("openai_base_url_override")
        pricing_override = None

    completion_provider = get_completion_provider(
        model,
        openai_base_url_override,
        openai_api_key_override,
    )
    return completion_provider.complete(
        messages,
        {
            "model": model,
            "temperature": float(self._param("temperature")),
            "top_p": float(self._param("top_p")),
            # None when no per-model pricing is configured; the OpenAI
            # provider reads this via args.get("pricing") and falls back
            # to its built-in table.
            "pricing": pricing_override,
        },
        stream,
    )
Expand All @@ -149,13 +148,15 @@ class AssistantGlobalArgs:


def init_assistant(
args: AssistantGlobalArgs, custom_assistants: Dict[str, AssistantConfig]
args: AssistantGlobalArgs,
custom_assistants: Dict[str, AssistantConfig],
model_configs: Optional[Dict[str, ModelConfig]] = None,
) -> Assistant:
name = args.assistant_name
if name in custom_assistants:
assistant = Assistant.from_config(name, custom_assistants[name])
assistant = Assistant.from_config(name, custom_assistants[name], model_configs=model_configs)
elif name in DEFAULT_ASSISTANTS:
assistant = Assistant.from_config(name, DEFAULT_ASSISTANTS[name])
assistant = Assistant.from_config(name, DEFAULT_ASSISTANTS[name], model_configs=model_configs)
else:
print(f"Unknown assistant: {name}")
sys.exit(1)
Expand Down
17 changes: 15 additions & 2 deletions gptcli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,27 @@
import yaml
from attr import dataclass

from gptcli.assistant import AssistantConfig
from gptcli.providers.llama import LLaMAModelConfig
from gptcli.completion import Message

CONFIG_FILE_PATHS = [
os.path.join(os.path.expanduser("~"), ".config", "gpt-cli", "gpt.yml"),
os.path.join(os.path.expanduser("~"), ".gptrc"),
]

# FIX(review): the move from gptcli/assistant.py dropped the
# ``(TypedDict, total=False)`` base.  Callers rely on dict semantics —
# ``config.copy()``, ``config.get(key)``, ``config[key] = ...`` in
# Assistant.from_config — which a plain class does not provide, and
# ``total=False`` keeps every key optional.  Restored here.
# NOTE(review): ensure ``TypedDict`` is imported from ``typing`` at the
# top of this file — confirm against the (not shown) import block.
class AssistantConfig(TypedDict, total=False):
    """Per-assistant configuration; all keys are optional."""

    messages: List[Message]
    model: str
    openai_base_url_override: Optional[str]
    openai_api_key_override: Optional[str]
    temperature: float
    top_p: float

@dataclass
class ModelConfig:
    # Per-model override settings; instances are keyed by model name in
    # GptCliConfig.model_configs and consulted in Assistant.complete_chat.
    # Overrides the assistant-level OpenAI API key for this model.
    api_key: Optional[str] = None
    # Overrides the OpenAI-compatible endpoint base URL for this model.
    base_url: Optional[str] = None
    # Per-token price table forwarded to the provider for usage-cost
    # reporting; presumably the keys match what UsageEvent.with_pricing
    # expects — TODO confirm against gptcli/providers/openai.py.
    pricing: Optional[Dict[str, float]] = None

@dataclass
class GptCliConfig:
Expand All @@ -30,7 +43,7 @@ class GptCliConfig:
assistants: Dict[str, AssistantConfig] = {}
interactive: Optional[bool] = None
llama_models: Optional[Dict[str, LLaMAModelConfig]] = None

model_configs: Optional[Dict[str, ModelConfig]] = None

def choose_config_file(paths: List[str]) -> str:
for path in paths:
Expand Down
2 changes: 1 addition & 1 deletion gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def main():
if config.llama_models is not None:
init_llama_models(config.llama_models)

assistant = init_assistant(cast(AssistantGlobalArgs, args), config.assistants)
assistant = init_assistant(cast(AssistantGlobalArgs, args), config.assistants, model_configs=config.model_configs)

if args.prompt is not None:
run_non_interactive(args, assistant)
Expand Down
18 changes: 10 additions & 8 deletions gptcli/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def complete(
):
yield MessageDeltaEvent(response.choices[0].delta.content)

if response.usage and (pricing := gpt_pricing(args["model"])):
pricing = args.get("pricing") or gpt_pricing(args["model"])
if response.usage and pricing:
yield UsageEvent.with_pricing(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
Expand All @@ -73,13 +74,14 @@ def complete(
next_choice = response.choices[0]
if next_choice.message.content:
yield MessageDeltaEvent(next_choice.message.content)
if response.usage and (pricing := gpt_pricing(args["model"])):
yield UsageEvent.with_pricing(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
pricing=pricing,
)
pricing = args.get("pricing") or gpt_pricing(args["model"])
if response.usage and pricing:
yield UsageEvent.with_pricing(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
pricing=pricing,
)

except openai.BadRequestError as e:
raise BadRequestError(e.message) from e
Expand Down