Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 119 additions & 40 deletions nbs/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,21 @@
"source": [
"#| export\n",
"from datetime import datetime\n",
"from itertools import accumulate\n",
"from fastcore.script import *\n",
"from fastcore.tools import *\n",
"from fastcore.utils import *\n",
"from fastlite import database\n",
"from functools import partial, wraps\n",
"from lisette import *\n",
"from rich.live import Live\n",
"from rich.spinner import Spinner\n",
"from rich.console import Console\n",
"from rich.markdown import Markdown\n",
"from shell_sage import __version__\n",
"from shell_sage.config import *\n",
"from subprocess import check_output as co, DEVNULL\n",
"\n",
"import asyncio,litellm,os,pyperclip,re,subprocess,sys"
"import asyncio,os,pyperclip,re,subprocess,sys"
]
},
{
Expand All @@ -58,11 +60,57 @@
"outputs": [],
"source": [
"#| export\n",
"litellm.drop_params = True\n",
"console = Console()\n",
"print = console.print"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "977bd215",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def Chat(*arg, **kw):\n",
" \"Lazy load lisette to make ssage more responsive\"\n",
" import litellm \n",
" from lisette import Chat\n",
" \n",
" litellm.drop_params = True\n",
" return Chat(*arg, **kw)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "257cdfa7",
"metadata": {},
"outputs": [],
"source": [
"from contextlib import contextmanager\n",
"from IPython.display import clear_output\n",
"# jupyter does work with rich.live.Live, the fixes this.\n",
"@contextmanager\n",
"def Live(start, **kw):\n",
" print(start)\n",
" def update(s, refresh=False): clear_output(True);print(s)\n",
" yield NS(update=update)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b56fc5bc",
"metadata": {},
"outputs": [],
"source": [
"def print_md(md_stream):\n",
" \"Print streamed markdown\"\n",
" with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n",
" for part in md_stream: live.update(Markdown(part), refresh=True)"
]
},
{
"cell_type": "markdown",
"id": "c643b9f0",
Expand Down Expand Up @@ -572,15 +620,38 @@
{
"cell_type": "code",
"execution_count": null,
"id": "36283633",
"id": "68be9484",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def get_res(sage, q, opts):\n",
" from litellm.types.utils import ModelResponseStream # lazy load\n",
" # need to use stream=True to get search citations\n",
" for o in sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key): ...\n",
" return o.choices[0].message.content"
" gen = sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key) \n",
" yield from accumulate(o.choices[0].delta.content or \"\" for o in gen if isinstance(o, ModelResponseStream))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "325e2cbd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['', '', '', '', '', '', '', 'No', 'No']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"opts=NS(api_base='', api_key='')\n",
"list(get_res(ssage, 'Use tools to check if we have a .git in the current directory. Respond with yes/no', opts))"
]
},
{
Expand Down Expand Up @@ -630,8 +701,7 @@
}
],
"source": [
"opts=NS(api_base='', api_key='')\n",
"print(Markdown(get_res(ssage, 'Hi!', opts)))"
"print_md(get_res(ssage, 'Hi!', opts))"
]
},
{
Expand Down Expand Up @@ -661,7 +731,7 @@
}
],
"source": [
"print(Markdown(get_res(ssage, 'Please use your view command to see what files are in the current directory. Only respond with a single paragraph', opts)))"
"print_md(get_res(ssage, 'Please use your view command to see what files are in the current directory. Only respond with a single paragraph', opts))"
]
},
{
Expand Down Expand Up @@ -697,7 +767,7 @@
}
],
"source": [
"print(Markdown(get_res(ssage, 'Please search the web for interesting facts about Linux. Only respond with a single paragraph.', opts)))"
"print_md(get_res(ssage, 'Please search the web for interesting facts about Linux. Only respond with a single paragraph.', opts));"
]
},
{
Expand Down Expand Up @@ -773,37 +843,42 @@
" opts = get_opts(history_lines=history_lines, model=model, search=search,\n",
" api_base=api_base, api_key=api_key, code_theme=code_theme,\n",
" code_lexer=code_lexer, log=None)\n",
" res=\"\"\n",
" try:\n",
" with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n",
" \n",
" if mode not in ['default', 'sassy']:\n",
" raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n",
" \n",
" md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n",
" inline_code_theme=opts.code_theme)\n",
" query = ' '.join(query)\n",
" ctxt = '' if skip_system else _sys_info()\n",
"\n",
" if mode not in ['default', 'sassy']:\n",
" raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n",
" \n",
" md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n",
" inline_code_theme=opts.code_theme)\n",
" query = ' '.join(query)\n",
" ctxt = '' if skip_system else _sys_info()\n",
"\n",
" # Get tmux history if in a tmux session\n",
" if os.environ.get('TMUX'):\n",
" if opts.history_lines is None or opts.history_lines < 0:\n",
" opts.history_lines = tmux_history_lim()\n",
" history = get_history(opts.history_lines, pid)\n",
" if history: ctxt += f'<terminal_history>\\n{history}\\n</terminal_history>'\n",
" # Get tmux history if in a tmux session\n",
" if os.environ.get('TMUX'):\n",
" if opts.history_lines is None or opts.history_lines < 0:\n",
" opts.history_lines = tmux_history_lim()\n",
" history = get_history(opts.history_lines, pid)\n",
" if history: ctxt += f'<terminal_history>\\n{history}\\n</terminal_history>'\n",
"\n",
" # Read from stdin if available\n",
" if not sys.stdin.isatty():\n",
" ctxt += f'\\n<context>\\n{sys.stdin.read()}</context>'\n",
" \n",
" query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n",
" # Read from stdin if available\n",
" if not sys.stdin.isatty() and not IN_NOTEBOOK:\n",
" ctxt += f'\\n<context>\\n{sys.stdin.read()}</context>'\n",
" \n",
" query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n",
"\n",
" sage = get_sage(opts.model, mode, search=opts.search)\n",
" res = get_res(sage, query, opts)\n",
" \n",
" # Handle logging if the log flag is set\n",
" if opts.log:\n",
" db = mk_db()\n",
" db.logs.insert(Log(timestamp=datetime.now().isoformat(), query=query,\n",
" response=res, model=opts.model, mode=mode))\n",
" print(md(res))"
" sage = get_sage(opts.model, mode, search=opts.search)\n",
" for res in get_res(sage, query, opts):\n",
" live.update(md(res), refresh=True)\n",
" \n",
" # Handle logging if the log flag is set\n",
" if opts.log:\n",
" db = mk_db()\n",
" db.logs.insert(Log(timestamp=datetime.now().isoformat(), query=query,\n",
" response=res, model=opts.model, mode=mode))\n",
" except KeyboardInterrupt:\n",
" print(\"Interrupted.\")"
]
},
{
Expand Down Expand Up @@ -853,7 +928,7 @@
}
],
"source": [
"main('Teach me about rsync. Reply with a single paragraph.', history_lines=0)"
"main(['Teach me about rsync. Reply with a single paragraph.'], history_lines=0)"
]
},
{
Expand Down Expand Up @@ -974,7 +1049,11 @@
]
}
],
"metadata": {},
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 2 additions & 1 deletion shell_sage/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
'syms': { 'shell_sage.config': { 'shell_sage.config.ShellSageConfig': ('config.html#shellsageconfig', 'shell_sage/config.py'),
'shell_sage.config._cfg_path': ('config.html#_cfg_path', 'shell_sage/config.py'),
'shell_sage.config.get_cfg': ('config.html#get_cfg', 'shell_sage/config.py')},
'shell_sage.core': { 'shell_sage.core.Log': ('core.html#log', 'shell_sage/core.py'),
'shell_sage.core': { 'shell_sage.core.Chat': ('core.html#chat', 'shell_sage/core.py'),
'shell_sage.core.Log': ('core.html#log', 'shell_sage/core.py'),
'shell_sage.core._aliases': ('core.html#_aliases', 'shell_sage/core.py'),
'shell_sage.core._sys_info': ('core.html#_sys_info', 'shell_sage/core.py'),
'shell_sage.core.extract': ('core.html#extract', 'shell_sage/core.py'),
Expand Down
Loading