Add TracerMessage for observability (#133)
svilupp authored Apr 15, 2024
1 parent 9482d67 commit 1cda053
Showing 13 changed files with 785 additions and 8 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Added a few new open-weights models hosted by Fireworks.ai to the registry (DBRX Instruct, Mixtral 8x22b Instruct, Qwen 72b). If you're curious about how well they work, try them!
- Added basic support for observability downstream: custom callback infrastructure built around `initialize_tracer` and `finalize_tracer`, plus the dedicated types `TracerMessage` and `TracerMessageLike`. See `?TracerMessage` for more information and the corresponding `aigenerate` docstring (a short usage sketch follows this file's diff).

### Updated
- Changed default model for `RAGTools.CohereReranker` to "rerank-english-v3.0".

### Fixed

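A minimal usage sketch of the tracing flow described in the changelog entry above (assumes a configured OpenAI API key; the `gpt4t` model alias is taken from the docstrings below and is only illustrative):

```julia
using PromptingTools
const PT = PromptingTools

# Wrap any provider schema in the new TracerSchema
wrap_schema = PT.TracerSchema(PT.OpenAISchema())

# Same call as before; the result comes back wrapped in a TracerMessage
msg = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t")
msg isa PT.TracerMessage # true
msg.content              # forwards to the wrapped AIMessage's content
```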
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.19.0"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -31,6 +32,7 @@ RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra"]
AbstractTrees = "0.4"
Aqua = "0.7"
Base64 = "<0.0.1, 1"
Dates = "<0.0.1, 1"
GoogleGenAI = "0.3"
HTTP = "1"
JSON3 = "1"
8 changes: 4 additions & 4 deletions src/Experimental/RAGTools/retrieval.jl
@@ -390,7 +390,7 @@ end
verbose::Bool = false,
api_key::AbstractString = PT.COHERE_API_KEY,
top_n::Integer = length(candidates.scores),
model::AbstractString = "rerank-english-v2.0",
model::AbstractString = "rerank-english-v3.0",
return_documents::Bool = false,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
@@ -404,10 +404,10 @@ Re-ranks a list of candidate chunks using the Cohere Rerank API. See https://coh
- `question`: The query to be used for the search.
- `candidates`: The candidate chunks to be re-ranked.
- `top_n`: The number of most relevant documents to return. Default is `length(candidates.scores)`.
- `model`: The model to use for reranking. Default is `rerank-english-v2.0`.
- `model`: The model to use for reranking. Default is `rerank-english-v3.0`.
- `return_documents`: A boolean flag indicating whether to return the reranked documents in the response. Default is `false`.
- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`.
- `cost_tracker`: An atomic counter to track the cost of the retrieval. Default is `Threads.Atomic{Float64}(0.0)`. Not currently tracked (cost unclear).
- `cost_tracker`: An atomic counter to track the cost of the retrieval. Not implemented/tracked (cost unclear); provided for consistency.
"""
function rerank(
@@ -416,7 +416,7 @@ function rerank(
verbose::Bool = false,
api_key::AbstractString = PT.COHERE_API_KEY,
top_n::Integer = length(candidates.scores),
model::AbstractString = "rerank-english-v2.0",
model::AbstractString = "rerank-english-v3.0",
return_documents::Bool = false,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
2 changes: 2 additions & 0 deletions src/PromptingTools.jl
@@ -2,6 +2,7 @@ module PromptingTools

import AbstractTrees
using Base64: base64encode
using Dates: now, DateTime
using Logging
using OpenAI
using JSON3
@@ -72,6 +73,7 @@ include("llm_ollama.jl")
include("llm_google.jl")
include("llm_anthropic.jl")
include("llm_sharegpt.jl")
include("llm_tracer.jl")

## Convenience utils
export @ai_str, @aai_str, @ai!_str, @aai!_str
22 changes: 22 additions & 0 deletions src/llm_interface.jl
@@ -285,6 +285,28 @@ Frequently used schema for finetuning LLMs. Conversations are recorded as a vect
"""
struct ShareGPTSchema <: AbstractShareGPTSchema end

abstract type AbstractTracerSchema <: AbstractPromptSchema end

"""
TracerSchema <: AbstractTracerSchema
A schema designed to wrap another schema, enabling pre- and post-execution callbacks for tracing and additional functionality. Results of the wrapped calls are returned as `TracerMessage`s to trace the execution flow, facilitating observability and debugging in complex conversational AI systems.
The `TracerSchema` acts as a middleware, allowing developers to insert custom logic before and after the execution of the primary schema's functionality. This can include logging, performance measurement, or any other form of tracing required to understand or improve the execution flow.
# Usage
```julia
wrap_schema = TracerSchema(OpenAISchema())
msg = aigenerate(wrap_schema, "Say hi!"; model="gpt-4")
# output type should be TracerMessage
msg isa TracerMessage
```
You can define your own tracer schema and the corresponding methods: `initialize_tracer`, `finalize_tracer`. See `src/llm_tracer.jl`
"""
struct TracerSchema <: AbstractTracerSchema
schema::AbstractPromptSchema
end

## Dispatch into a default schema (can be set by Preferences.jl)
# Since we load it as strings, we need to convert it to a symbol and instantiate it
global PROMPT_SCHEMA::AbstractPromptSchema = @load_preference("PROMPT_SCHEMA",
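As the `TracerSchema` docstring notes, you can define your own tracer schema and the corresponding `initialize_tracer`/`finalize_tracer` methods. A sketch of what that could look like, using only the API shown in this commit (the `LoggingTracerSchema` name and the log message are illustrative, not part of the package):

```julia
using PromptingTools
const PT = PromptingTools
using Dates: now

# Hypothetical tracer schema that logs when a traced ai* call starts.
# The default `finalize_tracer` fallback still wraps results into TracerMessage.
struct LoggingTracerSchema <: PT.AbstractTracerSchema
    schema::PT.AbstractPromptSchema
end

function PT.initialize_tracer(schema::LoggingTracerSchema; model = "",
        tracer_kwargs = NamedTuple(), kwargs...)
    time_sent = now()
    @info "Tracing ai* call" model time_sent
    # Return the same payload the default implementation builds on
    return (; time_sent, model, tracer_kwargs...)
end

# Usage: msg = aigenerate(LoggingTracerSchema(PT.OpenAISchema()), "Say hi!"; model = "gpt4t")
```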
198 changes: 198 additions & 0 deletions src/llm_tracer.jl
@@ -0,0 +1,198 @@
# Tracing infrastructure for logging and other callbacks
# - Define your own schema that is subtype of AbstractTracerSchema and wraps the underlying LLM provider schema
# - Customize initialize_tracer and finalize_tracer with your custom callback
# - Call your ai* function with the tracer schema as usual

# Simple passthrough, do nothing
"""
render(tracer_schema::AbstractTracerSchema,
conv::AbstractVector{<:AbstractMessage}; kwargs...)
Passthrough. No changes.
"""
function render(tracer_schema::AbstractTracerSchema,
conv::AbstractVector{<:AbstractMessage}; kwargs...)
return conv
end

"""
initialize_tracer(
tracer_schema::AbstractTracerSchema; model = "", tracer_kwargs = NamedTuple(), kwargs...)
Initializes the `tracer`/callback (if necessary). You can provide any keyword arguments in `tracer_kwargs` (eg, `parent_id`, `thread_id`, `run_id`).
It is executed prior to the `ai*` calls.
In the default implementation, we just collect the necessary data to build the tracer object in `finalize_tracer`.
"""
function initialize_tracer(
tracer_schema::AbstractTracerSchema; model = "", tracer_kwargs = NamedTuple(), kwargs...)
return (; time_sent = now(), model, tracer_kwargs...)
end

"""
finalize_tracer(
tracer_schema::AbstractTracerSchema, tracer, msg_or_conv; tracer_kwargs = NamedTuple(), model = "", kwargs...)
Finalizes the tracer/callback with whatever is needed after the `ai*` calls. Use `tracer_kwargs` to provide any information necessary (eg, `parent_id`, `thread_id`, `run_id`).
In the default implementation, we convert all non-tracer messages into `TracerMessage`.
"""
function finalize_tracer(
tracer_schema::AbstractTracerSchema, tracer, msg_or_conv; tracer_kwargs = NamedTuple(), model = "", kwargs...)
# All relevant kwargs were already captured in `tracer`, so we can ignore them in this implementation
time_received = now()
# work with arrays for unified processing
is_vector = msg_or_conv isa AbstractVector
conv = msg_or_conv isa AbstractVector{<:AbstractMessage} ?
convert(Vector{AbstractMessage}, msg_or_conv) :
AbstractMessage[msg_or_conv]
# wrap any non-traced messages and set their timestamps
for i in eachindex(conv)
msg = conv[i]
# change into TracerMessage if not already, use the current kwargs
if !istracermessage(msg)
# we saved our data for `tracer`
conv[i] = TracerMessage(; object = msg, tracer..., time_received)
end
end
return is_vector ? conv : first(conv)
end

"""
aigenerate(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
Wraps the normal `aigenerate` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
Logic:
- calls `initialize_tracer`
- calls `aigenerate` (with the `tracer_schema.schema`)
- calls `finalize_tracer`
# Example
```julia
wrap_schema = PT.TracerSchema(PT.OpenAISchema())
msg = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t")
msg isa TracerMessage # true
msg.content # access content as if it were the underlying message
PT.pprint(msg) # pretty-print the message
```
It works on a vector of messages and converts only the non-tracer ones, eg,
```julia
wrap_schema = PT.TracerSchema(PT.OpenAISchema())
conv = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t", return_all = true)
all(PT.istracermessage, conv) # true
```
"""
function aigenerate(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...)
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
msg_or_conv = aigenerate(tracer_schema.schema, prompt; merged_kwargs...)
return finalize_tracer(
tracer_schema, tracer, msg_or_conv; model, tracer_kwargs, kwargs...)
end

"""
aiembed(tracer_schema::AbstractTracerSchema,
doc_or_docs::Union{AbstractString, AbstractVector{<:AbstractString}}, postprocess::Function = identity;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
Wraps the normal `aiembed` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
Logic:
- calls `initialize_tracer`
- calls `aiembed` (with the `tracer_schema.schema`)
- calls `finalize_tracer`
"""
function aiembed(tracer_schema::AbstractTracerSchema,
doc_or_docs::Union{AbstractString, AbstractVector{<:AbstractString}}, postprocess::Function = identity;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...)
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
embed_or_conv = aiembed(
tracer_schema.schema, doc_or_docs, postprocess; merged_kwargs...)
return finalize_tracer(
tracer_schema, tracer, embed_or_conv; model, tracer_kwargs, kwargs...)
end

"""
aiclassify(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
Wraps the normal `aiclassify` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
Logic:
- calls `initialize_tracer`
- calls `aiclassify` (with the `tracer_schema.schema`)
- calls `finalize_tracer`
"""
function aiclassify(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...)
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
classify_or_conv = aiclassify(tracer_schema.schema, prompt; merged_kwargs...)
return finalize_tracer(
tracer_schema, tracer, classify_or_conv; model, tracer_kwargs, kwargs...)
end

"""
aiextract(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
Wraps the normal `aiextract` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
Logic:
- calls `initialize_tracer`
- calls `aiextract` (with the `tracer_schema.schema`)
- calls `finalize_tracer`
"""
function aiextract(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...)
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
extract_or_conv = aiextract(tracer_schema.schema, prompt; merged_kwargs...)
return finalize_tracer(
tracer_schema, tracer, extract_or_conv; model, tracer_kwargs, kwargs...)
end

"""
aiscan(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
Wraps the normal `aiscan` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
Logic:
- calls `initialize_tracer`
- calls `aiscan` (with the `tracer_schema.schema`)
- calls `finalize_tracer`
"""
function aiscan(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...)
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
scan_or_conv = aiscan(tracer_schema.schema, prompt; merged_kwargs...)
return finalize_tracer(
tracer_schema, tracer, scan_or_conv; model, tracer_kwargs, kwargs...)
end

"""
aiimage(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
Wraps the normal `aiimage` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
Logic:
- calls `initialize_tracer`
- calls `aiimage` (with the `tracer_schema.schema`)
- calls `finalize_tracer`
"""
function aiimage(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
tracer_kwargs = NamedTuple(), model = "", kwargs...)
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...)
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
image_or_conv = aiimage(tracer_schema.schema, prompt; merged_kwargs...)
return finalize_tracer(
tracer_schema, tracer, image_or_conv; model, tracer_kwargs, kwargs...)
end
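Finally, a hedged end-to-end sketch combining the docstring examples above: it traces a full conversation, passes one of the example tracer identifiers (`thread_id`), and reads back the timing captured by the default `initialize_tracer`/`finalize_tracer`. Direct property access to `time_sent`/`time_received` on `TracerMessage` is an assumption here, since the type's definition lives outside this file:

```julia
using PromptingTools
const PT = PromptingTools

wrap_schema = PT.TracerSchema(PT.OpenAISchema())

# Trace a whole conversation; every non-tracer message gets wrapped
conv = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t", return_all = true,
    tracer_kwargs = (; thread_id = :hello_world))  # extra identifiers for the tracer only
all(PT.istracermessage, conv) # true

ai_msg = last(conv)
ai_msg.content                          # forwards to the wrapped AIMessage
ai_msg.time_received - ai_msg.time_sent # assumed field access; rough round-trip latency
```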