Add ChatModelWithLLMIface to use chat models #33

Merged · 1 commit · Dec 28, 2023
40 changes: 40 additions & 0 deletions gptstonks_api/chat_model_llm_iface.py
@@ -0,0 +1,40 @@
from typing import Any

from langchain.llms import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.outputs import Generation, LLMResult


class ChatModelWithLLMIface(BaseLLM):
    """Wrapper model to transform the ChatModel interface into an LLM interface.

    This class will be moved to `openbb-chat`.
    """

    chat_model: BaseChatModel
    system_message: str = "You write concise and complete answers."

    def _generate(
        self,
        prompts: list[str],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts."""
        outputs = []
        for prompt in prompts:
            # Prepend the configured system message and pass the raw prompt as a human message.
            messages = [SystemMessage(content=self.system_message), HumanMessage(content=prompt)]
            chat_result = self.chat_model._generate(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            # Flatten the chat generations back into plain-text LLM generations.
            outputs.append(
                [Generation(text=chat_gen.text) for chat_gen in chat_result.generations]
            )
        return LLMResult(generations=outputs)

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        # _llm_type is a property on BaseChatModel, so it is read, not called.
        return self.chat_model._llm_type
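For context, a minimal usage sketch of the wrapper (the model name and prompt below are illustrative, not part of the PR):

from langchain.chat_models import ChatOpenAI

from gptstonks_api.chat_model_llm_iface import ChatModelWithLLMIface

# Wrap a chat-only model so it exposes the plain LLM interface.
llm = ChatModelWithLLMIface(
    chat_model=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
)

# generate() takes plain string prompts, like any LangChain LLM; each prompt
# is sent to the chat model as a [SystemMessage, HumanMessage] pair.
result = llm.generate(["Summarize what an ETF is in one sentence."])
print(result.generations[0][0].text)

Because the wrapper subclasses BaseLLM, it can be passed anywhere the API previously expected a completion model, e.g. to initialize_agent.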
7 changes: 6 additions & 1 deletion gptstonks_api/main.py
@@ -8,6 +8,7 @@
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from langchain.agents import AgentType, Tool, initialize_agent
+from langchain.chat_models import ChatOpenAI
 from langchain.globals import set_debug
 from langchain.llms import Bedrock, LlamaCpp, OpenAI, VertexAI
 from langchain.tools import DuckDuckGoSearchResults, WikipediaQueryRun
@@ -26,6 +27,7 @@
 from openbb import obb
 from openbb_chat.kernels.auto_llama_index import AutoLlamaIndex
 
+from .chat_model_llm_iface import ChatModelWithLLMIface
 from .utils import (
     arun_qa_over_tool_output,
     fix_frequent_code_errors,
@@ -115,7 +117,10 @@ def init_data():
         "top_p": float(os.getenv("LLM_TOP_P", 1.0)),
     }
     if model_provider == "openai":
-        llm = OpenAI(**llm_common_kwargs)
+        if "instruct" in llm_model_name:
+            llm = OpenAI(**llm_common_kwargs)
+        else:
+            llm = ChatModelWithLLMIface(chat_model=ChatOpenAI(**llm_common_kwargs))
     elif model_provider == "anyscale":
         raise NotImplementedError("Anyscale does not support yet async API in langchain")
     elif model_provider == "bedrock":
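The new branch keys off the model name alone: OpenAI models with "instruct" in the name (e.g. gpt-3.5-turbo-instruct) still speak the legacy completions API through OpenAI, while every other OpenAI model is chat-only and is wrapped in ChatModelWithLLMIface, so the rest of init_data() is unchanged. A quick illustrative check of the routing (model names are examples, not from the PR):

for llm_model_name in ["gpt-3.5-turbo-instruct", "gpt-3.5-turbo", "gpt-4"]:
    target = "OpenAI" if "instruct" in llm_model_name else "ChatModelWithLLMIface(ChatOpenAI)"
    print(f"{llm_model_name} -> {target}")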