Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ public static class Pipeline {
*/
private boolean queryTranslationEnabled = false;

/**
* 是否启用假设性文档嵌入功能,默认为false。
*/
private boolean hypotheticalDocumentEmbeddingEnabled = false;

/**
* 查询翻译的目标语言,默认为"English"。
*/
Expand Down Expand Up @@ -220,7 +225,15 @@ public void setQueryTranslationEnabled(boolean queryTranslationEnabled) {
this.queryTranslationEnabled = queryTranslationEnabled;
}

public String getQueryTranslationLanguage() {
public boolean isHypotheticalDocumentEmbeddingEnabled() {
return hypotheticalDocumentEmbeddingEnabled;
}

public void setHypotheticalDocumentEmbeddingEnabled(boolean hypotheticalDocumentEmbeddingEnabled) {
this.hypotheticalDocumentEmbeddingEnabled = hypotheticalDocumentEmbeddingEnabled;
}

public String getQueryTranslationLanguage() {
return queryTranslationLanguage;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.alibaba.cloud.ai.example.deepresearch.rag.post.DocumentSelectFirstProcess;
import com.alibaba.cloud.ai.example.deepresearch.rag.retriever.RrfHybridElasticsearchRetriever;
import com.alibaba.cloud.ai.example.deepresearch.rag.strategy.RrfFusionStrategy;
import com.alibaba.cloud.ai.example.deepresearch.rag.transformer.HyDeTransformer;
import org.elasticsearch.client.RestClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -64,6 +65,8 @@ public class DefaultHybridRagProcessor implements HybridRagProcessor {

private final TranslationQueryTransformer queryTransformer;

private final HyDeTransformer hyDeTransformer;

private final DocumentSelectFirstProcess documentPostProcessor;

private final RrfFusionStrategy rrfFusionStrategy;
Expand Down Expand Up @@ -98,6 +101,9 @@ public DefaultHybridRagProcessor(@Qualifier("ragVectorStore") VectorStore vector
.build()
: null;

this.hyDeTransformer = ragProperties.getPipeline().isHypotheticalDocumentEmbeddingEnabled()
? HyDeTransformer.builder().chatClientBuilder(chatClientBuilder).build() : null;

// 初始化文档后处理器
this.documentPostProcessor = ragProperties.getPipeline().isPostProcessingSelectFirstEnabled()
? new DocumentSelectFirstProcess() : null;
Expand Down Expand Up @@ -133,7 +139,7 @@ public List<org.springframework.ai.rag.Query> preProcess(org.springframework.ai.
if (queryTransformer != null) {
queries = queries.stream().flatMap(q -> {
org.springframework.ai.rag.Query transformed = queryTransformer.transform(q);
return transformed != null ? Stream.of(transformed) : Stream.empty();
return Stream.of(transformed);
}).collect(Collectors.toList());
}

Expand All @@ -142,6 +148,14 @@ public List<org.springframework.ai.rag.Query> preProcess(org.springframework.ai.
queries = queries.stream().flatMap(q -> queryExpander.expand(q).stream()).collect(Collectors.toList());
}

// 假设性文档生成
if (hyDeTransformer != null) {
queries = queries.stream().flatMap(q -> {
org.springframework.ai.rag.Query transformed = hyDeTransformer.transform(q);
return Stream.of(transformed);
}).collect(Collectors.toList());
}

return queries;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.alibaba.cloud.ai.example.deepresearch.rag.transformer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* Generate hypothetical document for query. The implements of the Hypothetical Document
* Embeddings
* <a href="https://arxiv.org/abs/2212.10496">https://arxiv.org/abs/2212.10496</a>
*
* @author benym
*/
public class HyDeTransformer implements QueryTransformer {

private static final Logger logger = LoggerFactory.getLogger(HyDeTransformer.class);

private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
Given a user question, write a comprehensive and informative passage that directly answers the question.

The passage should be factual, well-structured, and contain specific details.

Question: {query}

Passage:
""");

private final ChatClient chatClient;

private final PromptTemplate promptTemplate;

public HyDeTransformer(ChatClient.Builder chatClientBuilder, PromptTemplate promptTemplate) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query");
}

@Override
public Query transform(Query query) {
Assert.notNull(query, "query cannot be null");
var hyDeQueryText = this.chatClient.prompt()
.user(user -> user.text(this.promptTemplate.getTemplate()).param("query", query.text()))
.call()
.content();
if (!StringUtils.hasText(hyDeQueryText)) {
logger.warn("Query generate hyDe document result is null/empty. Returning the input query unchanged.");
return query;
}
return query.mutate().text(hyDeQueryText).build();
}

public static Builder builder() {
return new Builder();
}

public static final class Builder {

private ChatClient.Builder chatClientBuilder;

@Nullable
private PromptTemplate promptTemplate;

private Builder() {
}

public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
this.chatClientBuilder = chatClientBuilder;
return this;
}

public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}

public HyDeTransformer build() {
return new HyDeTransformer(this.chatClientBuilder, this.promptTemplate);
}

}

}