Commit
add tests
TheoMcCabe committed Oct 16, 2023
1 parent a436d0f commit 17a2d34
Showing 4 changed files with 171 additions and 92 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -17,6 +17,7 @@ build/
 .env.sh
 venv/
 ENV/
+venv_test_installation/
 
 # IDE-specific files
 .vscode/
@@ -80,3 +81,5 @@ docs/db/
 poetry.lock
 .aider*
 .gpteng
+
+#test artifacts
153 changes: 73 additions & 80 deletions gpt_engineer/core/ai.py
@@ -139,11 +139,10 @@ def __init__(self, model_name="gpt-4", temperature=0.1, azure_endpoint=""):
         """
         self.temperature = temperature
         self.azure_endpoint = azure_endpoint
-        self.model_name = (
-            fallback_model(model_name) if azure_endpoint == "" else model_name
-        )
-        self.llm = create_chat_model(self, self.model_name, self.temperature)
-        self.tokenizer = get_tokenizer(self.model_name)
+        self.model_name = model_name
+
+        self.llm = self._create_chat_model()
+        self.tokenizer = self._get_tokenizer()
         logger.debug(f"Using model {self.model_name} with llm {self.llm}")
 
         # initialize token usage log
@@ -170,6 +169,7 @@ def start(self, system: str, user: str, step_name: str) -> List[Message]:
         List[Message]
             The list of messages in the conversation.
         """
+
         messages: List[Message] = [
             SystemMessage(content=system),
             HumanMessage(content=user),
@@ -463,94 +463,87 @@ def num_tokens_from_messages(self, messages: List[Message]) -> int:
         n_tokens += 2  # every reply is primed with <im_start>assistant
         return n_tokens
 
-def fallback_model(model: str) -> str:
-    """
-    Retrieve the specified model, or fallback to "gpt-3.5-turbo" if the model is not available.
-
-    Parameters
-    ----------
-    model : str
-        The name of the model to retrieve.
-
-    Returns
-    -------
-    str
-        The name of the retrieved model, or "gpt-3.5-turbo" if the specified model is not available.
-    """
-    try:
-        openai.Model.retrieve(model)
-        return model
-    except openai.InvalidRequestError:
-        print(
-            f"Model {model} not available for provided API key. Reverting "
-            "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
-            "https://openai.com/waitlist/gpt-4-api\n"
-        )
-        return "gpt-3.5-turbo"
+    def _check_model_acess_and_fallback(self) -> str:
+        """
+        Retrieve the specified model, or fallback to "gpt-3.5-turbo" if the model is not available.
+
+        Parameters
+        ----------
+        model : str
+            The name of the model to retrieve.
+
+        Returns
+        -------
+        str
+            The name of the retrieved model, or "gpt-3.5-turbo" if the specified model is not available.
+        """
+        try:
+            openai.Model.retrieve(self.model_name)
+        except openai.InvalidRequestError:
+            print(
+                f"Model {self.model_name} not available for provided API key. Reverting "
+                "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
+                "https://openai.com/waitlist/gpt-4-api\n"
+            )
+            self.model_name = "gpt-3.5-turbo"
 
-def create_chat_model(self, model: str, temperature) -> BaseChatModel:
-    """
-    Create a chat model with the specified model name and temperature.
-
-    Parameters
-    ----------
-    model : str
-        The name of the model to create.
-    temperature : float
-        The temperature to use for the model.
-
-    Returns
-    -------
-    BaseChatModel
-        The created chat model.
-    """
-    if self.azure_endpoint:
-        return AzureChatOpenAI(
-            openai_api_base=self.azure_endpoint,
-            openai_api_version="2023-05-15",  # might need to be flexible in the future
-            deployment_name=model,
-            openai_api_type="azure",
-            streaming=True,
-        )
-    # Fetch available models from OpenAI API
-    supported = [model["id"] for model in openai.Model.list()["data"]]
-    if model not in supported:
-        raise ValueError(
-            f"Model {model} is not supported, supported models are: {supported}"
-        )
-    return ChatOpenAI(
-        model=model,
-        temperature=temperature,
-        streaming=True,
-        client=openai.ChatCompletion,
-    )
+    def _create_chat_model(self) -> BaseChatModel:
+        """
+        Create a chat model with the specified model name and temperature.
+
+        Parameters
+        ----------
+        model : str
+            The name of the model to create.
+        temperature : float
+            The temperature to use for the model.
+
+        Returns
+        -------
+        BaseChatModel
+            The created chat model.
+        """
+        if self.azure_endpoint:
+            return AzureChatOpenAI(
+                openai_api_base=self.azure_endpoint,
+                openai_api_version="2023-05-15",  # might need to be flexible in the future
+                deployment_name=self.model_name,
+                openai_api_type="azure",
+                streaming=True,
+            )
+
+        self._check_model_acess_and_fallback()
+
+        return ChatOpenAI(
+            model=self.model_name,
+            temperature=self.temperature,
+            streaming=True,
+            client=openai.ChatCompletion,
+        )
 
-def get_tokenizer(model: str):
-    """
-    Get the tokenizer for the specified model.
-
-    Parameters
-    ----------
-    model : str
-        The name of the model to get the tokenizer for.
-
-    Returns
-    -------
-    Tokenizer
-        The tokenizer for the specified model.
-    """
-    if "gpt-4" in model or "gpt-3.5" in model:
-        return tiktoken.encoding_for_model(model)
-
-    logger.debug(
-        f"No encoder implemented for model {model}."
-        "Defaulting to tiktoken cl100k_base encoder."
-        "Use results only as estimates."
-    )
-    return tiktoken.get_encoding("cl100k_base")
+    def _get_tokenizer(self):
+        """
+        Get the tokenizer for the specified model.
+
+        Parameters
+        ----------
+        model : str
+            The name of the model to get the tokenizer for.
+
+        Returns
+        -------
+        Tokenizer
+            The tokenizer for the specified model.
+        """
+        if "gpt-4" in self.model_name or "gpt-3.5" in self.model_name:
+            return tiktoken.encoding_for_model(self.model_name)
+
+        logger.debug(
+            f"No encoder implemented for model {self.model_name}."
+            "Defaulting to tiktoken cl100k_base encoder."
+            "Use results only as estimates."
+        )
+        return tiktoken.get_encoding("cl100k_base")
 
 
 def serialize_messages(messages: List[Message]) -> str:
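The net effect of the ai.py changes is that model fallback, chat-model construction, and tokenizer selection become instance methods (_check_model_acess_and_fallback, _create_chat_model, _get_tokenizer) instead of module-level helpers, so a test can replace them without ever reaching the OpenAI API. A minimal sketch of that idea, assuming only the class and import paths visible in this diff; the committed tests below do the same thing with pytest's monkeypatch rather than a subclass:

# Illustrative sketch only, not part of this commit: subclass AI and override
# the new private factory so no OpenAI call is ever made.
from langchain.chat_models.fake import FakeListChatModel

from gpt_engineer.core.ai import AI


class OfflineAI(AI):  # hypothetical name, for illustration
    def _create_chat_model(self):
        # Return a canned-response chat model instead of ChatOpenAI/AzureChatOpenAI.
        return FakeListChatModel(responses=["canned reply"])


ai = OfflineAI("fake-model")
messages = ai.start("system prompt", "user prompt", "step name")
print(messages[-1].content)  # "canned reply"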
101 changes: 95 additions & 6 deletions tests/test_ai.py
@@ -1,9 +1,98 @@
 import pytest
 
 from gpt_engineer.core.ai import AI
+from langchain.chat_models.fake import FakeListChatModel
+from langchain.chat_models.base import BaseChatModel
+import copy
+
+
+def test_start(monkeypatch):
+    """
+    Test function for the AI system.
+    This test sets up a fake LLM and tests that the start method successfully returns a response
+    """
+
+    # arrange
+    def mock_create_chat_model(self) -> BaseChatModel:
+        return FakeListChatModel(responses=["response1", "response2", "response3"])
+
+    monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model)
+
+    ai = AI("fake")
+
+    # act
+    response_messages = ai.start("system prompt", "user prompt", "step name")
+
+    # assert
+    assert response_messages[-1].content == "response1"
+
+
+def test_next(monkeypatch):
+    """
+    Test function for the AI system.
+    This test sets up a fake LLM and tests that the start method successfully returns a response
+    """
+
+    # arrange
+    def mock_create_chat_model(self) -> BaseChatModel:
+        return FakeListChatModel(responses=["response1", "response2", "response3"])
+
+    monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model)
+
+    ai = AI("fake")
+    response_messages = ai.start("system prompt", "user prompt", "step name")
+
+    # act
+    response_messages = ai.next(
+        response_messages, "next user prompt", step_name="step name"
+    )
+
+    # assert
+    assert response_messages[-1].content == "response2"
+
+
+def test_token_logging(monkeypatch):
+    """
+    Test function for the AI system.
+    This test sets up a fake LLM and tests that the start method successfully returns a response
+    """
+
+    # arrange
+    def mock_create_chat_model(self) -> BaseChatModel:
+        return FakeListChatModel(responses=["response1", "response2", "response3"])
+
+    monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model)
+
+    ai = AI("fake")
+
+    # act
+    initial_token_counts = (
+        ai.cumulative_prompt_tokens,
+        ai.cumulative_completion_tokens,
+        ai.cumulative_total_tokens,
+    )
+    response_messages = ai.start("system prompt", "user prompt", "step name")
+    token_counts_1 = (
+        ai.cumulative_prompt_tokens,
+        ai.cumulative_completion_tokens,
+        ai.cumulative_total_tokens,
+    )
+    ai.next(response_messages, "next user prompt", step_name="step name")
+    token_counts_2 = (
+        ai.cumulative_prompt_tokens,
+        ai.cumulative_completion_tokens,
+        ai.cumulative_total_tokens,
+    )
+
+    # assert
+    assert initial_token_counts == (0, 0, 0)
+
+    assert_all_greater_than(
+        token_counts_1, (1, 1, 1)
+    )  # all the token counts are greater than 1
+
+    assert_all_greater_than(
+        token_counts_2, token_counts_1
+    )  # all counts in token_counts_2 greater than token_counts_1
 
 
-@pytest.mark.xfail(reason="Constructor assumes API access")
-def test_ai():
-    AI()
-    # TODO Assert that methods behave and not only constructor.
+def assert_all_greater_than(left, right):
+    assert all(x > y for x, y in zip(left, right))
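These tests hinge on FakeListChatModel replaying its responses list in order, one entry per call: start() consumes "response1", the follow-up next() consumes "response2", and each call still runs the real tiktoken-based accounting, which is why the cumulative counters in test_token_logging grow even with a fake model. A small standalone sketch of that replay behaviour, assuming the langchain 0.0.x-era import paths used above:

# Standalone sketch, not part of this commit, of the fake model the tests rely on.
from langchain.chat_models.fake import FakeListChatModel
from langchain.schema import HumanMessage

fake_llm = FakeListChatModel(responses=["response1", "response2", "response3"])

# Each invocation returns the next canned response as an AIMessage.
first = fake_llm([HumanMessage(content="user prompt")])
second = fake_llm([HumanMessage(content="next user prompt")])

assert first.content == "response1"
assert second.content == "response2"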
6 changes: 0 additions & 6 deletions tests/test_db.py
@@ -45,12 +45,6 @@ def test_DBs_initialization(tmp_path):
     assert isinstance(dbs_instance.project_metadata, DB)
 
 
-def test_invalid_path():
-    with pytest.raises((PermissionError, OSError)):
-        # Test with a path that will raise a permission error
-        DB("/root/test")
-
-
 def test_large_files(tmp_path):
     db = DB(tmp_path)
     large_content = "a" * (10**6)  # 1MB of data
