From 1cda053423ac23d977cf52d692cdc5978ce20ba3 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Mon, 15 Apr 2024 08:52:33 +0100 Subject: [PATCH] Add TraceMessage for observability (#133) --- CHANGELOG.md | 5 + Project.toml | 2 + src/Experimental/RAGTools/retrieval.jl | 8 +- src/PromptingTools.jl | 2 + src/llm_interface.jl | 22 +++ src/llm_tracer.jl | 198 +++++++++++++++++++++++ src/messages.jl | 188 +++++++++++++++++++++- src/templates.jl | 20 +++ src/user_preferences.jl | 22 ++- test/llm_tracer.jl | 213 +++++++++++++++++++++++++ test/messages.jl | 89 ++++++++++- test/runtests.jl | 2 + test/serialization.jl | 22 +++ 13 files changed, 785 insertions(+), 8 deletions(-) create mode 100644 src/llm_tracer.jl create mode 100644 test/llm_tracer.jl diff --git a/CHANGELOG.md b/CHANGELOG.md index 9eb0c084..1390dda4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Added a few new open-weights models hosted by Fireworks.ai to the registry (DBRX Instruct, Mixtral 8x22b Instruct, Qwen 72b). If you're curious about how well they work, try them! +- Added basic support for observability downstream. Created custom callback infrastructure with `initialize_tracer` and `finalize_tracer` and dedicated types are `TracerMessage` and `TracerMessageLike`. See `?TracerMessage` for more information and the corresponding `aigenerate` docstring. + +### Updated +- Changed default model for `RAGTools.CohereReranker` to "cohere-rerank-english-v3.0". ### Fixed diff --git a/Project.toml b/Project.toml index 30bc0fb8..3516cdce 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.19.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -31,6 +32,7 @@ RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra"] AbstractTrees = "0.4" Aqua = "0.7" Base64 = "<0.0.1, 1" +Dates = "<0.0.1, 1" GoogleGenAI = "0.3" HTTP = "1" JSON3 = "1" diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index d0effeb9..948aea38 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -390,7 +390,7 @@ end verbose::Bool = false, api_key::AbstractString = PT.COHERE_API_KEY, top_n::Integer = length(candidates.scores), - model::AbstractString = "rerank-english-v2.0", + model::AbstractString = "rerank-english-v3.0", return_documents::Bool = false, cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) @@ -404,10 +404,10 @@ Re-ranks a list of candidate chunks using the Cohere Rerank API. See https://coh - `question`: The query to be used for the search. - `candidates`: The candidate chunks to be re-ranked. - `top_n`: The number of most relevant documents to return. Default is `length(documents)`. -- `model`: The model to use for reranking. Default is `rerank-english-v2.0`. +- `model`: The model to use for reranking. Default is `rerank-english-v3.0`. - `return_documents`: A boolean flag indicating whether to return the reranked documents in the response. Default is `false`. - `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`. -- `cost_tracker`: An atomic counter to track the cost of the retrieval. Default is `Threads.Atomic{Float64}(0.0)`. Not currently tracked (cost unclear). +- `cost_tracker`: An atomic counter to track the cost of the retrieval. Not implemented /tracked (cost unclear). Provided for consistency. """ function rerank( @@ -416,7 +416,7 @@ function rerank( verbose::Bool = false, api_key::AbstractString = PT.COHERE_API_KEY, top_n::Integer = length(candidates.scores), - model::AbstractString = "rerank-english-v2.0", + model::AbstractString = "rerank-english-v3.0", return_documents::Bool = false, cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl index 8e270b42..672f2d5b 100644 --- a/src/PromptingTools.jl +++ b/src/PromptingTools.jl @@ -2,6 +2,7 @@ module PromptingTools import AbstractTrees using Base64: base64encode +using Dates: now, DateTime using Logging using OpenAI using JSON3 @@ -72,6 +73,7 @@ include("llm_ollama.jl") include("llm_google.jl") include("llm_anthropic.jl") include("llm_sharegpt.jl") +include("llm_tracer.jl") ## Convenience utils export @ai_str, @aai_str, @ai!_str, @aai!_str diff --git a/src/llm_interface.jl b/src/llm_interface.jl index 78c90e69..690140f6 100644 --- a/src/llm_interface.jl +++ b/src/llm_interface.jl @@ -285,6 +285,28 @@ Frequently used schema for finetuning LLMs. Conversations are recorded as a vect """ struct ShareGPTSchema <: AbstractShareGPTSchema end +abstract type AbstractTracerSchema <: AbstractPromptSchema end + +""" + TracerSchema <: AbstractTracerSchema + +A schema designed to wrap another schema, enabling pre- and post-execution callbacks for tracing and additional functionalities. This type is specifically utilized within the `TracerMessage` type to trace the execution flow, facilitating observability and debugging in complex conversational AI systems. + +The `TracerSchema` acts as a middleware, allowing developers to insert custom logic before and after the execution of the primary schema's functionality. This can include logging, performance measurement, or any other form of tracing required to understand or improve the execution flow. + +# Usage +```julia +wrap_schema = TracerSchema(OpenAISchema()) +msg = aigenerate(wrap_schema, "Say hi!"; model="gpt-4") +# output type should be TracerMessage +msg isa TracerMessage +``` +You can define your own tracer schema and the corresponding methods: `initialize_tracer`, `finalize_tracer`. See `src/llm_tracer.jl` +""" +struct TracerSchema <: AbstractTracerSchema + schema::AbstractPromptSchema +end + ## Dispatch into a default schema (can be set by Preferences.jl) # Since we load it as strings, we need to convert it to a symbol and instantiate it global PROMPT_SCHEMA::AbstractPromptSchema = @load_preference("PROMPT_SCHEMA", diff --git a/src/llm_tracer.jl b/src/llm_tracer.jl new file mode 100644 index 00000000..1d63349a --- /dev/null +++ b/src/llm_tracer.jl @@ -0,0 +1,198 @@ +# Tracing infrastructure for logging and other callbacks +# - Define your own schema that is subtype of AbstractTracerSchema and wraps the underlying LLM provider schema +# - Customize initialize_tracer and finalize_tracer with your custom callback +# - Call your ai* function with the tracer schema as usual + +# Simple passthrough, do nothing +""" + render(tracer_schema::AbstractTracerSchema, + conv::AbstractVector{<:AbstractMessage}; kwargs...) + +Passthrough. No changes. +""" +function render(tracer_schema::AbstractTracerSchema, + conv::AbstractVector{<:AbstractMessage}; kwargs...) + return conv +end + +""" + initialize_tracer( + tracer_schema::AbstractTracerSchema; model = "", tracer_kwargs = NamedTuple(), kwargs...) + +Initializes `tracer`/callback (if necessary). Can provide any keyword arguments in `tracer_kwargs` (eg, `parent_id`, `thread_id`, `run_id`). +Is executed prior to the `ai*` calls. + +In the default implementation, we just collect the necessary data to build the tracer object in `finalize_tracer`. +""" +function initialize_tracer( + tracer_schema::AbstractTracerSchema; model = "", tracer_kwargs = NamedTuple(), kwargs...) + return (; time_sent = now(), model, tracer_kwargs...) +end + +""" + finalize_tracer( + tracer_schema::AbstractTracerSchema, tracer, msg_or_conv; tracer_kwargs = NamedTuple(), model = "", kwargs...) + +Finalizes the calltracer of whatever is nedeed after the `ai*` calls. Use `tracer_kwargs` to provide any information necessary (eg, `parent_id`, `thread_id`, `run_id`). + +In the default implementation, we convert all non-tracer messages into `TracerMessage`. +""" +function finalize_tracer( + tracer_schema::AbstractTracerSchema, tracer, msg_or_conv; tracer_kwargs = NamedTuple(), model = "", kwargs...) + # We already captured all kwargs, they are already in `tracer`, we can ignore them in this implementation + time_received = now() + # work with arrays for unified processing + is_vector = msg_or_conv isa AbstractVector + conv = msg_or_conv isa AbstractVector{<:AbstractMessage} ? + convert(Vector{AbstractMessage}, msg_or_conv) : + AbstractMessage[msg_or_conv] + # all msg non-traced, set times + for i in eachindex(conv) + msg = conv[i] + # change into TracerMessage if not already, use the current kwargs + if !istracermessage(msg) + # we saved our data for `tracer` + conv[i] = TracerMessage(; object = msg, tracer..., time_received) + end + end + return is_vector ? conv : first(conv) +end + +""" + aigenerate(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + +Wraps the normal `aigenerate` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`). + +Logic: +- calls `initialize_tracer` +- calls `aigenerate` (with the `tracer_schema.schema`) +- calls `finalize_tracer` + +# Example +```julia +wrap_schema = PT.TracerSchema(PT.OpenAISchema()) +msg = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t") +msg isa TracerMessage # true +msg.content # access content like if it was the message +PT.pprint(msg) # pretty-print the message +``` + +It works on a vector of messages and converts only the non-tracer ones, eg, +```julia +wrap_schema = PT.TracerSchema(PT.OpenAISchema()) +conv = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t", return_all = true) +all(PT.istracermessage, conv) #true +``` +""" +function aigenerate(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...) + merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided + msg_or_conv = aigenerate(tracer_schema.schema, prompt; merged_kwargs...) + return finalize_tracer( + tracer_schema, tracer, msg_or_conv; model, tracer_kwargs, kwargs...) +end + +""" + aiembed(tracer_schema::AbstractTracerSchema, + doc_or_docs::Union{AbstractString, AbstractVector{<:AbstractString}}, postprocess::Function = identity; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + +Wraps the normal `aiembed` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`). + +Logic: +- calls `initialize_tracer` +- calls `aiembed` (with the `tracer_schema.schema`) +- calls `finalize_tracer` +""" +function aiembed(tracer_schema::AbstractTracerSchema, + doc_or_docs::Union{AbstractString, AbstractVector{<:AbstractString}}, postprocess::Function = identity; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...) + merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided + embed_or_conv = aiembed( + tracer_schema.schema, doc_or_docs, postprocess; merged_kwargs...) + return finalize_tracer( + tracer_schema, tracer, embed_or_conv; model, tracer_kwargs..., kwargs...) +end + +""" + aiclassify(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + +Wraps the normal `aiclassify` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`). + +Logic: +- calls `initialize_tracer` +- calls `aiclassify` (with the `tracer_schema.schema`) +- calls `finalize_tracer` +""" +function aiclassify(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...) + merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided + classify_or_conv = aiclassify(tracer_schema.schema, prompt; merged_kwargs...) + return finalize_tracer( + tracer_schema, tracer, classify_or_conv; model, tracer_kwargs..., kwargs...) +end + +""" + aiextract(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + +Wraps the normal `aiextract` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`). + +Logic: +- calls `initialize_tracer` +- calls `aiextract` (with the `tracer_schema.schema`) +- calls `finalize_tracer` +""" +function aiextract(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...) + merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided + extract_or_conv = aiextract(tracer_schema.schema, prompt; merged_kwargs...) + return finalize_tracer( + tracer_schema, tracer, extract_or_conv; model, tracer_kwargs..., kwargs...) +end + +""" + aiscan(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + +Wraps the normal `aiscan` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`). + +Logic: +- calls `initialize_tracer` +- calls `aiscan` (with the `tracer_schema.schema`) +- calls `finalize_tracer` +""" +function aiscan(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...) + merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided + scan_or_conv = aiscan(tracer_schema.schema, prompt; merged_kwargs...) + return finalize_tracer( + tracer_schema, tracer, scan_or_conv; model, tracer_kwargs..., kwargs...) +end + +""" + aiimage(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + +Wraps the normal `aiimage` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`). + +Logic: +- calls `initialize_tracer` +- calls `aiimage` (with the `tracer_schema.schema`) +- calls `finalize_tracer` +""" +function aiimage(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE; + tracer_kwargs = NamedTuple(), model = "", kwargs...) + tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...) + merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided + image_or_conv = aiimage(tracer_schema.schema, prompt; merged_kwargs...) + return finalize_tracer( + tracer_schema, tracer, image_or_conv; model, tracer_kwargs..., kwargs...) +end diff --git a/src/messages.jl b/src/messages.jl index c0a330e8..fcae0d64 100644 --- a/src/messages.jl +++ b/src/messages.jl @@ -4,6 +4,9 @@ abstract type AbstractMessage end abstract type AbstractChatMessage <: AbstractMessage end # with text-based content abstract type AbstractDataMessage <: AbstractMessage end # with data-based content, eg, embeddings +abstract type AbstractTracerMessage{T <: AbstractMessage} <: AbstractMessage end # message with annotation that exposes the underlying message +# Complementary type for tracing, follows the same API as TracerMessage +abstract type AbstractTracer{T <: Any} end ## Allowed inputs for ai* functions, AITemplate is resolved one level higher const ALLOWED_PROMPT_TYPE = Union{ @@ -122,14 +125,23 @@ Base.@kwdef struct DataMessage{T <: Any} <: AbstractDataMessage _type::Symbol = :datamessage end +### Other Message methods # content-only constructor function (MSG::Type{<:AbstractChatMessage})(prompt::AbstractString) MSG(; content = prompt) end +function (MSG::Type{<:AbstractChatMessage})(msg::AbstractChatMessage) + MSG(; msg.content) +end +function (MSG::Type{<:AbstractChatMessage})(msg::AbstractTracerMessage{<:AbstractChatMessage}) + MSG(; msg.content) +end + isusermessage(m::AbstractMessage) = m isa UserMessage issystemmessage(m::AbstractMessage) = m isa SystemMessage isdatamessage(m::AbstractMessage) = m isa DataMessage isaimessage(m::AbstractMessage) = m isa AIMessage +istracermessage(m::AbstractMessage) = m isa AbstractTracerMessage # equality check for testing, only equal if all fields are equal and type is the same Base.var"=="(m1::AbstractMessage, m2::AbstractMessage) = false @@ -193,6 +205,151 @@ function attach_images_to_user_message(msgs::Vector{T}; return msgs end +############################## +### TracerMessages +# - They are mutable (to update iteratively) +# - they contain a message and additional metadata +# - they expose as much of the underlying message as possible to allow the same operations +""" + TracerMessage{T <: Union{AbstractChatMessage, AbstractDataMessage}} <: AbstractTracerMessage + +A mutable wrapper message designed for tracing the flow of messages through the system, allowing for iterative updates and providing additional metadata for observability. + +# Fields +- `object::T`: The original message being traced, which can be either a chat or data message. +- `from::Union{Nothing, Symbol}`: The identifier of the sender of the message. +- `to::Union{Nothing, Symbol}`: The identifier of the intended recipient of the message. +- `viewers::Vector{Symbol}`: A list of identifiers for entities that have access to view the message, in addition to the sender and recipient. +- `time_received::DateTime`: The timestamp when the message was received by the tracing system. +- `time_sent::Union{Nothing, DateTime}`: The timestamp when the message was originally sent, if available. +- `model::String`: The name of the model that generated the message. Defaults to empty. +- `parent_id::Symbol`: An identifier for the job or process that the message is associated with. Higher-level tracing ID. +- `thread_id::Symbol`: An identifier for the thread (series of messages for one model/agent) or execution context within the job where the message originated. It should be the same for messages in the same thread. +- `_type::Symbol`: A fixed symbol identifying the type of the message as `:eventmessage`, used for type discrimination. + +This structure is particularly useful for debugging, monitoring, and auditing the flow of messages in systems that involve complex interactions or asynchronous processing. + +All fields are optional besides the `object`. + +Useful methods: `pprint` (pretty prints the underlying message), `unwrap` (to get the `object` out of tracer), `align_tracer!` (to set all shared IDs in a vector of tracers to the same), `istracermessage` to check if given message is an AbstractTracerMessage + +# Example +```julia +wrap_schema = PT.TracerSchema(PT.OpenAISchema()) +msg = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t") +msg # isa TracerMessage +msg.content # access content like if it was the message +``` +""" +Base.@kwdef mutable struct TracerMessage{T <: + Union{AbstractChatMessage, AbstractDataMessage}} <: + AbstractTracerMessage{T} + object::T + from::Union{Nothing, Symbol} = nothing # who sent it + to::Union{Nothing, Symbol} = nothing # who received it + viewers::Vector{Symbol} = Symbol[] # who has access to it (besides from, to) + time_received::DateTime = now() + time_sent::Union{Nothing, DateTime} = nothing + model::String = "" + parent_id::Symbol = gensym("parent") + thread_id::Symbol = gensym("thread") + run_id::Union{Nothing, Int} = Int(rand(Int32)) + _type::Symbol = :tracermessage +end +function TracerMessage(msg::Union{AbstractChatMessage, AbstractDataMessage}; kwargs...) + TracerMessage(; object = msg, kwargs...) +end + +""" + TracerMessageLike{T <: Any} <: AbstractTracer + +A mutable structure designed for general-purpose tracing within the system, capable of handling any type of object that is part of the AI Conversation. +It provides a flexible way to track and annotate objects as they move through different parts of the system, facilitating debugging, monitoring, and auditing. + +# Fields +- `object::T`: The original object being traced. +- `from::Union{Nothing, Symbol}`: The identifier of the sender or origin of the object. +- `to::Union{Nothing, Symbol}`: The identifier of the intended recipient or destination of the object. +- `viewers::Vector{Symbol}`: A list of identifiers for entities that have access to view the object, in addition to the sender and recipient. +- `time_received::DateTime`: The timestamp when the object was received by the tracing system. +- `time_sent::Union{Nothing, DateTime}`: The timestamp when the object was originally sent, if available. +- `model::String`: The name of the model or process that generated or is associated with the object. Defaults to empty. +- `parent_id::Symbol`: An identifier for the job or process that the object is associated with. Higher-level tracing ID. +- `thread_id::Symbol`: An identifier for the thread or execution context within the job where the object originated. It should be the same for objects in the same thread. +- `run_id::Union{Nothing, Int}`: A unique identifier for the run or instance of the process that generated the object. Defaults to a random integer. +- `_type::Symbol`: A fixed symbol identifying the type of the tracer as `:tracermessage`, used for type discrimination. + +This structure is particularly useful for systems that involve complex interactions or asynchronous processing, where tracking the flow and transformation of objects is crucial. + +All fields are optional besides the `object`. +""" +@kwdef mutable struct TracerMessageLike{T <: Any} <: AbstractTracer{T} + object::T + from::Union{Nothing, Symbol} = nothing # who sent it + to::Union{Nothing, Symbol} = nothing # who received it + viewers::Vector{Symbol} = Symbol[] # who has access to it (besides from, to) + time_received::DateTime = now() + time_sent::Union{Nothing, DateTime} = nothing + model::String = "" + parent_id::Symbol = gensym("parent") + thread_id::Symbol = gensym("thread") + run_id::Union{Nothing, Int} = Int(rand(Int32)) + _type::Symbol = :tracermessagelike + ## TracerMessageLike() = new() +end +## function TracerMessageLike() +## TracerMessageLike(; object = undef) +## end +function TracerMessageLike( + object; kwargs...) + TracerMessageLike(; object, kwargs...) +end +Base.var"=="(t1::AbstractTracer, t2::AbstractTracer) = false +function Base.var"=="(t1::AbstractTracer{T}, t2::AbstractTracer{T}) where {T <: Any} + ## except for run_id, that's random and not important for content comparison + all([getproperty(t1, f) == getproperty(t2, f) for f in fieldnames(T) if f != :run_id]) +end + +# Shared methods +# getproperty for tracer messages forwards to the message when relevant +function Base.getproperty(t::Union{AbstractTracerMessage, AbstractTracer}, f::Symbol) + obj = getfield(t, :object) + if hasproperty(obj, f) + getproperty(obj, f) + else + getfield(t, f) + end +end + +function Base.copy(t::T) where {T <: Union{AbstractTracerMessage, AbstractTracer}} + T([deepcopy(getfield(t, f)) for f in fieldnames(T)]...) +end + +"Unwraps the tracer message, returning the original `object`." +function unwrap(t::Union{AbstractTracerMessage, AbstractTracer}) + getfield(t, :object) +end + +"Aligns the tracer message, updating the `parent_id`, `thread_id`. Often used to align multiple tracers in the vector to have the same IDs." +function align_tracer!( + t::Union{AbstractTracerMessage, AbstractTracer}; parent_id::Symbol = t.parent_id, + thread_id::Symbol = t.thread_id) + t.parent_id = parent_id + t.thread_id = thread_id + return t +end +"Aligns multiple tracers in the vector to have the same Parent and Thread IDs as the first item." +function align_tracer!( + vect::AbstractVector{<:Union{AbstractTracerMessage, AbstractTracer}}) + if !isempty(vect) + t = first(vect) + align_tracer!.(vect; t.parent_id, t.thread_id) + else + vect + end +end + +############################## ## Helpful accessors "Helpful accessor for the last message in `conversation`. Returns the last message in the conversation." function last_message(conversation::AbstractVector{<:AbstractMessage}) @@ -236,6 +393,12 @@ function Base.show(io::IO, ::MIME"text/plain", m::AbstractDataMessage) print(io, "(", typeof(m.content), ")") end end +function Base.show(io::IO, ::MIME"text/plain", t::AbstractTracerMessage) + dump(IOContext(io, :limit => true), t, maxdepth = 1) +end +function Base.show(io::IO, ::MIME"text/plain", t::AbstractTracer) + dump(IOContext(io, :limit => true), t, maxdepth = 1) +end ## Dispatch for render # function render(schema::AbstractPromptSchema, @@ -259,7 +422,8 @@ function StructTypes.subtypes(::Type{AbstractMessage}) aimessage = AIMessage, systemmessage = SystemMessage, metadatamessage = MetadataMessage, - datamessage = DataMessage) + datamessage = DataMessage, + tracermessage = TracerMessage) end StructTypes.StructType(::Type{AbstractChatMessage}) = StructTypes.AbstractType() @@ -272,12 +436,26 @@ function StructTypes.subtypes(::Type{AbstractChatMessage}) metadatamessage = MetadataMessage) end +StructTypes.StructType(::Type{AbstractTracerMessage}) = StructTypes.AbstractType() +StructTypes.subtypekey(::Type{AbstractTracerMessage}) = :_type +function StructTypes.subtypes(::Type{AbstractTracerMessage}) + (tracermessage = TracerMessage,) +end + +StructTypes.StructType(::Type{AbstractTracer}) = StructTypes.AbstractType() +StructTypes.subtypekey(::Type{AbstractTracer}) = :_type +function StructTypes.subtypes(::Type{AbstractTracer}) + (tracermessagelike = TracerMessageLike,) +end + StructTypes.StructType(::Type{MetadataMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{SystemMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{UserMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{UserMessageWithImages}) = StructTypes.Struct() StructTypes.StructType(::Type{AIMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{DataMessage}) = StructTypes.Struct() +StructTypes.StructType(::Type{TracerMessage}) = StructTypes.Struct() # Ignore mutability once we serialize +StructTypes.StructType(::Type{TracerMessageLike}) = StructTypes.Struct() # Ignore mutability once we serialize ### Utilities for Pretty Printing """ @@ -311,6 +489,14 @@ function pprint(io::IO, msg::AbstractMessage; text_width::Int = displaysize(io)[ print(io, "\n", "-"^20, "\n") print(io, content, "\n\n") end + +function pprint(io::IO, t::Union{AbstractTracerMessage, AbstractTracer}; + text_width::Int = displaysize(io)[2]) + role = "$(nameof(typeof(t))) with:" + print(io, "-"^20, "\n") + print(io, role, "\n") + pprint(io, unwrap(t); text_width) +end """ pprint(io::IO, conversation::AbstractVector{<:AbstractMessage}) diff --git a/src/templates.jl b/src/templates.jl index 399b41a1..617609a1 100644 --- a/src/templates.jl +++ b/src/templates.jl @@ -375,6 +375,26 @@ function aiimage(schema::AbstractPromptSchema, template::Symbol; kwargs...) aiimage(schema, AITemplate(template); kwargs...) end +## Dispatch for TracerSchema to avoid ambiguities +function render(schema::AbstractTracerSchema, template::AITemplate; kwargs...) + render(schema.schema, template; kwargs...) +end +function aigenerate(schema::AbstractTracerSchema, template::Symbol; kwargs...) + aigenerate(schema, render(schema, AITemplate(template)); kwargs...) +end +function aiclassify(schema::AbstractTracerSchema, template::Symbol; kwargs...) + aiclassify(schema, render(schema, AITemplate(template)); kwargs...) +end +function aiextract(schema::AbstractTracerSchema, template::Symbol; kwargs...) + aiextract(schema, render(schema, AITemplate(template)); kwargs...) +end +function aiscan(schema::AbstractTracerSchema, template::Symbol; kwargs...) + aiscan(schema, render(schema, AITemplate(template)); kwargs...) +end +function aiimage(schema::AbstractTracerSchema, template::Symbol; kwargs...) + aiimage(schema, render(schema, AITemplate(template)); kwargs...) +end + ## Utility for creating templates """ create_template(; user::AbstractString, system::AbstractString="Act as a helpful AI assistant.", diff --git a/src/user_preferences.jl b/src/user_preferences.jl index 8387f55a..27c70a61 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -523,9 +523,27 @@ registry = Dict{String, ModelSpec}( "accounts/fireworks/models/mixtral-8x7b-instruct" => ModelSpec( "accounts/fireworks/models/mixtral-8x7b-instruct", FireworksOpenAISchema(), - 4e-7, #unknown, expected 1.25e-7 - 1.6e-6, #unknown, expected 3.75e-7 + 5e-7, + 5e-7, "Mixtral (8x7b) from Mistral, hosted by Fireworks.ai. For more information, see [models](https://fireworks.ai/models/fireworks/mixtral-8x7b-instruct)."), + "accounts/fireworks/models/mixtral-8x22b-instruct-preview" => ModelSpec( + "accounts/fireworks/models/mixtral-8x22b-instruct-preview", + FireworksOpenAISchema(), + 9e-7, + 9e-7, + "Mixtral (8x22b) from Mistral, instruction finetuned and hosted by Fireworks.ai. For more information, see [models](https://fireworks.ai/models/fireworks/mixtral-8x22b-instruct-preview)."), + "accounts/fireworks/models/dbrx-instruct" => ModelSpec( + "accounts/fireworks/models/dbrx-instruct", + FireworksOpenAISchema(), + 1.6e-6, + 1.6e-6, + "Databricks DBRX Instruct, hosted by Fireworks.ai. For more information, see [models](https://fireworks.ai/models/fireworks/dbrx-instruct)."), + "accounts/fireworks/models/qwen-72b-chat" => ModelSpec( + "accounts/fireworks/models/qwen-72b-chat", + FireworksOpenAISchema(), + 9e-7, + 9e-7, + "Qwen is a 72B parameter model from Alibaba Cloud, hosted by from Fireworks.ai. For more information, see [models](https://fireworks.ai/models/fireworks/dbrx-instruct)."), "accounts/fireworks/models/firefunction-v1" => ModelSpec( "accounts/fireworks/models/firefunction-v1", FireworksOpenAISchema(), diff --git a/test/llm_tracer.jl b/test/llm_tracer.jl new file mode 100644 index 00000000..8e6d24ed --- /dev/null +++ b/test/llm_tracer.jl @@ -0,0 +1,213 @@ +using PromptingTools: TestEchoOpenAISchema, render, OpenAISchema, TracerSchema +using PromptingTools: AIMessage, SystemMessage, AbstractMessage +using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, TracerMessage +using PromptingTools: CustomProvider, + CustomOpenAISchema, MistralOpenAISchema, MODEL_EMBEDDING, + MODEL_IMAGE_GENERATION +using PromptingTools: initialize_tracer, finalize_tracer, isaimessage, istracermessage, + unwrap, AITemplate + +@testset "render-Tracer" begin + schema = TracerSchema(OpenAISchema()) + # Given a schema and a vector of messages with handlebar variables, it should replace the variables with the correct values in the conversation dictionary. + messages = [ + SystemMessage("Act as a helpful AI assistant"), + UserMessage("Hello, my name is {{name}}") + ] + conv = render(schema, messages) + @test conv == messages + + conv = render(schema, AITemplate(:InputClassifier)) + @test conv isa Vector +end + +@testset "initialize_tracer" begin + schema = TracerSchema(OpenAISchema()) + time_before = now() + + ## default initialization + tracer = initialize_tracer(schema; tracer_kwargs = (; a = 1)) + @test tracer.time_sent >= time_before + @test tracer.model == "" + @test tracer.a == 1 + + ## custom model and tracer_kwargs + custom_model = "custom_model" + custom_tracer_kwargs = (parent_id = :parent, thread_id = :thread, run_id = 1) + tracer = initialize_tracer( + schema; model = custom_model, tracer_kwargs = custom_tracer_kwargs) + @test tracer.time_sent >= time_before + @test tracer.model == custom_model + @test tracer.parent_id == :parent + @test tracer.thread_id == :thread + @test tracer.run_id == 1 +end + +@testset "finalize_tracer" begin + schema = TracerSchema(OpenAISchema()) + tracer = initialize_tracer(schema; model = "test_model", + tracer_kwargs = (parent_id = :parent, thread_id = :thread, run_id = 1)) + time_before = now() + + # single non-tracer message + msg = SystemMessage("Test message") + finalized_msg = finalize_tracer(schema, tracer, msg) + @test finalized_msg isa TracerMessage + @test finalized_msg.object == msg + @test finalized_msg.model == "test_model" + @test finalized_msg.parent_id == :parent + @test finalized_msg.thread_id == :thread + @test finalized_msg.run_id == 1 + @test finalized_msg.time_received >= time_before + + # vector of non-tracer messages + msgs = [SystemMessage("Test message 1"), SystemMessage("Test message 2")] + finalized_msgs = finalize_tracer(schema, tracer, msgs) + @test all(istracermessage, finalized_msgs) + @test length(finalized_msgs) == 2 + @test finalized_msgs[1].object == msgs[1] + @test finalized_msgs[2].object == msgs[2] + @test all(finalized_msgs) do msg + msg.model == "test_model" + end + @test all(finalized_msgs) do msg + msg.time_received >= time_before + end + + # mixed vector of tracer and non-tracer messages + tracer_msg = TracerMessage(; + object = SystemMessage("Already tracer"), tracer..., time_received = now()) + msgs = [UserMessage("Test message"), tracer_msg] + finalized_msgs = finalize_tracer(schema, tracer, msgs) + @test all(istracermessage, finalized_msgs) + @test length(finalized_msgs) == 2 + @test finalized_msgs[1] isa TracerMessage + @test finalized_msgs[2] === tracer_msg # should be the same object, not a new one +end + +@testset "aigenerate-Tracer" begin + # corresponds to OpenAI API v1 + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "Hello!"), + :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) + + # Real generation API + schema1 = TestEchoOpenAISchema(; response, status = 200) |> TracerSchema + msg = aigenerate( + schema1, "Hello World"; model = "xyz", tracer_kwargs = (; thread_id = :ABC1)) + @test istracermessage(msg) + @test unwrap(msg) |> isaimessage + @test msg.content == "Hello!" + @test msg.model == "xyz" + @test msg.thread_id == :ABC1 + + msg = aigenerate(schema1, :BlankSystemUser) + @test istracermessage(msg) +end + +@testset "aiembed-Tracer" begin + # corresponds to OpenAI API v1 + response1 = Dict(:data => [Dict(:embedding => ones(128))], + :usage => Dict(:total_tokens => 2, :prompt_tokens => 2, :completion_tokens => 0)) + + # Real generation API + schema1 = TestEchoOpenAISchema(; response = response1, status = 200) |> TracerSchema + msg = aiembed(schema1, "Hello World") + @test istracermessage(msg) + @test unwrap(msg) isa DataMessage +end + +@testset "aiclassify-Tracer" begin + # corresponds to OpenAI API v1 + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "1"), + :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) + + # Real generation API + schema1 = TestEchoOpenAISchema(; response, status = 200) |> TracerSchema + choices = [ + ("A", "any animal or creature"), + ("P", "for any plant or tree"), + ("O", "for everything else") + ] + msg = aiclassify(schema1, :InputClassifier; input = "pelican", choices) + @test istracermessage(msg) + @test unwrap(msg) isa AIMessage + @test msg.content == "A" +end + +@testset "aiextract-OpenAI" begin + # mock return type + struct RandomType1235 + x::Int + end + return_type = RandomType1235 + + mock_choice = Dict( + :message => Dict(:content => "Hello!", + :tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(Dict(:x => 1)))) + ]), + :logprobs => Dict(:content => [Dict(:logprob => -0.5), Dict(:logprob => -0.4)]), + :finish_reason => "stop") + ## Test with a single sample + response = Dict(:choices => [mock_choice], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) + schema1 = TestEchoOpenAISchema(; response, status = 200) |> TracerSchema + msg = aiextract(schema1, "Extract number 1"; return_type, + model = "gpt4", + api_kwargs = (; temperature = 0, n = 2)) + @test istracermessage(msg) + @test unwrap(msg) isa DataMessage + @test msg.content == RandomType1235(1) + @test msg.log_prob ≈ -0.9 + + msg = aiextract(schema1, :BlankSystemUser; return_type) + @test istracermessage(msg) +end + +@testset "aiscan-Tracer" begin + ## Test with single sample and log_probs samples + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "Hello1!"), + :finish_reason => "stop", + :logprobs => Dict(:content => [ + Dict(:logprob => -0.1), + Dict(:logprob => -0.2) + ])) + ], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) + schema1 = TestEchoOpenAISchema(; response, status = 200) |> TracerSchema + msg = aiscan(schema1, "Describe the image"; + image_url = "https://example.com/image.png", + model = "gpt4", http_kwargs = (; verbose = 3), + api_kwargs = (; temperature = 0)) + @test istracermessage(msg) + @test unwrap(msg) isa AIMessage + @test msg.content == "Hello1!" + @test msg.log_prob ≈ -0.3 + + msg = aiscan(schema1, :BlankSystemUser; image_url = "https://example.com/image.png") + @test istracermessage(msg) +end + +@testset "aiimage-Tracer" begin + # corresponds to OpenAI API v1 for create_images + payload = Dict(:url => "xyz/url", :revised_prompt => "New prompt") + response1 = Dict(:data => [payload]) + schema1 = TestEchoOpenAISchema(; response = response1, status = 200) |> TracerSchema + + msg = aiimage(schema1, "Hello World") + @test istracermessage(msg) + @test unwrap(msg) isa DataMessage + + msg = aiimage(schema1, :BlankSystemUser) + @test istracermessage(msg) +end \ No newline at end of file diff --git a/test/messages.jl b/test/messages.jl index 662034aa..e03f5678 100644 --- a/test/messages.jl +++ b/test/messages.jl @@ -2,7 +2,10 @@ using PromptingTools: AIMessage, SystemMessage, MetadataMessage, AbstractMessage using PromptingTools: UserMessage, UserMessageWithImages, DataMessage using PromptingTools: _encode_local_image, attach_images_to_user_message, last_message, last_output -using PromptingTools: isusermessage, issystemmessage, isdatamessage, isaimessage +using PromptingTools: isusermessage, issystemmessage, isdatamessage, isaimessage, + istracermessage +using PromptingTools: TracerMessageLike, TracerMessage, align_tracer!, unwrap, + AbstractTracerMessage, AbstractTracer, pprint @testset "Message constructors" begin # Creates an instance of MSG with the given content string. @@ -31,6 +34,8 @@ using PromptingTools: isusermessage, issystemmessage, isdatamessage, isaimessage @test SystemMessage(content) |> issystemmessage @test DataMessage(; content) |> isdatamessage @test AIMessage(; content) |> isaimessage + @test UserMessage(content) |> AIMessage |> isaimessage + @test UserMessage(content) != AIMessage(content) end @testset "UserMessageWithImages" begin content = "Hello, world!" @@ -110,4 +115,86 @@ end msg = UserMessage("Hello, world 2!") @test last_message(msg) == msg @test last_output(msg) == "Hello, world 2!" +end + +@testset "TracerMessage,TracerMessageLike" begin + # Tracer functionality + msg1 = UserMessage("Hi") + msg2 = AIMessage("Hi there!") + + # Create wrapper + tr1 = TracerMessage(msg1; from = :me, to = :you) + @test istracermessage(tr1) + @test tr1.object == msg1 + @test tr1.from == :me + @test tr1.to == :you + + # Message methods + tr2 = TracerMessage(msg2; from = :you, to = :me) + @test tr1.content == msg1.content + @test tr2.run_id == msg2.run_id + @test tr1 != tr2 + @test tr1 == tr1 + @test UserMessage(tr2).content == msg2.content + @test copy(tr1) == tr1 + @test copy(tr2) !== tr2 + + # Specific methods + # unwrap the tracer + @test unwrap(tr1) == msg1 + + # Align random IDs + conv = [tr1, tr2] + align_tracer!(conv) + @test conv[1].parent_id == conv[2].parent_id + @test conv[1].thread_id == conv[2].thread_id + + empty_ = AbstractTracer[] + @test empty_ == align_tracer!(empty_) + + ## TracerMessageLike + str = "Test Message" + tracer = TracerMessageLike(str) + @test tracer.object == str + @test unwrap(tracer) == str + + # methods + tracer2 = TracerMessageLike(str) + @test tracer == tracer2 + + struct TracerRandom1 <: AbstractTracer{Int} end + tracer3 = TracerRandom1() + @test tracer != tracer3 + + # show and pprint for TracerMessage + # Test show method + io_show = IOBuffer() + show(io_show, MIME("text/plain"), tr1) + show_output = String(take!(io_show)) + @test occursin("TracerMessage", show_output) + @test occursin("UserMessage", show_output) + @test occursin("you", show_output) + + # Test pprint method + io_pprint = IOBuffer() + pprint(io_pprint, tr1) + pprint_output = String(take!(io_pprint)) + @test occursin("TracerMessage with:", pprint_output) + @test occursin("User Message", pprint_output) + @test occursin("Hi", pprint_output) + + # show and pprint for TracerMessageLike + # Test show method + io_show = IOBuffer() + show(io_show, MIME("text/plain"), tracer) + show_output = String(take!(io_show)) + @test occursin("TracerMessageLike{String}", show_output) + @test occursin("Test Message", show_output) + + # Test pprint method + io_pprint = IOBuffer() + pprint(io_pprint, tracer) + pprint_output = String(take!(io_pprint)) + @test occursin("TracerMessageLike with:", pprint_output) + @test occursin("Test Message", pprint_output) end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index eebcbdf0..eb8614f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using PromptingTools using OpenAI, HTTP, JSON3 using SparseArrays, LinearAlgebra, Markdown using Statistics +using Dates: now using Test, Pkg const PT = PromptingTools using Aqua @@ -23,6 +24,7 @@ end include("llm_google.jl") include("llm_anthropic.jl") include("llm_sharegpt.jl") + include("llm_tracer.jl") include("macros.jl") include("templates.jl") include("serialization.jl") diff --git a/test/serialization.jl b/test/serialization.jl index 2434147b..8c4dfacb 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -49,4 +49,26 @@ end s = read(fn, String) @test s == """{"conversations":[{"value":"System message 1","from":"system"},{"value":"User message","from":"human"},{"value":"AI message","from":"gpt"}]}""" +end + +@testset "Serialization - TracerMessage" begin + conv = AbstractMessage[SystemMessage("System message 1"), + UserMessage("User message"), + AIMessage("AI message")] + traced_conv = TracerMessage.(conv) + align_tracer!(traced_conv) + tmp, _ = mktemp() + save_conversation(tmp, traced_conv) + loaded_tracer = load_conversation(tmp) + @test loaded_tracer == traced_conv + + # We cannot recover all type information !!! + obj = Dict{String, Any}("a" => 1, "b" => 2) + tr = TracerMessageLike(obj; from = :user, to = :ai, model = "TestModel") + tmp, _ = mktemp() + JSON3.write(tmp, tr) + tr2 = JSON3.read(tmp, TracerMessageLike) + @test tr2.from == tr.from + @test tr2.to == tr.to + @test unwrap(tr) == unwrap(tr2) == obj end \ No newline at end of file