This repository was archived by the owner on Jan 2, 2025. It is now read-only.

Commit a04579e

Re-introduce HyDE, work around batched calls (#735)
* Re-introduce HyDE logic. This reverts commit f18eb02.
* Add trace logs to hyde doc generation
* Log around qdrant batch search
* Patch batch_search_with using multiple calls to search_points

1 parent 1e5cedf commit a04579e
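
In outline, the change has the agent ask the LLM for short hypothetical answer snippets (HyDE), embed and search for those snippets alongside the literal query, and merge the hits before deduplication. Below is a minimal, self-contained sketch of that flow using stand-in helpers; none of the names in it are the crate's API, and the real steps are Agent::hyde, Agent::batch_semantic_search and Semantic::batch_search in the diffs that follow.

    // Stand-ins only; see the diffs below for the real implementations.
    fn llm_hypothetical_docs(query: &str) -> Vec<String> {
        // In the diff this is `Agent::hyde`: an LLM completion parsed into
        // fenced code snippets by `try_parse_hypothetical_documents`.
        vec![format!("fn hypothetical_answer() {{ /* for: {query} */ }}")]
    }

    fn semantic_search(needle: &str, limit: usize) -> Vec<String> {
        // Stand-in for a vector search over the code index.
        (0..limit.min(2)).map(|i| format!("hit {i} for {needle:?}")).collect()
    }

    fn code_search_with_hyde(query: &str) -> Vec<String> {
        // 1. Search for the user's query directly.
        let mut results = semantic_search(query, 10);
        // 2. Generate hypothetical documents and search for each of them too.
        for doc in llm_hypothetical_docs(query) {
            results.extend(semantic_search(&doc, 10));
        }
        // 3. The commit then deduplicates the merged hits (MMR against the
        //    mean-pooled query embeddings) before returning them.
        results
    }

    fn main() {
        println!("{:#?}", code_search_with_hyde("how are qdrant points filtered?"));
    }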

File tree

3 files changed: +273, -1 lines changed


server/bleep/src/semantic.rs

Lines changed: 110 additions & 0 deletions
@@ -18,6 +18,7 @@ use qdrant_client::{
     },
 };
 
+use futures::{stream, StreamExt, TryStreamExt};
 use rayon::prelude::*;
 use thiserror::Error;
 use tracing::{debug, info, trace, warn};
@@ -326,6 +327,62 @@ impl Semantic {
         Ok(response.result)
     }
 
+    pub async fn batch_search_with<'a>(
+        &self,
+        parsed_queries: &[&SemanticQuery<'a>],
+        vectors: Vec<Embedding>,
+        limit: u64,
+        offset: u64,
+    ) -> anyhow::Result<Vec<ScoredPoint>> {
+        // FIXME: This method uses `search_points` internally, and not `search_batch_points`. It's
+        // not clear why, but it seems that the `batch` variant of the `qdrant` calls leads to
+        // HTTP2 errors on some deployment configurations. A typical example error:
+        //
+        // ```
+        // hyper::proto::h2::client: client response error: stream error received: stream no longer needed
+        // ```
+        //
+        // Given that qdrant uses `tonic`, this may be a `tonic` issue, possibly similar to:
+        // https://github.com/hyperium/tonic/issues/222
+
+        // Queries should contain the same filters, so we get the first one
+        let parsed_query = parsed_queries.first().unwrap();
+        let filters = &build_conditions(parsed_query);
+
+        let responses = stream::iter(vectors.into_iter())
+            .map(|vector| async move {
+                let points = SearchPoints {
+                    limit,
+                    vector,
+                    collection_name: COLLECTION_NAME.to_string(),
+                    offset: Some(offset),
+                    score_threshold: Some(SCORE_THRESHOLD),
+                    with_payload: Some(WithPayloadSelector {
+                        selector_options: Some(with_payload_selector::SelectorOptions::Enable(
+                            true,
+                        )),
+                    }),
+                    filter: Some(Filter {
+                        must: filters.clone(),
+                        ..Default::default()
+                    }),
+                    with_vectors: Some(WithVectorsSelector {
+                        selector_options: Some(with_vectors_selector::SelectorOptions::Enable(
+                            true,
+                        )),
+                    }),
+                    ..Default::default()
+                };
+
+                self.qdrant.search_points(&points).await
+            })
+            .buffered(10)
+            .try_collect::<Vec<_>>()
+            .await?;
+
+        Ok(responses.into_iter().flat_map(|r| r.result).collect())
+    }
+
     pub async fn search<'a>(
         &self,
         parsed_query: &SemanticQuery<'a>,
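
The FIXME in the hunk above trades a single search_batch_points RPC for one search_points call per query, run with bounded concurrency. To isolate that pattern, here is a small self-contained sketch: only the futures combinators are real, while search_one and the numeric "queries" are placeholders, not anything from the crate.

    use futures::{executor::block_on, stream, StreamExt, TryStreamExt};

    // Placeholder for a single-query backend call such as `search_points`.
    async fn search_one(query: u32) -> Result<Vec<u32>, String> {
        Ok(vec![query, query * 10])
    }

    // Fan the queries out as individual requests, keep at most 10 in flight,
    // stop at the first error, and flatten the per-query results.
    async fn batch_via_single_calls(queries: Vec<u32>) -> Result<Vec<u32>, String> {
        let responses = stream::iter(queries)
            .map(search_one)
            .buffered(10)
            .try_collect::<Vec<_>>()
            .await?;
        Ok(responses.into_iter().flatten().collect())
    }

    fn main() {
        let merged = block_on(batch_via_single_calls(vec![1, 2, 3])).unwrap();
        println!("{merged:?}"); // [1, 10, 2, 20, 3, 30]
    }

As in the patched batch_search_with, buffered(10) preserves the input order of the results and the first error aborts the whole batch.
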
@@ -357,6 +414,46 @@ impl Semantic {
         Ok(deduplicate_snippets(results, vector, limit))
     }
 
+    pub async fn batch_search<'a>(
+        &self,
+        parsed_queries: &[&SemanticQuery<'a>],
+        limit: u64,
+        offset: u64,
+        retrieve_more: bool,
+    ) -> anyhow::Result<Vec<Payload>> {
+        if parsed_queries.iter().any(|q| q.target().is_none()) {
+            anyhow::bail!("no search target for query");
+        };
+
+        let vectors = parsed_queries
+            .iter()
+            .map(|q| self.embed(&q.target().unwrap()))
+            .collect::<anyhow::Result<Vec<_>>>()?;
+
+        tracing::trace!(?parsed_queries, "performing qdrant batch search");
+
+        let result = self
+            .batch_search_with(
+                parsed_queries,
+                vectors.clone(),
+                if retrieve_more { limit * 2 } else { limit }, // Retrieve double `limit` and deduplicate
+                offset,
+            )
+            .await;
+
+        tracing::trace!(?result, "qdrant batch search returned");
+
+        let results = result?
+            .into_iter()
+            .map(Payload::from_qdrant)
+            .collect::<Vec<_>>();
+
+        // deduplicate with mmr with respect to the mean of query vectors
+        // TODO: implement a more robust multi-vector deduplication strategy
+        let target_vector = mean_pool(vectors);
+        Ok(deduplicate_snippets(results, target_vector, limit))
+    }
+
     #[allow(clippy::too_many_arguments)]
     #[tracing::instrument(skip(self, repo_ref, relative_path, buffer))]
     pub async fn insert_points_for_buffer(
@@ -593,6 +690,19 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
     dot(a, b) / (norm(a) * norm(b))
 }
 
+// Calculate the element-wise mean of the embeddings
+fn mean_pool(embeddings: Vec<Vec<f32>>) -> Vec<f32> {
+    let len = embeddings.len() as f32;
+    let mut result = vec![0.0; EMBEDDING_DIM];
+    for embedding in embeddings {
+        for (i, v) in embedding.iter().enumerate() {
+            result[i] += v;
+        }
+    }
+    result.iter_mut().for_each(|v| *v /= len);
+    result
+}
+
 // returns a list of indices to preserve from `snippets`
 //
 // query_embedding: the embedding of the query terms
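
For reference, here is the mean_pool added above rewritten without the crate's EMBEDDING_DIM constant, plus a quick numeric check. The names in this sketch are illustrative only and not part of the commit.

    // Dimension is taken from the first input instead of EMBEDDING_DIM.
    fn mean_pool_generic(embeddings: &[Vec<f32>]) -> Vec<f32> {
        let len = embeddings.len() as f32;
        let dim = embeddings.first().map_or(0, Vec::len);
        let mut result = vec![0.0; dim];
        for embedding in embeddings {
            for (i, v) in embedding.iter().enumerate() {
                result[i] += v;
            }
        }
        result.iter_mut().for_each(|v| *v /= len);
        result
    }

    fn main() {
        // The element-wise mean of two 3-dimensional embeddings.
        let pooled = mean_pool_generic(&[vec![1.0, 2.0, 3.0], vec![3.0, 6.0, 9.0]]);
        assert_eq!(pooled, vec![2.0, 4.0, 6.0]);
    }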

server/bleep/src/webserver/answer.rs

Lines changed: 66 additions & 1 deletion
@@ -738,7 +738,14 @@ impl Agent {
         self.update(Update::Step(SearchStep::Code(query.clone())))
             .await?;
 
-        let results = self.semantic_search(query.into(), 10, 0, true).await?;
+        let mut results = self.semantic_search(query.into(), 10, 0, true).await?;
+
+        let hyde_docs = self.hyde(query).await?;
+        if !hyde_docs.is_empty() {
+            let hyde_docs = hyde_docs.iter().map(|d| d.into()).collect();
+            let hyde_results = self.batch_semantic_search(hyde_docs, 10, 0, true).await?;
+            results.extend(hyde_results);
+        }
 
         let chunks = results
             .into_iter()
@@ -764,6 +771,7 @@ impl Agent {
         self.track_query(
             EventData::input_stage("semantic code search")
                 .with_payload("query", query)
+                .with_payload("hyde_queries", &hyde_docs)
                 .with_payload("chunks", &chunks)
                 .with_payload("raw_prompt", &prompt),
         );
@@ -849,6 +857,33 @@ impl Agent {
         Ok(Some(action))
     }
 
+    async fn hyde(&self, query: &str) -> Result<Vec<String>> {
+        let prompt = vec![llm_gateway::api::Message::system(
+            &prompts::hypothetical_document_prompt(query),
+        )];
+
+        tracing::trace!(?query, "generating hyde docs");
+
+        let response = self
+            .llm_gateway
+            .clone()
+            .model("gpt-3.5-turbo-0613")
+            .chat(&prompt, None)
+            .await?
+            .try_collect::<String>()
+            .await?;
+
+        tracing::trace!("parsing hyde response");
+
+        let documents = prompts::try_parse_hypothetical_documents(&response);
+
+        for doc in documents.iter() {
+            info!(?doc, "got hyde doc");
+        }
+
+        Ok(documents)
+    }
+
     async fn proc(&mut self, question: &str, path_aliases: &[usize]) -> Result<String> {
         let paths = path_aliases
             .iter()
@@ -1364,6 +1399,36 @@ impl Agent {
             .await
     }
 
+    async fn batch_semantic_search(
+        &self,
+        queries: Vec<Literal<'_>>,
+        limit: u64,
+        offset: u64,
+        retrieve_more: bool,
+    ) -> Result<Vec<semantic::Payload>> {
+        let queries = queries
+            .iter()
+            .map(|q| SemanticQuery {
+                target: Some(q.clone()),
+                repos: [Literal::Plain(
+                    self.conversation.repo_ref.display_name().into(),
+                )]
+                .into(),
+                ..self.conversation.last_exchange().query.clone()
+            })
+            .collect::<Vec<_>>();
+
+        let queries = queries.iter().collect::<Vec<_>>();
+
+        debug!(?queries, %self.thread_id, "executing semantic query");
+        self.app
+            .semantic
+            .as_ref()
+            .unwrap()
+            .batch_search(queries.as_slice(), limit, offset, retrieve_more)
+            .await
+    }
+
     async fn get_file_content(&self, path: &str) -> Result<Option<ContentDocument>> {
         let branch = self.conversation.last_exchange().query.first_branch();
         let repo_ref = &self.conversation.repo_ref;

server/bleep/src/webserver/answer/prompts.rs

Lines changed: 97 additions & 0 deletions
@@ -280,3 +280,100 @@ Respect these rules at all times:
 - Always begin your answer with an appropriate title"#
     )
 }
+
+pub fn hypothetical_document_prompt(query: &str) -> String {
+    format!(
+        r#"Write three code snippets that could hypothetically be returned by a code search engine as the answer to the query: {query}
+
+- Write these snippets in a variety of programming languages
+- The snippets should not be too similar to one another
+- Each snippet should be between 5 and 10 lines long
+- Surround the snippets in triple backticks
+
+For example:
+
+What's the Qdrant threshold?
+
+```rust
+SearchPoints {{
+    limit,
+    vector: vectors.get(idx).unwrap().clone(),
+    collection_name: COLLECTION_NAME.to_string(),
+    offset: Some(offset),
+    score_threshold: Some(0.3),
+    with_payload: Some(WithPayloadSelector {{
+        selector_options: Some(with_payload_selector::SelectorOptions::Enable(true)),
+    }}),
+```
+
+```python
+memories = self.client.search(
+    collection_name=self.collection_name,
+    limit=k,
+    score_threshold=threshold,
+    search_params=SearchParams(
+        quantization=QuantizationSearchParams(
+            ignore=False,
+            rescore=True
+        )
+    )
+)
+```
+
+```typescript
+const res = await qdrant.search(collectionName, {{
+  vector: embedding,
+  score_threshold:
+    distanceMetric === "Cosine" ? 1 - threshold : threshold,
+  with_payload: true,
+}});
+```
+"#
+    )
+}
+
+pub fn try_parse_hypothetical_documents(document: &str) -> Vec<String> {
+    let pattern = r"```([\s\S]*?)```";
+    let re = regex::Regex::new(pattern).unwrap();
+
+    re.captures_iter(document)
+        .map(|m| m[1].trim().to_string())
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_hypothetical_document() {
+        let document = r#"Here is some pointless text
+
+```rust
+pub fn final_explanation_prompt(context: &str, query: &str, query_history: &str) -> String {
+    struct Rule<'a> {
+        title: &'a str,
+        description: &'a str,
+        note: &'a str,
+        schema: &'a str,```
+
+Here is some more pointless text
+
+```
+pub fn functions() -> serde_json::Value {
+    serde_json::json!(
+```"#;
+        let expected = vec![
+            r#"rust
+pub fn final_explanation_prompt(context: &str, query: &str, query_history: &str) -> String {
+    struct Rule<'a> {
+        title: &'a str,
+        description: &'a str,
+        note: &'a str,
+        schema: &'a str,"#,
+            r#"pub fn functions() -> serde_json::Value {
+    serde_json::json!("#,
+        ];
+
+        assert_eq!(try_parse_hypothetical_documents(document), expected);
+    }
+}
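
A hypothetical round-trip of the two helpers above, assuming it sits next to them in this module: the hard-coded mock completion stands in for the LLM response that Agent::hyde supplies in the real flow.

    fn demo_hyde_round_trip() {
        // Build the prompt that gets sent to the LLM.
        let prompt = hypothetical_document_prompt("how do I open a TCP socket?");
        assert!(prompt.contains("triple backticks"));

        // Mock completion: one fenced snippet, as the prompt requests.
        let mock_response =
            "Some prose\n```rust\nlet listener = std::net::TcpListener::bind(\"127.0.0.1:8080\");\n```\nMore prose";
        let docs = try_parse_hypothetical_documents(mock_response);
        assert_eq!(docs.len(), 1);
        assert!(docs[0].starts_with("rust"));
    }

The parser keeps the language tag because the regex captures everything between the backtick fences, which is also why the existing test expects its first snippet to begin with "rust".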
