Skip to content

Commit

Permalink
Add assertion for empty docs in get_embeddings (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliadmtru authored Aug 23, 2024
1 parent 3083e9d commit 5b31403
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:Abstract
target_batch_size_length::Int = 80_000,
ntasks::Int = 4 * Threads.nthreads(),
kwargs...)
@assert !isempty(docs) "The list of docs to get embeddings from should not be empty."

## check if extension is available
ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt)
if isnothing(ext)
Expand Down Expand Up @@ -338,6 +340,8 @@ function get_embeddings(
target_batch_size_length::Int = 80_000,
ntasks::Int = 4 * Threads.nthreads(),
kwargs...)
@assert !isempty(docs) "The list of docs to get embeddings from should not be empty."

emb = get_embeddings(BatchEmbedder(), docs; verbose, model, truncate_dimension,
cost_tracker, target_batch_size_length, ntasks, kwargs...)
# This will return Matrix{Bool}, eg, map(>(0),emb)
Expand Down Expand Up @@ -387,6 +391,8 @@ function get_embeddings(
target_batch_size_length::Int = 80_000,
ntasks::Int = 4 * Threads.nthreads(),
kwargs...)
@assert !isempty(docs) "The list of docs to get embeddings from should not be empty."

emb = get_embeddings(BatchEmbedder(), docs; verbose, model, truncate_dimension,
cost_tracker, target_batch_size_length, ntasks, kwargs...)
# This will return Matrix{UInt64} to save space
Expand Down
5 changes: 5 additions & 0 deletions test/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ end
end

@testset "get_embeddings" begin
# docs should not be empty
@test_throws AssertionError get_embeddings(BatchEmbedder(), String[])
@test_throws AssertionError get_embeddings(BinaryBatchEmbedder(), String[])
@test_throws AssertionError get_embeddings(BitPackedBatchEmbedder(), String[])

# corresponds to OpenAI API v1
response1 = Dict(:data => [Dict(:embedding => ones(128, 2))],
:usage => Dict(:total_tokens => 2, :prompt_tokens => 2, :completion_tokens => 0))
Expand Down

0 comments on commit 5b31403

Please sign in to comment.