Skip to content

Commit

Permalink
Multiline input in REPL (TheR1D#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
arafatsyed authored Dec 12, 2023
1 parent e61caf5 commit f50d544
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
11 changes: 11 additions & 0 deletions sgpt/handlers/repl_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ class ReplHandler(ChatHandler):
def __init__(self, chat_id: str, role: SystemRole) -> None:
super().__init__(chat_id, role)

def get_multiline_input(self) -> str:
multiline_input = ""
while True:
user_input = typer.prompt("...", prompt_suffix="")
multiline_input += user_input + "\n"
if user_input == '"""':
break
return multiline_input

def handle(self, prompt: str, **kwargs: Any) -> None: # type: ignore
if self.initiated:
rich_print(Rule(title="Chat History", style="bold magenta"))
Expand All @@ -34,6 +43,8 @@ def handle(self, prompt: str, **kwargs: Any) -> None: # type: ignore
while True:
# Infinite loop until user exits with Ctrl+C.
prompt = typer.prompt(">>>", prompt_suffix=" ")
if prompt == '"""':
prompt = self.get_multiline_input()
if prompt == "exit()":
# This is also useful during tests.
raise typer.Exit()
Expand Down
25 changes: 25 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,31 @@ def test_repl_default(
assert ">>> What is my favorite number + 2?" in result.stdout
assert "8" in result.stdout

def test_repl_multiline(
self,
):
dict_arguments = {
"prompt": "",
"--repl": "temp",
}
inputs = [
'"""',
"Please remember my favorite number: 6",
"What is my favorite number + 2?",
'"""',
"exit()",
]
result = runner.invoke(
app, self.get_arguments(**dict_arguments), input="\n".join(inputs)
)

assert result.exit_code == 0
assert '"""' in result.stdout
assert "Please remember my favorite number: 6" in result.stdout
assert "What is my favorite number + 2?" in result.stdout
assert '"""' in result.stdout
assert "8" in result.stdout

def test_repl_shell(self):
# Temp chat session from previous test should be overwritten.
dict_arguments = {
Expand Down

0 comments on commit f50d544

Please sign in to comment.