|
35 | 35 | "source": [ |
36 | 36 | "#| export\n", |
37 | 37 | "from datetime import datetime\n", |
| 38 | + "from itertools import accumulate\n", |
38 | 39 | "from fastcore.script import *\n", |
39 | 40 | "from fastcore.tools import *\n", |
40 | 41 | "from fastcore.utils import *\n", |
41 | 42 | "from fastlite import database\n", |
42 | 43 | "from functools import partial, wraps\n", |
43 | | - "from lisette import *\n", |
| 44 | + "from rich.live import Live\n", |
| 45 | + "from rich.spinner import Spinner\n", |
44 | 46 | "from rich.console import Console\n", |
45 | 47 | "from rich.markdown import Markdown\n", |
46 | 48 | "from shell_sage import __version__\n", |
47 | 49 | "from shell_sage.config import *\n", |
48 | 50 | "from subprocess import check_output as co, DEVNULL\n", |
49 | 51 | "\n", |
50 | | - "import asyncio,litellm,os,pyperclip,re,subprocess,sys" |
| 52 | + "import asyncio,os,pyperclip,re,subprocess,sys" |
51 | 53 | ] |
52 | 54 | }, |
53 | 55 | { |
|
58 | 60 | "outputs": [], |
59 | 61 | "source": [ |
60 | 62 | "#| export\n", |
61 | | - "litellm.drop_params = True\n", |
62 | 63 | "console = Console()\n", |
63 | 64 | "print = console.print" |
64 | 65 | ] |
65 | 66 | }, |
| 67 | + { |
| 68 | + "cell_type": "code", |
| 69 | + "execution_count": null, |
| 70 | + "id": "977bd215", |
| 71 | + "metadata": {}, |
| 72 | + "outputs": [], |
| 73 | + "source": [ |
| 74 | + "#| export\n", |
| 75 | + "def Chat(*arg, **kw):\n", |
| 76 | + " \"Lazy load lisette to make ssage more responsive\"\n", |
| 77 | + " import litellm \n", |
| 78 | + " from lisette import Chat\n", |
| 79 | + " \n", |
| 80 | + " litellm.drop_params = True\n", |
| 81 | + " return Chat(*arg, **kw)" |
| 82 | + ] |
| 83 | + }, |
| 84 | + { |
| 85 | + "cell_type": "code", |
| 86 | + "execution_count": null, |
| 87 | + "id": "257cdfa7", |
| 88 | + "metadata": {}, |
| 89 | + "outputs": [], |
| 90 | + "source": [ |
| 91 | + "from contextlib import contextmanager\n", |
| 92 | + "from IPython.display import clear_output\n", |
| 93 | + "# jupyter does work with rich.live.Live, the fixes this.\n", |
| 94 | + "@contextmanager\n", |
| 95 | + "def Live(start, **kw):\n", |
| 96 | + " print(start)\n", |
| 97 | + " def update(s, refresh=False): clear_output(True);print(s)\n", |
| 98 | + " yield NS(update=update)" |
| 99 | + ] |
| 100 | + }, |
| 101 | + { |
| 102 | + "cell_type": "code", |
| 103 | + "execution_count": null, |
| 104 | + "id": "b56fc5bc", |
| 105 | + "metadata": {}, |
| 106 | + "outputs": [], |
| 107 | + "source": [ |
| 108 | + "def print_md(md_stream):\n", |
| 109 | + " \"Print streamed markdown\"\n", |
| 110 | + " with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n", |
| 111 | + " for part in md_stream: live.update(Markdown(part), refresh=True)" |
| 112 | + ] |
| 113 | + }, |
66 | 114 | { |
67 | 115 | "cell_type": "markdown", |
68 | 116 | "id": "c643b9f0", |
|
572 | 620 | { |
573 | 621 | "cell_type": "code", |
574 | 622 | "execution_count": null, |
575 | | - "id": "36283633", |
| 623 | + "id": "68be9484", |
576 | 624 | "metadata": {}, |
577 | 625 | "outputs": [], |
578 | 626 | "source": [ |
579 | 627 | "#| export\n", |
580 | 628 | "def get_res(sage, q, opts):\n", |
| 629 | + " from litellm.types.utils import ModelResponseStream # lazy load\n", |
581 | 630 | " # need to use stream=True to get search citations\n", |
582 | | - " for o in sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key): ...\n", |
583 | | - " return o.choices[0].message.content" |
| 631 | + " gen = sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key) \n", |
| 632 | + " yield from accumulate(o.choices[0].delta.content or \"\" for o in gen if isinstance(o, ModelResponseStream))" |
| 633 | + ] |
| 634 | + }, |
| 635 | + { |
| 636 | + "cell_type": "code", |
| 637 | + "execution_count": null, |
| 638 | + "id": "325e2cbd", |
| 639 | + "metadata": {}, |
| 640 | + "outputs": [ |
| 641 | + { |
| 642 | + "data": { |
| 643 | + "text/plain": [ |
| 644 | + "['', '', '', '', '', '', '', 'No', 'No']" |
| 645 | + ] |
| 646 | + }, |
| 647 | + "execution_count": null, |
| 648 | + "metadata": {}, |
| 649 | + "output_type": "execute_result" |
| 650 | + } |
| 651 | + ], |
| 652 | + "source": [ |
| 653 | + "opts=NS(api_base='', api_key='')\n", |
| 654 | + "list(get_res(ssage, 'Use tools to check if we have a .git in the current directory. Respond with yes/no', opts))" |
584 | 655 | ] |
585 | 656 | }, |
586 | 657 | { |
|
630 | 701 | } |
631 | 702 | ], |
632 | 703 | "source": [ |
633 | | - "opts=NS(api_base='', api_key='')\n", |
634 | | - "print(Markdown(get_res(ssage, 'Hi!', opts)))" |
| 704 | + "print_md(get_res(ssage, 'Hi!', opts))" |
635 | 705 | ] |
636 | 706 | }, |
637 | 707 | { |
|
661 | 731 | } |
662 | 732 | ], |
663 | 733 | "source": [ |
664 | | - "print(Markdown(get_res(ssage, 'Please use your view command to see what files are in the current directory. Only respond with a single paragraph', opts)))" |
| 734 | + "print_md(get_res(ssage, 'Please use your view command to see what files are in the current directory. Only respond with a single paragraph', opts))" |
665 | 735 | ] |
666 | 736 | }, |
667 | 737 | { |
|
697 | 767 | } |
698 | 768 | ], |
699 | 769 | "source": [ |
700 | | - "print(Markdown(get_res(ssage, 'Please search the web for interesting facts about Linux. Only respond with a single paragraph.', opts)))" |
| 770 | + "print_md(get_res(ssage, 'Please search the web for interesting facts about Linux. Only respond with a single paragraph.', opts));" |
701 | 771 | ] |
702 | 772 | }, |
703 | 773 | { |
|
773 | 843 | " opts = get_opts(history_lines=history_lines, model=model, search=search,\n", |
774 | 844 | " api_base=api_base, api_key=api_key, code_theme=code_theme,\n", |
775 | 845 | " code_lexer=code_lexer, log=None)\n", |
| 846 | + " res=\"\"\n", |
| 847 | + " try:\n", |
| 848 | + " with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n", |
| 849 | + " \n", |
| 850 | + " if mode not in ['default', 'sassy']:\n", |
| 851 | + " raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n", |
| 852 | + " \n", |
| 853 | + " md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n", |
| 854 | + " inline_code_theme=opts.code_theme)\n", |
| 855 | + " query = ' '.join(query)\n", |
| 856 | + " ctxt = '' if skip_system else _sys_info()\n", |
776 | 857 | "\n", |
777 | | - " if mode not in ['default', 'sassy']:\n", |
778 | | - " raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n", |
779 | | - " \n", |
780 | | - " md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n", |
781 | | - " inline_code_theme=opts.code_theme)\n", |
782 | | - " query = ' '.join(query)\n", |
783 | | - " ctxt = '' if skip_system else _sys_info()\n", |
784 | | - "\n", |
785 | | - " # Get tmux history if in a tmux session\n", |
786 | | - " if os.environ.get('TMUX'):\n", |
787 | | - " if opts.history_lines is None or opts.history_lines < 0:\n", |
788 | | - " opts.history_lines = tmux_history_lim()\n", |
789 | | - " history = get_history(opts.history_lines, pid)\n", |
790 | | - " if history: ctxt += f'<terminal_history>\\n{history}\\n</terminal_history>'\n", |
| 858 | + " # Get tmux history if in a tmux session\n", |
| 859 | + " if os.environ.get('TMUX'):\n", |
| 860 | + " if opts.history_lines is None or opts.history_lines < 0:\n", |
| 861 | + " opts.history_lines = tmux_history_lim()\n", |
| 862 | + " history = get_history(opts.history_lines, pid)\n", |
| 863 | + " if history: ctxt += f'<terminal_history>\\n{history}\\n</terminal_history>'\n", |
791 | 864 | "\n", |
792 | | - " # Read from stdin if available\n", |
793 | | - " if not sys.stdin.isatty():\n", |
794 | | - " ctxt += f'\\n<context>\\n{sys.stdin.read()}</context>'\n", |
795 | | - " \n", |
796 | | - " query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n", |
| 865 | + " # Read from stdin if available\n", |
| 866 | + " if not sys.stdin.isatty() and not IN_NOTEBOOK:\n", |
| 867 | + " ctxt += f'\\n<context>\\n{sys.stdin.read()}</context>'\n", |
| 868 | + " \n", |
| 869 | + " query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n", |
797 | 870 | "\n", |
798 | | - " sage = get_sage(opts.model, mode, search=opts.search)\n", |
799 | | - " res = get_res(sage, query, opts)\n", |
800 | | - " \n", |
801 | | - " # Handle logging if the log flag is set\n", |
802 | | - " if opts.log:\n", |
803 | | - " db = mk_db()\n", |
804 | | - " db.logs.insert(Log(timestamp=datetime.now().isoformat(), query=query,\n", |
805 | | - " response=res, model=opts.model, mode=mode))\n", |
806 | | - " print(md(res))" |
| 871 | + " sage = get_sage(opts.model, mode, search=opts.search)\n", |
| 872 | + " for res in get_res(sage, query, opts):\n", |
| 873 | + " live.update(md(res), refresh=True)\n", |
| 874 | + " \n", |
| 875 | + " # Handle logging if the log flag is set\n", |
| 876 | + " if opts.log:\n", |
| 877 | + " db = mk_db()\n", |
| 878 | + " db.logs.insert(Log(timestamp=datetime.now().isoformat(), query=query,\n", |
| 879 | + " response=res, model=opts.model, mode=mode))\n", |
| 880 | + " except KeyboardInterrupt:\n", |
| 881 | + " print(\"Interrupted.\")" |
807 | 882 | ] |
808 | 883 | }, |
809 | 884 | { |
|
853 | 928 | } |
854 | 929 | ], |
855 | 930 | "source": [ |
856 | | - "main('Teach me about rsync. Reply with a single paragraph.', history_lines=0)" |
| 931 | + "main(['Teach me about rsync. Reply with a single paragraph.'], history_lines=0)" |
857 | 932 | ] |
858 | 933 | }, |
859 | 934 | { |
|
974 | 1049 | ] |
975 | 1050 | } |
976 | 1051 | ], |
977 | | - "metadata": {}, |
| 1052 | + "metadata": { |
| 1053 | + "language_info": { |
| 1054 | + "name": "python" |
| 1055 | + } |
| 1056 | + }, |
978 | 1057 | "nbformat": 4, |
979 | 1058 | "nbformat_minor": 5 |
980 | 1059 | } |
0 commit comments