From a649431a351725ea32c4720b4b00119d4c06564c Mon Sep 17 00:00:00 2001
From: J S <49557684+svilupp@users.noreply.github.com>
Date: Wed, 26 Jun 2024 10:26:32 +0100
Subject: [PATCH] Rag Tools fix + relaxing `const` for API key loading

---
Reviewer note (this section is ignored by `git am`): this patch was
reconstructed from a copy whose line breaks and indentation were lost, so the
leading whitespace below is a best-effort reconstruction. Relative to the
original commit it (1) restores the `DATABRICKS_HOST` global and its loading,
(2) adds typed declarations for `TOGETHER_API_KEY` and `FIREWORKS_API_KEY`,
which were only assigned inside `load_api_keys!`, and (3) restores the
previous `LOCAL_SERVER` default. The post-image blob hashes on the affected
`index` lines are therefore stale.

 CHANGELOG.md                       |   3 +
 Project.toml                       |   2 +-
 src/Experimental/RAGTools/types.jl |   5 +-
 src/PromptingTools.jl              |   3 +
 src/messages.jl                    |  11 ++-
 src/user_preferences.jl            | 149 +++++++++++++++--------------
 test/messages.jl                   |   6 ++
 7 files changed, 101 insertions(+), 78 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 03abbe7f..2b55e3b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 ### Fixed
+- Fixed loading RAGResult when one of the candidate fields was `nothing`.
+- Utility type checks like `isusermessage`, `issystemmessage`, `isdatamessage`, `isaimessage`, `istracermessage` do not throw errors when given any arbitrary input types (previously they only worked for `AbstractMessage` types). It's an `isa` check, so it should work for all input types.
+- Changed preference loading to use typed `global` instead of `const`, to fix issues with API keys not being loaded properly on start. You can now also call `PromptingTools.load_api_keys!()` to re-load the API keys (and ENV variables) manually.
 
 ## [0.33.0]
 
diff --git a/Project.toml b/Project.toml
index c2a20a98..a6e98bd8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "PromptingTools"
 uuid = "670122d1-24a8-4d70-bfce-740807c42192"
 authors = ["J S @svilupp and contributors"]
-version = "0.33.0"
+version = "0.33.1"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl
index 41c783a3..71686ae7 100644
--- a/src/Experimental/RAGTools/types.jl
+++ b/src/Experimental/RAGTools/types.jl
@@ -736,9 +736,10 @@ function StructTypes.constructfrom(::Type{RAGResult}, obj::Union{Dict, JSON3.Obj
     ## Retype where necessary
     for f in [
         :emb_candidates, :tag_candidates, :filtered_candidates, :reranked_candidates]
-        if haskey(obj, f) && haskey(obj[f], :index_ids)
+        ## Check for nothing value, because tag_candidates can be empty
+        if haskey(obj, f) && !isnothing(obj[f]) && haskey(obj[f], :index_ids)
             obj[f] = StructTypes.constructfrom(MultiCandidateChunks, obj[f])
-        elseif haskey(obj, f)
+        elseif haskey(obj, f) && !isnothing(obj[f])
             obj[f] = StructTypes.constructfrom(CandidateChunks, obj[f])
         end
     end
diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl
index c4827e75..de0047b6 100644
--- a/src/PromptingTools.jl
+++ b/src/PromptingTools.jl
@@ -88,6 +88,9 @@ include("Experimental/Experimental.jl")
 function __init__()
     # Load templates
     load_templates!()
+
+    # Load ENV variables
+    load_api_keys!()
 end
 
 # Enable precompilation to reduce start time, disabled logging
diff --git a/src/messages.jl b/src/messages.jl
index fdc93120..718c906d 100644
--- a/src/messages.jl
+++ b/src/messages.jl
@@ -137,11 +137,12 @@ function (MSG::Type{<:AbstractChatMessage})(msg::AbstractTracerMessage{<:Abstrac
     MSG(; msg.content)
 end
 
-isusermessage(m::AbstractMessage) = m isa UserMessage
-issystemmessage(m::AbstractMessage) = m isa SystemMessage
-isdatamessage(m::AbstractMessage) = m isa DataMessage
-isaimessage(m::AbstractMessage) = m isa AIMessage
-istracermessage(m::AbstractMessage) = m isa AbstractTracerMessage
+## It checks types so it should be defined for all inputs
+isusermessage(m::Any) = m isa UserMessage
+issystemmessage(m::Any) = m isa SystemMessage
+isdatamessage(m::Any) = m isa DataMessage
+isaimessage(m::Any) = m isa AIMessage
+istracermessage(m::Any) = m isa AbstractTracerMessage
 isusermessage(m::AbstractTracerMessage) = isusermessage(m.object)
 issystemmessage(m::AbstractTracerMessage) = issystemmessage(m.object)
 isdatamessage(m::AbstractTracerMessage) = isdatamessage(m.object)
diff --git a/src/user_preferences.jl b/src/user_preferences.jl
index ef7fe8ec..32b3271c 100644
--- a/src/user_preferences.jl
+++ b/src/user_preferences.jl
@@ -124,80 +124,89 @@ function get_preferences(key::String)
 end
 
 ## Load up GLOBALS
-const MODEL_CHAT::String = @load_preference("MODEL_CHAT", default="gpt-3.5-turbo")
-const MODEL_EMBEDDING::String = @load_preference("MODEL_EMBEDDING",
+global MODEL_CHAT::String = @load_preference("MODEL_CHAT", default="gpt-3.5-turbo")
+global MODEL_EMBEDDING::String = @load_preference("MODEL_EMBEDDING",
     default="text-embedding-3-small")
-const MODEL_IMAGE_GENERATION = @load_preference("MODEL_IMAGE_GENERATION",
+global MODEL_IMAGE_GENERATION::String = @load_preference("MODEL_IMAGE_GENERATION",
     default="dall-e-3")
 # the prompt schema default is defined in llm_interace.jl !
 # const PROMPT_SCHEMA = OpenAISchema()
 
 # First, load from preferences, then from environment variables
-# Note: We load first into a variable `temp_` to avoid inlining of the get(ENV...) call
-_temp = get(ENV, "OPENAI_API_KEY", "")
-const OPENAI_API_KEY::String = @load_preference("OPENAI_API_KEY",
-    default=_temp);
-# Note: Disable this warning by setting OPENAI_API_KEY to anything
-isempty(OPENAI_API_KEY) &&
-    @warn "OPENAI_API_KEY variable not set! OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=`!"
-
-_temp = get(ENV, "MISTRALAI_API_KEY", "")
-const MISTRALAI_API_KEY::String = @load_preference("MISTRALAI_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "COHERE_API_KEY", "")
-const COHERE_API_KEY::String = @load_preference("COHERE_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "DATABRICKS_API_KEY", "")
-const DATABRICKS_API_KEY::String = @load_preference("DATABRICKS_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "DATABRICKS_HOST", "")
-const DATABRICKS_HOST::String = @load_preference("DATABRICKS_HOST",
-    default=_temp);
-
-_temp = get(ENV, "TAVILY_API_KEY", "")
-const TAVILY_API_KEY::String = @load_preference("TAVILY_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "GOOGLE_API_KEY", "")
-const GOOGLE_API_KEY::String = @load_preference("GOOGLE_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "TOGETHER_API_KEY", "")
-const TOGETHER_API_KEY::String = @load_preference("TOGETHER_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "FIREWORKS_API_KEY", "")
-const FIREWORKS_API_KEY::String = @load_preference("FIREWORKS_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "ANTHROPIC_API_KEY", "")
-const ANTHROPIC_API_KEY::String = @load_preference("ANTHROPIC_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "VOYAGE_API_KEY", "")
-const VOYAGE_API_KEY::String = @load_preference("VOYAGE_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "GROQ_API_KEY", "")
-const GROQ_API_KEY::String = @load_preference("GROQ_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "DEEPSEEK_API_KEY", "")
-const DEEPSEEK_API_KEY::String = @load_preference("DEEPSEEK_API_KEY",
-    default=_temp);
-
-_temp = get(ENV, "LOCAL_SERVER", "http://localhost:10897/v1")
-## Address of the local server
-const LOCAL_SERVER::String = @load_preference("LOCAL_SERVER",
-    default=_temp);
-
-_temp = get(ENV, "LOG_DIR", joinpath(pwd(), "log"))
-## Address of the local server
-const LOG_DIR::String = @load_preference("LOG_DIR",
-    default=_temp);
+# Instantiate empty global variables
+global OPENAI_API_KEY::String = ""
+global MISTRALAI_API_KEY::String = ""
+global COHERE_API_KEY::String = ""
+global DATABRICKS_API_KEY::String = ""
+global DATABRICKS_HOST::String = ""
+global TAVILY_API_KEY::String = ""
+global GOOGLE_API_KEY::String = ""
+global TOGETHER_API_KEY::String = ""
+global FIREWORKS_API_KEY::String = ""
+global ANTHROPIC_API_KEY::String = ""
+global VOYAGE_API_KEY::String = ""
+global GROQ_API_KEY::String = ""
+global DEEPSEEK_API_KEY::String = ""
+global LOCAL_SERVER::String = ""
+global LOG_DIR::String = ""
+
+# Load them on init
+"Loads API keys and related settings from preferences, falling back to the ENV variable of the same name when a preference is not set."
+function load_api_keys!()
+    global OPENAI_API_KEY
+    OPENAI_API_KEY = @load_preference("OPENAI_API_KEY",
+        default=get(ENV, "OPENAI_API_KEY", ""))
+    # Note: Disable this warning by setting OPENAI_API_KEY to anything
+    isempty(OPENAI_API_KEY) &&
+        @warn "OPENAI_API_KEY variable not set! OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=`!"
+
+    global MISTRALAI_API_KEY
+    MISTRALAI_API_KEY = @load_preference("MISTRALAI_API_KEY",
+        default=get(ENV, "MISTRALAI_API_KEY", ""))
+    global COHERE_API_KEY
+    COHERE_API_KEY = @load_preference("COHERE_API_KEY",
+        default=get(ENV, "COHERE_API_KEY", ""))
+    global DATABRICKS_API_KEY
+    DATABRICKS_API_KEY = @load_preference("DATABRICKS_API_KEY",
+        default=get(ENV, "DATABRICKS_API_KEY", ""))
+    global DATABRICKS_HOST
+    DATABRICKS_HOST = @load_preference("DATABRICKS_HOST",
+        default=get(ENV, "DATABRICKS_HOST", ""))
+    global TAVILY_API_KEY
+    TAVILY_API_KEY = @load_preference("TAVILY_API_KEY",
+        default=get(ENV, "TAVILY_API_KEY", ""))
+    global GOOGLE_API_KEY
+    GOOGLE_API_KEY = @load_preference("GOOGLE_API_KEY",
+        default=get(ENV, "GOOGLE_API_KEY", ""))
+    global TOGETHER_API_KEY
+    TOGETHER_API_KEY = @load_preference("TOGETHER_API_KEY",
+        default=get(ENV, "TOGETHER_API_KEY", ""))
+    global FIREWORKS_API_KEY
+    FIREWORKS_API_KEY = @load_preference("FIREWORKS_API_KEY",
+        default=get(ENV, "FIREWORKS_API_KEY", ""))
+    global ANTHROPIC_API_KEY
+    ANTHROPIC_API_KEY = @load_preference("ANTHROPIC_API_KEY",
+        default=get(ENV, "ANTHROPIC_API_KEY", ""))
+    global VOYAGE_API_KEY
+    VOYAGE_API_KEY = @load_preference("VOYAGE_API_KEY",
+        default=get(ENV, "VOYAGE_API_KEY", ""))
+    global GROQ_API_KEY
+    GROQ_API_KEY = @load_preference("GROQ_API_KEY",
+        default=get(ENV, "GROQ_API_KEY", ""))
+    global DEEPSEEK_API_KEY
+    DEEPSEEK_API_KEY = @load_preference("DEEPSEEK_API_KEY",
+        default=get(ENV, "DEEPSEEK_API_KEY", ""))
+    global LOCAL_SERVER
+    LOCAL_SERVER = @load_preference("LOCAL_SERVER",
+        default=get(ENV, "LOCAL_SERVER", "http://localhost:10897/v1"))
+    global LOG_DIR
+    LOG_DIR = @load_preference("LOG_DIR",
+        default=get(ENV, "LOG_DIR", joinpath(pwd(), "log")))
+
+    return nothing
+end
+# Try to load already for safety
+load_api_keys!()
 
 ## CONVERSATION HISTORY
 """
@@ -212,8 +221,8 @@ See also: `push_conversation!`, `resize_conversation!`
 """
 const CONV_HISTORY = Vector{Vector{<:Any}}()
 const CONV_HISTORY_LOCK = ReentrantLock()
-const MAX_HISTORY_LENGTH = @load_preference("MAX_HISTORY_LENGTH",
-    default=5)::Union{Int, Nothing}
+global MAX_HISTORY_LENGTH::Union{Int, Nothing} = @load_preference("MAX_HISTORY_LENGTH",
+    default=5)
 
 ## Model registry
 # A dictionary of model names and their specs (ie, name, costs per token, etc.)
diff --git a/test/messages.jl b/test/messages.jl
index 93e2b2de..793c3324 100644
--- a/test/messages.jl
+++ b/test/messages.jl
@@ -37,6 +37,12 @@ using PromptingTools: TracerSchema, SaverSchema
     @test AIMessage(; content) |> isaimessage
     @test UserMessage(content) |> AIMessage |> isaimessage
     @test UserMessage(content) != AIMessage(content)
+    ## check handling other types
+    @test isusermessage(1) == false
+    @test issystemmessage(nothing) == false
+    @test isdatamessage(1) == false
+    @test isaimessage(missing) == false
+    @test istracermessage(1) == false
 end
 @testset "UserMessageWithImages" begin
     content = "Hello, world!"