Add OpenAI ChatCompletion Agent (oughtinc#310)
stuhlmueller authored Jul 26, 2023
2 parents ab59d0b + 7ce39a1 · commit 9cfbcf0
Showing 4 changed files with 125 additions and 3 deletions.
3 changes: 3 additions & 0 deletions ice/agent.py
@@ -8,6 +8,7 @@
from ice.agents.fake import FakeAgent
from ice.agents.human import HumanAgent
from ice.agents.openai import OpenAIAgent
from ice.agents.openai import OpenAIChatCompletionAgent
from ice.agents.openai_reasoning import OpenAIReasoningAgent
from ice.agents.ought_inference import OughtInferenceAgent
from ice.agents.squad import SquadAgent
@@ -24,6 +25,8 @@ def __init__(self, *args, **kwargs):


MACHINE_AGENTS = {
"chatgpt": lambda: OpenAIChatCompletionAgent(model="gpt-3.5-turbo"),
"gpt-4": lambda: OpenAIChatCompletionAgent(model="gpt-4"),
"instruct": lambda: OpenAIAgent(),
"instruct-reasoning": lambda: OpenAIReasoningAgent(),
"instruct-reasoning-crowd": lambda: OpenAIReasoningAgent(num_workers=8),
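The two new registry entries make the chat models selectable by name. A minimal usage sketch (each registry value is a zero-argument factory, so a lookup is called to construct the agent):

from ice.agent import MACHINE_AGENTS

chat_agent = MACHINE_AGENTS["chatgpt"]()  # OpenAIChatCompletionAgent(model="gpt-3.5-turbo")
gpt4_agent = MACHINE_AGENTS["gpt-4"]()    # OpenAIChatCompletionAgent(model="gpt-4")
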
88 changes: 88 additions & 0 deletions ice/agents/openai.py
@@ -6,6 +6,7 @@

from ice.agents.base import Agent
from ice.agents.base import Stop
from ice.apis.openai import openai_chatcomplete
from ice.apis.openai import openai_complete
from ice.environment import env
from ice.utils import longest_common_prefix
@@ -139,3 +140,90 @@ def lookup_prob(choice: str):
def _print_markdown(self, obj: Any):
"""Print the text with markdown formatting."""
env().print(obj, format_markdown=True)


class OpenAIChatCompletionAgent(Agent):
"""An agent that uses the OpenAI ChatCompletion API to generate completions."""

def __init__(
self,
model: str = "gpt-3.5-turbo",
temperature: float = 0.0,
top_p: float = 1.0,
):
self.model = model
self.temperature = temperature
self.top_p = top_p

async def complete(
self,
*,
prompt: str,
stop: Stop = None,
verbose: bool = False,
default: str = "",
max_tokens: int = 256,
) -> str:
"""Generate an answer to a question given some context."""
if verbose:
self._print_markdown(prompt)
response = await self._complete(prompt, stop=stop, max_tokens=max_tokens)
completion = self._extract_completion(response)
if verbose:
self._print_markdown(completion)
return completion

async def classify(
self,
*,
prompt: str,
choices: tuple[str, ...],
default: Optional[str] = None,
verbose: bool = False,
) -> tuple[dict[str, float], Optional[str]]:
raise NotImplementedError(
"OpenAI ChatCompletion has no option to score a classification."
)

async def relevance(
self,
*,
context: str,
question: str,
verbose: bool = False,
default: Optional[float] = None,
) -> float:
raise NotImplementedError(
"OpenAI ChatCompletion has no option to return a relevance score."
)

async def predict(
self, *, context: str, default: str = "", verbose: bool = False
) -> dict[str, float]:
raise NotImplementedError(
"OpenAI ChatCompletion does not support getting probabilities."
)

async def _complete(self, prompt, **kwargs) -> dict:
"""Send a completion request to the OpenAI API with the given prompt and parameters."""
kwargs.update(
{
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"n": 1,
}
)
messages = [{"role": "user", "content": prompt}]
response = await openai_chatcomplete(messages, **kwargs)
if "choices" not in response:
raise ValueError(f"No choices in response: {response}")
return response

def _extract_completion(self, response: dict) -> str:
"""Extract the answer text from the completion response."""
return response["choices"][0]["message"]["content"].strip()

def _print_markdown(self, obj: Any):
"""Print the text with markdown formatting."""
env().print(obj, format_markdown=True)
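
A minimal end-to-end sketch of the new agent, assuming OpenAI credentials are configured (e.g. via the OPENAI_API_KEY environment variable); the prompt is illustrative:

import asyncio

from ice.agents.openai import OpenAIChatCompletionAgent

async def main():
    agent = OpenAIChatCompletionAgent(model="gpt-3.5-turbo")
    # complete() is the only generation method implemented; classify(),
    # relevance(), and predict() raise NotImplementedError.
    answer = await agent.complete(prompt="What is 2 + 2?", max_tokens=16)
    print(answer)

asyncio.run(main())
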
35 changes: 33 additions & 2 deletions ice/apis/openai.py
@@ -131,7 +131,7 @@ async def _post(
# TODO: support more model types for conversion


-def get_davinci_equivalent_tokens(response: dict) -> int:
+def extract_total_tokens(response: dict) -> int:
return response.get("usage", {}).get("total_tokens", 0)


@@ -166,5 +166,36 @@ async def openai_complete(
response = await _post("completions", json=params, cache_id=cache_id)
if isinstance(response, TooLongRequestError):
raise response
-    add_fields(davinci_equivalent_tokens=get_davinci_equivalent_tokens(response))
+    add_fields(davinci_equivalent_tokens=extract_total_tokens(response))
return response


@trace
async def openai_chatcomplete(
messages: list[dict[str, str]],
stop: Optional[str] = "\n",
top_p: float = 1,
temperature: float = 0,
model: str = "gpt-3.5-turbo",
max_tokens: int = 256,
logit_bias: Optional[Mapping[str, Union[int, float]]] = None,
n: int = 1,
cache_id: int = 0, # for repeated non-deterministic sampling using caching
) -> dict:
"""Send a completion request to the OpenAI API and return the JSON response."""
params = {
"messages": messages,
"stop": stop,
"top_p": top_p,
"temperature": temperature,
"model": model,
"max_tokens": max_tokens,
"n": n,
}
if logit_bias:
params["logit_bias"] = logit_bias # type: ignore[assignment]
response = await _post("chat/completions", json=params, cache_id=cache_id)
if isinstance(response, TooLongRequestError):
raise response
add_fields(total_tokens=extract_total_tokens(response))
return response
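
A sketch of calling the new helper directly; the message content is illustrative, and stop=None overrides the "\n" default so multi-line replies are not truncated:

import asyncio

from ice.apis.openai import openai_chatcomplete

async def main():
    # Message dicts follow the ChatCompletion role/content shape.
    messages = [{"role": "user", "content": "Say hello in one word."}]
    response = await openai_chatcomplete(messages, stop=None, max_tokens=16)
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())
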
2 changes: 1 addition & 1 deletion ice/recipe.py
@@ -1,5 +1,5 @@
import asyncio
-import importlib
+import importlib.util
import sys
from abc import abstractmethod
from collections.abc import Awaitable
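
Context for the recipe.py change: importing importlib alone does not guarantee that the importlib.util submodule is bound as an attribute, so code that loads modules from file paths needs the explicit submodule import. A sketch of the common pattern that depends on it (the module name and path are illustrative, not from this diff):

import importlib.util

# Load a recipe module from an arbitrary file path.
spec = importlib.util.spec_from_file_location("my_recipe", "/path/to/recipe.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)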
