Skip to content

Commit

Permalink
Rag Tools fix + relaxing const for API key loading
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Jun 26, 2024
1 parent da4089f commit a649431
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 78 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

### Fixed
- Fixed loading RAGResult when one of the candidate fields was `nothing`.
- Utility type checks like `isusermessage`, `issystemmessage`, `isdatamessage`, `isaimessage`, `istracermessage` do not throw errors when given any arbitrary input types (previously they only worked for `AbstractMessage` types). It's a `isa` check, so it should work for all input types.
- Changed preference loading to use typed `global` instead of `const`, to fix issues with API keys not being loaded properly on start. You can now also call `PromptingTools.load_api_keys!()` to re-load the API keys (and ENV variables) manually.

## [0.33.0]

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.33.0"
version = "0.33.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
5 changes: 3 additions & 2 deletions src/Experimental/RAGTools/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -736,9 +736,10 @@ function StructTypes.constructfrom(::Type{RAGResult}, obj::Union{Dict, JSON3.Obj
## Retype where necessary
for f in [
:emb_candidates, :tag_candidates, :filtered_candidates, :reranked_candidates]
if haskey(obj, f) && haskey(obj[f], :index_ids)
## Check for nothing value, because tag_candidates can be empty
if haskey(obj, f) && !isnothing(obj[f]) && haskey(obj[f], :index_ids)
obj[f] = StructTypes.constructfrom(MultiCandidateChunks, obj[f])
elseif haskey(obj, f)
elseif haskey(obj, f) && !isnothing(obj[f])
obj[f] = StructTypes.constructfrom(CandidateChunks, obj[f])
end
end
Expand Down
3 changes: 3 additions & 0 deletions src/PromptingTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ include("Experimental/Experimental.jl")
function __init__()
# Load templates
load_templates!()

# Load ENV variables
load_api_keys!()
end

# Enable precompilation to reduce start time, disabled logging
Expand Down
11 changes: 6 additions & 5 deletions src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,12 @@ function (MSG::Type{<:AbstractChatMessage})(msg::AbstractTracerMessage{<:Abstrac
MSG(; msg.content)
end

isusermessage(m::AbstractMessage) = m isa UserMessage
issystemmessage(m::AbstractMessage) = m isa SystemMessage
isdatamessage(m::AbstractMessage) = m isa DataMessage
isaimessage(m::AbstractMessage) = m isa AIMessage
istracermessage(m::AbstractMessage) = m isa AbstractTracerMessage
## It checks types so it should be defined for all inputs
isusermessage(m::Any) = m isa UserMessage
issystemmessage(m::Any) = m isa SystemMessage
isdatamessage(m::Any) = m isa DataMessage
isaimessage(m::Any) = m isa AIMessage
istracermessage(m::Any) = m isa AbstractTracerMessage
isusermessage(m::AbstractTracerMessage) = isusermessage(m.object)
issystemmessage(m::AbstractTracerMessage) = issystemmessage(m.object)
isdatamessage(m::AbstractTracerMessage) = isdatamessage(m.object)
Expand Down
143 changes: 73 additions & 70 deletions src/user_preferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,80 +124,83 @@ function get_preferences(key::String)
end

## Load up GLOBALS
const MODEL_CHAT::String = @load_preference("MODEL_CHAT", default="gpt-3.5-turbo")
const MODEL_EMBEDDING::String = @load_preference("MODEL_EMBEDDING",
global MODEL_CHAT::String = @load_preference("MODEL_CHAT", default="gpt-3.5-turbo")
global MODEL_EMBEDDING::String = @load_preference("MODEL_EMBEDDING",
default="text-embedding-3-small")
const MODEL_IMAGE_GENERATION = @load_preference("MODEL_IMAGE_GENERATION",
global MODEL_IMAGE_GENERATION::String = @load_preference("MODEL_IMAGE_GENERATION",
default="dall-e-3")
# the prompt schema default is defined in llm_interace.jl !
# const PROMPT_SCHEMA = OpenAISchema()

# First, load from preferences, then from environment variables
# Note: We load first into a variable `temp_` to avoid inlining of the get(ENV...) call
_temp = get(ENV, "OPENAI_API_KEY", "")
const OPENAI_API_KEY::String = @load_preference("OPENAI_API_KEY",
default=_temp);
# Note: Disable this warning by setting OPENAI_API_KEY to anything
isempty(OPENAI_API_KEY) &&
@warn "OPENAI_API_KEY variable not set! OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=<api-key>`!"

_temp = get(ENV, "MISTRALAI_API_KEY", "")
const MISTRALAI_API_KEY::String = @load_preference("MISTRALAI_API_KEY",
default=_temp);

_temp = get(ENV, "COHERE_API_KEY", "")
const COHERE_API_KEY::String = @load_preference("COHERE_API_KEY",
default=_temp);

_temp = get(ENV, "DATABRICKS_API_KEY", "")
const DATABRICKS_API_KEY::String = @load_preference("DATABRICKS_API_KEY",
default=_temp);

_temp = get(ENV, "DATABRICKS_HOST", "")
const DATABRICKS_HOST::String = @load_preference("DATABRICKS_HOST",
default=_temp);

_temp = get(ENV, "TAVILY_API_KEY", "")
const TAVILY_API_KEY::String = @load_preference("TAVILY_API_KEY",
default=_temp);

_temp = get(ENV, "GOOGLE_API_KEY", "")
const GOOGLE_API_KEY::String = @load_preference("GOOGLE_API_KEY",
default=_temp);

_temp = get(ENV, "TOGETHER_API_KEY", "")
const TOGETHER_API_KEY::String = @load_preference("TOGETHER_API_KEY",
default=_temp);

_temp = get(ENV, "FIREWORKS_API_KEY", "")
const FIREWORKS_API_KEY::String = @load_preference("FIREWORKS_API_KEY",
default=_temp);

_temp = get(ENV, "ANTHROPIC_API_KEY", "")
const ANTHROPIC_API_KEY::String = @load_preference("ANTHROPIC_API_KEY",
default=_temp);

_temp = get(ENV, "VOYAGE_API_KEY", "")
const VOYAGE_API_KEY::String = @load_preference("VOYAGE_API_KEY",
default=_temp);

_temp = get(ENV, "GROQ_API_KEY", "")
const GROQ_API_KEY::String = @load_preference("GROQ_API_KEY",
default=_temp);

_temp = get(ENV, "DEEPSEEK_API_KEY", "")
const DEEPSEEK_API_KEY::String = @load_preference("DEEPSEEK_API_KEY",
default=_temp);

_temp = get(ENV, "LOCAL_SERVER", "http://localhost:10897/v1")
## Address of the local server
const LOCAL_SERVER::String = @load_preference("LOCAL_SERVER",
default=_temp);

_temp = get(ENV, "LOG_DIR", joinpath(pwd(), "log"))
## Address of the local server
const LOG_DIR::String = @load_preference("LOG_DIR",
default=_temp);
# Instantiate empty global variables
global OPENAI_API_KEY::String = ""
global MISTRALAI_API_KEY::String = ""
global COHERE_API_KEY::String = ""
global DATABRICKS_API_KEY::String = ""
global TAVILY_API_KEY::String = ""
global GOOGLE_API_KEY::String = ""
global ANTHROPIC_API_KEY::String = ""
global VOYAGE_API_KEY::String = ""
global GROQ_API_KEY::String = ""
global DEEPSEEK_API_KEY::String = ""
global LOCAL_SERVER::String = ""
global LOG_DIR::String = ""

# Load them on init
"Loads API keys from environment variables and preferences"
function load_api_keys!()
global OPENAI_API_KEY
OPENAI_API_KEY = @load_preference("OPENAI_API_KEY",
default=get(ENV, "OPENAI_API_KEY", ""))
# Note: Disable this warning by setting OPENAI_API_KEY to anything
isempty(OPENAI_API_KEY) &&
@warn "OPENAI_API_KEY variable not set! OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=<api-key>`!"

global MISTRALAI_API_KEY
MISTRALAI_API_KEY = @load_preference("MISTRALAI_API_KEY",
default=get(ENV, "MISTRALAI_API_KEY", ""))
global COHERE_API_KEY
COHERE_API_KEY = @load_preference("COHERE_API_KEY",
default=get(ENV, "COHERE_API_KEY", ""))
global DATABRICKS_API_KEY
DATABRICKS_API_KEY = @load_preference("DATABRICKS_API_KEY",
default=get(ENV, "DATABRICKS_API_KEY", ""))
global TAVILY_API_KEY
TAVILY_API_KEY = @load_preference("TAVILY_API_KEY",
default=get(ENV, "TAVILY_API_KEY", ""))
global GOOGLE_API_KEY
GOOGLE_API_KEY = @load_preference("GOOGLE_API_KEY",
default=get(ENV, "GOOGLE_API_KEY", ""))
global TOGETHER_API_KEY
TOGETHER_API_KEY = @load_preference("TOGETHER_API_KEY",
default=get(ENV, "TOGETHER_API_KEY", ""))
global FIREWORKS_API_KEY
FIREWORKS_API_KEY = @load_preference("FIREWORKS_API_KEY",
default=get(ENV, "FIREWORKS_API_KEY", ""))
global ANTHROPIC_API_KEY
ANTHROPIC_API_KEY = @load_preference("ANTHROPIC_API_KEY",
default=get(ENV, "ANTHROPIC_API_KEY", ""))
global VOYAGE_API_KEY
VOYAGE_API_KEY = @load_preference("VOYAGE_API_KEY",
default=get(ENV, "VOYAGE_API_KEY", ""))
global GROQ_API_KEY
GROQ_API_KEY = @load_preference("GROQ_API_KEY",
default=get(ENV, "GROQ_API_KEY", ""))
global DEEPSEEK_API_KEY
DEEPSEEK_API_KEY = @load_preference("DEEPSEEK_API_KEY",
default=get(ENV, "DEEPSEEK_API_KEY", ""))
global LOCAL_SERVER
LOCAL_SERVER = @load_preference("LOCAL_SERVER",
default=get(ENV, "LOCAL_SERVER", ""))
global LOG_DIR
LOG_DIR = @load_preference("LOG_DIR",
default=get(ENV, "LOG_DIR", joinpath(pwd(), "log")))

return nothing
end
# Try to load already for safety
load_api_keys!()

## CONVERSATION HISTORY
"""
Expand All @@ -212,8 +215,8 @@ See also: `push_conversation!`, `resize_conversation!`
"""
const CONV_HISTORY = Vector{Vector{<:Any}}()
const CONV_HISTORY_LOCK = ReentrantLock()
const MAX_HISTORY_LENGTH = @load_preference("MAX_HISTORY_LENGTH",
default=5)::Union{Int, Nothing}
global MAX_HISTORY_LENGTH::Union{Int, Nothing} = @load_preference("MAX_HISTORY_LENGTH",
default=5)

## Model registry
# A dictionary of model names and their specs (ie, name, costs per token, etc.)
Expand Down
6 changes: 6 additions & 0 deletions test/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ using PromptingTools: TracerSchema, SaverSchema
@test AIMessage(; content) |> isaimessage
@test UserMessage(content) |> AIMessage |> isaimessage
@test UserMessage(content) != AIMessage(content)
## check handling other types
@test isusermessage(1) == false
@test issystemmessage(nothing) == false
@test isdatamessage(1) == false
@test isaimessage(missing) == false
@test istracermessage(1) == false
end
@testset "UserMessageWithImages" begin
content = "Hello, world!"
Expand Down

0 comments on commit a649431

Please sign in to comment.