diff --git a/CHANGELOG.md b/CHANGELOG.md
index fe223049..6216d685 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,13 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 - Added support for `aigenerate` with Anthropic API. Preset model aliases are `claudeo`, `claudes`, and `claudeh`, for Claude 3 Opus, Sonnet, and Haiku, respectively.
 - Enabled the GoogleGenAI extension since `GoogleGenAI.jl` is now officially registered. You can use `aigenerate` by setting the model to `gemini` and providing the `GOOGLE_API_KEY` environment variable.
+- Added utilities to make preparation of finetuning datasets easier. You can now export your conversations in JSONL format with ShareGPT formatting (e.g., for Axolotl). See `?PT.save_conversations` for more information.
 
 ### Fixed
 
 ## [0.16.1]
 
 ### Fixed
-- Fixed a bug where `set_node_style!` was not accepting any Stylers expect for the vanilla `Styler`.
+- Fixed a bug where `set_node_style!` was not accepting any Stylers except for the vanilla `Styler`.
 
 ## [0.16.0]
diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl
index 4679c6cf..8e270b42 100644
--- a/src/PromptingTools.jl
+++ b/src/PromptingTools.jl
@@ -71,6 +71,7 @@ include("llm_ollama_managed.jl")
 include("llm_ollama.jl")
 include("llm_google.jl")
 include("llm_anthropic.jl")
+include("llm_sharegpt.jl")
 
 ## Convenience utils
 export @ai_str, @aai_str, @ai!_str, @aai!_str
diff --git a/src/llm_interface.jl b/src/llm_interface.jl
index 8e749479..78c90e69 100644
--- a/src/llm_interface.jl
+++ b/src/llm_interface.jl
@@ -276,6 +276,15 @@ struct AnthropicSchema <: AbstractAnthropicSchema end
     inputs::Any = nothing
 end
 
+abstract type AbstractShareGPTSchema <: AbstractPromptSchema end
+
+"""
+    ShareGPTSchema <: AbstractShareGPTSchema
+
+Frequently used schema for finetuning LLMs. Conversations are recorded as a vector of dicts with keys `from` and `value` (analogous to OpenAI's `role` and `content`).
+"""
+struct ShareGPTSchema <: AbstractShareGPTSchema end
+
 ## Dispatch into a default schema (can be set by Preferences.jl)
 # Since we load it as strings, we need to convert it to a symbol and instantiate it
 global PROMPT_SCHEMA::AbstractPromptSchema = @load_preference("PROMPT_SCHEMA",
diff --git a/src/llm_sharegpt.jl b/src/llm_sharegpt.jl
new file mode 100644
index 00000000..fc935497
--- /dev/null
+++ b/src/llm_sharegpt.jl
@@ -0,0 +1,38 @@
+### RENDERING
+function sharegpt_role(msg::AbstractMessage)
+    throw(ArgumentError("Unsupported message type $(typeof(msg))"))
+end
+sharegpt_role(::AIMessage) = "gpt"
+sharegpt_role(::UserMessage) = "human"
+sharegpt_role(::SystemMessage) = "system"
+
+function render(::AbstractShareGPTSchema, conv::AbstractVector{<:AbstractMessage})
+    Dict("conversations" => [Dict("from" => sharegpt_role(msg), "value" => msg.content)
+                             for msg in conv])
+end
+
+### AI Functions
+function aigenerate(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
+        kwargs...)
+    error("ShareGPT schema does not support aigenerate. Please use OpenAISchema instead.")
+end
+function aiembed(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
+        kwargs...)
+    error("ShareGPT schema does not support aiembed. Please use OpenAISchema instead.")
+end
+function aiclassify(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
+        kwargs...)
+    error("ShareGPT schema does not support aiclassify. Please use OpenAISchema instead.")
+end
+function aiextract(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
+        kwargs...)
+ error("ShareGPT schema does not support aiextract. Please use OpenAISchema instead.") +end +function aiscan(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE; + kwargs...) + error("ShareGPT schema does not support aiscan. Please use OpenAISchema instead.") +end +function aiimage(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE; + kwargs...) + error("ShareGPT schema does not support aiimage. Please use OpenAISchema instead.") +end diff --git a/src/serialization.jl b/src/serialization.jl index 93455561..d05645ed 100644 --- a/src/serialization.jl +++ b/src/serialization.jl @@ -60,3 +60,53 @@ Loads a conversation (`messages`) from `io_or_file` function load_conversation(io_or_file::Union{IO, AbstractString}) messages = JSON3.read(io_or_file, Vector{AbstractMessage}) end + +""" + save_conversations(schema::AbstractPromptSchema, filename::AbstractString, + conversations::Vector{<:AbstractVector{<:PT.AbstractMessage}}) + +Saves provided conversations (vector of vectors of `messages`) to `filename` rendered in the particular `schema`. + +Commonly used for finetuning models with `schema = ShareGPTSchema()` + +The format is JSON Lines, where each line is a JSON object representing one provided conversation. + +See also: `save_conversation` + +# Examples + +You must always provide a VECTOR of conversations +```julia +messages = AbstractMessage[SystemMessage("System message 1"), + UserMessage("User message"), + AIMessage("AI message")] +conversation = [messages] # vector of vectors + +dir = tempdir() +fn = joinpath(dir, "conversations.jsonl") +save_conversations(fn, conversation) + +# Content of the file (one line for each conversation) +# {"conversations":[{"value":"System message 1","from":"system"},{"value":"User message","from":"human"},{"value":"AI message","from":"gpt"}]} +``` +""" +function save_conversations(schema::AbstractPromptSchema, filename::AbstractString, + conversations::Vector{<:AbstractVector{<:AbstractMessage}}) + @assert endswith(filename, ".jsonl") "Filename must end with `.jsonl` (JSON Lines format)." 
+    io = IOBuffer()
+    for i in eachindex(conversations)
+        conv = conversations[i]
+        rendered_conv = render(schema, conv)
+        JSON3.write(io, rendered_conv)
+        # separate each conversation by a newline
+        i < length(conversations) && print(io, "\n")
+    end
+    write(filename, String(take!(io)))
+    return nothing
+end
+
+# shortcut for ShareGPTSchema
+function save_conversations(filename::AbstractString,
+        conversations::Vector{<:AbstractVector{<:AbstractMessage}})
+    save_conversations(ShareGPTSchema(), filename, conversations)
+end
\ No newline at end of file
diff --git a/test/llm_sharegpt.jl b/test/llm_sharegpt.jl
new file mode 100644
index 00000000..812269d3
--- /dev/null
+++ b/test/llm_sharegpt.jl
@@ -0,0 +1,42 @@
+using PromptingTools: render, ShareGPTSchema
+using PromptingTools: AIMessage, SystemMessage, AbstractMessage
+using PromptingTools: UserMessage, UserMessageWithImages, DataMessage
+
+@testset "render-ShareGPT" begin
+    schema = ShareGPTSchema()
+    # Ignores any handlebar replacement, takes conversations as is
+    messages = [
+        SystemMessage("Act as a helpful AI assistant"),
+        UserMessage("Hello, my name is {{name}}"),
+        AIMessage("Hello, my name is {{name}}")
+    ]
+    expected_output = Dict("conversations" => [
+        Dict("value" => "Act as a helpful AI assistant", "from" => "system"),
+        Dict("value" => "Hello, my name is {{name}}", "from" => "human"),
+        Dict("value" => "Hello, my name is {{name}}", "from" => "gpt")])
+    conversation = render(schema, messages)
+    @test conversation == expected_output
+
+    # It does NOT support any advanced message types (UserMessageWithImages, DataMessage)
+    messages = [
+        UserMessage("Hello"),
+        DataMessage(; content = ones(3, 3))
+    ]
+
+    @test_throws ArgumentError render(schema, messages)
+
+    messages = [
+        SystemMessage("System message 1"),
+        UserMessageWithImages("User message"; image_url = "https://example.com/image.png")
+    ]
+    @test_throws ArgumentError render(schema, messages)
+end
+
+@testset "not implemented ai* functions" begin
+    @test_throws ErrorException aigenerate(ShareGPTSchema(), "prompt")
+    @test_throws ErrorException aiembed(ShareGPTSchema(), "prompt")
+    @test_throws ErrorException aiextract(ShareGPTSchema(), "prompt")
+    @test_throws ErrorException aiclassify(ShareGPTSchema(), "prompt")
+    @test_throws ErrorException aiscan(ShareGPTSchema(), "prompt")
+    @test_throws ErrorException aiimage(ShareGPTSchema(), "prompt")
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 367ea98d..eebcbdf0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -22,6 +22,7 @@ end
     include("llm_ollama.jl")
     include("llm_google.jl")
    include("llm_anthropic.jl")
+    include("llm_sharegpt.jl")
     include("macros.jl")
     include("templates.jl")
     include("serialization.jl")
diff --git a/test/serialization.jl b/test/serialization.jl
index 9b04667f..2434147b 100644
--- a/test/serialization.jl
+++ b/test/serialization.jl
@@ -1,6 +1,7 @@
 using PromptingTools: AIMessage,
-    SystemMessage, UserMessage, UserMessageWithImages, AbstractMessage, DataMessage
-using PromptingTools: save_conversation, load_conversation
+    SystemMessage, UserMessage, UserMessageWithImages, AbstractMessage,
+    DataMessage, ShareGPTSchema
+using PromptingTools: save_conversation, load_conversation, save_conversations
 using PromptingTools: save_template, load_template
 
 @testset "Serialization - Messages" begin
@@ -23,7 +24,7 @@ end
     version = "1.1"
     msgs = [
        SystemMessage("You are an impartial AI judge evaluting whether the provided statement is \"true\" or \"false\". Answer \"unknown\" if you cannot decide."),
-        UserMessage("# Statement\n\n{{it}}"),
+        UserMessage("# Statement\n\n{{it}}")
     ]
     tmp, _ = mktemp()
     save_template(tmp,
@@ -36,3 +37,16 @@ end
     @test metadata[1].content == "Template Metadata"
     @test metadata[1].source == ""
 end
+
+@testset "Serialization - Conversations" begin
+    # Test save_conversations
+    messages = AbstractMessage[SystemMessage("System message 1"),
+        UserMessage("User message"),
+        AIMessage("AI message")]
+    dir = tempdir()
+    fn = joinpath(dir, "conversations.jsonl")
+    save_conversations(fn, [messages])
+    s = read(fn, String)
+    @test s ==
+          """{"conversations":[{"value":"System message 1","from":"system"},{"value":"User message","from":"human"},{"value":"AI message","from":"gpt"}]}"""
+end
\ No newline at end of file