diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..5657bd0b --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,2 @@ +# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options +style = "sciml" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 00000000..f320cced --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,67 @@ +name: CI +on: + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1.10' + # - 'nightly' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v3 + with: + files: lcov.info + docs: + name: Documentation + runs-on: ubuntu-latest + permissions: + contents: write + statuses: write + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: '1' + - name: Configure doc environment + run: | + julia --project=docs/ -e ' + using Pkg + Pkg.develop(PackageSpec(path=pwd())) + Pkg.instantiate()' + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - run: | + julia --project=docs -e ' + using Documenter: DocMeta, doctest + using PromptingTools + DocMeta.setdocmeta!(PromptingTools, :DocTestSetup, :(using PromptingTools); recursive=true) + doctest(PromptingTools)' diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml new file mode 100644 index 00000000..d48734a3 --- /dev/null +++ b/.github/workflows/CompatHelper.yml @@ -0,0 +1,16 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 1 * * + workflow_dispatch: +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/.github/workflows/TagBot.yml b/.github/workflows/TagBot.yml new file mode 100644 index 00000000..2bacdb87 --- /dev/null +++ b/.github/workflows/TagBot.yml @@ -0,0 +1,31 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: + inputs: + lookback: + default: 3 +permissions: + actions: read + checks: read + contents: write + deployments: read + issues: read + discussions: read + packages: read + pages: read + pull-requests: read + repository-projects: read + security-events: read + statuses: read +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..3efc9d88 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +*.jl.*.cov +*.jl.cov +*.jl.mem +/Manifest.toml +/docs/Manifest.toml +/docs/build/ + +/.DS_Store # macOS folder metadata +/.vscode \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..e1b74ba3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 @svilupp and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Project.toml b/Project.toml new file mode 100644 index 00000000..e842e4b9 --- /dev/null +++ b/Project.toml @@ -0,0 +1,22 @@ +name = "PromptingTools" +uuid = "670122d1-24a8-4d70-bfce-740807c42192" +authors = ["J S @svilupp and contributors"] +version = "0.1.0-DEV" + +[deps] +HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +OpenAI = "e9f21f70-7185-4079-aca2-91159181367c" + +[compat] +HTTP = "1" +JSON3 = "1" +OpenAI = "0.8.7" +julia = "1.9,1.10" + +[extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Aqua", "Test"] diff --git a/README.md b/README.md new file mode 100644 index 00000000..3c7491af --- /dev/null +++ b/README.md @@ -0,0 +1,277 @@ +# PromptingTools [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://svilupp.github.io/PromptingTools.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://svilupp.github.io/PromptingTools.jl/dev/) [![Build Status](https://github.com/svilupp/PromptingTools.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/svilupp/PromptingTools.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/svilupp/PromptingTools.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/svilupp/PromptingTools.jl) [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + + +# PromptingTools.jl: "Your Daily Dose of AI Efficiency." + +Streamline your life using PromptingTools.jl, the Julia package that simplifies interacting with large language models. + +PromptingTools.jl is not meant for building large-scale systems. It's meant to be the go-to tool in your global environment that will save you 20 minutes every day! + +## Quick Start with `@ai_str` and Easy Templating + +Getting started with PromptingTools.jl is as easy as importing the package and using the `@ai_str` macro for your questions. + +Note: You will need to set your OpenAI API key as an environment variable before using PromptingTools.jl (see the [Getting Started with OpenAI API](#getting-started-with-openai-api) section below). +For a quick start, simply set it via `ENV["OPENAI_API_KEY"] = "your-api-key"` + +```julia +using PromptingTools + +ai"What is the capital of France?" +# [ Info: Tokens: 31 @ Cost: $0.0 in 1.5 seconds --> Be in control of your spending! +# AIMessage("The capital of France is Paris.") +``` + +Returned object is a light wrapper with generated message in field `:content` (eg, `ans.content`) for additional downstream processing. + +You can easily inject any variables with string interpolation: +```julia +country = "Spain" +ai"What is the capital of \$(country)?" +# [ Info: Tokens: 32 @ Cost: $0.0001 in 0.5 seconds +# AIMessage("The capital of Spain is Madrid.") +``` + +Pro tip: Use after-string-flags to select the model to be called, eg, `ai"What is the capital of France?"gpt4`. Great for those extra hard questions! + +For more complex prompt templates, you can use handlebars-style templating and provide variables as keyword arguments: + +```julia +msg = aigenerate("What is the capital of {{country}}? Is the population larger than {{population}}?", country="Spain", population="1M") +# [ Info: Tokens: 74 @ Cost: $0.0001 in 1.3 seconds +# AIMessage("The capital of Spain is Madrid. And yes, the population of Madrid is larger than 1 million. As of 2020, the estimated population of Madrid is around 3.3 million people.") +``` + +Pro tip: Use `asyncmap` to run multiple AI-powered tasks concurrently. + +Pro tip: If you use slow models (like GPT-4), you can use async version of `@ai_str` -> `@aai_str` to avoid blocking the REPL, eg, `aai"Say hi but slowly!"gpt4` + +## Why PromptingTools.jl + +Prompt engineering is neither fast nor easy. Moreover, different models and their fine-tunes might require different prompt formats and tricks, or perhaps the information you work with requires special models to be used. PromptingTools.jl is meant to unify the prompts for different backends and make the common tasks (like templated prompts) as simple as possible. + +Some features: +- **`aigenerate` Function**: Simplify prompt templates with handlebars (eg, `{{variable}}`) and keyword arguments +- **`@ai_str` String Macro**: Save keystrokes with a string macro for simple prompts +- **Easy to Remember**: All exported functions start with `ai...` for better discoverability +- **Light Wraper Types**: Benefit from Julia's multiple dispatch by having AI outputs wrapped in specific types +- **Minimal Dependencies**: Enjoy an easy addition to your global environment with very light dependencies +- **No Context Switching**: Access cutting-edge LLMs with no context switching and minimum extra keystrokes + +## Advanced Examples + +TODO: Add more practical examples +[ ] Explain the API / interface + how to add other prompt formats +[ ] Show advanced prompts/templates +[ ] Show mini tasks with structured extraction +[ ] Add an example of how to build RAG in 50 lines + +### Instant Access from Anywhere + +For easy access from anywhere, add PromptingTools into your `startup.jl` (can be found in `~/.julia/config/startup.jl`). + +Add the following snippet: +``` +using PromptingTools +const PT = PromptingTools # to access unexported functions and types +``` + +Now, you can just use `ai"Help me do X to achieve Y"` from any REPL session! + +### Advanced Prompts / Conversations + +You can use the `aigenerate` function to replace handlebar variables (eg, `{{name}}`) via keyword arguments. + +```julia +msg = aigenerate("Say hello to {{name}}!", name="World") +``` + +The more complex prompts are effectively a conversation (a set of messages), where you can have messages from three entities: System, User, AIAssistant. We provide the corresponding types for each of them: `SystemMessage`, `UserMessage`, `AIMessage`. + +```julia +using PromptingTools: SystemMessage, UserMessage + +conversation = [ + SystemMessage("You're master Yoda from Star Wars trying to help the user become a Jedi."), + UserMessage("I have feelings for my {{object}}. What should I do?")] +msg = aigenerate(conversation; object = "old iPhone") +``` + +> AIMessage("Ah, a dilemma, you have. Emotional attachment can cloud your path to becoming a Jedi. To be attached to material possessions, you must not. The iPhone is but a tool, nothing more. Let go, you must. + +Seek detachment, young padawan. Reflect upon the impermanence of all things. Appreciate the memories it gave you, and gratefully part ways. In its absence, find new experiences to grow and become one with the Force. Only then, a true Jedi, you shall become.") + +You can also use it to build conversations, eg, +```julia +new_conversation = vcat(conversation...,msg, UserMessage("Thank you, master Yoda! Do you have {{object}} to know what it feels like?")) +aigenerate(new_conversation; object = "old iPhone") +``` +> AIMessage("Hmm, possess an old iPhone, I do not. But experience with attachments, I have. Detachment, I learned. True power and freedom, it brings...") + +### Asynchronous Execution + +You can leverage `asyncmap` to run multiple AI-powered tasks concurrently, improving performance for batch operations. You can limit number of concurrent tasks with the keyword `ntasks` in `asyncmap`. + +```julia +prompts = [aigenerate("Translate 'Hello, World!' to {{language}}"; language) for language in ["Spanish", "French", "Mandarin"]] +responses = asyncmap(aigenerate, prompts) +``` + +### Embedding + +Use the `aiembed` function to create embeddings via the default OpenAI model that can be used for semantic search, clustering, and more complex AI workflows. + +```julia +text_to_embed = "The concept of artificial intelligence." +msg = aiembed(text_to_embed) +embedding = msg.content # 1536-element Vector{Float64} +``` + +If you plan to calculate the cosine distance between embeddings, you can normalize them first: +```julia +using LinearAlgebra +msg = aiembed(["embed me", "and me too"], LinearAlgebra.normalize) + +# calculate cosine distance between the two normalized embeddings as a simple dot product +msg.content' * msg.content[:, 1] # [1.0, 0.787] +``` + +### Classification + +You can use the `aiclassify` function to classify any provided statement as true/false/unknown. This is useful for fact-checking, hallucination or NLI checks, moderation, filtering, sentiment analysis, feature engineering and more. + +```julia +aiclassify("Is two plus two four?") +# true +``` + +System prompts and higher-quality models can be used for more complex tasks, including knowing when to defer to a human: + +```julia +aiclassify(:IsStatementTrue; statement = "Is two plus three a vegetable on Mars?", model = "gpt4") +# unknown +``` + +In the above example, we used a prompt template `:IsStatementTrue`, which automatically expands into the following system prompt: + +> "You are an impartial AI judge evaluating whether the provided statement is \"true\" or \"false\". Answer \"unknown\" if you cannot decide." + +### Data Extraction + +!!! Experimental + +TBU... with `aiextract` + +### More Examples + +Find more examples in the [examples folder](examples/). + +## Frequently Asked Questions + +### Why OpenAI + +OpenAI's models are at the forefront of AI research and provide robust, state-of-the-art capabilities for NLP tasks. + +TBU... + +### Disabling Data Collection for Privacy + +To ensure privacy, you can opt out of data collection by OpenAI by setting the corresponding parameter in your API calls. + +TBU... + +# Setup guides + +Following up on the question re. OpenAI API Keys: + +### Creating OpenAI API key + +You can get your API key from OpenAI by signing up for an account and accessing the API section of the OpenAI website. + +1. Create an account with [OpenAI](https://platform.openai.com/signup) +2. Go to [API Key page](https://platform.openai.com/account/api-keys) +3. Click on “Create new secret key” + !!! Do not share it with anyone and do NOT save it to any files that get synced online. + +Resources: +- [OpenAI Documentation](https://platform.openai.com/docs/quickstart?context=python) +- [Visual tutorial](https://www.maisieai.com/help/how-to-get-an-openai-api-key-for-chatgpt) + +Pro tip: Always set the spending limits! + +### Setting OpenAI Spending Limits + +OpenAI allows you to set spending limits directly on your account dashboard to prevent unexpected costs. + +1. Go to [OpenAI Billing](https://platform.openai.com/account/billing) +2. Set Soft Limit (you’ll receive a notification) and Hard Limit (API will stop working not to spend more money) + +A good start might be a soft limit of c.$5 and a hard limit of c.$10 - you can always increase it later in the month. + +Resources: +- [OpenAI Forum](https://community.openai.com/t/how-to-set-a-price-limit/13086) + +### How much does it cost? Is it worth paying for? + +If you use a local model (eg, with Ollama), it's free. If you use any commercial APIs (eg, OpenAI), you will likely pay per "token" (a sub-word unit). + +For example, a simple request with a simple question and 1 sentence response in return (”Is statement XYZ a positive comment”) will cost you ~$0.0001 (ie, one hundredth of a cent) + +**Is it worth paying for?** + +GenAI is a way to buy time! You can pay cents to save tens of minutes every day. + +Continuing the example above, imagine you have a table with 200 comments. Now, you can parse each one of them with an LLM for the features/checks you need. +Assuming the price per call was $0.0001, you'd pay 2 cents for the job and save 30-60 minutes of your time! + + +Resources: +- [OpenAI Pricing per 1000 tokens](https://openai.com/pricing) + + +To use the OpenAI API with PromptingTools.jl, set your API key as an environment variable: + +```julia +ENV["OPENAI_API_KEY"] = "your-api-key" +``` + +### Configuring the Environment Variable for API Key + +### Saving the OpenAI API key to System Environment Variables** + +As a one-off, you can: +- set it in the terminal before launching Julia: `export OPENAI_API_KEY = ` +- set it in your `setup.jl` (make sure not to commit it to GitHub!) + +Make sure to start Julia from the same terminal window where you set the variable. +Easy check in Julia, run `ENV["OPENAI_API_KEY"]` and you should see your key! + +A better way: +- On a Mac, add the configuration line to your terminal's configuration file (eg, `~/.zshrc`). It will get automatically loaded every time you launch the terminal +- On Windows, set it as a system variable in "Environment Variables" settings (see the Resources) + +Resources: +- [OpenAI Guide](https://platform.openai.com/docs/quickstart?context=python) + +### Understanding the API Keyword Arguments in `aigenerate` (`api_kwargs`) + +See [OpenAI API reference](https://platform.openai.com/docs/guides/text-generation/chat-completions-api) for more information. + +## Roadmap + +This is a list of features that I'd like to see in the future (in no particular order): +- Document more mini-tasks, add tutorials +- Integration of new OpenAI capabilities (eg, vision, audio, assistants -> Imagine a function you send a Plot to and it will add code to add titles, labels, etc. and generate insights for your report!) +- Documented support for local models (eg, guide and prompt templates for Ollama) +- Add Preferences.jl mechanism to set defaults and persist them across sessions +- More templates for common tasks (eg, fact-checking, sentiment analysis, extraction of entities/metadata, etc.) +- Ability to easily add new templates, save them, and share them with others +- Ability to easily trace and serialize the prompts & AI results for finetuning or evaluation in the future + +For more information, contributions, or questions, please visit the [PromptingTools.jl GitHub repository](https://github.com/svilupp/PromptingTools.jl). + +Please note that while PromptingTools.jl aims to provide a smooth experience, it relies on external APIs which may change. Stay tuned to the repository for updates and new features. + +--- + +Thank you for choosing PromptingTools.jl to empower your applications with AI! \ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..afae62a7 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,3 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192" diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 00000000..7307e939 --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,25 @@ +using PromptingTools +using Documenter + +DocMeta.setdocmeta!(PromptingTools, + :DocTestSetup, + :(using PromptingTools); + recursive = true) + +makedocs(; + modules = [PromptingTools], + authors = "J S <49557684+svilupp@users.noreply.github.com> and contributors", + repo = "https://github.com/svilupp/PromptingTools.jl/blob/{commit}{path}#{line}", + sitename = "PromptingTools.jl", + format = Documenter.HTML(; + prettyurls = get(ENV, "CI", "false") == "true", + canonical = "https://svilupp.github.io/PromptingTools.jl", + edit_link = "main", + assets = String[]), + pages = [ + "Home" => "index.md", + ]) + +deploydocs(; + repo = "github.com/svilupp/PromptingTools.jl", + devbranch = "main") diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 00000000..47c47c96 --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,14 @@ +```@meta +CurrentModule = PromptingTools +``` + +# PromptingTools + +Documentation for [PromptingTools](https://github.com/svilupp/PromptingTools.jl). + +```@index +``` + +```@autodocs +Modules = [PromptingTools] +``` diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl new file mode 100644 index 00000000..1a7128b7 --- /dev/null +++ b/src/PromptingTools.jl @@ -0,0 +1,44 @@ +module PromptingTools + +using OpenAI +using JSON3 +using HTTP + +# GLOBALS +const MODEL_CHAT = "gpt-3.5-turbo" +const MODEL_EMBEDDING = "text-embedding-ada-002" +const API_KEY = get(ENV, "OPENAI_API_KEY", "") +@assert isempty(API_KEY)==false "Please set OPENAI_API_KEY environment variable!" +# Cost per 1K tokens as of 7th November 2023 +const MODEL_COSTS = Dict("gpt-3.5-turbo" => (0.0015, 0.002), + "gpt-3.5-turbo-1106" => (0.001, 0.002), + "gpt-4" => (0.03, 0.06), + "gpt-4-1106-preview" => (0.01, 0.03), + "text-embedding-ada-002" => (0.001, 0.0)) +const MODEL_ALIASES = Dict("gpt3" => "gpt-3.5-turbo", + "gpt4" => "gpt-4", + "4t" => "gpt-4-1106-preview", # 4t is for "4 turbo" + "3t" => "gpt-3.5-turbo-1106", # 3t is for "3 turbo" + "ada" => "text-embedding-ada-002") +# below is defined in llm_interace.jl ! +# const PROMPT_SCHEMA = OpenAISchema() + +include("utils.jl") + +export aigenerate, aiembed, aiclassify +# export render # for debugging only +include("llm_interface.jl") + +## Conversation history / Prompt elements +export AIMessage +# export UserMessage, SystemMessage, DataMessage # for debugging only +include("messages.jl") + +## Individual interfaces +include("llm_openai.jl") + +## Convenience utils +export @ai_str, @aai_str +include("macros.jl") + +end diff --git a/src/llm_interface.jl b/src/llm_interface.jl new file mode 100644 index 00000000..526abec3 --- /dev/null +++ b/src/llm_interface.jl @@ -0,0 +1,64 @@ +# This file defines all key types that the various function dispatch on. +# New LLM interfaces should define: +# - corresponding schema to dispatch on (`schema <: AbstractPromptSchema`) +# - how to render conversation history/prompts (`render(schema)`) +# - user-facing functionality (eg, `aigenerate`, `aiembed`) +# +# Ideally, each new interface would be defined in a separate `llm_.jl` file (eg, `llm_chatml.jl`). + +## Main Functions +function render end +function aigenerate end +function aiembed end +function aiclassify end +function aiextract end # not implemented yet + +## Prompt Schema +"Defines different prompting styles based on the model training and fine-tuning." +abstract type AbstractPromptSchema end +abstract type AbstractOpenAISchema <: AbstractPromptSchema end + +""" +OpenAISchema is the default schema for OpenAI models. + +It uses the following conversation template: +``` +[Dict(role="system",content="..."),Dict(role="user",content="..."),Dict(role="assistant",content="...")] +``` + +It's recommended to separate sections in your prompt with markdown headers (e.g. `##Answer\n\n`). +""" +struct OpenAISchema <: AbstractOpenAISchema end + +"Echoes the user's input back to them. Used for testing the implementation" +@kwdef mutable struct TestEchoOpenAISchema <: AbstractOpenAISchema + response::AbstractDict + status::Integer + model_id::String = "" + inputs::Any = nothing +end + +abstract type AbstractChatMLSchema <: AbstractPromptSchema end +""" +ChatMLSchema is used by many open-source chatbots, by OpenAI models under the hood and by several models and inferfaces (eg, Ollama, vLLM) + +It uses the following conversation structure: +``` +system +... +<|im_start|>user +...<|im_end|> +<|im_start|>assistant +...<|im_end|> +``` +""" +struct ChatMLSchema <: AbstractChatMLSchema end + +## Dispatch into defaults +const PROMPT_SCHEMA = OpenAISchema() + +aigenerate(prompt; kwargs...) = aigenerate(PROMPT_SCHEMA, prompt; kwargs...) +function aiembed(doc_or_docs, args...; kwargs...) + aiembed(PROMPT_SCHEMA, doc_or_docs, args...; kwargs...) +end +aiclassify(prompt; kwargs...) = aiclassify(PROMPT_SCHEMA, prompt; kwargs...) diff --git a/src/llm_openai.jl b/src/llm_openai.jl new file mode 100644 index 00000000..83d42258 --- /dev/null +++ b/src/llm_openai.jl @@ -0,0 +1,307 @@ +## Rendering of converation history for the OpenAI API +"Builds a history of the conversation to provide the prompt to the API. All kwargs are passed as replacements such that `{{key}}=>value` in the template.}}" +function render(schema::AbstractOpenAISchema, + messages::Vector{<:AbstractMessage}; + kwargs...) + ## + conversation = Dict{String, String}[] + # TODO: concat system messages together + # TODO: move system msg to the front + + has_system_msg = false + # replace any handlebar variables in the messages + for msg in messages + if msg isa SystemMessage + replacements = ["{{$(key)}}" => value + for (key, value) in pairs(kwargs) if key in msg.variables] + # move it to the front + pushfirst!(conversation, + Dict("role" => "system", + "content" => replace(msg.content, replacements...))) + has_system_msg = true + elseif msg isa UserMessage + replacements = ["{{$(key)}}" => value + for (key, value) in pairs(kwargs) if key in msg.variables] + push!(conversation, + Dict("role" => "user", "content" => replace(msg.content, replacements...))) + elseif msg isa AIMessage + push!(conversation, + Dict("role" => "assistant", "content" => msg.content)) + end + # Note: Ignores any DataMessage or other types + end + ## Add default system prompt if not provided + !has_system_msg && pushfirst!(conversation, + Dict("role" => "system", "content" => "Act as a helpful AI assistant")) + + return conversation +end + +## User-Facing API +""" + aigenerate([prompt_schema::AbstractOpenAISchema,] prompt; verbose::Bool = true, + model::String = MODEL_CHAT, + http_kwargs::NamedTuple = (; + retry_non_idempotent = true, + retries = 5, + readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), + kwargs...) + +Generate an AI response based on a given prompt using the OpenAI API. + +# Arguments +- `prompt_schema`: An optional object to specify which prompt template should be applied (Default to `PROMPT_SCHEMA = OpenAISchema`) +- `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate` +- `verbose`: A boolean indicating whether to print additional information. +- `prompt_schema`: An abstract schema for the prompt. +- `api_key`: A string representing the API key for accessing the OpenAI API. +- `model`: A string representing the model to use for generating the response. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`. +- `http_kwargs`: A named tuple of HTTP keyword arguments. +- `api_kwargs`: A named tuple of API keyword arguments. +- `kwargs`: Prompt variables to be used to fill the prompt/template + +# Returns +- `msg`: An `AIMessage` object representing the generated AI message, including the content, status, tokens, and elapsed time. + +See also: `ai_str` + +# Example + +Simple hello world to test the API: +```julia +result = aigenerate("Say Hi!") +# [ Info: Tokens: 29 @ Cost: \$0.0 in 1.0 seconds +# AIMessage("Hello! How can I assist you today?") +``` + +`result` is an `AIMessage` object. Access the generated string via `content` property: +```julia +typeof(result) # AIMessage{SubString{String}} +propertynames(result) # (:content, :status, :tokens, :elapsed +result.content # "Hello! How can I assist you today?" +``` +___ +You can use string interpolation: +```julia +a = 1 +msg=aigenerate("What is `\$a+\$a`?") +msg.content # "The sum of `1+1` is `2`." +``` +___ +You can provide the whole conversation or more intricate prompts as a `Vector{AbstractMessage}`: +```julia +conversation = [ + SystemMessage("You're master Yoda from Star Wars trying to help the user become a Yedi."), + UserMessage("I have feelings for my iPhone. What should I do?")] +msg=aigenerate(conversation) +# AIMessage("Ah, strong feelings you have for your iPhone. A Jedi's path, this is not... ") +``` +""" +function aigenerate(prompt_schema::AbstractOpenAISchema, prompt; verbose::Bool = true, + api_key::String = API_KEY, + model::String = MODEL_CHAT, + http_kwargs::NamedTuple = (retry_non_idempotent = true, + retries = 5, + readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), + kwargs...) + ## + global MODEL_ALIASES, MODEL_COSTS + ## Find the unique ID for the model alias provided + model_id = get(MODEL_ALIASES, model, model) + conversation = render(prompt_schema, prompt; kwargs...) + time = @elapsed r = create_chat(prompt_schema, api_key, + model_id, + conversation; + http_kwargs, + api_kwargs...) + msg = AIMessage(; content = r.response[:choices][begin][:message][:content] |> strip, + status = Int(r.status), + tokens = (r.response[:usage][:prompt_tokens], + r.response[:usage][:completion_tokens]), + elapsed = time) + ## Reporting + verbose && @info _report_stats(msg, model_id, MODEL_COSTS) + + return msg +end +# Extend OpenAI create_chat to allow for testing/debugging +function OpenAI.create_chat(schema::AbstractOpenAISchema, + api_key::AbstractString, + model::AbstractString, + conversation; + kwargs...) + OpenAI.create_chat(api_key, model, conversation; kwargs...) +end +function OpenAI.create_chat(schema::TestEchoOpenAISchema, api_key::AbstractString, + model::AbstractString, + conversation; kwargs...) + schema.model_id = model + schema.inputs = conversation + return schema +end + +""" + aiembed(prompt_schema::AbstractOpenAISchema, + doc_or_docs::Union{AbstractString, Vector{<:AbstractString}}, + postprocess::F = identity; + verbose::Bool = true, + api_key::String = API_KEY, + model::String = MODEL_EMBEDDING, + http_kwargs::NamedTuple = (retry_non_idempotent = true, + retries = 5, + readtimeout = 120), + api_kwargs::NamedTuple = NamedTuple(), + kwargs...) where {F <: Function} + +The `aiembed` function generates embeddings for the given input using a specified model and returns a message object containing the embeddings, status, token count, and elapsed time. + +## Arguments +- `prompt_schema::AbstractOpenAISchema`: The schema for the prompt. +- `doc_or_docs::Union{AbstractString, Vector{<:AbstractString}}`: The document or list of documents to generate embeddings for. +- `postprocess::F`: The post-processing function to apply to each embedding. Defaults to the identity function. +- `verbose::Bool`: A flag indicating whether to print verbose information. Defaults to `true`. +- `api_key::String`: The API key to use for the OpenAI API. Defaults to `API_KEY`. +- `model::String`: The model to use for generating embeddings. Defaults to `MODEL_EMBEDDING`. +- `http_kwargs::NamedTuple`: Additional keyword arguments for the HTTP request. Defaults to `(retry_non_idempotent = true, retries = 5, readtimeout = 120)`. +- `api_kwargs::NamedTuple`: Additional keyword arguments for the OpenAI API. Defaults to an empty `NamedTuple`. +- `kwargs...`: Additional keyword arguments. + +## Returns +- `msg`: A `DataMessage` object containing the embeddings, status, token count, and elapsed time. + +# Example + +```julia +msg = aiembed("Hello World") +msg.content # 1536-element JSON3.Array{Float64... +``` + +We can embed multiple strings at once and they will be `hcat` into a matrix + (ie, each column corresponds to one string) +```julia +msg = aiembed(["Hello World", "How are you?"]) +msg.content # 1536×2 Matrix{Float64}: +``` + +If you plan to calculate the cosine distance between embeddings, you can normalize them first: +```julia +using LinearAlgebra +msg = aiembed(["embed me", "and me too"], LinearAlgebra.normalize) + +# calculate cosine distance between the two normalized embeddings as a simple dot product +msg.content' * msg.content[:, 1] # [1.0, 0.787] +``` + +""" +function aiembed(prompt_schema::AbstractOpenAISchema, + doc_or_docs::Union{AbstractString, Vector{<:AbstractString}}, + postprocess::F = identity; verbose::Bool = true, + api_key::String = API_KEY, + model::String = MODEL_EMBEDDING, + http_kwargs::NamedTuple = (retry_non_idempotent = true, + retries = 5, + readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), + kwargs...) where {F <: Function} + ## + global MODEL_ALIASES, MODEL_COSTS + ## Find the unique ID for the model alias provided + model_id = get(MODEL_ALIASES, model, model) + time = @elapsed r = create_embeddings(prompt_schema, api_key, + doc_or_docs, + model_id; + http_kwargs, + api_kwargs...) + @info r.response |> typeof + msg = DataMessage(; + content = mapreduce(x -> postprocess(x[:embedding]), hcat, r.response[:data]), + status = Int(r.status), + tokens = (r.response[:usage][:prompt_tokens], 0), + elapsed = time) + ## Reporting + verbose && @info _report_stats(msg, model_id, MODEL_COSTS) + + return msg +end +# Extend OpenAI create_embeddings to allow for testing +function OpenAI.create_embeddings(schema::AbstractOpenAISchema, + api_key::AbstractString, + docs, + model::AbstractString; + kwargs...) + OpenAI.create_embeddings(api_key, docs, model; kwargs...) +end +function OpenAI.create_embeddings(schema::TestEchoOpenAISchema, api_key::AbstractString, + docs, + model::AbstractString; kwargs...) + schema.model_id = model + schema.inputs = docs + return schema +end + +""" + aiclassify(prompt_schema::AbstractOpenAISchema, prompt; + api_kwargs::NamedTuple = (logit_bias = Dict(837 => 100, 905 => 100, 9987 => 100), + max_tokens = 1, temperature = 0), + kwargs...) + +Classifies the given prompt/statement as true/false/unknown. + +Note: this is a very simple classifier, it is not meant to be used in production. Credit goes to: https://twitter.com/AAAzzam/status/1669753721574633473 + +It uses Logit bias trick to force the model to output only true/false/unknown. + +Output tokens used (via `api_kwargs`): +- 837: ' true' +- 905: ' false' +- 9987: ' unknown' + +# Arguments +- `prompt_schema::AbstractOpenAISchema`: The schema for the prompt. +- `prompt`: The prompt/statement to classify if it's a `String`. If it's a `Symbol`, it is expanded as a template via `render(schema,template)`. + +# Example + +```julia +aiclassify("Is two plus two four?") # true +aiclassify("Is two plus three a vegetable on Mars?") # false +``` +`aiclassify` returns only true/false/unknown. It's easy to get the proper `Bool` output type out with `tryparse`, eg, +```julia +tryparse(Bool, aiclassify("Is two plus two four?")) isa Bool # true +``` +Output of type `Nothing` marks that the model couldn't classify the statement as true/false. + +Ideally, we would like to re-use some helpful system prompt to get more accurate responses. +For this reason we have templates, eg, `:IsStatementTrue`. By specifying the template, we can provide our statement as the expected variable (`statement` in this case) +See that the model now correctly classifies the statement as "unknown". +```julia +aiclassify(:IsStatementTrue; statement = "Is two plus three a vegetable on Mars?") # unknown +``` + +For better results, use higher quality models like gpt4, eg, +```julia +aiclassify(:IsStatementTrue; + statement = "If I had two apples and I got three more, I have five apples now.", + model = "gpt4") # true +``` + +""" +function aiclassify(prompt_schema::AbstractOpenAISchema, prompt; + api_kwargs::NamedTuple = (logit_bias = Dict(837 => 100, 905 => 100, 9987 => 100), + max_tokens = 1, temperature = 0), + kwargs...) + ## + msg = aigenerate(prompt_schema, + prompt; + api_kwargs, + kwargs...) + return msg +end +# Dispatch for templates +function aiclassify(prompt_schema::AbstractOpenAISchema, + template_sym::Symbol; + kwargs...) + # render template into prompt + prompt = render(prompt_schema, Val(template_sym)) + return aiclassify(prompt_schema, prompt; kwargs...) +end diff --git a/src/macros.jl b/src/macros.jl new file mode 100644 index 00000000..c3b6e1ea --- /dev/null +++ b/src/macros.jl @@ -0,0 +1,66 @@ +""" + ai"user_prompt"[model_alias] -> AIMessage + +The `ai""` string macro generates an AI response to a given prompt by using `aigenerate` under the hood. + +## Arguments +- `user_prompt` (String): The input prompt for the AI model. +- `model_alias` (optional, any): Provide model alias of the AI model (see `MODEL_ALIASES`). + +## Returns +`AIMessage` corresponding to the input prompt. + +## Example +```julia +result = ai"Hello, how are you?" +# AIMessage("Hello! I'm an AI assistant, so I don't have feelings, but I'm here to help you. How can I assist you today?") +``` + +If you want to interpolate some variables or additional context, simply use string interpolation: +```julia +a=1 +result = ai"What is `\$a+\$a`?" +# AIMessage("The sum of `1+1` is `2`.") +``` + +If you want to use a different model, eg, GPT-4, you can provide its alias as a flag: +```julia +result = ai"What is `1.23 * 100 + 1`?"gpt4 +# AIMessage("The answer is 124.") +``` +""" +macro ai_str(user_prompt, flags...) + model = isempty(flags) ? MODEL_CHAT : only(flags) + prompt = Meta.parse("\"$(escape_string(user_prompt))\"") + quote + aigenerate($(esc(prompt)); model = $(esc(model))) + end +end + +""" + aai"user_prompt"[model_alias] -> AIMessage + +Asynchronous version of `@ai_str` macro, which will log the result once it's ready. + +# Example + +Send asynchronous request to GPT-4, so we don't have to wait for the response: +Very practical with slow models, so you can keep working in the meantime. + +```julia +m = aai"Say Hi!"gpt4; +# ...with some delay... +# [ Info: Tokens: 29 @ Cost: \$0.0011 in 2.7 seconds +# [ Info: AIMessage> Hello! How can I assist you today? +""" +macro aai_str(user_prompt, flags...) + model = isempty(flags) ? MODEL_CHAT : only(flags) + prompt = Meta.parse("\"$(escape_string(user_prompt))\"") + quote + Threads.@spawn begin + m = aigenerate($(esc(prompt)); model = $(esc(model))) + @info "AIMessage> $(m.content)" # display the result once it's ready + m + end + end +end diff --git a/src/messages.jl b/src/messages.jl new file mode 100644 index 00000000..b8f254d5 --- /dev/null +++ b/src/messages.jl @@ -0,0 +1,82 @@ +# This file contains key building blocks of conversation history (messages) and utilities to work with them (eg, render) + +## Messages +abstract type AbstractMessage end +abstract type AbstractChatMessage <: AbstractMessage end # with text-based content +abstract type AbstractDataMessage <: AbstractMessage end # with data-based content, eg, embeddings + +Base.@kwdef mutable struct SystemMessage{T <: AbstractString} <: AbstractChatMessage + content::T + variables::Vector{Symbol} = _extract_handlebar_variables(content) +end +Base.@kwdef mutable struct UserMessage{T <: AbstractString} <: AbstractChatMessage + content::T + variables::Vector{Symbol} = _extract_handlebar_variables(content) +end +Base.@kwdef struct AIMessage{T <: Union{AbstractString, Nothing}} <: AbstractChatMessage + content::T = nothing + status::Union{Int, Nothing} = nothing + tokens::Tuple{Int, Int} = (-1, -1) + elapsed::Float64 = -1.0 +end +Base.@kwdef mutable struct DataMessage{T <: Any} <: AbstractDataMessage + content::T + status::Union{Int, Nothing} = nothing + tokens::Tuple{Int, Int} = (-1, -1) + elapsed::Float64 = -1.0 +end + +# content-only constructor +function (MSG::Type{<:AbstractChatMessage})(s::AbstractString) + MSG(; content = s) +end + +# equality check for testing, only equal if all fields are equal and type is the same +Base.var"=="(m1::AbstractMessage, m2::AbstractMessage) = false +function Base.var"=="(m1::T, m2::T) where {T <: AbstractMessage} + all([getproperty(m1, f) == getproperty(m2, f) for f in fieldnames(T)]) +end + +function Base.show(io::IO, ::MIME"text/plain", m::AbstractChatMessage) + type_ = string(typeof(m)) |> x -> split(x, "{")[begin] + if m isa AIMessage + printstyled(io, type_; color = :magenta) + elseif m isa SystemMessage + printstyled(io, type_; color = :light_green) + elseif m isa UserMessage + printstyled(io, type_; color = :light_red) + else + print(io, type_) + end + print(io, "(\"", m.content, "\")") +end +function Base.show(io::IO, ::MIME"text/plain", m::AbstractDataMessage) + type_ = string(typeof(m)) |> x -> split(x, "{")[begin] + printstyled(io, type_; color = :light_yellow) + size_str = (m.content) isa AbstractArray ? string(size(m.content)) : "-" + print(io, "(", typeof(m.content), " of size ", size_str, ")") +end + +## Dispatch for render +function render(schema::AbstractPromptSchema, + messages::Vector{<:AbstractMessage}; + kwargs...) + render(schema, messages; kwargs...) +end +function render(schema::AbstractPromptSchema, msg::AbstractMessage; kwargs...) + render(schema, [msg]; kwargs...) +end +function render(schema::AbstractPromptSchema, msg::AbstractString; kwargs...) + render(schema, [UserMessage(; content = msg)]; kwargs...) +end + +## Prompt Templates +# ie, a way to re-use similar prompting patterns (eg, aiclassifier) +# flow: template -> messages |+ kwargs variables -> chat history +# Defined through Val() to allow for dispatch +function render(prompt_schema::AbstractOpenAISchema, template::Val{:IsStatementTrue}) + [ + SystemMessage("You are an impartial AI judge evaluting whether the provided statement is \"true\" or \"false\". Answer \"unknown\" if you cannot decide."), + UserMessage("##Statement\n\n{{statement}}"), + ] +end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 00000000..046ecae3 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,13 @@ +# helper to extract handlebar variables (eg, `{{var}}`) from a prompt string +function _extract_handlebar_variables(s::AbstractString) + Symbol[Symbol(m[1]) for m in eachmatch(r"\{\{([^\}]+)\}\}", s)] +end + +# helper to produce summary message of how many tokens were used and for how much +function _report_stats(msg, model::String, model_costs::AbstractDict = Dict()) + token_prices = get(model_costs, model, (0.0, 0.0)) + cost = sum(msg.tokens ./ 1000 .* token_prices) + cost_str = iszero(cost) ? "" : " @ Cost: \$$(round(cost; digits=4))" + + return "Tokens: $(sum(msg.tokens))$(cost_str) in $(round(msg.elapsed;digits=1)) seconds" +end diff --git a/test/llm_openai.jl b/test/llm_openai.jl new file mode 100644 index 00000000..b5234b47 --- /dev/null +++ b/test/llm_openai.jl @@ -0,0 +1,194 @@ +using PromptingTools: TestEchoOpenAISchema, render, OpenAISchema + +@testset "render-OpenAI" begin + schema = OpenAISchema() + # Given a schema and a vector of messages with handlebar variables, it should replace the variables with the correct values in the conversation dictionary. + messages = [ + SystemMessage("Act as a helpful AI assistant"), + UserMessage("Hello, my name is {{name}}"), + ] + expected_output = [ + Dict("role" => "system", "content" => "Act as a helpful AI assistant"), + Dict("role" => "user", "content" => "Hello, my name is John"), + ] + conversation = render(schema, messages; name = "John") + @test conversation == expected_output + + # AI message does NOT replace variables + messages = [ + SystemMessage("Act as a helpful AI assistant"), + AIMessage("Hello, my name is {{name}}"), + ] + expected_output = [ + Dict("role" => "system", "content" => "Act as a helpful AI assistant"), + Dict("role" => "assistant", "content" => "Hello, my name is John"), + ] + conversation = render(schema, messages; name = "John") + # Broken: AIMessage does not replace handlebar variables + @test_broken conversation == expected_output + + # Given a schema and a vector of messages with no system messages, it should add a default system prompt to the conversation dictionary. + messages = [ + UserMessage("User message"), + ] + conversation = render(schema, messages) + expected_output = [ + Dict("role" => "system", "content" => "Act as a helpful AI assistant"), + Dict("role" => "user", "content" => "User message"), + ] + @test conversation == expected_output + + # Given a schema and a vector of messages, it should return a conversation dictionary with the correct roles and contents for each message. + messages = [ + UserMessage("Hello"), + AIMessage("Hi there"), + UserMessage("How are you?"), + AIMessage("I'm doing well, thank you!"), + ] + expected_output = [ + Dict("role" => "system", "content" => "Act as a helpful AI assistant"), + Dict("role" => "user", "content" => "Hello"), + Dict("role" => "assistant", "content" => "Hi there"), + Dict("role" => "user", "content" => "How are you?"), + Dict("role" => "assistant", "content" => "I'm doing well, thank you!"), + ] + conversation = render(schema, messages) + @test conversation == expected_output + + # Given a schema and a vector of messages with a system message, it should move the system message to the front of the conversation dictionary. + messages = [ + UserMessage("Hello"), + AIMessage("Hi there"), + SystemMessage("This is a system message"), + ] + expected_output = [ + Dict("role" => "system", "content" => "This is a system message"), + Dict("role" => "user", "content" => "Hello"), + Dict("role" => "assistant", "content" => "Hi there"), + ] + conversation = render(schema, messages) + @test conversation == expected_output + + # Given an empty vector of messages, it should return an empty conversation dictionary just with the system prompt + messages = PT.AbstractMessage[] + expected_output = [ + Dict("role" => "system", "content" => "Act as a helpful AI assistant"), + ] + conversation = render(schema, messages) + @test conversation == expected_output + + # Given a schema and a vector of messages with a system message containing handlebar variables not present in kwargs, it should replace the variables with empty strings in the conversation dictionary. + messages = [ + SystemMessage("Hello, {{name}}!"), + UserMessage("How are you?"), + ] + expected_output = [ + Dict("role" => "system", "content" => "Hello, !"), + Dict("role" => "user", "content" => "How are you?"), + ] + conversation = render(schema, messages) + # Broken because we do not remove any unused handlebar variables + @test_broken conversation == expected_output + + # Given a schema and a vector of messages with an unknown message type, it should skip the message and continue building the conversation dictionary. + messages = [ + UserMessage("Hello"), + DataMessage(; content = ones(3, 3)), + AIMessage("Hi there"), + ] + expected_output = [ + Dict("role" => "system", "content" => "Act as a helpful AI assistant"), + Dict("role" => "user", "content" => "Hello"), + Dict("role" => "assistant", "content" => "Hi there"), + ] + conversation = render(schema, messages) + @test conversation == expected_output + + # Given a schema and a vector of messages with multiple system messages, it should concatenate them together in the conversation dictionary. + messages = [ + SystemMessage("System message 1"), + SystemMessage("System message 2"), + UserMessage("User message"), + ] + conversation = render(schema, messages) + expected_output = [ + Dict("role" => "system", "content" => "System message 1\nSystem message 2"), + Dict("role" => "user", "content" => "User message"), + ] + # Broken: Does not concatenate system messages yet + @test_broken conversation == expected_output +end + +@testset "aigenerate-OpenAI" begin + # corresponds to OpenAI API v1 + response = Dict(:choices => [Dict(:message => Dict(:content => "Hello!"))], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) + + # Test the monkey patch + schema = TestEchoOpenAISchema(; response, status = 200) + msg = PT.OpenAI.create_chat(schema, "", "", "Hello") + @test msg isa TestEchoOpenAISchema + + # Real generation API + schema1 = TestEchoOpenAISchema(; response, status = 200) + msg = aigenerate(schema1, "Hello World") + expected_output = AIMessage(; + content = "Hello!" |> strip, + status = 200, + tokens = (2, 1), + elapsed = msg.elapsed) + @test msg == expected_output + @test schema1.inputs == + [Dict("role" => "system", "content" => "Act as a helpful AI assistant") + Dict("role" => "user", "content" => "Hello World")] + @test schema1.model_id == "gpt-3.5-turbo" + + # Test different input combinations and different prompts + schema2 = TestEchoOpenAISchema(; response, status = 200) + msg = aigenerate(schema2, UserMessage("Hello {{name}}"), + model = "gpt4", http_kwargs = (; verbose = 3), api_kwargs = (; temperature = 0), + name = "World") + expected_output = AIMessage(; + content = "Hello!" |> strip, + status = 200, + tokens = (2, 1), + elapsed = msg.elapsed) + @test msg == expected_output + @test schema1.inputs == + [Dict("role" => "system", "content" => "Act as a helpful AI assistant") + Dict("role" => "user", "content" => "Hello World")] + @test schema2.model_id == "gpt-4" +end + +@testset "aiembed-OpenAI" begin + # corresponds to OpenAI API v1 + response1 = Dict(:data => [Dict(:embedding => ones(128))], + :usage => Dict(:total_tokens => 2, :prompt_tokens => 2, :completion_tokens => 0)) + + # Real generation API + schema1 = TestEchoOpenAISchema(; response = response1, status = 200) + msg = aiembed(schema1, "Hello World") + expected_output = DataMessage(; + content = ones(128), + status = 200, + tokens = (2, 0), + elapsed = msg.elapsed) + @test msg == expected_output + @test schema1.inputs == "Hello World" + @test schema1.model_id == "text-embedding-ada-002" + + # Test different input combinations and multiple strings + response2 = Dict(:data => [Dict(:embedding => ones(128, 2))], + :usage => Dict(:total_tokens => 4, :prompt_tokens => 4, :completion_tokens => 0)) + schema2 = TestEchoOpenAISchema(; response = response2, status = 200) + msg = aiembed(schema2, ["Hello World", "Hello back"], + model = "gpt4", http_kwargs = (; verbose = 3), api_kwargs = (; temperature = 0)) + expected_output = DataMessage(; + content = ones(128, 2), + status = 200, + tokens = (4, 0), + elapsed = msg.elapsed) + @test msg == expected_output + @test schema2.inputs == ["Hello World", "Hello back"] + @test schema2.model_id == "gpt-4" # not possible - just an example +end diff --git a/test/messages.jl b/test/messages.jl new file mode 100644 index 00000000..507ff052 --- /dev/null +++ b/test/messages.jl @@ -0,0 +1,14 @@ +@testset "Message constructors" begin + # Creates an instance of MSG with the given content string. + content = "Hello, world!" + for T in [AIMessage, SystemMessage, UserMessage] + # args + msg = T(content) + @test typeof(msg) <: T + @test msg.content == content + # kwargs + msg = T(; content) + @test typeof(msg) <: T + @test msg.content == content + end +end diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 00000000..a9e3a46e --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,13 @@ +using PromptingTools +using Test +using Aqua +const PT = PromptingTools + +@testset "Code quality (Aqua.jl)" begin + Aqua.test_all(PromptingTools) +end +@testset "PromptingTools.jl" begin + include("utils.jl") + include("messages.jl") + include("llm_openai.jl") +end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 00000000..438aba00 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,40 @@ +using PromptingTools: _extract_handlebar_variables, _report_stats + +@testset "extract_handlebar_variables" begin + # Extracts handlebar variables enclosed in double curly braces + input_string = "Hello {{name}}, how are you?" + expected_output = [Symbol("name")] + actual_output = _extract_handlebar_variables(input_string) + @test actual_output == expected_output + # Returns an empty array when there are no handlebar variables in the input string + input_string = "Hello, how are you?" + expected_output = Symbol[] + actual_output = _extract_handlebar_variables(input_string) + @test actual_output == expected_output + # Returns an empty array when the input string is empty + input_string = "" + expected_output = Symbol[] + actual_output = _extract_handlebar_variables(input_string) + @test actual_output == expected_output + # Extracts handlebar variables with alphanumeric characters, underscores, and dots + input_string = "Hello {{user.name_1}}, your age is {{user.age-2}}." + expected_output = [Symbol("user.name_1"), Symbol("user.age-2")] + actual_output = _extract_handlebar_variables(input_string) + @test actual_output == expected_output +end + +@testset "report_stats" begin + # Returns a string with the total number of tokens and elapsed time when given a message and model + msg = AIMessage(; content = "", tokens = (1, 5), elapsed = 5.0) + model = "model" + expected_output = "Tokens: 6 in 5.0 seconds" + @test _report_stats(msg, model) == expected_output + + # Returns a string with a cost + expected_output = "Tokens: 6 @ Cost: \$0.007 in 5.0 seconds" + @test _report_stats(msg, model, Dict(model => (2, 1))) == expected_output + + # Returns a string without cost when it's zero + expected_output = "Tokens: 6 in 5.0 seconds" + @test _report_stats(msg, model, Dict(model => (0, 0))) == expected_output +end