Skip to content

Commit

Permalink
Add ShareGPT template (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Mar 26, 2024
1 parent 5077557 commit 4085aeb
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 4 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added support for `aigenerate` with Anthropic API. Preset model aliases are `claudeo`, `claudes`, and `claudeh`, for Claude 3 Opus, Sonnet, and Haiku, respectively.
- Enabled the GoogleGenAI extension since `GoogleGenAI.jl` is now officially registered. You can use `aigenerate` by setting the model to `gemini` and providing the `GOOGLE_API_KEY` environment variable.
- Added utilities to make preparation of finetuning datasets easier. You can now export your conversations in JSONL format with ShareGPT formatting (eg, for Axolotl). See `?PT.save_conversations` for more information.

### Fixed

## [0.16.1]

### Fixed
- Fixed a bug where `set_node_style!` was not accepting any Stylers expect for the vanilla `Styler`.
- Fixed a bug where `set_node_style!` was not accepting any Stylers except for the vanilla `Styler`.

## [0.16.0]

Expand Down
1 change: 1 addition & 0 deletions src/PromptingTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ include("llm_ollama_managed.jl")
include("llm_ollama.jl")
include("llm_google.jl")
include("llm_anthropic.jl")
include("llm_sharegpt.jl")

## Convenience utils
export @ai_str, @aai_str, @ai!_str, @aai!_str
Expand Down
9 changes: 9 additions & 0 deletions src/llm_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,15 @@ struct AnthropicSchema <: AbstractAnthropicSchema end
inputs::Any = nothing
end

abstract type AbstractShareGPTSchema <: AbstractPromptSchema end

"""
    ShareGPTSchema <: AbstractShareGPTSchema

Frequently used schema for finetuning LLMs. Conversations are recorded as
a vector of dicts with keys `from` and `value` (similar to OpenAI's
`role`/`content` message format).
"""
struct ShareGPTSchema <: AbstractShareGPTSchema end

## Dispatch into a default schema (can be set by Preferences.jl)
# Since we load it as strings, we need to convert it to a symbol and instantiate it
global PROMPT_SCHEMA::AbstractPromptSchema = @load_preference("PROMPT_SCHEMA",
Expand Down
38 changes: 38 additions & 0 deletions src/llm_sharegpt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
### RENDERING
"""
    sharegpt_role(msg::AbstractMessage)

Return the ShareGPT `from` role string for `msg`: `"system"`, `"human"`, or `"gpt"`.

Throws `ArgumentError` for message types without a ShareGPT equivalent
(eg, `DataMessage`, `UserMessageWithImages`).
"""
function sharegpt_role(msg::AbstractMessage)
    ## Bug fix: the argument was previously unnamed (`::AbstractMessage`),
    ## so `typeof(msg)` raised `UndefVarError` instead of the intended `ArgumentError`.
    throw(ArgumentError("Unsupported message type $(typeof(msg))"))
end
sharegpt_role(::AIMessage) = "gpt"
sharegpt_role(::UserMessage) = "human"
sharegpt_role(::SystemMessage) = "system"

"""
    render(::AbstractShareGPTSchema, conv::AbstractVector{<:AbstractMessage})

Render a conversation into the ShareGPT format: a `Dict` with a single
`"conversations"` key holding one `from`/`value` dict per message.
"""
function render(::AbstractShareGPTSchema, conv::AbstractVector{<:AbstractMessage})
    items = map(conv) do msg
        Dict("from" => sharegpt_role(msg), "value" => msg.content)
    end
    return Dict("conversations" => items)
end

### AI Functions
# ShareGPT is an export/storage format only, so generative calls are rejected.
function aigenerate(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("ShareGPT schema does not support aigenerate. Please use OpenAISchema instead."))
end
# Embeddings are not defined for the ShareGPT export format.
function aiembed(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("ShareGPT schema does not support aiembed. Please use OpenAISchema instead."))
end
# Classification is not defined for the ShareGPT export format.
function aiclassify(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("ShareGPT schema does not support aiclassify. Please use OpenAISchema instead."))
end
# Structured extraction is not defined for the ShareGPT export format.
function aiextract(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("ShareGPT schema does not support aiextract. Please use OpenAISchema instead."))
end
# Image scanning is not defined for the ShareGPT export format.
function aiscan(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("ShareGPT schema does not support aiscan. Please use OpenAISchema instead."))
end
# Image generation is not defined for the ShareGPT export format.
function aiimage(prompt_schema::AbstractShareGPTSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("ShareGPT schema does not support aiimage. Please use OpenAISchema instead."))
end
50 changes: 50 additions & 0 deletions src/serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,53 @@ Loads a conversation (`messages`) from `io_or_file`
function load_conversation(io_or_file::Union{IO, AbstractString})
    # Deserialize directly into the message hierarchy; the intermediate
    # local binding of the original added nothing.
    return JSON3.read(io_or_file, Vector{AbstractMessage})
end

"""
save_conversations(schema::AbstractPromptSchema, filename::AbstractString,
conversations::Vector{<:AbstractVector{<:PT.AbstractMessage}})
Saves provided conversations (vector of vectors of `messages`) to `filename` rendered in the particular `schema`.
Commonly used for finetuning models with `schema = ShareGPTSchema()`
The format is JSON Lines, where each line is a JSON object representing one provided conversation.
See also: `save_conversation`
# Examples
You must always provide a VECTOR of conversations
```julia
messages = AbstractMessage[SystemMessage("System message 1"),
UserMessage("User message"),
AIMessage("AI message")]
conversation = [messages] # vector of vectors
dir = tempdir()
fn = joinpath(dir, "conversations.jsonl")
save_conversations(fn, conversation)
# Content of the file (one line for each conversation)
# {"conversations":[{"value":"System message 1","from":"system"},{"value":"User message","from":"human"},{"value":"AI message","from":"gpt"}]}
```
"""
function save_conversations(schema::AbstractPromptSchema, filename::AbstractString,
conversations::Vector{<:AbstractVector{<:AbstractMessage}})
@assert endswith(filename, ".jsonl") "Filename must end with `.jsonl` (JSON Lines format)."
io = IOBuffer()
for i in eachindex(conversations)
conv = conversations[i]
rendered_conv = render(schema, conv)
JSON3.write(io, rendered_conv)
# separate each conversation by newline
i < length(conversations) && print(io, "\n")
end
write(filename, String(take!(io)))
return nothing
end

# Convenience method: defaults the schema to `ShareGPTSchema()`.
function save_conversations(filename::AbstractString,
        conversations::Vector{<:AbstractVector{<:AbstractMessage}})
    return save_conversations(ShareGPTSchema(), filename, conversations)
end
42 changes: 42 additions & 0 deletions test/llm_sharegpt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using PromptingTools: render, ShareGPTSchema
using PromptingTools: AIMessage, SystemMessage, AbstractMessage
using PromptingTools: UserMessage, UserMessageWithImages, DataMessage

@testset "render-ShareGPT" begin
schema = ShareGPTSchema()
# Ignores any handlebar replacement, takes conversations as is
messages = [
SystemMessage("Act as a helpful AI assistant"),
UserMessage("Hello, my name is {{name}}"),
AIMessage("Hello, my name is {{name}}")
]
expected_output = Dict("conversations" => [
Dict("value" => "Act as a helpful AI assistant", "from" => "system"),
Dict("value" => "Hello, my name is {{name}}", "from" => "human"),
Dict("value" => "Hello, my name is {{name}}", "from" => "gpt")])
conversation = render(schema, messages)
@test conversation == expected_output

# IT DOES NOT support any advanced message types (UserMessageWithImages, DataMessage)
messages = [
UserMessage("Hello"),
DataMessage(; content = ones(3, 3))
]

@test_throws ArgumentError render(schema, messages)

messages = [
SystemMessage("System message 1"),
UserMessageWithImages("User message"; image_url = "https://example.com/image.png")
]
@test_throws ArgumentError render(schema, messages)
end

@testset "not implemented ai* functions" begin
@test_throws ErrorException aigenerate(ShareGPTSchema(), "prompt")
@test_throws ErrorException aiembed(ShareGPTSchema(), "prompt")
@test_throws ErrorException aiextract(ShareGPTSchema(), "prompt")
@test_throws ErrorException aiclassify(ShareGPTSchema(), "prompt")
@test_throws ErrorException aiscan(ShareGPTSchema(), "prompt")
@test_throws ErrorException aiimage(ShareGPTSchema(), "prompt")
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ end
include("llm_ollama.jl")
include("llm_google.jl")
include("llm_anthropic.jl")
include("llm_sharegpt.jl")
include("macros.jl")
include("templates.jl")
include("serialization.jl")
Expand Down
20 changes: 17 additions & 3 deletions test/serialization.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using PromptingTools: AIMessage,
SystemMessage, UserMessage, UserMessageWithImages, AbstractMessage, DataMessage
using PromptingTools: save_conversation, load_conversation
SystemMessage, UserMessage, UserMessageWithImages, AbstractMessage,
DataMessage, ShareGPTSchema
using PromptingTools: save_conversation, load_conversation, save_conversations
using PromptingTools: save_template, load_template

@testset "Serialization - Messages" begin
Expand All @@ -23,7 +24,7 @@ end
version = "1.1"
msgs = [
SystemMessage("You are an impartial AI judge evaluting whether the provided statement is \"true\" or \"false\". Answer \"unknown\" if you cannot decide."),
UserMessage("# Statement\n\n{{it}}"),
UserMessage("# Statement\n\n{{it}}")
]
tmp, _ = mktemp()
save_template(tmp,
Expand All @@ -36,3 +37,16 @@ end
@test metadata[1].content == "Template Metadata"
@test metadata[1].source == ""
end

@testset "Serialization - Messages" begin
# Test save_conversations
messages = AbstractMessage[SystemMessage("System message 1"),
UserMessage("User message"),
AIMessage("AI message")]
dir = tempdir()
fn = joinpath(dir, "conversations.jsonl")
save_conversations(fn, [messages])
s = read(fn, String)
@test s ==
"""{"conversations":[{"value":"System message 1","from":"system"},{"value":"User message","from":"human"},{"value":"AI message","from":"gpt"}]}"""
end

0 comments on commit 4085aeb

Please sign in to comment.