sgpt.py
"""
shell-gpt: An interface to OpenAI's GPT-3 API
This module provides a simple interface for OpenAI's GPT-3 API using Typer
as the command line interface. It supports different modes of output including
shell commands and code, and allows users to specify the desired OpenAI model
and length and other options of the output. Additionally, it supports executing
shell commands directly from the interface.
API Key is stored locally for easy use in future runs.
"""
import os
import platform
from enum import Enum
from time import sleep
from pathlib import Path
from getpass import getpass
from types import DynamicClassAttribute
from tempfile import NamedTemporaryFile

import typer
import requests

# Click is part of typer.
from click import MissingParameter, BadParameter
from rich.progress import Progress, SpinnerColumn, TextColumn

API_URL = "https://api.openai.com/v1/completions"
DATA_FOLDER = os.path.expanduser("~/.config")
KEY_FILE = Path(DATA_FOLDER) / "shell-gpt" / "api_key.txt"
# pylint: disable=invalid-name
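# NOTE: __str__ and value are overridden below so that Typer/Click show and
# accept the short member names (davinci, curie, codex) on the command line;
# because the enum also subclasses str, the full model identifier (e.g.
# "text-davinci-003") is still what ends up in the JSON request body.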
class Model(str, Enum):
    davinci = "text-davinci-003"
    curie = "text-curie-001"
    codex = "code-davinci-002"

    def __str__(self):
        return self.name

    @DynamicClassAttribute
    def value(self):
        return self.name
# pylint: enable=invalid-name
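# The key is cached as plain text at ~/.config/shell-gpt/api_key.txt, so an
# old or invalid key can be reset by deleting that file; the next run will
# prompt for it again.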
def get_api_key():
    if not KEY_FILE.exists():
        api_key = getpass(prompt="Please enter your API secret key")
        KEY_FILE.parent.mkdir(parents=True, exist_ok=True)
        KEY_FILE.write_text(api_key)
    else:
        api_key = KEY_FILE.read_text().strip()
    return api_key
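# Decorator contract: callers of the wrapped function must always pass a
# spinner= keyword argument. wrapper() pops it before delegating, so the
# wrapped function never receives it. With the spinner enabled, the request is
# made inside the rich Progress context (transient=True), so the spinner is
# cleared automatically as soon as the call returns.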
def loading_spinner(func):
    def wrapper(*args, **kwargs):
        if not kwargs.pop("spinner"):
            return func(*args, **kwargs)
        text = TextColumn("[green]Requesting OpenAI...")
        with Progress(SpinnerColumn(), text, transient=True) as progress:
            progress.add_task("request")
            return func(*args, **kwargs)

    return wrapper
def get_edited_prompt():
    with NamedTemporaryFile(suffix=".txt", delete=False) as file:
        # Create file and store path.
        file_path = file.name
    editor = os.environ.get("EDITOR", "vim")
    # This will write text to file using $EDITOR.
    os.system(f"{editor} {file_path}")
    # Read file when editor is closed.
    with open(file_path, "r") as file:
        output = file.read()
    os.remove(file_path)
    if not output:
        raise BadParameter("Couldn't get valid PROMPT from $EDITOR")
    return output
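# For reference, a successful response from the completions endpoint has
# roughly the shape {"choices": [{"text": "..."}], "usage": {...}, ...}, which
# is why openai_request() returns response.json()["choices"][0]["text"].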
@loading_spinner
def openai_request(prompt, model, max_tokens, api_key, temperature, top_p):
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    data = {
        "prompt": prompt,
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    response = requests.post(API_URL, headers=headers, json=data, timeout=180)
    response.raise_for_status()
    return response.json()["choices"][0]["text"]
def typer_writer(text, code, shell, animate):
    shell_or_code = shell or code
    color = "magenta" if shell_or_code else None
    if animate and not shell_or_code:
        for char in text:
            typer.secho(char, nl=False, fg=color, bold=shell_or_code)
            sleep(0.015)
        # Add new line at the end, to prevent % from appearing.
        typer.echo("")
        return
    typer.secho(text, fg=color, bold=shell_or_code)
# Using a lambda for a default value makes it appear as "dynamic" in --help.
def main(
    prompt: str = typer.Argument(None, show_default=False, help="The prompt to generate completions for."),
    model: Model = typer.Option("davinci", help="GPT-3 model name.", show_choices=True),
    max_tokens: int = typer.Option(lambda: 2048, help="Strict length of output (words)."),
    temperature: float = typer.Option(lambda: 1.0, min=0.0, max=1.0, help="Randomness of generated output."),
    top_probability: float = typer.Option(lambda: 1.0, min=0.1, max=1.0, help="Limits highest probable tokens."),
    shell: bool = typer.Option(False, "--shell", "-s", help="Provide shell command as output."),
    execute: bool = typer.Option(False, "--execute", "-e", help="Will execute --shell command."),
    code: bool = typer.Option(False, help="Provide code as output."),
    editor: bool = typer.Option(False, help="Open $EDITOR to provide a prompt."),
    animation: bool = typer.Option(True, help="Typewriter animation."),
    spinner: bool = typer.Option(True, help="Show loading spinner during API request."),
):
    api_key = get_api_key()
    if not prompt and not editor:
        raise MissingParameter(param_hint="PROMPT", param_type="string")
    if shell:
        # If default values were not changed, make the output more accurate.
        if temperature == 1.0 == top_probability:
            temperature, top_probability = 0.2, 1.0
        current_shell = "PowerShell" if platform.system() == "Windows" else "Bash"
        prompt = f"""
        Context: Provide only {current_shell} command as output.
        Prompt: {prompt}
        Command:
        """
    elif code:
        # If default values were not changed, make the output more creative (diverse).
        if temperature == 1.0 == top_probability:
            temperature, top_probability = 0.8, 0.2
        prompt = f"""
        Context: Provide only code as output.
        Prompt: {prompt}
        Code:
        """
    # Curie has a hard cap of 2048 tokens, including the prompt.
    if model == "text-curie-001" and max_tokens == 2048:
        max_tokens = 1024
    if editor:
        prompt = get_edited_prompt()
    response_text = openai_request(prompt, model, max_tokens, api_key, temperature, top_probability, spinner=spinner)
    # For some reason OpenAI returns several leading/trailing white spaces.
    response_text = response_text.strip()
    typer_writer(response_text, code, shell, animation)
    if shell and execute and typer.confirm("Execute shell command?"):
        os.system(response_text)
def entry_point():
    typer.run(main)


if __name__ == "__main__":
    entry_point()