From 5be0a21bbcc6795ffbc0b6bc4f47cdf204864394 Mon Sep 17 00:00:00 2001 From: Jason Weill <93281816+JasonWeill@users.noreply.github.com> Date: Tue, 6 Jun 2023 16:25:41 -0700 Subject: [PATCH] Register, update, and delete aliases (#136) * Validates registry name * WIP: Register alias * Raises exceptions * Refactors, adds delete and update commands * Additional examples * Update sample notebook * Update docs * List aliases * Refactoring * Recommends using 'update' command * WIP: Gets variable from user namespace, tests whether it's a chain * Updates sample workbook, calls custom chain * Updates user docs for aliases * Edits sample notebook * Alias list in text display, updates messaging * Updates sample workbook * Updates sample notebook, parsers to use click * Additional cleanup * Updates sample notebook, removes unahppy case * Fix error from rebase, updates sample notebook * Fixed error when --format is used * Update docs/source/users/index.md Co-authored-by: Piyush Jain * Update docs/source/users/index.md Co-authored-by: Piyush Jain * Wraps ValueError exceptions to not print stack trace --------- Co-authored-by: Piyush Jain --- docs/source/users/index.md | 53 ++ examples/commands.ipynb | 829 ++++++++++++++++-- .../jupyter_ai_magics/magics.py | 230 ++++- .../jupyter_ai_magics/parsers.py | 40 + 4 files changed, 1016 insertions(+), 136 deletions(-) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 52c758c15..1de889461 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -548,3 +548,56 @@ As a shortcut for explaining errors, you can use the `%ai error` command, which %ai error anthropic:claude-v1.2 ``` +### Creating and managing aliases + +You can create an alias for a model using the `%ai register` command. For example, the command: + +``` +%ai register claude anthropic:claude-v1.2 +``` + +will register the alias `claude` as pointing to the `anthropic` provider's `claude-v1.2` model. You can then use this alias as you would use any other model name: + +``` +%%ai claude +Write a poem about C++. +``` + +You can also define a custom LangChain chain: + +``` +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate +from langchain.llms import OpenAI + +llm = OpenAI(temperature=0.9) +prompt = PromptTemplate( + input_variables=["product"], + template="What is a good name for a company that makes {product}?", +) +chain = LLMChain(llm=llm, prompt=prompt) +``` + +… and then use `%ai register` to give it a name: + +``` +%ai register companyname chain +``` + +You can change an alias's target using the `%ai update` command: + +``` +%ai update claude anthropic:claude-instant-v1.0 +``` + +You can delete an alias using the `%ai delete` command: + +``` +%ai delete claude +``` + +You can see a list of all aliases by running the `%ai list` command. + +Aliases' names can contain ASCII letters (uppercase and lowercase), numbers, hyphens, underscores, and periods. They may not contain colons. They may also not override built-in commands — run `%ai help` for a list of these commands. + +Aliases must refer to models or `LLMChain` objects; they cannot refer to other aliases. diff --git a/examples/commands.ipynb b/examples/commands.ipynb index 76634cee8..e87e90205 100644 --- a/examples/commands.ipynb +++ b/examples/commands.ipynb @@ -90,9 +90,13 @@ " --help Show this message and exit.\n", "\n", "Commands:\n", - " error Explains the most recent error.\n", - " help Show this message and exit.\n", - " list List language models. 
See `%ai list --help` for options.\n", + " delete Delete an alias. See `%ai delete --help` for options.\n", + " error Explains the most recent error.\n", + " help Show this message and exit.\n", + " list List language models. See `%ai list --help` for options.\n", + " register Register a new alias. See `%ai register --help` for options.\n", + " update Update the target of an alias. See `%ai update --help` for\n", + " options.\n", "\n" ] } @@ -104,64 +108,31 @@ { "cell_type": "code", "execution_count": 3, - "id": "e1f2b767-0834-4b21-b132-093730efaffb", + "id": "1f249bdc-410b-42f4-b21b-b7bde9e06387", "metadata": { "tags": [] }, "outputs": [ { - "data": { - "text/markdown": [ - "There have been no errors since the kernel started." - ], - "text/plain": [ - "There have been no errors since the kernel started." - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Usage: %ai register [OPTIONS] NAME TARGET\n", + "\n", + " Register a new alias called NAME for the model or chain named TARGET.\n", + "\n", + "Options:\n", + " --help Show this message and exit.\n" + ] } ], "source": [ - "%ai error chatgpt" + "%ai register --help" ] }, { "cell_type": "code", "execution_count": 4, - "id": "0f073caa-265d-40d6-b537-d025b8df9f41", - "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "Cannot determine model provider from model ID `foo`.\n", - "\n", - "To see a list of models you can use, run `%ai list`\n", - "\n", - "If you were trying to run a command, run `%ai help` to see a list of commands." - ], - "text/plain": [ - "Cannot determine model provider from model ID 'foo'.\n", - "\n", - "To see a list of models you can use, run '%ai list'\n", - "\n", - "If you were trying to run a command, run '%ai help' to see a list of commands." - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%ai foo" - ] - }, - { - "cell_type": "code", - "execution_count": 5, "id": "bad2d8a8-6141-4247-9af7-7583426c59a6", "metadata": {}, "outputs": [ @@ -177,7 +148,16 @@ "| `openai` | `OPENAI_API_KEY` | | `openai:text-davinci-003`, `openai:text-davinci-002`, `openai:text-curie-001`, `openai:text-babbage-001`, `openai:text-ada-001`, `openai:davinci`, `openai:curie`, `openai:babbage`, `openai:ada` |\n", "| `openai-chat` | `OPENAI_API_KEY` | | `openai-chat:gpt-4`, `openai-chat:gpt-4-0314`, `openai-chat:gpt-4-32k`, `openai-chat:gpt-4-32k-0314`, `openai-chat:gpt-3.5-turbo`, `openai-chat:gpt-3.5-turbo-0301` |\n", "| `openai-chat-new` | `OPENAI_API_KEY` | | `openai-chat-new:gpt-4`, `openai-chat-new:gpt-4-0314`, `openai-chat-new:gpt-4-32k`, `openai-chat-new:gpt-4-32k-0314`, `openai-chat-new:gpt-3.5-turbo`, `openai-chat-new:gpt-3.5-turbo-0301` |\n", - "| `sagemaker-endpoint` | Not applicable. | N/A | This provider does not define a list of models. |\n" + "| `sagemaker-endpoint` | Not applicable. | N/A | This provider does not define a list of models. 
|\n", + "\n", + "Aliases and custom commands:\n", + "\n", + "| Name | Target |\n", + "|------|--------|\n", + "| `gpt2` | `huggingface_hub:gpt2` |\n", + "| `gpt3` | `openai:text-davinci-003` |\n", + "| `chatgpt` | `openai-chat:gpt-3.5-turbo` |\n", + "| `gpt4` | `openai-chat:gpt-4` |\n" ], "text/plain": [ "ai21\n", @@ -241,10 +221,16 @@ "\n", "sagemaker-endpoint\n", "* This provider does not define a list of models.\n", - "\n" + "\n", + "\n", + "Aliases and custom commands:\n", + "gpt2 - huggingface_hub:gpt2\n", + "gpt3 - openai:text-davinci-003\n", + "chatgpt - openai-chat:gpt-3.5-turbo\n", + "gpt4 - openai-chat:gpt-4\n" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -255,7 +241,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "4d84fcac-7348-4c02-9ec3-34b300ec8459", "metadata": {}, "outputs": [ @@ -281,7 +267,7 @@ "\n" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -290,29 +276,157 @@ "%ai list openai" ] }, + { + "cell_type": "markdown", + "id": "a7b640f4-c8ba-4c4e-ba91-7bb14758064c", + "metadata": {}, + "source": [ + "## Model aliases\n", + "\n", + "Using the syntax `%ai register NAME TARGET`, you can create a new alias to an existing alias's target. The target must be specified using the full `provider:model` syntax. You cannot create an alias to another alias." + ] + }, { "cell_type": "code", - "execution_count": 7, - "id": "b56ff0e3-42c2-4927-affd-be6a089dfa43", + "execution_count": 6, + "id": "539263a2-1c30-4338-9622-bdc500c17830", "metadata": {}, "outputs": [ { - "ename": "SyntaxError", - "evalue": "Missing parentheses in call to 'print'. Did you mean print(...)? (1142230402.py, line 1)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m Cell \u001b[0;32mIn[7], line 1\u001b[0;36m\u001b[0m\n\u001b[0;31m print 'foo'\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m Missing parentheses in call to 'print'. Did you mean print(...)?\n" - ] + "data": { + "text/markdown": [ + "Registered new alias `mychat`" + ], + "text/plain": [ + "Registered new alias `mychat`" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%ai register mychat openai-chat:gpt-4" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4a453ca5-9f33-4393-a936-d7bc6c4c8f63", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "| Provider | Environment variable | Set? | Models |\n", + "|----------|----------------------|------|--------|\n", + "| `ai21` | `AI21_API_KEY` | | `ai21:j1-large`, `ai21:j1-grande`, `ai21:j1-jumbo`, `ai21:j1-grande-instruct`, `ai21:j2-large`, `ai21:j2-grande`, `ai21:j2-jumbo`, `ai21:j2-grande-instruct`, `ai21:j2-jumbo-instruct` |\n", + "| `anthropic` | `ANTHROPIC_API_KEY` | | `anthropic:claude-v1`, `anthropic:claude-v1.0`, `anthropic:claude-v1.2`, `anthropic:claude-instant-v1`, `anthropic:claude-instant-v1.0` |\n", + "| `cohere` | `COHERE_API_KEY` | | `cohere:medium`, `cohere:xlarge` |\n", + "| `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | | This provider does not define a list of models. 
|\n", + "| `openai` | `OPENAI_API_KEY` | | `openai:text-davinci-003`, `openai:text-davinci-002`, `openai:text-curie-001`, `openai:text-babbage-001`, `openai:text-ada-001`, `openai:davinci`, `openai:curie`, `openai:babbage`, `openai:ada` |\n", + "| `openai-chat` | `OPENAI_API_KEY` | | `openai-chat:gpt-4`, `openai-chat:gpt-4-0314`, `openai-chat:gpt-4-32k`, `openai-chat:gpt-4-32k-0314`, `openai-chat:gpt-3.5-turbo`, `openai-chat:gpt-3.5-turbo-0301` |\n", + "| `openai-chat-new` | `OPENAI_API_KEY` | | `openai-chat-new:gpt-4`, `openai-chat-new:gpt-4-0314`, `openai-chat-new:gpt-4-32k`, `openai-chat-new:gpt-4-32k-0314`, `openai-chat-new:gpt-3.5-turbo`, `openai-chat-new:gpt-3.5-turbo-0301` |\n", + "| `sagemaker-endpoint` | Not applicable. | N/A | This provider does not define a list of models. |\n", + "\n", + "Aliases and custom commands:\n", + "\n", + "| Name | Target |\n", + "|------|--------|\n", + "| `gpt2` | `huggingface_hub:gpt2` |\n", + "| `gpt3` | `openai:text-davinci-003` |\n", + "| `chatgpt` | `openai-chat:gpt-3.5-turbo` |\n", + "| `gpt4` | `openai-chat:gpt-4` |\n", + "| `mychat` | `openai-chat:gpt-4` |\n" + ], + "text/plain": [ + "ai21\n", + "Requires environment variable AI21_API_KEY (set)\n", + "* ai21:j1-large\n", + "* ai21:j1-grande\n", + "* ai21:j1-jumbo\n", + "* ai21:j1-grande-instruct\n", + "* ai21:j2-large\n", + "* ai21:j2-grande\n", + "* ai21:j2-jumbo\n", + "* ai21:j2-grande-instruct\n", + "* ai21:j2-jumbo-instruct\n", + "\n", + "anthropic\n", + "Requires environment variable ANTHROPIC_API_KEY (set)\n", + "* anthropic:claude-v1\n", + "* anthropic:claude-v1.0\n", + "* anthropic:claude-v1.2\n", + "* anthropic:claude-instant-v1\n", + "* anthropic:claude-instant-v1.0\n", + "\n", + "cohere\n", + "Requires environment variable COHERE_API_KEY (set)\n", + "* cohere:medium\n", + "* cohere:xlarge\n", + "\n", + "huggingface_hub\n", + "Requires environment variable HUGGINGFACEHUB_API_TOKEN (set)\n", + "* This provider does not define a list of models.\n", + "\n", + "openai\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai:text-davinci-003\n", + "* openai:text-davinci-002\n", + "* openai:text-curie-001\n", + "* openai:text-babbage-001\n", + "* openai:text-ada-001\n", + "* openai:davinci\n", + "* openai:curie\n", + "* openai:babbage\n", + "* openai:ada\n", + "\n", + "openai-chat\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat:gpt-4\n", + "* openai-chat:gpt-4-0314\n", + "* openai-chat:gpt-4-32k\n", + "* openai-chat:gpt-4-32k-0314\n", + "* openai-chat:gpt-3.5-turbo\n", + "* openai-chat:gpt-3.5-turbo-0301\n", + "\n", + "openai-chat-new\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat-new:gpt-4\n", + "* openai-chat-new:gpt-4-0314\n", + "* openai-chat-new:gpt-4-32k\n", + "* openai-chat-new:gpt-4-32k-0314\n", + "* openai-chat-new:gpt-3.5-turbo\n", + "* openai-chat-new:gpt-3.5-turbo-0301\n", + "\n", + "sagemaker-endpoint\n", + "* This provider does not define a list of models.\n", + "\n", + "\n", + "Aliases and custom commands:\n", + "gpt2 - huggingface_hub:gpt2\n", + "gpt3 - openai:text-davinci-003\n", + "chatgpt - openai-chat:gpt-3.5-turbo\n", + "gpt4 - openai-chat:gpt-4\n", + "mychat - openai-chat:gpt-4\n" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "print 'foo'" + "%ai list" ] }, { "cell_type": "code", "execution_count": 8, - "id": "63581a74-9237-4f11-aff2-48612c97cb27", + "id": "ea75ac14-d83d-45b4-8c1d-40e4fbc1d263", "metadata": { "tags": [] }, @@ 
-320,7 +434,9 @@ { "data": { "text/markdown": [ - "The error message \"SyntaxError: Missing parentheses in call to 'print'. Did you mean print(...)?\", occurs when attempting to use the print function without parentheses in Python 3, which is not syntactically valid. In Python 2, it was possible to use the print statement without parentheses, but in Python 3 it has become a function and its use requires parentheses. The error message suggests that the print statement should be replaced with print() to rectify the error." + "\n", + "\n", + "This is an example of a **Model X**." ], "text/plain": [ "" @@ -330,8 +446,8 @@ "metadata": { "text/markdown": { "jupyter_ai": { - "model_id": "gpt-3.5-turbo", - "provider_id": "openai-chat" + "model_id": "text-davinci-003", + "provider_id": "openai" } } }, @@ -339,56 +455,59 @@ } ], "source": [ - "%ai error chatgpt" + "%%ai gpt3\n", + "What model is this?" ] }, { "cell_type": "code", "execution_count": 9, - "id": "1afe4536-f908-4bd7-aec4-f8d1cc3bf01f", + "id": "3e7fb7fb-d61c-4909-959b-14bc97af409a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "Updated target of alias `mychat`" + ], + "text/plain": [ + "Updated target of alias `mychat`" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%ai update mychat openai:text-davinci-003" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a89a4c2a-b37f-4a3d-86cb-67747113d3d7", "metadata": { "tags": [] }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/miniconda3/envs/jupyter-ai/lib/python3.10/site-packages/langchain/llms/anthropic.py:134: UserWarning: This Anthropic LLM is deprecated. Please use `from langchain.chat_models import ChatAnthropic` instead\n", - " warnings.warn(\n" - ] - }, { "data": { "text/markdown": [ "\n", - "The error `Cell In[7], line 1 print 'foo' ^ SyntaxError: Missing parentheses in call to 'print' . Did you mean print(...)?` \n", - "is occurring because in Python 3, the `print` statement has changed. \n", - "\n", - "In Python 2, you could simply do:\n", - "`print 'foo'`\n", - "\n", - "to print the string `foo`. \n", "\n", - "However, in Python 3, the `print` function requires parentheses:\n", - "`print('foo')`\n", - "\n", - "So the error is telling you that you're trying to use the Python 2 `print` statement in Python 3 code. \n", - "It's suggesting that you likely meant to call the `print()` function instead, with parentheses: \n", - "`print('foo')`.\n", - "\n", - "Adding the parentheses will fix the error." + "This model is not specified." ], "text/plain": [ "" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": { "text/markdown": { "jupyter_ai": { - "model_id": "claude-v1.2", - "provider_id": "anthropic" + "model_id": "text-davinci-003", + "provider_id": "openai" } } }, @@ -396,13 +515,531 @@ } ], "source": [ - "%ai error anthropic:claude-v1.2" + "%%ai mychat\n", + "What model is this?" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c7453152-6c70-4f91-bef6-4ae38d700f52", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "| Provider | Environment variable | Set? 
| Models |\n", + "|----------|----------------------|------|--------|\n", + "| `ai21` | `AI21_API_KEY` | | `ai21:j1-large`, `ai21:j1-grande`, `ai21:j1-jumbo`, `ai21:j1-grande-instruct`, `ai21:j2-large`, `ai21:j2-grande`, `ai21:j2-jumbo`, `ai21:j2-grande-instruct`, `ai21:j2-jumbo-instruct` |\n", + "| `anthropic` | `ANTHROPIC_API_KEY` | | `anthropic:claude-v1`, `anthropic:claude-v1.0`, `anthropic:claude-v1.2`, `anthropic:claude-instant-v1`, `anthropic:claude-instant-v1.0` |\n", + "| `cohere` | `COHERE_API_KEY` | | `cohere:medium`, `cohere:xlarge` |\n", + "| `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | | This provider does not define a list of models. |\n", + "| `openai` | `OPENAI_API_KEY` | | `openai:text-davinci-003`, `openai:text-davinci-002`, `openai:text-curie-001`, `openai:text-babbage-001`, `openai:text-ada-001`, `openai:davinci`, `openai:curie`, `openai:babbage`, `openai:ada` |\n", + "| `openai-chat` | `OPENAI_API_KEY` | | `openai-chat:gpt-4`, `openai-chat:gpt-4-0314`, `openai-chat:gpt-4-32k`, `openai-chat:gpt-4-32k-0314`, `openai-chat:gpt-3.5-turbo`, `openai-chat:gpt-3.5-turbo-0301` |\n", + "| `openai-chat-new` | `OPENAI_API_KEY` | | `openai-chat-new:gpt-4`, `openai-chat-new:gpt-4-0314`, `openai-chat-new:gpt-4-32k`, `openai-chat-new:gpt-4-32k-0314`, `openai-chat-new:gpt-3.5-turbo`, `openai-chat-new:gpt-3.5-turbo-0301` |\n", + "| `sagemaker-endpoint` | Not applicable. | N/A | This provider does not define a list of models. |\n", + "\n", + "Aliases and custom commands:\n", + "\n", + "| Name | Target |\n", + "|------|--------|\n", + "| `gpt2` | `huggingface_hub:gpt2` |\n", + "| `gpt3` | `openai:text-davinci-003` |\n", + "| `chatgpt` | `openai-chat:gpt-3.5-turbo` |\n", + "| `gpt4` | `openai-chat:gpt-4` |\n", + "| `mychat` | `openai:text-davinci-003` |\n" + ], + "text/plain": [ + "ai21\n", + "Requires environment variable AI21_API_KEY (set)\n", + "* ai21:j1-large\n", + "* ai21:j1-grande\n", + "* ai21:j1-jumbo\n", + "* ai21:j1-grande-instruct\n", + "* ai21:j2-large\n", + "* ai21:j2-grande\n", + "* ai21:j2-jumbo\n", + "* ai21:j2-grande-instruct\n", + "* ai21:j2-jumbo-instruct\n", + "\n", + "anthropic\n", + "Requires environment variable ANTHROPIC_API_KEY (set)\n", + "* anthropic:claude-v1\n", + "* anthropic:claude-v1.0\n", + "* anthropic:claude-v1.2\n", + "* anthropic:claude-instant-v1\n", + "* anthropic:claude-instant-v1.0\n", + "\n", + "cohere\n", + "Requires environment variable COHERE_API_KEY (set)\n", + "* cohere:medium\n", + "* cohere:xlarge\n", + "\n", + "huggingface_hub\n", + "Requires environment variable HUGGINGFACEHUB_API_TOKEN (set)\n", + "* This provider does not define a list of models.\n", + "\n", + "openai\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai:text-davinci-003\n", + "* openai:text-davinci-002\n", + "* openai:text-curie-001\n", + "* openai:text-babbage-001\n", + "* openai:text-ada-001\n", + "* openai:davinci\n", + "* openai:curie\n", + "* openai:babbage\n", + "* openai:ada\n", + "\n", + "openai-chat\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat:gpt-4\n", + "* openai-chat:gpt-4-0314\n", + "* openai-chat:gpt-4-32k\n", + "* openai-chat:gpt-4-32k-0314\n", + "* openai-chat:gpt-3.5-turbo\n", + "* openai-chat:gpt-3.5-turbo-0301\n", + "\n", + "openai-chat-new\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat-new:gpt-4\n", + "* openai-chat-new:gpt-4-0314\n", + "* openai-chat-new:gpt-4-32k\n", + "* openai-chat-new:gpt-4-32k-0314\n", + "* openai-chat-new:gpt-3.5-turbo\n", + "* 
openai-chat-new:gpt-3.5-turbo-0301\n", + "\n", + "sagemaker-endpoint\n", + "* This provider does not define a list of models.\n", + "\n", + "\n", + "Aliases and custom commands:\n", + "gpt2 - huggingface_hub:gpt2\n", + "gpt3 - openai:text-davinci-003\n", + "chatgpt - openai-chat:gpt-3.5-turbo\n", + "gpt4 - openai-chat:gpt-4\n", + "mychat - openai:text-davinci-003\n" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%ai list" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "971f823d-8f35-4201-a6cb-12a21a87628a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "Deleted alias `mychat`" + ], + "text/plain": [ + "Deleted alias `mychat`" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%ai delete mychat" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6fb1b54a-ee4f-41c0-b666-ea9e0a1279ee", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "| Provider | Environment variable | Set? | Models |\n", + "|----------|----------------------|------|--------|\n", + "| `ai21` | `AI21_API_KEY` | | `ai21:j1-large`, `ai21:j1-grande`, `ai21:j1-jumbo`, `ai21:j1-grande-instruct`, `ai21:j2-large`, `ai21:j2-grande`, `ai21:j2-jumbo`, `ai21:j2-grande-instruct`, `ai21:j2-jumbo-instruct` |\n", + "| `anthropic` | `ANTHROPIC_API_KEY` | | `anthropic:claude-v1`, `anthropic:claude-v1.0`, `anthropic:claude-v1.2`, `anthropic:claude-instant-v1`, `anthropic:claude-instant-v1.0` |\n", + "| `cohere` | `COHERE_API_KEY` | | `cohere:medium`, `cohere:xlarge` |\n", + "| `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | | This provider does not define a list of models. |\n", + "| `openai` | `OPENAI_API_KEY` | | `openai:text-davinci-003`, `openai:text-davinci-002`, `openai:text-curie-001`, `openai:text-babbage-001`, `openai:text-ada-001`, `openai:davinci`, `openai:curie`, `openai:babbage`, `openai:ada` |\n", + "| `openai-chat` | `OPENAI_API_KEY` | | `openai-chat:gpt-4`, `openai-chat:gpt-4-0314`, `openai-chat:gpt-4-32k`, `openai-chat:gpt-4-32k-0314`, `openai-chat:gpt-3.5-turbo`, `openai-chat:gpt-3.5-turbo-0301` |\n", + "| `openai-chat-new` | `OPENAI_API_KEY` | | `openai-chat-new:gpt-4`, `openai-chat-new:gpt-4-0314`, `openai-chat-new:gpt-4-32k`, `openai-chat-new:gpt-4-32k-0314`, `openai-chat-new:gpt-3.5-turbo`, `openai-chat-new:gpt-3.5-turbo-0301` |\n", + "| `sagemaker-endpoint` | Not applicable. | N/A | This provider does not define a list of models. 
|\n", + "\n", + "Aliases and custom commands:\n", + "\n", + "| Name | Target |\n", + "|------|--------|\n", + "| `gpt2` | `huggingface_hub:gpt2` |\n", + "| `gpt3` | `openai:text-davinci-003` |\n", + "| `chatgpt` | `openai-chat:gpt-3.5-turbo` |\n", + "| `gpt4` | `openai-chat:gpt-4` |\n" + ], + "text/plain": [ + "ai21\n", + "Requires environment variable AI21_API_KEY (set)\n", + "* ai21:j1-large\n", + "* ai21:j1-grande\n", + "* ai21:j1-jumbo\n", + "* ai21:j1-grande-instruct\n", + "* ai21:j2-large\n", + "* ai21:j2-grande\n", + "* ai21:j2-jumbo\n", + "* ai21:j2-grande-instruct\n", + "* ai21:j2-jumbo-instruct\n", + "\n", + "anthropic\n", + "Requires environment variable ANTHROPIC_API_KEY (set)\n", + "* anthropic:claude-v1\n", + "* anthropic:claude-v1.0\n", + "* anthropic:claude-v1.2\n", + "* anthropic:claude-instant-v1\n", + "* anthropic:claude-instant-v1.0\n", + "\n", + "cohere\n", + "Requires environment variable COHERE_API_KEY (set)\n", + "* cohere:medium\n", + "* cohere:xlarge\n", + "\n", + "huggingface_hub\n", + "Requires environment variable HUGGINGFACEHUB_API_TOKEN (set)\n", + "* This provider does not define a list of models.\n", + "\n", + "openai\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai:text-davinci-003\n", + "* openai:text-davinci-002\n", + "* openai:text-curie-001\n", + "* openai:text-babbage-001\n", + "* openai:text-ada-001\n", + "* openai:davinci\n", + "* openai:curie\n", + "* openai:babbage\n", + "* openai:ada\n", + "\n", + "openai-chat\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat:gpt-4\n", + "* openai-chat:gpt-4-0314\n", + "* openai-chat:gpt-4-32k\n", + "* openai-chat:gpt-4-32k-0314\n", + "* openai-chat:gpt-3.5-turbo\n", + "* openai-chat:gpt-3.5-turbo-0301\n", + "\n", + "openai-chat-new\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat-new:gpt-4\n", + "* openai-chat-new:gpt-4-0314\n", + "* openai-chat-new:gpt-4-32k\n", + "* openai-chat-new:gpt-4-32k-0314\n", + "* openai-chat-new:gpt-3.5-turbo\n", + "* openai-chat-new:gpt-3.5-turbo-0301\n", + "\n", + "sagemaker-endpoint\n", + "* This provider does not define a list of models.\n", + "\n", + "\n", + "Aliases and custom commands:\n", + "gpt2 - huggingface_hub:gpt2\n", + "gpt3 - openai:text-davinci-003\n", + "chatgpt - openai-chat:gpt-3.5-turbo\n", + "gpt4 - openai-chat:gpt-4\n" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%ai list" + ] + }, + { + "cell_type": "markdown", + "id": "178e9916-5bd8-40bf-a273-ea03553663b4", + "metadata": {}, + "source": [ + "## Custom chains\n", + "\n", + "You can define a LangChain chain in a local variable and use that as the target in a magic `%ai register` command." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9ef639c7-1ca8-48af-b5e5-f76cfa5779f3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chains import LLMChain\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.llms import OpenAI\n", + "\n", + "llm = OpenAI(temperature=0.9)\n", + "prompt = PromptTemplate(\n", + " input_variables=[\"product\"],\n", + " template=\"What is a good name for a company that makes {product}?\",\n", + ")\n", + "chain = LLMChain(llm=llm, prompt=prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "29d5239f-7601-405e-b059-4e881ebf7ab1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chains import LLMChain\n", + "chain = LLMChain(llm=llm, prompt=prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "43e7a77c-93af-4ef7-a104-f932c9f54183", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Bright Feet Socks\n" + ] + } + ], + "source": [ + "# Run the chain only specifying the input variable.\n", + "print(chain.run(\"colorful socks\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9badc567-9720-4e33-ab4a-54fda5129f36", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "Registered new alias `company`" + ], + "text/plain": [ + "Registered new alias `company`" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%ai register company chain" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "92b75d71-8844-4872-b424-b0023706abb1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "| Provider | Environment variable | Set? | Models |\n", + "|----------|----------------------|------|--------|\n", + "| `ai21` | `AI21_API_KEY` | | `ai21:j1-large`, `ai21:j1-grande`, `ai21:j1-jumbo`, `ai21:j1-grande-instruct`, `ai21:j2-large`, `ai21:j2-grande`, `ai21:j2-jumbo`, `ai21:j2-grande-instruct`, `ai21:j2-jumbo-instruct` |\n", + "| `anthropic` | `ANTHROPIC_API_KEY` | | `anthropic:claude-v1`, `anthropic:claude-v1.0`, `anthropic:claude-v1.2`, `anthropic:claude-instant-v1`, `anthropic:claude-instant-v1.0` |\n", + "| `cohere` | `COHERE_API_KEY` | | `cohere:medium`, `cohere:xlarge` |\n", + "| `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | | This provider does not define a list of models. |\n", + "| `openai` | `OPENAI_API_KEY` | | `openai:text-davinci-003`, `openai:text-davinci-002`, `openai:text-curie-001`, `openai:text-babbage-001`, `openai:text-ada-001`, `openai:davinci`, `openai:curie`, `openai:babbage`, `openai:ada` |\n", + "| `openai-chat` | `OPENAI_API_KEY` | | `openai-chat:gpt-4`, `openai-chat:gpt-4-0314`, `openai-chat:gpt-4-32k`, `openai-chat:gpt-4-32k-0314`, `openai-chat:gpt-3.5-turbo`, `openai-chat:gpt-3.5-turbo-0301` |\n", + "| `openai-chat-new` | `OPENAI_API_KEY` | | `openai-chat-new:gpt-4`, `openai-chat-new:gpt-4-0314`, `openai-chat-new:gpt-4-32k`, `openai-chat-new:gpt-4-32k-0314`, `openai-chat-new:gpt-3.5-turbo`, `openai-chat-new:gpt-3.5-turbo-0301` |\n", + "| `sagemaker-endpoint` | Not applicable. | N/A | This provider does not define a list of models. 
|\n", + "\n", + "Aliases and custom commands:\n", + "\n", + "| Name | Target |\n", + "|------|--------|\n", + "| `gpt2` | `huggingface_hub:gpt2` |\n", + "| `gpt3` | `openai:text-davinci-003` |\n", + "| `chatgpt` | `openai-chat:gpt-3.5-turbo` |\n", + "| `gpt4` | `openai-chat:gpt-4` |\n", + "| `company` | *custom chain* |\n" + ], + "text/plain": [ + "ai21\n", + "Requires environment variable AI21_API_KEY (set)\n", + "* ai21:j1-large\n", + "* ai21:j1-grande\n", + "* ai21:j1-jumbo\n", + "* ai21:j1-grande-instruct\n", + "* ai21:j2-large\n", + "* ai21:j2-grande\n", + "* ai21:j2-jumbo\n", + "* ai21:j2-grande-instruct\n", + "* ai21:j2-jumbo-instruct\n", + "\n", + "anthropic\n", + "Requires environment variable ANTHROPIC_API_KEY (set)\n", + "* anthropic:claude-v1\n", + "* anthropic:claude-v1.0\n", + "* anthropic:claude-v1.2\n", + "* anthropic:claude-instant-v1\n", + "* anthropic:claude-instant-v1.0\n", + "\n", + "cohere\n", + "Requires environment variable COHERE_API_KEY (set)\n", + "* cohere:medium\n", + "* cohere:xlarge\n", + "\n", + "huggingface_hub\n", + "Requires environment variable HUGGINGFACEHUB_API_TOKEN (set)\n", + "* This provider does not define a list of models.\n", + "\n", + "openai\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai:text-davinci-003\n", + "* openai:text-davinci-002\n", + "* openai:text-curie-001\n", + "* openai:text-babbage-001\n", + "* openai:text-ada-001\n", + "* openai:davinci\n", + "* openai:curie\n", + "* openai:babbage\n", + "* openai:ada\n", + "\n", + "openai-chat\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat:gpt-4\n", + "* openai-chat:gpt-4-0314\n", + "* openai-chat:gpt-4-32k\n", + "* openai-chat:gpt-4-32k-0314\n", + "* openai-chat:gpt-3.5-turbo\n", + "* openai-chat:gpt-3.5-turbo-0301\n", + "\n", + "openai-chat-new\n", + "Requires environment variable OPENAI_API_KEY (set)\n", + "* openai-chat-new:gpt-4\n", + "* openai-chat-new:gpt-4-0314\n", + "* openai-chat-new:gpt-4-32k\n", + "* openai-chat-new:gpt-4-32k-0314\n", + "* openai-chat-new:gpt-3.5-turbo\n", + "* openai-chat-new:gpt-3.5-turbo-0301\n", + "\n", + "sagemaker-endpoint\n", + "* This provider does not define a list of models.\n", + "\n", + "\n", + "Aliases and custom commands:\n", + "gpt2 - huggingface_hub:gpt2\n", + "gpt3 - openai:text-davinci-003\n", + "chatgpt - openai-chat:gpt-3.5-turbo\n", + "gpt4 - openai-chat:gpt-4\n", + "company - custom chain\n" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%ai list" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "cfef0fee-a7c6-49e4-8d90-9aa12f7b91d1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "\n", + "**Funky Toes**" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": { + "text/markdown": { + "jupyter_ai": { + "custom_chain_id": "company" + } + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "%%ai company\n", + "colorful socks" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "06c698e7-e2cf-41b5-88de-2be4d3b60eba", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\n", + "\n", + "Rainbow Toe Socks." 
+ ] + }, + "execution_count": 20, + "metadata": { + "jupyter_ai": { + "custom_chain_id": "company" + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "%%ai company --format text\n", + "colorful socks" ] }, { "cell_type": "code", "execution_count": null, - "id": "d47aa5ec-31e0-4670-b1d1-6b4cd7f75832", + "id": "849f48a3-9477-4a12-afa6-fd2e79c0764f", "metadata": {}, "outputs": [], "source": [] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index bc1f9b362..70a16a07a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -1,7 +1,9 @@ import base64 import json +import keyword import os import re +import sys import warnings from typing import Optional @@ -12,8 +14,17 @@ from jupyter_ai_magics.utils import decompose_model_id, load_providers from .providers import BaseProvider -from .parsers import cell_magic_parser, line_magic_parser, CellArgs, ErrorArgs, HelpArgs, ListArgs - +from .parsers import (cell_magic_parser, + line_magic_parser, + CellArgs, + DeleteArgs, + ErrorArgs, + HelpArgs, + ListArgs, + RegisterArgs, + UpdateArgs) + +from langchain.chains import LLMChain MODEL_ID_ALIASES = { "gpt2": "huggingface_hub:gpt2", @@ -68,6 +79,8 @@ def _repr_mimebundle_(self, include=None, exclude=None): "text": TextWithMetadata } +NA_MESSAGE = 'N/A' + MARKDOWN_PROMPT_TEMPLATE = '{prompt}\n\nProduce output in markdown format only.' PROVIDER_NO_MODELS = 'This provider does not define a list of models.' @@ -92,7 +105,7 @@ def _repr_mimebundle_(self, include=None, exclude=None): "text": '{prompt}' # No customization } -AI_COMMANDS = { "error", "help", "list" } +AI_COMMANDS = { "delete", "error", "help", "list", "register", "update" } class FormatDict(dict): """Subclass of dict to be passed to str#format(). Suppresses KeyError and @@ -119,7 +132,9 @@ def __init__(self, shell): "`from langchain.chat_models import ChatOpenAI`") self.providers = load_providers() - + + # initialize a registry of custom model/chain names + self.custom_model_registry = MODEL_ID_ALIASES def _ai_bulleted_list_models_for_provider(self, provider_id, Provider): output = "" @@ -146,7 +161,7 @@ def _ai_inline_list_models_for_provider(self, provider_id, Provider): # Is the required environment variable set? def _ai_env_status_for_provider_markdown(self, provider_id): - na_message = 'Not applicable. | N/A ' + na_message = 'Not applicable. | ' + NA_MESSAGE if (provider_id not in self.providers or self.providers[provider_id].auth_strategy == None): @@ -185,6 +200,80 @@ def _ai_env_status_for_provider_text(self, provider_id): return output + "\n" + # Is this a name of a Python variable that can be called as a LangChain chain? + def _is_langchain_chain(self, name): + # Reserved word in Python? + if (keyword.iskeyword(name)): + return False; + + acceptable_name = re.compile('^[a-zA-Z0-9_]+$') + if (not acceptable_name.match(name)): + return False; + + ipython = get_ipython() + return(name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain)) + + # Is this an acceptable name for an alias? + def _validate_name(self, register_name): + # A registry name contains ASCII letters, numbers, hyphens, underscores, + # and periods. 
No other characters, including a colon, are permitted + acceptable_name = re.compile('^[a-zA-Z0-9._-]+$') + if (not acceptable_name.match(register_name)): + raise ValueError('A registry name may contain ASCII letters, numbers, hyphens, underscores, ' + + 'and periods. No other characters, including a colon, are permitted') + + # Initially set or update an alias to a target + def _safely_set_target(self, register_name, target): + # If target is a string, treat this as an alias to another model. + if self._is_langchain_chain(target): + ip = get_ipython() + self.custom_model_registry[register_name] = ip.user_ns[target] + else: + # Ensure that the destination is properly formatted + if (':' not in target): + raise ValueError( + 'Target model must be an LLMChain object or a model name in PROVIDER_ID:MODEL_NAME format') + + self.custom_model_registry[register_name] = target + + def handle_delete(self, args: DeleteArgs): + if (args.name in AI_COMMANDS): + raise ValueError(f"Reserved command names, including {args.name}, cannot be deleted") + + if (args.name not in self.custom_model_registry): + raise ValueError(f"There is no alias called {args.name}") + + del self.custom_model_registry[args.name] + output = f"Deleted alias `{args.name}`" + return TextOrMarkdown(output, output) + + def handle_register(self, args: RegisterArgs): + # Existing command names are not allowed + if (args.name in AI_COMMANDS): + raise ValueError(f"The name {args.name} is reserved for a command") + + # Existing registered names are not allowed + if (args.name in self.custom_model_registry): + raise ValueError(f"The name {args.name} is already associated with a custom model; " + + 'use %ai update to change its target') + + # Does the new name match expected format? + self._validate_name(args.name) + + self._safely_set_target(args.name, args.target) + output = f"Registered new alias `{args.name}`" + return TextOrMarkdown(output, output) + + def handle_update(self, args: UpdateArgs): + if (args.name in AI_COMMANDS): + raise ValueError(f"Reserved command names, including {args.name}, cannot be updated") + + if (args.name not in self.custom_model_registry): + raise ValueError(f"There is no alias called {args.name}") + + self._safely_set_target(args.name, args.target) + output = f"Updated target of alias `{args.name}`" + return TextOrMarkdown(output, output) def _ai_list_command_markdown(self, single_provider=None): output = ("| Provider | Environment variable | Set? | Models |\n" @@ -201,6 +290,20 @@ def _ai_list_command_markdown(self, single_provider=None): + self._ai_inline_list_models_for_provider(provider_id, Provider) + " |\n") + # Also list aliases. + if (single_provider is None and len(self.custom_model_registry) > 0): + output += ("\nAliases and custom commands:\n\n" + + "| Name | Target |\n" + + "|------|--------|\n") + for key, value in self.custom_model_registry.items(): + output += f"| `{key}` | " + if isinstance(value, str): + output += f"`{value}`" + else: + output += "*custom chain*" + + output += " |\n" + return output def _ai_list_command_text(self, single_provider=None): @@ -216,6 +319,18 @@ def _ai_list_command_text(self, single_provider=None): + self._ai_env_status_for_provider_text(provider_id) # includes \n if nonblank + self._ai_bulleted_list_models_for_provider(provider_id, Provider)) + # Also list aliases. 
+ if (single_provider is None and len(self.custom_model_registry) > 0): + output += "\nAliases and custom commands:\n" + for key, value in self.custom_model_registry.items(): + output += f"{key} - " + if isinstance(value, str): + output += value + else: + output += "custom chain" + + output += "\n" + return output def handle_error(self, args: ErrorArgs): @@ -261,6 +376,10 @@ def _append_exchange_openai(self, prompt: str, output: str): }) def _decompose_model_id(self, model_id: str): + """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" + if model_id in self.custom_model_registry: + model_id = self.custom_model_registry[model_id] + return decompose_model_id(model_id, self.providers) def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: @@ -269,6 +388,35 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: return None return self.providers[provider_id] + + def display_output(self, output, display_format, md): + # build output display + DisplayClass = DISPLAYS_BY_FORMAT[display_format] + + # if the user wants code, add another cell with the output. + if display_format == 'code': + # Strip a leading language indicator and trailing triple-backticks + lang_indicator = r'^```[a-zA-Z0-9]*\n' + output = re.sub(lang_indicator, '', output) + output = re.sub(r'\n```$', '', output) + new_cell_payload = dict( + source='set_next_input', + text=output, + replace=False, + ) + ip = get_ipython() + ip.payload_manager.write_payload(new_cell_payload) + return HTML('AI generated code inserted below ⬇️', metadata=md); + + if DisplayClass is None: + return output + if display_format == 'json': + # JSON display expects a dict, not a JSON string + output = json.loads(output) + output_display = DisplayClass(output, metadata=md) + + # finally, display output display + return output_display def handle_help(self, _: HelpArgs): with click.Context(cell_magic_parser, info_name="%%ai") as ctx: @@ -287,7 +435,24 @@ def run_ai_cell(self, args: CellArgs, prompt: str): # Apply a prompt template. prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt = prompt) - # determine provider and local model IDs + # interpolate user namespace into prompt + ip = get_ipython() + prompt = prompt.format_map(FormatDict(ip.user_ns)) + + # Determine provider and local model IDs + # If this is a custom chain, send the message to the custom chain. + if (args.model_id in self.custom_model_registry and + isinstance(self.custom_model_registry[args.model_id], LLMChain)): + + return self.display_output( + self.custom_model_registry[args.model_id].run(prompt), + args.format, + { + "jupyter_ai": { + "custom_chain_id": args.model_id + } + }) + provider_id, local_model_id = self._decompose_model_id(args.model_id) Provider = self._get_provider(provider_id) if Provider is None: @@ -328,9 +493,6 @@ def run_ai_cell(self, args: CellArgs, prompt: str): if provider_id == "openai-chat": self._append_exchange_openai(prompt, output) - # build output display - DisplayClass = DISPLAYS_BY_FORMAT[args.format] - md = { "jupyter_ai": { "provider_id": provider_id, @@ -338,30 +500,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str): } } - # if the user wants code, add another cell with the output. 
- if args.format == 'code': - # Strip a leading language indicator and trailing triple-backticks - lang_indicator = r'^```[a-zA-Z0-9]*\n' - output = re.sub(lang_indicator, '', output) - output = re.sub(r'\n```$', '', output) - new_cell_payload = dict( - source='set_next_input', - text=output, - replace=False, - ) - ip = get_ipython() - ip.payload_manager.write_payload(new_cell_payload) - return HTML('AI generated code inserted below ⬇️', metadata=md); - - if DisplayClass is None: - return output - if args.format == 'json': - # JSON display expects a dict, not a JSON string - output = json.loads(output) - output_display = DisplayClass(output, metadata=md) - - # finally, display output display - return output_display + return self.display_output(output, args.format, md) @line_cell_magic def ai(self, line, cell=None): @@ -376,13 +515,24 @@ def ai(self, line, cell=None): # case we want to exit early. return - if args.type == "error": - return self.handle_error(args) - if args.type == "help": - return self.handle_help(args) - if args.type == "list": - return self.handle_list(args) - + # If a value error occurs, don't print the full stacktrace + try: + if args.type == "error": + return self.handle_error(args) + if args.type == "help": + return self.handle_help(args) + if args.type == "list": + return self.handle_list(args) + if args.type == "register": + return self.handle_register(args) + if args.type == "delete": + return self.handle_delete(args) + if args.type == "update": + return self.handle_update(args) + except ValueError as e: + print(e, file=sys.stderr) + return + # hint to the IDE that this object must be of type `RootArgs` args: CellArgs = args diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index c4820f5f4..db63cd4da 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -25,6 +25,20 @@ class ListArgs(BaseModel): type: Literal["list"] = "list" provider_id: Optional[str] +class RegisterArgs(BaseModel): + type: Literal["register"] = "register" + name: str + target: str + +class DeleteArgs(BaseModel): + type: Literal["delete"] = "delete" + name: str + +class UpdateArgs(BaseModel): + type: Literal["update"] = "update" + name: str + target: str + class LineMagicGroup(click.Group): """Helper class to print the help string for cell magics as well when `%ai --help` is called.""" @@ -89,3 +103,29 @@ def help_subparser(): def list_subparser(**kwargs): """List language models, optionally scoped to PROVIDER_ID.""" return ListArgs(**kwargs) + +@line_magic_parser.command(name='register', + short_help="Register a new alias. See `%ai register --help` for options." +) +@click.argument('name') +@click.argument('target') +def register_subparser(**kwargs): + """Register a new alias called NAME for the model or chain named TARGET.""" + return RegisterArgs(**kwargs) + +@line_magic_parser.command(name='delete', + short_help="Delete an alias. See `%ai delete --help` for options." +) +@click.argument('name') +def register_subparser(**kwargs): + """Delete an alias called NAME.""" + return DeleteArgs(**kwargs) + +@line_magic_parser.command(name='update', + short_help="Update the target of an alias. See `%ai update --help` for options." +) +@click.argument('name') +@click.argument('target') +def register_subparser(**kwargs): + """Update an alias called NAME to refer to the model or chain named TARGET.""" + return UpdateArgs(**kwargs)
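
The subcommands added to `parsers.py` above follow a click-plus-pydantic pattern: each click command returns a typed arguments object instead of printing, and the magic then dispatches on its `type` field (`register`, `delete`, `update`). As a rough, self-contained sketch of that pattern — the group name, callback name, and alias below are illustrative only, not part of this patch — invoking such a parser programmatically with `standalone_mode=False` returns the parsed object rather than exiting the process:

```
from typing import Literal

import click
from pydantic import BaseModel


class RegisterArgs(BaseModel):
    type: Literal["register"] = "register"
    name: str
    target: str


@click.group()
def demo_parser():
    """Toy stand-in for the line magic's click group."""
    pass


@demo_parser.command(name="register")
@click.argument("name")
@click.argument("target")
def register_subcommand(**kwargs):
    """Register a new alias called NAME for the model or chain named TARGET."""
    # Returning a pydantic model gives the caller a validated, typed object.
    return RegisterArgs(**kwargs)


# standalone_mode=False makes click return the command's result instead of
# calling sys.exit(), so the caller can inspect `args.type` and route it.
args = demo_parser(
    ["register", "myalias", "openai:text-davinci-003"],
    standalone_mode=False,
)
print(args.type, args.name, args.target)
# register myalias openai:text-davinci-003
```

Under this pattern, a line such as `%ai register myalias openai:text-davinci-003` parses into the same `RegisterArgs` shape that `handle_register` consumes in `magics.py`.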