Skip to content

Commit 258a639

Browse files
committed
Show spinner, then stream, add ctrl+c support
Streaming allows users to start reading the response as soon as the first fragment is generated by the API, significantly improving perceived speed. litellm is now imported lazily so that a spinner can be shown while it loads and sends the first response. This makes the ssage command feel faster and more responsive. Added support for Ctrl+C, letting users interrupt ssage mid-sentence — perfect for recalling an exact command and skipping the rest of the explanation. Example: ```py ssage gh mk pr I can see you're trying to create a pull request with GitHub CLI, but the command syntax isn't quite right. The correct command is: ``` gh pr create ``` This will interactively prompt you Interrupted. ```
1 parent 5a1bd7e commit 258a639

File tree

3 files changed

+192
-96
lines changed

3 files changed

+192
-96
lines changed

nbs/00_core.ipynb

Lines changed: 119 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,21 @@
3535
"source": [
3636
"#| export\n",
3737
"from datetime import datetime\n",
38+
"from itertools import accumulate\n",
3839
"from fastcore.script import *\n",
3940
"from fastcore.tools import *\n",
4041
"from fastcore.utils import *\n",
4142
"from fastlite import database\n",
4243
"from functools import partial, wraps\n",
43-
"from lisette import *\n",
44+
"from rich.live import Live\n",
45+
"from rich.spinner import Spinner\n",
4446
"from rich.console import Console\n",
4547
"from rich.markdown import Markdown\n",
4648
"from shell_sage import __version__\n",
4749
"from shell_sage.config import *\n",
4850
"from subprocess import check_output as co, DEVNULL\n",
4951
"\n",
50-
"import asyncio,litellm,os,pyperclip,re,subprocess,sys"
52+
"import asyncio,os,pyperclip,re,subprocess,sys"
5153
]
5254
},
5355
{
@@ -58,11 +60,57 @@
5860
"outputs": [],
5961
"source": [
6062
"#| export\n",
61-
"litellm.drop_params = True\n",
6263
"console = Console()\n",
6364
"print = console.print"
6465
]
6566
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": null,
70+
"id": "977bd215",
71+
"metadata": {},
72+
"outputs": [],
73+
"source": [
74+
"#| export\n",
75+
"def Chat(*arg, **kw):\n",
76+
" \"Lazy load lisette to make ssage more responsive\"\n",
77+
" import litellm \n",
78+
" from lisette import Chat\n",
79+
" \n",
80+
" litellm.drop_params = True\n",
81+
" return Chat(*arg, **kw)"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"id": "257cdfa7",
88+
"metadata": {},
89+
"outputs": [],
90+
"source": [
91+
"from contextlib import contextmanager\n",
92+
"from IPython.display import clear_output\n",
93+
"# jupyter does work with rich.live.Live, the fixes this.\n",
94+
"@contextmanager\n",
95+
"def Live(start, **kw):\n",
96+
" print(start)\n",
97+
" def update(s, refresh=False): clear_output(True);print(s)\n",
98+
" yield NS(update=update)"
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"execution_count": null,
104+
"id": "b56fc5bc",
105+
"metadata": {},
106+
"outputs": [],
107+
"source": [
108+
"def print_md(md_stream):\n",
109+
" \"Print streamed markdown\"\n",
110+
" with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n",
111+
" for part in md_stream: live.update(Markdown(part), refresh=True)"
112+
]
113+
},
66114
{
67115
"cell_type": "markdown",
68116
"id": "c643b9f0",
@@ -572,15 +620,38 @@
572620
{
573621
"cell_type": "code",
574622
"execution_count": null,
575-
"id": "36283633",
623+
"id": "68be9484",
576624
"metadata": {},
577625
"outputs": [],
578626
"source": [
579627
"#| export\n",
580628
"def get_res(sage, q, opts):\n",
629+
" from litellm.types.utils import ModelResponseStream # lazy load\n",
581630
" # need to use stream=True to get search citations\n",
582-
" for o in sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key): ...\n",
583-
" return o.choices[0].message.content"
631+
" gen = sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key) \n",
632+
" yield from accumulate(o.choices[0].delta.content or \"\" for o in gen if isinstance(o, ModelResponseStream))"
633+
]
634+
},
635+
{
636+
"cell_type": "code",
637+
"execution_count": null,
638+
"id": "325e2cbd",
639+
"metadata": {},
640+
"outputs": [
641+
{
642+
"data": {
643+
"text/plain": [
644+
"['', '', '', '', '', '', '', 'No', 'No']"
645+
]
646+
},
647+
"execution_count": null,
648+
"metadata": {},
649+
"output_type": "execute_result"
650+
}
651+
],
652+
"source": [
653+
"opts=NS(api_base='', api_key='')\n",
654+
"list(get_res(ssage, 'Use tools to check if we have a .git in the current directory. Respond with yes/no', opts))"
584655
]
585656
},
586657
{
@@ -630,8 +701,7 @@
630701
}
631702
],
632703
"source": [
633-
"opts=NS(api_base='', api_key='')\n",
634-
"print(Markdown(get_res(ssage, 'Hi!', opts)))"
704+
"print_md(get_res(ssage, 'Hi!', opts))"
635705
]
636706
},
637707
{
@@ -661,7 +731,7 @@
661731
}
662732
],
663733
"source": [
664-
"print(Markdown(get_res(ssage, 'Please use your view command to see what files are in the current directory. Only respond with a single paragraph', opts)))"
734+
"print_md(get_res(ssage, 'Please use your view command to see what files are in the current directory. Only respond with a single paragraph', opts))"
665735
]
666736
},
667737
{
@@ -697,7 +767,7 @@
697767
}
698768
],
699769
"source": [
700-
"print(Markdown(get_res(ssage, 'Please search the web for interesting facts about Linux. Only respond with a single paragraph.', opts)))"
770+
"print_md(get_res(ssage, 'Please search the web for interesting facts about Linux. Only respond with a single paragraph.', opts));"
701771
]
702772
},
703773
{
@@ -773,37 +843,42 @@
773843
" opts = get_opts(history_lines=history_lines, model=model, search=search,\n",
774844
" api_base=api_base, api_key=api_key, code_theme=code_theme,\n",
775845
" code_lexer=code_lexer, log=None)\n",
846+
" res=\"\"\n",
847+
" try:\n",
848+
" with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n",
849+
" \n",
850+
" if mode not in ['default', 'sassy']:\n",
851+
" raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n",
852+
" \n",
853+
" md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n",
854+
" inline_code_theme=opts.code_theme)\n",
855+
" query = ' '.join(query)\n",
856+
" ctxt = '' if skip_system else _sys_info()\n",
776857
"\n",
777-
" if mode not in ['default', 'sassy']:\n",
778-
" raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n",
779-
" \n",
780-
" md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n",
781-
" inline_code_theme=opts.code_theme)\n",
782-
" query = ' '.join(query)\n",
783-
" ctxt = '' if skip_system else _sys_info()\n",
784-
"\n",
785-
" # Get tmux history if in a tmux session\n",
786-
" if os.environ.get('TMUX'):\n",
787-
" if opts.history_lines is None or opts.history_lines < 0:\n",
788-
" opts.history_lines = tmux_history_lim()\n",
789-
" history = get_history(opts.history_lines, pid)\n",
790-
" if history: ctxt += f'<terminal_history>\\n{history}\\n</terminal_history>'\n",
858+
" # Get tmux history if in a tmux session\n",
859+
" if os.environ.get('TMUX'):\n",
860+
" if opts.history_lines is None or opts.history_lines < 0:\n",
861+
" opts.history_lines = tmux_history_lim()\n",
862+
" history = get_history(opts.history_lines, pid)\n",
863+
" if history: ctxt += f'<terminal_history>\\n{history}\\n</terminal_history>'\n",
791864
"\n",
792-
" # Read from stdin if available\n",
793-
" if not sys.stdin.isatty():\n",
794-
" ctxt += f'\\n<context>\\n{sys.stdin.read()}</context>'\n",
795-
" \n",
796-
" query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n",
865+
" # Read from stdin if available\n",
866+
" if not sys.stdin.isatty() and not IN_NOTEBOOK:\n",
867+
" ctxt += f'\\n<context>\\n{sys.stdin.read()}</context>'\n",
868+
" \n",
869+
" query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n",
797870
"\n",
798-
" sage = get_sage(opts.model, mode, search=opts.search)\n",
799-
" res = get_res(sage, query, opts)\n",
800-
" \n",
801-
" # Handle logging if the log flag is set\n",
802-
" if opts.log:\n",
803-
" db = mk_db()\n",
804-
" db.logs.insert(Log(timestamp=datetime.now().isoformat(), query=query,\n",
805-
" response=res, model=opts.model, mode=mode))\n",
806-
" print(md(res))"
871+
" sage = get_sage(opts.model, mode, search=opts.search)\n",
872+
" for res in get_res(sage, query, opts):\n",
873+
" live.update(md(res), refresh=True)\n",
874+
" \n",
875+
" # Handle logging if the log flag is set\n",
876+
" if opts.log:\n",
877+
" db = mk_db()\n",
878+
" db.logs.insert(Log(timestamp=datetime.now().isoformat(), query=query,\n",
879+
" response=res, model=opts.model, mode=mode))\n",
880+
" except KeyboardInterrupt:\n",
881+
" print(\"Interrupted.\")"
807882
]
808883
},
809884
{
@@ -853,7 +928,7 @@
853928
}
854929
],
855930
"source": [
856-
"main('Teach me about rsync. Reply with a single paragraph.', history_lines=0)"
931+
"main(['Teach me about rsync. Reply with a single paragraph.'], history_lines=0)"
857932
]
858933
},
859934
{
@@ -974,7 +1049,11 @@
9741049
]
9751050
}
9761051
],
977-
"metadata": {},
1052+
"metadata": {
1053+
"language_info": {
1054+
"name": "python"
1055+
}
1056+
},
9781057
"nbformat": 4,
9791058
"nbformat_minor": 5
9801059
}

shell_sage/_modidx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
'syms': { 'shell_sage.config': { 'shell_sage.config.ShellSageConfig': ('config.html#shellsageconfig', 'shell_sage/config.py'),
99
'shell_sage.config._cfg_path': ('config.html#_cfg_path', 'shell_sage/config.py'),
1010
'shell_sage.config.get_cfg': ('config.html#get_cfg', 'shell_sage/config.py')},
11-
'shell_sage.core': { 'shell_sage.core.Log': ('core.html#log', 'shell_sage/core.py'),
11+
'shell_sage.core': { 'shell_sage.core.Chat': ('core.html#chat', 'shell_sage/core.py'),
12+
'shell_sage.core.Log': ('core.html#log', 'shell_sage/core.py'),
1213
'shell_sage.core._aliases': ('core.html#_aliases', 'shell_sage/core.py'),
1314
'shell_sage.core._sys_info': ('core.html#_sys_info', 'shell_sage/core.py'),
1415
'shell_sage.core.extract': ('core.html#extract', 'shell_sage/core.py'),

0 commit comments

Comments
 (0)