Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 30 additions & 26 deletions TestRunner/GenericTestRunner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import subprocess
from lib.utils import CodeGenSandbox
from abc import ABCMeta, abstractmethod
from pytest_plugins import ResultsCollector, SessionStartPlugin

Expand All @@ -10,35 +11,38 @@ def run(self, *args, **kwargs) -> (int, str): pass


class SubProcessTestRunner(GenericTestRunner):
code_dir: str
test_dir: str
sandbox: CodeGenSandbox

def __init__(self, _code, _test) -> None:
self.code_dir = _code
self.test_dir = _test
def __init__(self, sandbox: CodeGenSandbox) -> None:
self.sandbox = sandbox

def run(self, *args, **kwargs) -> (int, str):
# TODO: check that code_dir and test_dir exist
proc = subprocess.run(["pytest", self.test_dir], capture_output=True, universal_newlines=True)
proc = subprocess.run(
["pytest", self.sandbox.test_path],
cwd=self.sandbox.get_sandboxed_project_path(),
capture_output=True,
universal_newlines=True
)
return proc.returncode, proc.stdout


class InlineTestRunner(GenericTestRunner):

def __init__(self, _code, _test) -> None:
self.code_dir = _code
self.test_dir = _test

def run(self, *args, **kwargs) -> (int, str):
collector = ResultsCollector()
setup = SessionStartPlugin()
pytest.main(args=["-k", "ExampleClass"], plugins=[collector, setup])
_out = ""

if collector.exitcode > 0:
for report in collector.reports:
_out += f"{report.outcome.upper()} {report.nodeid} ... Outcome: - {report.longrepr.reprcrash.message}"
_out += "\n"
_out += report.longreprtext
_out += "\n"
return collector.exitcode, _out
# TODO: This implementation is currently defunct
# class InlineTestRunner(GenericTestRunner):
#
# def __init__(self, sandbox) -> None:
# self.test_dir = sandbox.get_sandboxed_test_path()
#
# def run(self, *args, **kwargs) -> (int, str):
# collector = ResultsCollector()
# setup = SessionStartPlugin()
# # TODO: Remove the ExampleClass reference here.
# pytest.main(args=["-k", "ExampleClass"], plugins=[collector, setup])
# _out = ""
#
# if collector.exitcode > 0:
# for report in collector.reports:
# _out += f"{report.outcome.upper()} {report.nodeid} ... Outcome: - {report.longrepr.reprcrash.message}"
# _out += "\n"
# _out += report.longreprtext
# _out += "\n"
# return collector.exitcode, _out
63 changes: 41 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import os
from lib.utils import CodeGenSandbox
import typer

import langroid as lr
from langroid.utils.configuration import set_global, Settings
from langroid.utils.logging import setup_colored_logging

from TestRunner.GenericTestRunner import GenericTestRunner, InlineTestRunner, SubProcessTestRunner
from TestRunner.GenericTestRunner import GenericTestRunner, SubProcessTestRunner

app = typer.Typer()
setup_colored_logging()


def generate_first_attempt(class_skeleton: str) -> None:
with open(class_skeleton, "r") as f:
def generate_first_attempt(sandbox: CodeGenSandbox) -> None:
with open(sandbox.get_sandboxed_class_path(), "r") as f:
class_skeleton = f.read()

cfg = lr.ChatAgentConfig(
Expand All @@ -30,11 +31,11 @@ def generate_first_attempt(class_skeleton: str) -> None:
f"Do not say 'here is the python code'"
f"Your output MUST be valid, runnable python code and NOTHING else."
f"{class_skeleton}")
with open(os.path.join(".", "generated", "test_class.py"), "w+") as _out:
with open(sandbox.get_sandboxed_class_path(), "w+") as _out:
_out.write(response.content)


def generate_next_attempt(test_results: str, test_results_insights: str) -> None:
def generate_next_attempt(sandbox: CodeGenSandbox, test_results: str, test_results_insights: str) -> None:
cfg = lr.ChatAgentConfig(
llm=lr.language_models.OpenAIGPTConfig(
chat_model="ollama/llama3:latest",
Expand All @@ -43,7 +44,7 @@ def generate_next_attempt(test_results: str, test_results_insights: str) -> None
vecdb=None
)
agent = lr.ChatAgent(cfg)
with open(os.path.join(".", "generated", "test_class.py"), "r") as f:
with open(sandbox.get_sandboxed_class_path(), "r") as f:
code_snippet = f.read()

prompt = f"""
Expand All @@ -65,7 +66,7 @@ def generate_next_attempt(test_results: str, test_results_insights: str) -> None
Your output MUST be valid, runnable python code and NOTHING else.
"""
response = agent.llm_response(prompt)
with open(os.path.join(".", "generated", "test_class.py"), "w") as _out:
with open(sandbox.get_sandboxed_class_path(), "w") as _out:
_out.write(response.content)


Expand Down Expand Up @@ -101,8 +102,8 @@ def teardown() -> None:
generated_file.truncate(0)


def chat(class_skeleton: str, test_dir: str, test_runner: GenericTestRunner, max_epochs: int=5) -> None:
generate_first_attempt(class_skeleton)
def chat(sandbox: CodeGenSandbox, test_runner: GenericTestRunner, max_epochs: int=5) -> None:
generate_first_attempt(sandbox)
solved = False
for _ in range(max_epochs):
# test_exit_code, test_results = get_test_results()
Expand All @@ -114,22 +115,44 @@ def chat(class_skeleton: str, test_dir: str, test_runner: GenericTestRunner, max
break
elif test_exit_code == 1:
results_insights = interpret_test_results(test_results)
generate_next_attempt(test_results, results_insights)
generate_next_attempt(sandbox, test_results, results_insights)
else:
solved = True
print("There is some problem with the test suite itself.")
break
teardown()
# teardown()
if not solved:
print(f"Reached the end of epoch {max_epochs} without finding a solution :(")


@app.command()
def main(
debug: bool = typer.Option(False, "--debug", "-d", help="debug mode"),
no_stream: bool = typer.Option(False, "--nostream", "-ns", help="no streaming"),
nocache: bool = typer.Option(False, "--nocache", "-nc", help="don't use cache"),
class_skeleton: str = typer.Option(None, "--class-skeleton", "-c", help="You must provide a class skeleton."),
test_dir: str = typer.Option(os.path.join(".", "test"), "--test-dir", "-t", help=""),
project_dir: str = typer.Argument(
default=".",
help="The project directory that contains your tests and class skeleton. "
"This directory may also have other contents. "
"The directory you give here will be cloned into a 'sandbox' for the code generator to operate in."
),
class_skeleton_path: str = typer.Argument(
default=os.path.join("assets", "test_class.py"),
help="Path to the class skeleton file, relative to project_dir."
),
test_path: str = typer.Argument(
default=os.path.join(".", "test"),
help="Path to the test file or directory, relative to project_dir."
),
sandbox_path: str = typer.Option(
"./build", "--sandbox-path", "-s",
help="You may optionally specify a location for the sandbox in which the code generator operates."
"Default: ./build"
),
max_epochs: int = typer.Option(
5, "--max-epochs", "-n", help="The maximum number of times to let the code generator try"
"before giving up."
)
) -> None:
set_global(
Settings(
Expand All @@ -138,15 +161,11 @@ def main(
stream=not no_stream,
)
)
assert os.path.isfile(class_skeleton), f"The class skeleton file provided does not exist! Got {class_skeleton}"
assert os.path.exists(test_dir), f"The test-dir provided does not exist! Got {test_dir}"

tr: GenericTestRunner = SubProcessTestRunner("", test_dir)
chat(
class_skeleton=class_skeleton,
test_dir=test_dir,
test_runner=tr
)

sandbox = CodeGenSandbox(project_dir, class_skeleton_path, test_path, sandbox_path)
sandbox.init_sandbox()
tr: GenericTestRunner = SubProcessTestRunner(sandbox)
chat(sandbox, tr, max_epochs=max_epochs)


if __name__ == "__main__":
Expand Down