Skip to content

Commit

Permalink
Describe shell command option (TheR1D#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
chinarjoshi authored May 16, 2023
1 parent 5b5b321 commit 39b5b18
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 62 deletions.
19 changes: 16 additions & 3 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def main(
help="Generate and execute shell commands.",
rich_help_panel="Assistance Options",
),
describe_shell: bool = typer.Option(
False,
"--describe-shell",
"-d",
help="Describe a shell command.",
rich_help_panel="Assistance Options",
),
code: bool = typer.Option(
False,
help="Generate only code.",
Expand Down Expand Up @@ -117,8 +124,10 @@ def main(
if not prompt and not editor and not repl:
raise MissingParameter(param_hint="PROMPT", param_type="string")

if shell and code:
raise BadArgumentUsage("--shell and --code options cannot be used together.")
if sum([shell, describe_shell, code]) > 1:
raise BadArgumentUsage(
"Only one of --shell, --describe-shell, and --code options can be used at a time."
)

if chat and repl:
raise BadArgumentUsage("--chat and --repl options cannot be used together.")
Expand All @@ -131,7 +140,11 @@ def main(

client = OpenAIClient(cfg.get("OPENAI_API_HOST"), cfg.get("OPENAI_API_KEY"))

role_class = DefaultRoles.get(shell, code) if not role else SystemRole.get(role)
role_class = (
DefaultRoles.get(shell, describe_shell, code)
if not role
else SystemRole.get(role)
)

if repl:
# Will be in infinite loop here until user exits with Ctrl+C.
Expand Down
56 changes: 0 additions & 56 deletions sgpt/make_prompt.py

This file was deleted.

14 changes: 12 additions & 2 deletions sgpt/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
Ensure the output is a valid shell command.
If multiple steps required try to combine them together."""

DESCRIBE_SHELL_ROLE = """Provide a terse, single sentence description
of the given shell command. Provide only plain text without Markdown formatting.
Do not show any warnings or information regarding your capabilities.
If you need to store any data, assume it will be stored in the chat."""

CODE_ROLE = """Provide only code as output without any description.
IMPORTANT: Provide only plain text without Markdown formatting.
IMPORTANT: Do not include markdown formatting such as ```.
Expand Down Expand Up @@ -67,6 +72,7 @@ def create_defaults(cls) -> None:
for default_role in (
SystemRole("default", DEFAULT_ROLE, "Answer", variables),
SystemRole("shell", SHELL_ROLE, "Command", variables),
SystemRole("describe_shell", DESCRIBE_SHELL_ROLE, "Description", variables),
SystemRole("code", CODE_ROLE, "Code"),
):
if not default_role.exists:
Expand Down Expand Up @@ -112,7 +118,8 @@ def get(cls, name: str) -> "SystemRole":
def create(cls, name: str) -> None:
role = typer.prompt("Enter role description")
expecting = typer.prompt(
"Enter expecting result, e.g. answer, code, shell command, etc."
"Enter expecting result, e.g. answer, code, \
shell command, command description, etc."
)
role = cls(name, role, expecting)
role.save()
Expand Down Expand Up @@ -183,12 +190,15 @@ def same_role(self, initial_message: str) -> bool:
class DefaultRoles(Enum):
DEFAULT = "default"
SHELL = "shell"
DESCRIBE_SHELL = "describe_shell"
CODE = "code"

@classmethod
def get(cls, shell: bool, code: bool) -> SystemRole:
def get(cls, shell: bool, describe_shell: bool, code: bool) -> SystemRole:
if shell:
return SystemRole.get(DefaultRoles.SHELL.value)
if describe_shell:
return SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
if code:
return SystemRole.get(DefaultRoles.CODE.value)
return SystemRole.get(DefaultRoles.DEFAULT.value)
Expand Down
48 changes: 47 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def test_shell(self):
assert result.exit_code == 0
assert "git commit" in result.stdout

def test_describe_shell(self):
dict_arguments = {
"prompt": "ls",
"--describe-shell": True,
}
result = runner.invoke(app, self.get_arguments(**dict_arguments))
assert result.exit_code == 0
assert "List " in result.stdout

def test_code(self):
"""
This test will request from OpenAI API a python code to make CLI app,
Expand Down Expand Up @@ -145,6 +154,22 @@ def test_chat_shell(self):
# If we are using --code, we cannot use --shell.
assert result.exit_code == 2

def test_chat_describe_shell(self):
chat_name = uuid4()
dict_arguments = {
"prompt": "git add",
"--chat": f"test_{chat_name}",
"--describe-shell": True,
"--temperature": 0,
}
result = runner.invoke(app, self.get_arguments(**dict_arguments))
assert result.exit_code == 0
assert "Add file contents to the index." in result.stdout
dict_arguments["prompt"] = "'-A'"
result = runner.invoke(app, self.get_arguments(**dict_arguments))
assert result.exit_code == 0
assert "all" in result.stdout

def test_chat_code(self):
chat_name = uuid4()
dict_arguments = {
Expand Down Expand Up @@ -194,7 +219,7 @@ def test_validation_code_shell(self):
}
result = runner.invoke(app, self.get_arguments(**dict_arguments))
assert result.exit_code == 2
assert "--shell and --code options cannot be used together" in result.stdout
assert "Only one of --shell, --describe-shell, and --code" in result.stdout

def test_repl_default(
self,
Expand Down Expand Up @@ -243,6 +268,27 @@ def test_repl_shell(self):
assert chat_messages[2]["content"].endswith("\nCommand:")
assert chat_messages[3]["content"] == "ls | sort"

def test_repl_describe_command(self):
# Temp chat session from previous test should be overwritten.
dict_arguments = {
"prompt": "",
"--repl": "temp",
"--describe-shell": True,
}
inputs = ["pacman -S", "-yu", "exit()"]
result = runner.invoke(
app, self.get_arguments(**dict_arguments), input="\n".join(inputs)
)
assert result.exit_code == 0
assert "Install" in result.stdout
assert "Update" in result.stdout

chat_storage = cfg.get("CHAT_CACHE_PATH")
tmp_chat = Path(chat_storage) / "temp"
chat_messages = json.loads(tmp_chat.read_text())
assert chat_messages[0]["content"].startswith("###")
assert chat_messages[0]["content"].endswith("\n###\nDescription:")

def test_repl_code(self):
dict_arguments = {
"prompt": "",
Expand Down

0 comments on commit 39b5b18

Please sign in to comment.