
Conversation


@cyyeh cyyeh commented Sep 18, 2025

Summary by CodeRabbit

  • New Features
    • Added runtime configuration endpoints: GET /configs to view current providers/pipelines and POST /configs to update them without restart.
    • Exposed per-pipeline metadata (LLM/embedder aliases, descriptions, DB data usage).
    • Introduced provider aliases in configs for easier selection.
  • Refactor
    • Pipelines reworked to support dynamic provider updates and standardized initialization.
    • Switched provider interfaces to property-based access (model, kwargs, context window).
  • Chores
    • Renamed indexing pipelines: sql_pairs → sql_pairs_indexing; instructions → instructions_indexing.
    • Safer document store initialization (conditional index recreation).


coderabbitai bot commented Sep 18, 2025

Walkthrough

Adds runtime configuration endpoints (/configs GET/POST), restructures startup to store providers and pipe components, and propagates description/alias metadata. Refactors provider interfaces to properties, introduces per-pipeline update_components across many pipelines, and renames some indexing pipeline keys. Adds Configs schema and pipe metadata builder.
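
For orientation, a minimal client sketch against the new endpoints. This is a sketch under assumptions: the service address and the exact POST payload shape (which must follow the Configs schema introduced in this PR) are not verified here.

import requests

BASE_URL = "http://localhost:5556"  # assumed wren-ai-service address

# Inspect the current runtime configuration: env vars, providers, and pipeline linkage.
current = requests.get(f"{BASE_URL}/configs").json()
print(current["providers"]["llm"], list(current["pipelines"]))

# Push a (possibly edited) configuration back without restarting the service.
# The body must match the Configs schema; reposting the GET snapshot is illustrative only.
resp = requests.post(f"{BASE_URL}/configs", json=current)
resp.raise_for_status()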

Changes

  • Service runtime config & startup wiring (src/__main__.py, src/providers/__init__.py, src/globals.py, src/utils.py): Adds GET/POST /configs for introspection and live reconfiguration; generate_components now returns (components, instantiated_providers); adds create_pipe_components; introduces the Configs schema and has_db_data_in_llm_prompt; pipeline metadata now includes description and aliases.
  • Core interfaces (src/core/provider.py, src/core/pipeline.py): Switches getter methods to read-only properties (model, model_kwargs, context_window_size, alias); extends PipelineComponent with description; adds BasicPipeline.update_components/_update_components and __str__ methods. See the property sketch after this list.
  • Providers: LLM/Embedder/Document store (src/providers/llm/litellm.py, src/providers/embedder/litellm.py, src/providers/document_store/qdrant.py): LLM provider adds alias; embedder renames api_base_url→api_base and adds api_version and alias; Qdrant resets collections only when recreate_index=True.
  • Generation pipelines refactor (src/pipelines/generation/*): Multiple pipelines now store providers/description, move AsyncDriver init to the start, replace inline components with _update_components, switch to provider properties (model), and add update_components where applicable. Files: chart_adjustment.py, chart_generation.py, data_assistance.py, followup_sql_generation.py, followup_sql_generation_reasoning.py, intent_classification.py, misleading_assistance.py, question_recommendation.py, relationship_recommendation.py, semantics_description.py, sql_answer.py, sql_correction.py, sql_diagnosis.py, sql_generation.py, sql_generation_reasoning.py, sql_question.py, sql_regeneration.py, sql_tables_extraction.py, user_guide_assistance.py.
  • Indexing pipelines refactor (src/pipelines/indexing/*): Adds a description param, persists stores/providers, centralizes component construction via _update_components, and adds update_components for dynamic rebinding. Files: db_schema.py, historical_question.py, instructions.py, project_meta.py, sql_pairs.py, table_description.py.
  • Retrieval pipelines refactor (src/pipelines/retrieval/*): Adds description, centralizes components/configs, adds update_components for dynamic reconfiguration, and switches to provider properties. Files: db_schema_retrieval.py, historical_question_retrieval.py, instructions.py, preprocess_sql_data.py, sql_executor.py, sql_functions.py, sql_pairs_retrieval.py.
  • Service pipeline naming adjustments (src/web/v1/services/semantics_preparation.py): Renames pipeline keys: sql_pairs→sql_pairs_indexing; instructions→instructions_indexing.
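
As a reference for the "Core interfaces" entry above, here is a minimal sketch of the property-based provider access, assuming the attribute names listed in the summary (not copied from src/core/provider.py):

class LLMProviderSketch:
    # Sketch only: field names are taken from the change summary, not the actual provider classes.
    def __init__(self, model: str, model_kwargs: dict, context_window_size: int, alias: str | None = None):
        self._model = model
        self._model_kwargs = model_kwargs
        self._context_window_size = context_window_size
        self._alias = alias

    @property
    def model(self) -> str:
        return self._model

    @property
    def model_kwargs(self) -> dict:
        return self._model_kwargs

    @property
    def context_window_size(self) -> int:
        return self._context_window_size

    @property
    def alias(self) -> str | None:
        return self._alias

Callers then read provider.model as an attribute rather than calling a getter, which is what the pipeline refactors below switch to.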

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant U as Client
    participant S as Service (FastAPI)
    participant C as generate_components
    participant SC as ServiceContainer
    participant G as create_pipe_components

    rect rgba(230,245,255,0.5)
    note over S: Startup
    S->>C: generate_components(configs)
    C-->>S: (pipe_components, instantiated_providers)
    S->>SC: build services with providers
    S->>G: create_pipe_components(ServiceContainer)
    G-->>S: pipe_service_components
    S->>S: store in app.state
    end

    rect rgba(240,255,230,0.5)
    note over U,S: Introspection
    U->>S: GET /configs
    S-->>U: env_vars, providers (llm/embedder), pipelines + linkage
    end

    rect rgba(255,245,230,0.5)
    note over U,S: Live reconfiguration
    U->>S: POST /configs (Configs)
    S->>S: rebuild providers (LLM/Embedder/Qdrant)
    S->>SC: update pipelines via update_components(...)
    S->>G: create_pipe_components(...)
    S-->>U: 200 OK (updated snapshot)
    end
sequenceDiagram
    autonumber
    participant Svc as Service
    participant P as Pipeline Instance
    participant Prov as Providers (LLM/Embedder/Store)

    note over Svc,P: Dynamic pipeline update
    Svc->>P: update_components(llm_provider, embedder_provider, document_store_provider)
    P->>P: set internal refs
    alt needs recompute
      P->>P: _update_components()
    end
    P-->>Svc: components refreshed
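
The second diagram maps onto the new BasicPipeline.update_components hook. A minimal sketch of the pattern, assuming the shape described in the walkthrough (method bodies are assumptions; the real logic lives in src/core/pipeline.py and the per-pipeline _update_components overrides):

class BasicPipelineSketch:
    # Sketch only: provider fields and the update flow are assumptions based on this PR's summary.
    def __init__(self, llm_provider=None, embedder_provider=None, document_store_provider=None):
        self._llm_provider = llm_provider
        self._embedder_provider = embedder_provider
        self._document_store_provider = document_store_provider
        self._components = self._update_components()

    def update_components(
        self,
        llm_provider=None,
        embedder_provider=None,
        document_store_provider=None,
        update_components: bool = True,
        **_,
    ):
        # Rebind whichever providers were supplied.
        if llm_provider is not None:
            self._llm_provider = llm_provider
        if embedder_provider is not None:
            self._embedder_provider = embedder_provider
        if document_store_provider is not None:
            self._document_store_provider = document_store_provider
        # Recompute derived components (generators, embedders, retrievers) when requested.
        if update_components:
            self._components = self._update_components()

    def _update_components(self) -> dict:
        # Per-pipeline override; returns the dict of components fed into the pipeline run.
        return {}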

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

module/ai-service, wren-ai-service, ai-env-changed

Suggested reviewers

  • yichieh-lu

Poem

A rabbit taps configs with gentle paws,
Flips pipes and providers without a pause.
Descriptions bloom, aliases align,
Components refresh—oh how they shine!
With POST and GET, we hop along—
Live-tuned burrows, swift and strong. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 3.90%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title "feat(wren-ai-service): configs-api" is concise, uses conventional commit style, and accurately captures the primary change in the PR (adding a configuration API and runtime reconfiguration endpoints for wren-ai-service); it is specific enough for a teammate scanning history to understand the main intent.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch feat/ai-service/configs-api

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (26)
wren-ai-service/src/providers/document_store/qdrant.py (2)

381-394: Dangerous boolean parsing for SHOULD_FORCE_DEPLOY can trigger destructive resets.

bool(os.getenv(...)) is True for any non-empty string (e.g., "false"). Combined with the new conditional reset, this can still wipe collections unintentionally.

-        recreate_index: bool = (
-            bool(os.getenv("SHOULD_FORCE_DEPLOY"))
-            if os.getenv("SHOULD_FORCE_DEPLOY")
-            else False
-        ),
+        recreate_index: bool = (
+            os.getenv("SHOULD_FORCE_DEPLOY", "").strip().lower()
+            in {"1", "true", "t", "yes", "y", "on"}
+        ),
         **_,
     ):
         self._location = location
         self._api_key = Secret.from_token(api_key) if api_key else None
         self._timeout = timeout
         self._embedding_model_dim = embedding_model_dim
-        if recreate_index:
-            self._reset_document_store(recreate_index)
+        if recreate_index:
+            self._reset_document_store(recreate_index)

395-402: Collection names out of sync with new indexing keys.

Reset still targets "sql_pairs" and "instructions" but upstream now uses "..._indexing". This will leave stale/new collections unmanaged.

     def _reset_document_store(self, recreate_index: bool):
         self.get_store(recreate_index=recreate_index)
         self.get_store(dataset_name="table_descriptions", recreate_index=recreate_index)
         self.get_store(dataset_name="view_questions", recreate_index=recreate_index)
-        self.get_store(dataset_name="sql_pairs", recreate_index=recreate_index)
-        self.get_store(dataset_name="instructions", recreate_index=recreate_index)
+        self.get_store(dataset_name="sql_pairs_indexing", recreate_index=recreate_index)
+        self.get_store(dataset_name="instructions_indexing", recreate_index=recreate_index)
         self.get_store(dataset_name="project_meta", recreate_index=recreate_index)
wren-ai-service/src/pipelines/generation/sql_answer.py (2)

152-153: Catch asyncio.TimeoutError, not built-in TimeoutError.

asyncio.wait_for raises asyncio.TimeoutError. Current except block won’t trigger.

Apply this diff:

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 break

155-165: Avoid dynamic default evaluated at import time.

current_time: str = Configuration().show_current_time() is evaluated once when the function is defined, not per call.

Apply this diff:

-    async def run(
+    async def run(
         self,
         query: str,
         sql: str,
         sql_data: dict,
         language: str,
-        current_time: str = Configuration().show_current_time(),
+        current_time: Optional[str] = None,
         query_id: Optional[str] = None,
         custom_instruction: Optional[str] = None,
     ) -> dict:
@@
-        return await self._pipe.execute(
+        if current_time is None:
+            current_time = Configuration().show_current_time()
+        return await self._pipe.execute(
             ["generate_answer"],
             inputs={
                 "query": query,
                 "sql": sql,
                 "sql_data": sql_data,
                 "language": language,
                 "current_time": current_time,
                 "query_id": query_id,
                 "custom_instruction": custom_instruction or "",
                 **self._components,
             },
         )
wren-ai-service/src/pipelines/generation/question_recommendation.py (2)

163-186: Template uses categories, but prompt() doesn’t pass it.

user_prompt_template references categories, yet prompt() neither accepts nor forwards it to PromptBuilder.run, so categories are silently ignored.

Proposed change (outside this hunk):

def prompt(
    previous_questions: list[str],
    documents: list,
    language: str,
    max_questions: int,
    max_categories: int,
    categories: list[str],              # add
    prompt_builder: PromptBuilder,
) -> dict:
    _prompt = prompt_builder.run(
        documents=documents,
        previous_questions=previous_questions,
        language=language,
        max_questions=max_questions,
        max_categories=max_categories,
        categories=categories,          # add
    )
    return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}

And add "categories": categories to the inputs of any upstream execute call if missing.


265-270: Avoid mutable default args.

previous_questions: list[str] = [] and categories: list[str] = [] share state across calls.

Apply this diff:

-        previous_questions: list[str] = [],
-        categories: list[str] = [],
+        previous_questions: Optional[list[str]] = None,
+        categories: Optional[list[str]] = None,
@@
-        return await self._pipe.execute(
+        previous_questions = previous_questions or []
+        categories = categories or []
+        return await self._pipe.execute(
             [self._final],
             inputs={
                 "documents": contexts,
                 "previous_questions": previous_questions,
                 "categories": categories,
wren-ai-service/src/pipelines/generation/sql_question.py (1)

62-68: Return type mismatch: generation step returns a tuple but is typed/used as dict

generate_sql_question is annotated to return dict, but it returns (dict, generator_name). post_process consumes it as a dict. Align by returning only the generator result.

 async def generate_sql_question(
-    prompt: dict, generator: Any, generator_name: str
+    prompt: dict, generator: Any, generator_name: str
 ) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+    return await generator(prompt=prompt.get("prompt"))
wren-ai-service/src/pipelines/generation/data_assistance.py (2)

80-87: Return type mismatch: generation step returns a tuple but is typed as dict

Returning (dict, generator_name) from data_assistance changes the final pipeline output shape. Return only the generator result.

 async def data_assistance(
     prompt: dict, generator: Any, query_id: str, generator_name: str
 ) -> dict:
-    return await generator(
+    return await generator(
         prompt=prompt.get("prompt"),
         query_id=query_id,
-    ), generator_name
+    )

152-153: Catch asyncio.TimeoutError, not TimeoutError

asyncio.wait_for raises asyncio.TimeoutError.

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 break
wren-ai-service/src/pipelines/generation/misleading_assistance.py (2)

80-87: Return type mismatch: generation step returns a tuple but is typed as dict

Return only the generator result to keep the pipeline output stable.

 async def misleading_assistance(
     prompt: dict, generator: Any, query_id: str, generator_name: str
 ) -> dict:
-    return await generator(
+    return await generator(
         prompt=prompt.get("prompt"),
         query_id=query_id,
-    ), generator_name
+    )

152-153: Catch asyncio.TimeoutError

Same issue as data_assistance.

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 break
wren-ai-service/src/pipelines/generation/chart_generation.py (1)

88-92: Return type mismatch: generation step returns a tuple but is typed/used as dict

post_process expects generate_chart: dict. Return only the generator result.

 async def generate_chart(prompt: dict, generator: Any, generator_name: str) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+    return await generator(prompt=prompt.get("prompt"))
wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py (2)

92-102: Return type mismatch: generation step returns a tuple but is typed/consumed as dict

post_process uses generate_sql_reasoning as dict. Return only the generator result.

 async def generate_sql_reasoning(
     prompt: dict,
     generator: Any,
     query_id: str,
     generator_name: str,
 ) -> dict:
-    return await generator(
+    return await generator(
         prompt=prompt.get("prompt"),
         query_id=query_id,
-    ), generator_name
+    )

172-173: Catch asyncio.TimeoutError

Same asyncio-specific exception applies here.

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 break
wren-ai-service/src/pipelines/generation/chart_adjustment.py (1)

114-121: Return type mismatch: generation step returns a tuple but is typed/used as dict

post_process expects dict. Return only generator result.

 async def generate_chart_adjustment(
     prompt: dict,
     generator: Any,
     generator_name: str,
 ) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+    return await generator(prompt=prompt.get("prompt"))
wren-ai-service/src/pipelines/generation/user_guide_assistance.py (1)

145-146: Catch asyncio.TimeoutError, not TimeoutError

asyncio.wait_for raises asyncio.TimeoutError.

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 break
wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py (2)

81-87: Bug: returning a tuple breaks post_process expecting a dict

post_process() calls .get("replies") on the output of this step; returning (dict, generator_name) will raise at runtime.

-async def generate_sql_reasoning(
-    prompt: dict, generator: Any, query_id: str, generator_name: str
-) -> dict:
-    return await generator(
-        prompt=prompt.get("prompt"), query_id=query_id
-    ), generator_name
+async def generate_sql_reasoning(
+    prompt: dict, generator: Any, query_id: str, generator_name: str
+) -> dict:
+    return await generator(prompt=prompt.get("prompt"), query_id=query_id)

157-158: Catch asyncio.TimeoutError

Same issue as other pipelines.

-            except TimeoutError:
+            except asyncio.TimeoutError:
                 break
wren-ai-service/src/pipelines/generation/relationship_recommendation.py (1)

113-115: Bug: return only the generator result

normalized(generate: dict) expects a dict; returning (dict, generator_name) will fail.

-async def generate(prompt: dict, generator: Any, generator_name: str) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+async def generate(prompt: dict, generator: Any, generator_name: str) -> dict:
+    return await generator(prompt=prompt.get("prompt"))
wren-ai-service/src/pipelines/generation/sql_regeneration.py (1)

130-136: Bug: return only the generator result

post_process consumes regenerate_sql as a dict; tuple return will break.

-async def regenerate_sql(
-    prompt: dict,
-    generator: Any,
-    generator_name: str,
-) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+async def regenerate_sql(
+    prompt: dict,
+    generator: Any,
+    generator_name: str,
+) -> dict:
+    return await generator(prompt=prompt.get("prompt"))
wren-ai-service/src/pipelines/generation/semantics_description.py (1)

148-151: Fix: generate() must return a dict, not a tuple.

Downstream expects dict (normalize() calls .get). Return the generator result and (optionally) inject generator_name into it.

-async def generate(prompt: dict, generator: Any, generator_name: str) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+async def generate(prompt: dict, generator: Any, generator_name: str) -> dict:
+    result = await generator(prompt=prompt.get("prompt"))
+    # optional: expose generator_name for observability
+    result["generator_name"] = generator_name
+    return result
wren-ai-service/src/pipelines/generation/sql_correction.py (1)

101-105: Fix: generate_sql_correction() must return a dict, not a tuple.

post_process() calls .get on its input.

-async def generate_sql_correction(
-    prompt: dict, generator: Any, generator_name: str
-) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+async def generate_sql_correction(
+    prompt: dict, generator: Any, generator_name: str
+) -> dict:
+    result = await generator(prompt=prompt.get("prompt"))
+    result["generator_name"] = generator_name  # optional
+    return result
wren-ai-service/src/pipelines/generation/followup_sql_generation.py (1)

121-128: Bug: generate_sql_in_followup returns a tuple but downstream expects a dict.

post_process calls .get("replies"), which fails on a tuple. Return a dict with "replies" and "generator_name".

 async def generate_sql_in_followup(
     prompt: dict, generator: Any, histories: list[AskHistory], generator_name: str
 ) -> dict:
     history_messages = construct_ask_history_messages(histories)
-    return await generator(
-        prompt=prompt.get("prompt"), history_messages=history_messages
-    ), generator_name
+    replies = await generator(
+        prompt=prompt.get("prompt"), history_messages=history_messages
+    )
+    return {"replies": replies, "generator_name": generator_name}
wren-ai-service/src/pipelines/generation/intent_classification.py (1)

294-297: Return the generator result dict, not a tuple

Returning (dict, generator_name) breaks downstream post_process expecting a dict.

 @observe(as_type="generation", capture_input=False)
 @trace_cost
 async def classify_intent(prompt: dict, generator: Any, generator_name: str) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+    return await generator(prompt=prompt.get("prompt"))
wren-ai-service/src/pipelines/retrieval/instructions.py (1)

169-179: default_instructions returns list but callers expect dict; fix type and return shape

This currently raises when filtered_documents is truthy and default_instructions is [].

-@observe(capture_input=False)
-def default_instructions(
+@observe(capture_input=False)
+def default_instructions(
     count_documents: int,
     retriever: Any,
     project_id: str,
     scope_filter: ScopeFilter,
     scope: str,
-) -> list[Document]:
-    if not count_documents:
-        return []
+) -> dict:
+    if not count_documents:
+        return {"documents": []}
@@
-    return dict(documents=res.get("documents"))
+    return dict(documents=res.get("documents"))

Also applies to: 133-166

wren-ai-service/src/pipelines/indexing/sql_pairs.py (1)

125-136: Fix delete_all: current filter never deletes anything when delete_all=True

clean() currently passes an empty sql_pair_ids -> the store builds an "in: []" filter (matches nothing). Pass delete_all through and build filters conditionally. Consider requiring project_id when delete_all=True to avoid accidental global wipes.

@@
 @component
 class SqlPairsCleaner:
-    def __init__(self, sql_pairs_store: DocumentStore) -> None:
+    def __init__(self, sql_pairs_store: DocumentStore) -> None:
         self.store = sql_pairs_store
 
     @component.output_types()
     async def run(
-        self, sql_pair_ids: List[str], project_id: Optional[str] = None
+        self,
+        sql_pair_ids: List[str],
+        project_id: Optional[str] = None,
+        delete_all: bool = False,
     ) -> None:
-        filter = {
-            "operator": "AND",
-            "conditions": [
-                {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids},
-            ],
-        }
-
-        if project_id:
-            filter["conditions"].append(
-                {"field": "project_id", "operator": "==", "value": project_id}
-            )
-
-        return await self.store.delete_documents(filter)
+        conditions = []
+        if not delete_all:
+            conditions.append(
+                {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids}
+            )
+        if project_id:
+            conditions.append(
+                {"field": "project_id", "operator": "==", "value": project_id}
+            )
+        filters = {"operator": "AND", "conditions": conditions} if conditions else None
+        return await self.store.delete_documents(filters)
@@
 async def clean(
     cleaner: SqlPairsCleaner,
     sql_pairs: List[SqlPair],
     embedding: Dict[str, Any] = {},
     project_id: str = "",
     delete_all: bool = False,
 ) -> Dict[str, Any]:
@@
-    if sql_pair_ids or delete_all:
-        await cleaner.run(sql_pair_ids=sql_pair_ids, project_id=project_id)
+    if sql_pair_ids or delete_all:
+        await cleaner.run(
+            sql_pair_ids=sql_pair_ids,
+            project_id=project_id,
+            delete_all=delete_all,
+        )

Also applies to: 53-74

🧹 Nitpick comments (57)
wren-ai-service/src/web/v1/services/semantics_preparation.py (1)

129-134: Fix error message: missing f-string and wrong field.

Currently the braces are literal and it references a non-existent attribute id. Use mdl_hash with an f-string.

-            return SemanticsPreparationStatusResponse(
-                status="failed",
-                error=SemanticsPreparationStatusResponse.SemanticsPreparationError(
-                    code="OTHERS",
-                    message="{prepare_semantics_status_request.id} is not found",
-                ),
-            )
+            return SemanticsPreparationStatusResponse(
+                status="failed",
+                error=SemanticsPreparationStatusResponse.SemanticsPreparationError(
+                    code="OTHERS",
+                    message=f"{prepare_semantics_status_request.mdl_hash} is not found",
+                ),
+            )
wren-ai-service/src/providers/document_store/qdrant.py (2)

161-163: Creating payload index on every init may raise if already exists.

Guard for idempotency to avoid noisy errors on restarts.

-        self.client.create_payload_index(
-            collection_name=index, field_name="project_id", field_schema="keyword"
-        )
+        try:
+            self.client.create_payload_index(
+                collection_name=index, field_name="project_id", field_schema="keyword"
+            )
+        except Exception as e:
+            logger.debug("create_payload_index skipped: %s", e)

351-357: Boolean coalescing drops explicit False.

Using or ignores caller-supplied False for scale_score/return_embedding. Respect None vs False.

-                top_k=top_k or self._top_k,
-                scale_score=scale_score or self._scale_score,
-                return_embedding=return_embedding or self._return_embedding,
+                top_k=self._top_k if top_k is None else top_k,
+                scale_score=self._scale_score if scale_score is None else scale_score,
+                return_embedding=(
+                    self._return_embedding
+                    if return_embedding is None
+                    else return_embedding
+                ),
wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py (2)

144-159: Dynamic update_components LGTM; consider chaining return self.

Returning self enables fluent reconfiguration, but optional.

     def update_components(
         self,
         embedder_provider: EmbedderProvider,
         document_store_provider: DocumentStoreProvider,
         **_,
     ):
         super().update_components(
             embedder_provider=embedder_provider,
             document_store_provider=document_store_provider,
             update_components=False,
         )
         self._view_questions_store = self._document_store_provider.get_store(
             dataset_name="view_questions"
         )
         self._components = self._update_components()
+        return self

24-35: Avoid shadowing built-in list.

Minor readability nit.

-        list = [
+        items = [
             {
                 "question": doc.content,
                 "summary": doc.meta.get("summary", ""),
                 "statement": doc.meta.get("statement") or doc.meta.get("sql"),
                 "viewId": doc.meta.get("viewId", ""),
             }
             for doc in documents
         ]
 
-        return {"documents": list}
+        return {"documents": items}
wren-ai-service/src/providers/llm/litellm.py (3)

32-43: Alias and api key name plumbed correctly. Consider lazy api key resolution for runtime config updates.

Storing _api_key at init means env/key updates after process start won't apply. Prefer deferring to os.getenv(self._api_key_name) at call time or add an api_key property that does this.

Example (outside this hunk):

@property
def api_key(self) -> Optional[str]:
    return os.getenv(self._api_key_name) if self._api_key_name else self._api_key

Then pass api_key=self.api_key in acompletion.


139-147: Guard against empty stream producing UnboundLocalError.

If no chunks arrive, chunk is undefined when calling connect_chunks(chunk, chunks). Guard or track the last chunk explicitly.

Apply this diff:

-                async for chunk in completion:
+                last_chunk = None
+                async for chunk in completion:
                     if chunk.choices and streaming_callback:
                         chunk_delta: StreamingChunk = build_chunk(chunk)
                         chunks.append(chunk_delta)
                         streaming_callback(
                             chunk_delta, query_id
                         )  # invoke callback with the chunk_delta
-                completions = [connect_chunks(chunk, chunks)]
+                        last_chunk = chunk
+                if last_chunk is None:
+                    return {"replies": [""], "meta": [{}]}
+                completions = [connect_chunks(last_chunk, chunks)]

75-81: Broaden retryable exceptions.

Only retrying on openai.APIError may miss connection/timeouts/rate-limits and LiteLLM errors. Consider including openai.RateLimitError, openai.APIConnectionError, openai.APITimeoutError, and LiteLLM exceptions.
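
A hedged sketch of what a broader retry decorator could look like (the decorated function here is illustrative, not the provider's actual method; litellm's own exception classes could be added to the tuple as well):

import backoff
import openai

# Retry on a wider, explicit set of transient errors rather than openai.APIError alone.
@backoff.on_exception(
    backoff.expo,
    (
        openai.APIError,
        openai.APIConnectionError,
        openai.APITimeoutError,
        openai.RateLimitError,
    ),
    max_time=60.0,
    max_tries=3,
)
async def _call_llm(generator, prompt: str):
    # Illustrative wrapper; the real retry target is the provider's completion call.
    return await generator(prompt=prompt)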

wren-ai-service/src/pipelines/generation/sql_answer.py (2)

120-129: Streaming queue logic is fine; ensure queue cleanup always runs.

Cleanup is gated on "" which is good. Just a note: consider a finally/timeout path to avoid orphaned queues if a session never sends finish.


82-87: Return type hint mismatch.

generate_answer returns a tuple (dict, generator_name) but is annotated as -> dict. Adjust to tuple[dict, str] or return only the dict.
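
A minimal sketch of the explicit annotation (signature assumed from the review comment, not copied from the source):

from typing import Any

async def generate_answer(
    prompt: dict, generator: Any, generator_name: str
) -> tuple[dict, str]:
    # Sketch only: keeps the tuple return but makes the annotation honest about it.
    return await generator(prompt=prompt.get("prompt")), generator_name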

wren-ai-service/src/pipelines/generation/sql_diagnosis.py (1)

92-96: Return type annotation incorrect.

post_process returns a parsed JSON object (dict), but the annotation is -> str.

Apply this diff:

-async def post_process(
-    generate_sql_diagnosis: dict,
-) -> str:
+async def post_process(
+    generate_sql_diagnosis: dict,
+) -> dict:
wren-ai-service/src/pipelines/generation/question_recommendation.py (2)

196-210: Consider validating LLM JSON with Pydantic.

After orjson.loads, validate against QuestionResult to fail fast on malformed outputs and normalize types.

Example (outside this hunk):

try:
    parsed = QuestionResult.model_validate_json(text)
    return parsed.model_dump()
except Exception as e:
    logger.error(f"Validation error: {e}")
    return {"questions": []}

191-193: Return type hint alignment.

generate returns a tuple (dict, generator_name) but annotated as -> dict. Align the annotation or the return shape.

wren-ai-service/src/pipelines/generation/sql_question.py (2)

71-75: Add defensive JSON parsing to avoid crashes on malformed replies

Guard against invalid JSON or empty replies to prevent 500s.

 @observe(capture_input=False)
 def post_process(
     generate_sql_question: dict,
 ) -> str:
-    return orjson.loads(generate_sql_question.get("replies")[0])["question"]
+    try:
+        reply = (generate_sql_question.get("replies") or [None])[0]
+        if not reply:
+            return ""
+        payload = orjson.loads(reply)
+        return payload.get("question", "")
+    except Exception as e:
+        logger.warning("sql_question: failed to parse reply: %s", e)
+        return ""

95-109: Description is stored but unused

If description is intended for metadata/configs, propagate it to pipe metadata; otherwise remove to avoid dead state.

wren-ai-service/src/pipelines/generation/data_assistance.py (1)

130-153: Ensure queues are cleaned up on timeout/error to avoid leaks

If the stream times out, the per-user queue remains in memory. Pop it in a finally block.

     async def get_streaming_results(self, query_id):
         async def _get_streaming_results(query_id):
             return await self._user_queues[query_id].get()

         if query_id not in self._user_queues:
             self._user_queues[
                 query_id
             ] = asyncio.Queue()  # Ensure the user's queue exists
-        while True:
-            try:
+        try:
+            while True:
+                try:
                     # Wait for an item from the user's queue
-                self._streaming_results = await asyncio.wait_for(
-                    _get_streaming_results(query_id), timeout=120
-                )
-                if (
-                    self._streaming_results == "<DONE>"
-                ):  # Check for end-of-stream signal
-                    del self._user_queues[query_id]
-                    break
-                if self._streaming_results:  # Check if there are results to yield
-                    yield self._streaming_results
-                    self._streaming_results = ""  # Clear after yielding
-            except TimeoutError:
-                break
+                    item = await asyncio.wait_for(
+                        _get_streaming_results(query_id), timeout=120
+                    )
+                    if item == "<DONE>":
+                        break
+                    if item:
+                        yield item
+                except asyncio.TimeoutError:
+                    break
+        finally:
+            self._user_queues.pop(query_id, None)
wren-ai-service/src/pipelines/generation/misleading_assistance.py (1)

130-153: Clean up queues on timeout/error

Mirror the leak prevention suggested for data_assistance.

-        while True:
-            try:
+        try:
+            while True:
+                try:
                     # Wait for an item from the user's queue
-                self._streaming_results = await asyncio.wait_for(
-                    _get_streaming_results(query_id), timeout=120
-                )
-                if (
-                    self._streaming_results == "<DONE>"
-                ):  # Check for end-of-stream signal
-                    del self._user_queues[query_id]
-                    break
-                if self._streaming_results:  # Check if there are results to yield
-                    yield self._streaming_results
-                    self._streaming_results = ""  # Clear after yielding
-            except TimeoutError:
-                break
+                    item = await asyncio.wait_for(
+                        _get_streaming_results(query_id), timeout=120
+                    )
+                    if item == "<DONE>":
+                        break
+                    if item:
+                        yield item
+                except asyncio.TimeoutError:
+                    break
+        finally:
+            self._user_queues.pop(query_id, None)
wren-ai-service/src/pipelines/generation/chart_generation.py (1)

137-143: File path robustness: avoid CWD-dependent open()

Use a path relative to this file to prevent failures when CWD differs (e.g., when installed as a package).

-        with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f:
-            _vega_schema = orjson.loads(f.read())
+        from pathlib import Path
+        schema_path = Path(__file__).parent / "utils" / "vega-lite-schema-v5.json"
+        _vega_schema = orjson.loads(schema_path.read_bytes())
wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py (1)

66-87: Avoid instantiating Configuration in default arguments

Use None default and construct inside to prevent unintended shared state.

 def prompt(
     query: str,
     documents: list[str],
     histories: list[AskHistory],
     sql_samples: list[dict],
     instructions: list[dict],
     prompt_builder: PromptBuilder,
-    configuration: Configuration | None = Configuration(),
+    configuration: Configuration | None = None,
 ) -> dict:
+    configuration = configuration or Configuration()
wren-ai-service/src/pipelines/generation/chart_adjustment.py (1)

164-169: File path robustness: avoid CWD-dependent open()

Same as chart_generation: use a path relative to this file.

-        with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f:
-            _vega_schema = orjson.loads(f.read())
+        from pathlib import Path
+        schema_path = Path(__file__).parent / "utils" / "vega-lite-schema-v5.json"
+        _vega_schema = orjson.loads(schema_path.read_bytes())
wren-ai-service/src/pipelines/retrieval/sql_functions.py (2)

24-33: Fix return annotation for _extract

function_type is a string; the type hint says list.

-        def _extract() -> tuple[str, list, str]:
+        def _extract() -> tuple[str, str, str]:

128-135: Avoid shadowing built-in input

Use a different variable name to prevent confusion with the built-in.

-        input = {
+        inputs = {
             "data_source": _data_source,
             "project_id": project_id,
             **self._components,
         }
-        result = await self._pipe.execute(["cache"], inputs=input)
+        result = await self._pipe.execute(["cache"], inputs=inputs)
wren-ai-service/src/pipelines/generation/user_guide_assistance.py (1)

70-76: Return only the generator result (not a tuple)

Downstream doesn’t consume generator_name from the return value; keeping it in the tuple is inconsistent and risks surprises.

-async def user_guide_assistance(
-    prompt: dict, generator: Any, query_id: str, generator_name: str
-) -> dict:
-    return await generator(
-        prompt=prompt.get("prompt"), query_id=query_id
-    ), generator_name
+async def user_guide_assistance(
+    prompt: dict, generator: Any, query_id: str, generator_name: str
+) -> dict:
+    return await generator(prompt=prompt.get("prompt"), query_id=query_id)
wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py (2)

64-76: Avoid defaulting to a constructed Configuration

Don’t use Configuration() as a default value; construct it inside when needed.

-def prompt(
+def prompt(
     query: str,
     documents: list[str],
     sql_samples: list[dict],
-    instructions: list[dict],
+    instructions: list[str] | list[dict],
     prompt_builder: PromptBuilder,
-    configuration: Configuration | None = Configuration(),
+    configuration: Configuration | None = None,
 ) -> dict:
-    _prompt = prompt_builder.run(
+    configuration = configuration or Configuration()
+    _prompt = prompt_builder.run(

167-171: Run signature: avoid constructing default Configuration

Mirror the fix in prompt and pass a fresh instance when absent.

-        configuration: Configuration = Configuration(),
+        configuration: Optional[Configuration] = None,
         query_id: Optional[str] = None,
 ):
@@
-            inputs={
+            inputs={
                 "query": query,
                 "documents": contexts,
                 "sql_samples": sql_samples or [],
                 "instructions": instructions or [],
-                "configuration": configuration,
+                "configuration": configuration or Configuration(),
                 "query_id": query_id,
                 **self._components,
             },
wren-ai-service/src/pipelines/generation/relationship_recommendation.py (2)

76-99: Type hint: function returns a list, not a dict

Adjust return annotation for clarity.

-@observe(capture_input=False)
-def cleaned_models(mdl: dict) -> dict:
+@observe(capture_input=False)
+def cleaned_models(mdl: dict) -> list[dict]:

119-129: Wrapper return type hint is wrong

wrapper returns a dict (or {}), not str.

-    def wrapper(text: str) -> str:
+    def wrapper(text: str) -> dict:
wren-ai-service/src/pipelines/generation/semantics_description.py (1)

155-165: Type hint fix in normalize(): wrapper returns dict.

Align annotation with return value and keep behavior.

-def wrapper(text: str) -> str:
+def wrapper(text: str) -> dict:
wren-ai-service/src/__main__.py (2)

151-174: Simplify pipeline llm/embedder mapping; you already store aliases.

The alias lookup map is model→alias but pipe_service_components already holds alias. Set directly.

-            if llm_model:
-                if llm_model_alias := _llm_model_alias_mapping.get(llm_model):
-                    _configs["pipelines"][pipe_name]["llm"] = llm_model_alias
-                else:
-                    _configs["pipelines"][pipe_name]["llm"] = llm_model
-            if embedding_model:
-                if embedding_model_alias := _embedder_model_alias_mapping.get(
-                    embedding_model
-                ):
-                    _configs["pipelines"][pipe_name]["embedder"] = embedding_model_alias
-                else:
-                    _configs["pipelines"][pipe_name]["embedder"] = embedding_model
+            if llm_model:
+                _configs["pipelines"][pipe_name]["llm"] = llm_model
+            if embedding_model:
+                _configs["pipelines"][pipe_name]["embedder"] = embedding_model

262-266: Return the updated config snapshot.

Saves a follow-up GET and confirms applied state.

-        for pipeline_name, _ in app.state.pipe_components.items():
-            pass
+        for pipeline_name, _ in app.state.pipe_components.items():
+            pass  # placeholder for future metadata updates
+        return get_configs()
wren-ai-service/src/core/pipeline.py (1)

52-56: Nit: make fields Optional[...] and default description empty string.

Improves typing consistency and avoids None checks downstream.

-@dataclass
-class PipelineComponent(Mapping):
-    description: str = None
-    llm_provider: LLMProvider = None
-    embedder_provider: EmbedderProvider = None
-    document_store_provider: DocumentStoreProvider = None
+@dataclass
+class PipelineComponent(Mapping):
+    description: str = ""
+    llm_provider: Optional[LLMProvider] = None
+    embedder_provider: Optional[EmbedderProvider] = None
+    document_store_provider: Optional[DocumentStoreProvider] = None
wren-ai-service/src/pipelines/indexing/instructions.py (4)

66-78: Use explicit 'filters' and avoid shadowing built-in 'filter'.

Minor clarity/readability and avoids shadowing a built-in. Pass the kwarg explicitly to match Haystack API.

Apply this diff:

-        filter = {
+        filters = {
             "operator": "AND",
             "conditions": [
                 {"field": "instruction_id", "operator": "in", "value": instruction_ids},
             ],
         }
 
         if project_id:
-            filter["conditions"].append(
+            filters["conditions"].append(
                 {"field": "project_id", "operator": "==", "value": project_id}
             )
 
-        return await self.store.delete_documents(filter)
+        return await self.store.delete_documents(filters=filters)

101-109: Avoid default mutable argument for 'embedding'.

Safer to use None and default to an empty payload at runtime.

Apply this diff:

-async def clean(
+async def clean(
     cleaner: InstructionsCleaner,
     instructions: List[Instruction],
-    embedding: Dict[str, Any] = {},
+    embedding: Optional[Dict[str, Any]] = None,
     project_id: str = "",
     delete_all: bool = False,
 ) -> Dict[str, Any]:
@@
-    return embedding
+    return embedding or {"documents": []}

116-121: Return type annotation mismatch.

Function returns the writer’s result; annotate accordingly for type checkers.

-async def write(
+async def write(
     clean: Dict[str, Any],
     writer: AsyncDocumentWriter,
-) -> None:
+) -> Dict[str, Any]:
     return await writer.run(documents=clean["documents"])

184-191: Avoid shadowing built-in 'input'.

Rename local variable for clarity.

-        input = {
+        inputs = {
             "project_id": project_id,
             "instructions": instructions,
             **self._components,
         }
 
-        return await self._pipe.execute(["write"], inputs=input)
+        return await self._pipe.execute(["write"], inputs=inputs)
wren-ai-service/src/pipelines/indexing/project_meta.py (1)

59-61: Return type annotation mismatch.

write returns the writer’s result; annotate accordingly.

-async def write(clean: dict[str, Any], writer: DocumentWriter) -> None:
+async def write(clean: dict[str, Any], writer: DocumentWriter) -> dict[str, Any]:
     return await writer.run(documents=clean["documents"])
wren-ai-service/src/providers/embedder/litellm.py (3)

128-142: Harden usage aggregation to avoid KeyError when earlier batches lack usage.

Make accumulation resilient across mixed responses.

-            if "usage" not in meta:
-                meta["usage"] = (
-                    dict(response.usage) if hasattr(response, "usage") else {}
-                )
-            else:
-                if hasattr(response, "usage"):
-                    meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
-                    meta["usage"]["total_tokens"] += response.usage.total_tokens
+            if hasattr(response, "usage"):
+                usage_dict = dict(response.usage)
+                u = meta.setdefault("usage", {})
+                u["prompt_tokens"] = u.get("prompt_tokens", 0) + usage_dict.get("prompt_tokens", 0)
+                u["total_tokens"] = u.get("total_tokens", 0) + usage_dict.get("total_tokens", 0)

52-53: Broaden retry policy to catch common transient errors.

Consider including rate limits/timeouts from the SDK to reduce flakiness.

Example:

-@backoff.on_exception(backoff.expo, openai.APIError, max_time=60.0, max_tries=3)
+@backoff.on_exception(backoff.expo, (openai.APIError, openai.RateLimitError), max_time=60.0, max_tries=3)

And similarly for AsyncDocumentEmbedder.run.

Also applies to: 146-147


18-21: Docstring says metadata is embedded, but only content is used.

Align comment with behavior, or concatenate selected metadata fields if intended.
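
If embedding metadata is the intended behavior, a hedged sketch of what that could look like (the meta field names here are hypothetical; today only doc.content is embedded):

def text_to_embed(doc) -> str:
    # Hypothetical helper: append selected metadata fields to the content before embedding.
    meta_parts = [
        str(doc.meta[key])
        for key in ("summary", "table_name")  # hypothetical field names
        if doc.meta.get(key)
    ]
    return "\n".join([doc.content or "", *meta_parts]).strip()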

wren-ai-service/src/pipelines/indexing/historical_question.py (3)

55-56: Return type annotation mismatch for ViewChunker.run.

It returns a dict; annotate accordingly.

-    def run(self, mdl: Dict[str, Any], project_id: Optional[str] = None) -> None:
+    def run(self, mdl: Dict[str, Any], project_id: Optional[str] = None) -> Dict[str, Any]:

128-130: Return type annotation mismatch.

write returns the writer’s result; annotate accordingly.

-async def write(clean: Dict[str, Any], writer: DocumentWriter) -> None:
+async def write(clean: Dict[str, Any], writer: DocumentWriter) -> Dict[str, Any]:
     return await writer.run(documents=clean["documents"])

85-92: Avoid tqdm in production hot path (optional).

Progress bars can add overhead/log noise in services; gate behind log level.

-            "documents": [
-                Document(**chunk)
-                for chunk in tqdm(
-                    chunks,
-                    desc=f"Project ID: {project_id}, Chunking views into documents",
-                )
-            ]
+            "documents": [
+                Document(**chunk)
+                for chunk in (
+                    tqdm(
+                        chunks,
+                        desc=f"Project ID: {project_id}, Chunking views into documents",
+                    )
+                    if logger.isEnabledFor(logging.INFO)
+                    else chunks
+                )
+            ]
wren-ai-service/src/pipelines/indexing/sql_pairs.py (5)

128-131: Avoid mutable default arguments

Use None defaults to prevent shared state across calls.

@@
-async def clean(
-    cleaner: SqlPairsCleaner,
-    sql_pairs: List[SqlPair],
-    embedding: Dict[str, Any] = {},
+async def clean(
+    cleaner: SqlPairsCleaner,
+    sql_pairs: List[SqlPair],
+    embedding: Optional[Dict[str, Any]] = None,
@@
-    sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs]
+    embedding = embedding or {}
+    sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs]
@@
-    async def clean(
-        self,
-        sql_pairs: List[SqlPair] = [],
+    async def clean(
+        self,
+        sql_pairs: Optional[List[SqlPair]] = None,
         project_id: Optional[str] = None,
         delete_all: bool = False,
     ) -> None:
-        await clean(
+        await clean(
             sql_pairs=sql_pairs,
             cleaner=self._components["cleaner"],
             project_id=project_id,
             delete_all=delete_all,
         )

Also applies to: 239-249


81-87: Type-guard boilerplate before .lower()

Guard against non-string values in mdl["models"][*]["properties"]["boilerplate"].

-return {
-    boilerplate.lower()
-    for model in mdl.get("models", [])
-    if (boilerplate := model.get("properties", {}).get("boilerplate"))
-}
+return {
+    boilerplate.lower()
+    for model in mdl.get("models", [])
+    if (boilerplate := model.get("properties", {}).get("boilerplate")) and isinstance(boilerplate, str)
+}

95-104: Ensure SqlPair.id is populated

External data may omit id. Provide a stable fallback.

-        SqlPair(
-            id=pair.get("id"),
+        SqlPair(
+            id=str(pair.get("id") or uuid.uuid4()),
             question=pair.get("question"),
             sql=pair.get("sql"),
         )

224-235: Rename local variable input to inputs

Avoid shadowing built-in input().

-        input = {
+        inputs = {
             "mdl_str": mdl_str,
             "project_id": project_id,
             "external_pairs": {
                 **self._external_pairs,
                 **(external_pairs or {}),
             },
             **self._components,
         }
 
-        return await self._pipe.execute(["write"], inputs=input)
+        return await self._pipe.execute(["write"], inputs=inputs)

36-49: Optional: use stable document ids to make overwrites independent of cleaning

Consider deriving Document.id from sql_pair_id to let DuplicatePolicy.OVERWRITE supersede prior versions even if cleaning is skipped.

-                Document(
-                    id=str(uuid.uuid4()),
+                Document(
+                    id=f"sqlpair:{sql_pair.id}",
wren-ai-service/src/pipelines/generation/intent_classification.py (1)

327-335: Fix typo in constant name for clarity

Rename INTENT_CLASSIFICAION_MODEL_KWARGS → INTENT_CLASSIFICATION_MODEL_KWARGS and update usage.

-INTENT_CLASSIFICAION_MODEL_KWARGS = {
+INTENT_CLASSIFICATION_MODEL_KWARGS = {
@@
-            "generator": self._llm_provider.get_generator(
+            "generator": self._llm_provider.get_generator(
                 system_prompt=intent_classification_system_prompt,
-                generation_kwargs=INTENT_CLASSIFICAION_MODEL_KWARGS,
+                generation_kwargs=INTENT_CLASSIFICATION_MODEL_KWARGS,
             ),

Also applies to: 380-384

wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py (2)

23-33: Avoid shadowing built-in list

Use a neutral name.

     def run(self, documents: List[Document]):
-        list = []
+        items = []
@@
-            list.append(formatted)
+            items.append(formatted)
 
-        return {"documents": list}
+        return {"documents": items}

38-41: Loosen store type to generic DocumentStore

The provider returns an async store; using the generic interface avoids coupling to a specific backend type.

-from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
+from haystack.document_stores.types import DocumentStore
@@
-async def count_documents(
-    store: QdrantDocumentStore, project_id: Optional[str] = None
+async def count_documents(
+    store: DocumentStore, project_id: Optional[str] = None
 ) -> int:
wren-ai-service/src/pipelines/retrieval/instructions.py (2)

23-35: Avoid shadowing built-in list

Rename local var.

     def run(self, documents: List[Document]):
-        list = []
+        items = []
@@
-            list.append(formatted)
+            items.append(formatted)
 
-        return {"documents": list}
+        return {"documents": items}

59-63: Loosen store type to generic DocumentStore

Aligns with provider interface and async usage.

-from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
+from haystack.document_stores.types import DocumentStore
@@
-async def count_documents(
-    store: QdrantDocumentStore, project_id: Optional[str] = None
+async def count_documents(
+    store: DocumentStore, project_id: Optional[str] = None
 ) -> int:
wren-ai-service/src/pipelines/generation/sql_generation.py (1)

157-159: Consider moving super().__init__() to after instance variables are initialized.

Moving the super().__init__() call after initializing instance variables follows a more typical initialization pattern and ensures all dependencies are ready before the base class potentially uses them.

-        super().__init__(
-            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
-        )
-
         self._document_store_provider = document_store_provider
         self._retriever = self._document_store_provider.get_retriever(
             self._document_store_provider.get_store("project_meta")
         )
         self._llm_provider = llm_provider
         self._engine = engine
         self._description = description
+        
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+        
         self._components = self._update_components()
wren-ai-service/src/utils.py (1)

259-268: Consider making the hardcoded pipeline list configurable.

The has_db_data_in_llm_prompt function hardcodes specific pipeline names. Consider making this list configurable through the Configs class or environment variables for better maintainability.

 def has_db_data_in_llm_prompt(pipe_name: str) -> bool:
-    pipes_containing_db_data = set(
-        [
-            "sql_answer",
-            "chart_adjustment",
-            "chart_generation",
-        ]
-    )
+    # Could be loaded from config or environment
+    pipes_containing_db_data = getattr(
+        settings, 
+        'PIPES_WITH_DB_DATA', 
+        {"sql_answer", "chart_adjustment", "chart_generation"}
+    )
 
     return pipe_name in pipes_containing_db_data
wren-ai-service/src/providers/__init__.py (1)

343-401: Fix: Function signature doesn't match return type annotation.

The function annotation says it returns dict[str, PipelineComponent] but actually returns a tuple. Update the annotation to match the implementation.

-def generate_components(configs: list[dict]) -> dict[str, PipelineComponent]:
+def generate_components(configs: list[dict]) -> tuple[dict[str, PipelineComponent], dict]:
wren-ai-service/src/pipelines/indexing/db_schema.py (1)

363-363: Keep helper.load_helpers(), but run it before component initialization or re-run it on updates.

load_helpers populates MODEL_PREPROCESSORS/COLUMN_PREPROCESSORS/COLUMN_COMMENT_HELPERS used by DDLChunker at runtime; it’s currently called after self._components = self._update_components() in wren-ai-service/src/pipelines/indexing/db_schema.py (DBSchema.__init__) and is not invoked in update_components(), so dynamically added helpers would not be picked up.

  • Action: Move helper.load_helpers() to before self._components = self._update_components() in DBSchema.__init__ OR call helper.load_helpers() inside update_components() so helpers are loaded/reloaded on dynamic updates.
  • Ensure load_helpers is idempotent / safe to call multiple times.
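
A minimal sketch of an idempotent guard, assuming nothing about the real helper module beyond what the comment above states:

_HELPERS_LOADED = False

def load_helpers() -> None:
    # Safe to call from both __init__ and update_components: the second call is a no-op.
    global _HELPERS_LOADED
    if _HELPERS_LOADED:
        return
    # ... populate MODEL_PREPROCESSORS / COLUMN_PREPROCESSORS / COLUMN_COMMENT_HELPERS here ...
    _HELPERS_LOADED = True
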
wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py (1)

511-534: Align update_components with BasicPipeline and remove unused retriever attributes

Callers pass None and _table_retriever/_dbschema_retriever are only assigned here — making providers Optional and dropping those attributes is safe.

-    def update_components(
-        self,
-        llm_provider: LLMProvider,
-        embedder_provider: EmbedderProvider,
-        document_store_provider: DocumentStoreProvider,
-        **_,
-    ):
-        super().update_components(
-            llm_provider=llm_provider,
-            embedder_provider=embedder_provider,
-            document_store_provider=document_store_provider,
-            update_components=False,
-        )
-        self._table_retriever = self._document_store_provider.get_retriever(
-            self._document_store_provider.get_store(dataset_name="table_descriptions"),
-            top_k=self._table_retrieval_size,
-        )
-        self._dbschema_retriever = self._document_store_provider.get_retriever(
-            self._document_store_provider.get_store(),
-            top_k=self._table_column_retrieval_size,
-        )
-        self._components = self._update_components()
-        self._configs = self._update_configs()
+    def update_components(
+        self,
+        llm_provider: Optional[LLMProvider] = None,
+        embedder_provider: Optional[EmbedderProvider] = None,
+        document_store_provider: Optional[DocumentStoreProvider] = None,
+        **_,
+    ):
+        super().update_components(
+            llm_provider=llm_provider,
+            embedder_provider=embedder_provider,
+            document_store_provider=document_store_provider,
+            update_components=False,
+        )
+        self._components = self._update_components()
+        self._configs = self._update_configs()

Comment on lines +99 to +146
@app.get("/configs")
def get_configs():
_configs = {
"env_vars": {},
"providers": {
"llm": [],
"embedder": [],
},
"pipelines": {},
}

_llm_model_alias_mapping = {}
_embedder_model_alias_mapping = {}

_llm_configs = []
for _, model_config in app.state.instantiated_providers["llm"].items():
_llm_config = {
"model": model_config._model,
"alias": model_config._alias,
"context_window_size": model_config._context_window_size,
"timeout": model_config._timeout,
"kwargs": model_config._model_kwargs,
}
if model_config._api_base:
_llm_config["api_base"] = model_config._api_base
if model_config._api_version:
_llm_config["api_version"] = model_config._api_version
_llm_configs.append(_llm_config)
_llm_model_alias_mapping[model_config._model] = model_config._alias
_configs["providers"]["llm"] = _llm_configs

_embedder_configs = []
# we only support one embedding model now
for _, model_config in app.state.instantiated_providers["embedder"].items():
_embedder_config = {
"model": model_config._model,
"alias": model_config._alias,
"dimension": app.state.instantiated_providers["document_store"][
"qdrant"
]._embedding_model_dim,
"timeout": model_config._timeout,
"kwargs": model_config._model_kwargs,
}
if model_config._api_base:
_embedder_config["api_base"] = model_config._api_base
if model_config._api_version:
_embedder_config["api_version"] = model_config._api_version
_embedder_configs.append(_embedder_config)

⚠️ Potential issue

🧩 Analysis chain

Sanitize exposed kwargs to avoid accidental secret leakage.

If any provider kwargs include sensitive fields, filter them before returning.

Proposed filter example:

-            "kwargs": model_config._model_kwargs,
+            "kwargs": {k: v for k, v in (model_config._model_kwargs or {}).items() if k.lower() not in {"api_key", "key", "token", "secret"}},

Run to verify presence of sensitive keys in kwargs across providers:


🏁 Script executed:

#!/bin/bash
rg -nP --type=py -C2 '_model_kwargs' src/providers

Length of output: 53


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "PWD: $(pwd)"
echo "Listing top-level files/dirs:"
ls -la || true

echo
echo "Searching repository for occurrences of '_model_kwargs' and related patterns..."
rg -n --hidden -S "_model_kwargs" -C3 || true

echo
echo "Searching for '\"kwargs':' occurrences (to see where kwargs are returned in JSON responses)..."
rg -n --hidden -S '"kwargs"\s*:' -C3 || true

echo
echo "Searching for 'def get_configs' and __main__ to locate the reported file."
rg -n --hidden -S "def get_configs" -C3 || true
rg -n --hidden -S "__main__\.py" -C3 || true

echo
echo "Searching for 'instantiated_providers' usages (to confirm provider objects structure)..."
rg -n --hidden -S "instantiated_providers" -C3 || true

echo
echo "Searching for common secret-like keys in code (api_key, token, secret) to spot exposed fields."
rg -n --hidden -S -i "api_key|apikey|token|secret|api_secret" -C1 || true

Length of output: 317607


Sanitize exposed provider kwargs in API responses

The API currently returns raw provider kwargs (risk: leaking secrets such as API keys/tokens). Filter sensitive keys (handle None) before returning.

  • Files & places to fix:
    • wren-ai-service/src/__main__.py — _llm_config and _embedder_config (currently "kwargs": model_config._model_kwargs). Replace with a sanitized dict. Example:
-            "kwargs": model_config._model_kwargs,
+            "kwargs": {k: v for k, v in (model_config._model_kwargs or {}).items()
+                       if k.lower() not in {"api_key","key","token","secret","password","access_key","access_token","secret_key","client_secret","authorization"}},
  • wren-ai-service/src/globals.py — llm_model_kwargs (line ~317) is populated from provider.model_kwargs; sanitize there too using the same filter before exposing.

  • Tests & config:

    • Tests that assert presence/shape of kwargs (e.g. tests/pytest/test_config.py, tests/pytest/providers/test_providers.py) may need updating if they expect raw values.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@app.get("/configs")
def get_configs():
_configs = {
"env_vars": {},
"providers": {
"llm": [],
"embedder": [],
},
"pipelines": {},
}
_llm_model_alias_mapping = {}
_embedder_model_alias_mapping = {}
_llm_configs = []
for _, model_config in app.state.instantiated_providers["llm"].items():
_llm_config = {
"model": model_config._model,
"alias": model_config._alias,
"context_window_size": model_config._context_window_size,
"timeout": model_config._timeout,
"kwargs": model_config._model_kwargs,
}
if model_config._api_base:
_llm_config["api_base"] = model_config._api_base
if model_config._api_version:
_llm_config["api_version"] = model_config._api_version
_llm_configs.append(_llm_config)
_llm_model_alias_mapping[model_config._model] = model_config._alias
_configs["providers"]["llm"] = _llm_configs
_embedder_configs = []
# we only support one embedding model now
for _, model_config in app.state.instantiated_providers["embedder"].items():
_embedder_config = {
"model": model_config._model,
"alias": model_config._alias,
"dimension": app.state.instantiated_providers["document_store"][
"qdrant"
]._embedding_model_dim,
"timeout": model_config._timeout,
"kwargs": model_config._model_kwargs,
}
if model_config._api_base:
_embedder_config["api_base"] = model_config._api_base
if model_config._api_version:
_embedder_config["api_version"] = model_config._api_version
_embedder_configs.append(_embedder_config)
@app.get("/configs")
def get_configs():
_configs = {
"env_vars": {},
"providers": {
"llm": [],
"embedder": [],
},
"pipelines": {},
}
_llm_model_alias_mapping = {}
_embedder_model_alias_mapping = {}
_llm_configs = []
for _, model_config in app.state.instantiated_providers["llm"].items():
_llm_config = {
"model": model_config._model,
"alias": model_config._alias,
"context_window_size": model_config._context_window_size,
"timeout": model_config._timeout,
"kwargs": {k: v for k, v in (model_config._model_kwargs or {}).items()
if k.lower() not in {"api_key", "key", "token", "secret", "password", "access_key", "access_token", "secret_key", "client_secret", "authorization"}},
}
if model_config._api_base:
_llm_config["api_base"] = model_config._api_base
if model_config._api_version:
_llm_config["api_version"] = model_config._api_version
_llm_configs.append(_llm_config)
_llm_model_alias_mapping[model_config._model] = model_config._alias
_configs["providers"]["llm"] = _llm_configs
_embedder_configs = []
# we only support one embedding model now
for _, model_config in app.state.instantiated_providers["embedder"].items():
_embedder_config = {
"model": model_config._model,
"alias": model_config._alias,
"dimension": app.state.instantiated_providers["document_store"][
"qdrant"
]._embedding_model_dim,
"timeout": model_config._timeout,
"kwargs": {k: v for k, v in (model_config._model_kwargs or {}).items()
if k.lower() not in {"api_key", "key", "token", "secret", "password", "access_key", "access_token", "secret_key", "client_secret", "authorization"}},
}
if model_config._api_base:
_embedder_config["api_base"] = model_config._api_base
if model_config._api_version:
_embedder_config["api_version"] = model_config._api_version
_embedder_configs.append(_embedder_config)

Comment on lines +178 to +206
@app.post("/configs")
def update_configs(configs_request: Configs):
try:
# override current instantiated_providers
app.state.instantiated_providers["embedder"] = {
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
**embedder_provider.__dict__
)
for embedder_provider in configs_request.providers.embedder
}
app.state.instantiated_providers["llm"] = {
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
**llm_provider.__dict__
)
for llm_provider in configs_request.providers.llm
}
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
location=app.state.instantiated_providers["document_store"][
"qdrant"
]._location,
api_key=app.state.instantiated_providers["document_store"][
"qdrant"
]._api_key,
timeout=app.state.instantiated_providers["document_store"][
"qdrant"
]._timeout,
embedding_model_dim=configs_request.providers.embedder[0].dimension,
recreate_index=True,
)

⚠️ Potential issue

Guard against empty embedder list.

An IndexError will be raised if providers.embedder is empty. Validate the list and return a 400 instead.

-def update_configs(configs_request: Configs):
+def update_configs(configs_request: Configs):
     try:
+        if not configs_request.providers.embedder:
+            raise HTTPException(status_code=400, detail="At least one embedder is required")
         # override current instantiated_providers
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@app.post("/configs")
def update_configs(configs_request: Configs):
try:
# override current instantiated_providers
app.state.instantiated_providers["embedder"] = {
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
**embedder_provider.__dict__
)
for embedder_provider in configs_request.providers.embedder
}
app.state.instantiated_providers["llm"] = {
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
**llm_provider.__dict__
)
for llm_provider in configs_request.providers.llm
}
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
location=app.state.instantiated_providers["document_store"][
"qdrant"
]._location,
api_key=app.state.instantiated_providers["document_store"][
"qdrant"
]._api_key,
timeout=app.state.instantiated_providers["document_store"][
"qdrant"
]._timeout,
embedding_model_dim=configs_request.providers.embedder[0].dimension,
recreate_index=True,
)
@app.post("/configs")
def update_configs(configs_request: Configs):
try:
if not configs_request.providers.embedder:
raise HTTPException(status_code=400, detail="At least one embedder is required")
# override current instantiated_providers
app.state.instantiated_providers["embedder"] = {
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
**embedder_provider.__dict__
)
for embedder_provider in configs_request.providers.embedder
}
app.state.instantiated_providers["llm"] = {
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
**llm_provider.__dict__
)
for llm_provider in configs_request.providers.llm
}
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
location=app.state.instantiated_providers["document_store"][
"qdrant"
]._location,
api_key=app.state.instantiated_providers["document_store"][
"qdrant"
]._api_key,
timeout=app.state.instantiated_providers["document_store"][
"qdrant"
]._timeout,
embedding_model_dim=configs_request.providers.embedder[0].dimension,
recreate_index=True,
)
🤖 Prompt for AI Agents
In wren-ai-service/src/__main__.py around lines 178 to 206, the code assumes
configs_request.providers.embedder has at least one item and will raise
IndexError when empty; before accessing
configs_request.providers.embedder[0].dimension, validate that
providers.embedder is not empty and if it is return a 400 HTTP response (raise
fastapi.HTTPException(status_code=400, detail="...")) or otherwise handle the
missing embedder (e.g., require a default embedder or skip updating
document_store.embedding_model_dim), then proceed to instantiate providers;
ensure the validation occurs before any use of embedder[0] so the endpoint fails
fast with a clear error message.
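
Alternatively, the non-empty constraint can live on the request schema itself so FastAPI rejects the payload with a 422 before the handler runs; a minimal sketch, assuming Configs and its nested models are plain Pydantic v2 models (field names are illustrative):

from pydantic import BaseModel, Field


class EmbedderProviderConfig(BaseModel):  # illustrative subset of the real schema
    model: str
    alias: str
    dimension: int


class Providers(BaseModel):
    llm: list[dict] = Field(default_factory=list)
    # Rejecting an empty list at validation time removes the IndexError path entirely.
    embedder: list[EmbedderProviderConfig] = Field(min_length=1)


class Configs(BaseModel):
    providers: Providers
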

Comment on lines +183 to +193
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
**embedder_provider.__dict__
)
for embedder_provider in configs_request.providers.embedder
}
app.state.instantiated_providers["llm"] = {
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
**llm_provider.__dict__
)
for llm_provider in configs_request.providers.llm
}

⚠️ Potential issue

Avoid Pydantic __dict__; use model_dump(exclude_none=True).

__dict__ may include internal or unvalidated values; model_dump is the stable, safe way to export fields.

-            f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
-                **embedder_provider.__dict__
-            )
+            f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
+                **embedder_provider.model_dump(exclude_none=True)
+            )
             for embedder_provider in configs_request.providers.embedder
         }
         app.state.instantiated_providers["llm"] = {
-            f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
-                **llm_provider.__dict__
-            )
+            f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
+                **llm_provider.model_dump(exclude_none=True)
+            )
             for llm_provider in configs_request.providers.llm
         }
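
To illustrate the difference on a toy Pydantic v2 model (not the service's actual schema):

from pydantic import BaseModel


class ProviderConfig(BaseModel):  # toy model for illustration only
    model: str
    alias: str
    api_base: str | None = None


cfg = ProviderConfig(model="gpt-4o-mini", alias="default")

# __dict__ exposes raw stored values, including fields left at None:
print(cfg.__dict__)                       # {'model': 'gpt-4o-mini', 'alias': 'default', 'api_base': None}
# model_dump(exclude_none=True) yields only validated, non-None fields:
print(cfg.model_dump(exclude_none=True))  # {'model': 'gpt-4o-mini', 'alias': 'default'}
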
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
**embedder_provider.__dict__
)
for embedder_provider in configs_request.providers.embedder
}
app.state.instantiated_providers["llm"] = {
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
**llm_provider.__dict__
)
for llm_provider in configs_request.providers.llm
}
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
**embedder_provider.model_dump(exclude_none=True)
)
for embedder_provider in configs_request.providers.embedder
}
app.state.instantiated_providers["llm"] = {
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
**llm_provider.model_dump(exclude_none=True)
)
for llm_provider in configs_request.providers.llm
}
🤖 Prompt for AI Agents
In wren-ai-service/src/__main__.py around lines 183 to 193, the code uses
Pydantic objects' __dict__ to expand kwargs for LitellmEmbedderProvider and
LitellmLLMProvider; replace each use of embedder_provider.__dict__ and
llm_provider.__dict__ with embedder_provider.model_dump(exclude_none=True) and
llm_provider.model_dump(exclude_none=True) respectively so only validated fields
are passed and internal attributes are omitted. Ensure both comprehensions call
model_dump(exclude_none=True) on each provider instance and no other behavior
changes are introduced.

Comment on lines +194 to +206
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
location=app.state.instantiated_providers["document_store"][
"qdrant"
]._location,
api_key=app.state.instantiated_providers["document_store"][
"qdrant"
]._api_key,
timeout=app.state.instantiated_providers["document_store"][
"qdrant"
]._timeout,
embedding_model_dim=configs_request.providers.embedder[0].dimension,
recreate_index=True,
)

⚠️ Potential issue

Danger: unconditional Qdrant reindex can drop data. Gate on dimension change.

Recreate only if the embedding dim actually changes (or behind an explicit flag).

-        app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
-            location=app.state.instantiated_providers["document_store"][
-                "qdrant"
-            ]._location,
-            api_key=app.state.instantiated_providers["document_store"][
-                "qdrant"
-            ]._api_key,
-            timeout=app.state.instantiated_providers["document_store"][
-                "qdrant"
-            ]._timeout,
-            embedding_model_dim=configs_request.providers.embedder[0].dimension,
-            recreate_index=True,
-        )
+        _current_qdrant = app.state.instantiated_providers["document_store"]["qdrant"]
+        _new_dim = (
+            configs_request.providers.embedder[0].dimension
+            if configs_request.providers.embedder
+            else _current_qdrant._embedding_model_dim
+        )
+        _should_recreate = _new_dim != _current_qdrant._embedding_model_dim
+        app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
+            location=_current_qdrant._location,
+            api_key=_current_qdrant._api_key,
+            timeout=_current_qdrant._timeout,
+            embedding_model_dim=_new_dim,
+            recreate_index=_should_recreate,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
location=app.state.instantiated_providers["document_store"][
"qdrant"
]._location,
api_key=app.state.instantiated_providers["document_store"][
"qdrant"
]._api_key,
timeout=app.state.instantiated_providers["document_store"][
"qdrant"
]._timeout,
embedding_model_dim=configs_request.providers.embedder[0].dimension,
recreate_index=True,
)
_current_qdrant = app.state.instantiated_providers["document_store"]["qdrant"]
_new_dim = (
configs_request.providers.embedder[0].dimension
if configs_request.providers.embedder
else _current_qdrant._embedding_model_dim
)
_should_recreate = _new_dim != _current_qdrant._embedding_model_dim
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
location=_current_qdrant._location,
api_key=_current_qdrant._api_key,
timeout=_current_qdrant._timeout,
embedding_model_dim=_new_dim,
recreate_index=_should_recreate,
)
🤖 Prompt for AI Agents
In wren-ai-service/src/__main__.py around lines 194 to 206, the code
unconditionally sets recreate_index=True when instantiating the QdrantProvider
which can drop existing data; change this to only recreate when the embedder
dimension actually changes or when an explicit override flag is set. Implement a
check that reads the current provider's embedding_model_dim (if present) and
compares it to configs_request.providers.embedder[0].dimension and set
recreate_index=True only if they differ; otherwise set recreate_index=False.
Also add support for an explicit override (e.g. a config or env flag like
configs_request.providers.document_store.force_recreate or
app.state.force_recreate) that, when true, forces recreate_index=True regardless
of dimension comparison.

Comment on lines +170 to +182
    def update_components(
        self,
        llm_provider: LLMProvider,
        document_store_provider: DocumentStoreProvider,
        **_,
    ):
        super().update_components(
            llm_provider=llm_provider, document_store_provider=document_store_provider
        )
        self._retriever = self._document_store_provider.get_retriever(
            self._document_store_provider.get_store("project_meta")
        )


⚠️ Potential issue

Missing call to self._components = self._update_components() in update_components method.

The update_components method updates providers and retriever but doesn't rebuild the components dictionary. This differs from the pattern in other pipelines like followup_sql_generation.py where self._components = self._update_components() is called at the end.

Apply this fix to ensure components are rebuilt with the updated providers:

     def update_components(
         self,
         llm_provider: LLMProvider,
         document_store_provider: DocumentStoreProvider,
         **_,
     ):
         super().update_components(
             llm_provider=llm_provider, document_store_provider=document_store_provider
         )
         self._retriever = self._document_store_provider.get_retriever(
             self._document_store_provider.get_store("project_meta")
         )
+        self._components = self._update_components()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def update_components(
        self,
        llm_provider: LLMProvider,
        document_store_provider: DocumentStoreProvider,
        **_,
    ):
        super().update_components(
            llm_provider=llm_provider, document_store_provider=document_store_provider
        )
        self._retriever = self._document_store_provider.get_retriever(
            self._document_store_provider.get_store("project_meta")
        )
        self._components = self._update_components()
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/generation/sql_generation.py around lines 170
to 182, the update_components method updates providers and the retriever but
does not rebuild the components dict; add a final call to self._components =
self._update_components() (exactly as other pipelines do) so the pipeline's
components are reconstructed using the updated llm_provider and
document_store_provider.

Comment on lines +488 to 509
     def _update_components(self):
         return {
             "embedder": self._embedder_provider.get_text_embedder(),
             "table_retriever": self._document_store_provider.get_retriever(
                 self._document_store_provider.get_store(
                     dataset_name="table_descriptions"
                 ),
                 top_k=self._table_retrieval_size,
             ),
-            "dbschema_retriever": document_store_provider.get_retriever(
-                document_store_provider.get_store(),
-                top_k=table_column_retrieval_size,
+            "dbschema_retriever": self._document_store_provider.get_retriever(
+                self._document_store_provider.get_store(),
+                top_k=self._table_column_retrieval_size,
             ),
-            "table_columns_selection_generator": llm_provider.get_generator(
+            "table_columns_selection_generator": self._llm_provider.get_generator(
                 system_prompt=table_columns_selection_system_prompt,
                 generation_kwargs=RETRIEVAL_MODEL_KWARGS,
             ),
-            "generator_name": llm_provider.get_model(),
+            "generator_name": self._llm_provider.model,
             "prompt_builder": PromptBuilder(
                 template=table_columns_selection_user_prompt_template
             ),
         }

⚠️ Potential issue

Guard against context overflow when prompting the selection LLM.

When the token count exceeds the context window, the pipeline still sends all DDLs to the selection generator, which can itself exceed that same context window. Consider chunking the DDLs, pre-filtering tables (e.g., by semantic similarity or name constraints), or reserving headroom for prompt overhead.

I can draft a chunked selection pass (N tables per call with merge) if helpful.
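
A minimal sketch of such a pass (the token counter, budget, and generate callable are placeholders rather than the pipeline's real components):

import tiktoken


def select_columns_in_chunks(
    ddls: list[str],
    generate,                      # placeholder: calls the selection LLM on one chunk
    context_window_size: int,
    prompt_overhead: int = 2000,   # headroom reserved for system/user prompt and completion
) -> dict[str, list[str]]:
    encoding = tiktoken.get_encoding("cl100k_base")
    budget = context_window_size - prompt_overhead

    # Greedily pack DDLs into chunks that stay under the token budget.
    chunks: list[list[str]] = [[]]
    used = 0
    for ddl in ddls:
        tokens = len(encoding.encode(ddl))
        if chunks[-1] and used + tokens > budget:
            chunks.append([])
            used = 0
        chunks[-1].append(ddl)
        used += tokens

    # Merge per-chunk selections into a single table -> columns mapping.
    selected: dict[str, list[str]] = {}
    for chunk in chunks:
        partial = generate(chunk)  # expected to return {table_name: [column, ...]}
        for table, columns in partial.items():
            merged = selected.setdefault(table, [])
            merged.extend(c for c in columns if c not in merged)
    return selected
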

🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py around lines
488-509, the selection LLM is sent all DDLs unbounded which can exceed the model
context window; modify _update_components and the selection flow to enforce a
token-safe pipeline: compute a token budget for the generator (context window
minus prompt overhead), pre-filter candidate tables by simple heuristics (name
patterns, schema size) and/or a fast semantic similarity pass to reduce
candidates, split remaining DDLs into chunks that fit the token budget (N tables
per call), call the generator iteratively for each chunk and merge the partial
selections into a final set, and ensure generator calls include reserved
headroom for system/user prompt and any completion; update any config to expose
chunk size and token headroom.

Comment on lines +501 to +505
"table_columns_selection_generator": self._llm_provider.get_generator(
system_prompt=table_columns_selection_system_prompt,
generation_kwargs=RETRIEVAL_MODEL_KWARGS,
),
"generator_name": llm_provider.get_model(),
"generator_name": self._llm_provider.model,

⚠️ Potential issue

Avoid logging chain‑of‑thought; tighten output schema to reduce tokens.

The system prompt asks for step‑by‑step “chain_of_thought_reasoning”. Combined with @observe(as_type="generation", capture_input=False) the output may still be logged by Langfuse. Either disable output capture on the generation step or drop CoT from the schema to brief rationales.

Outside this hunk, update the decorator:

@observe(as_type="generation", capture_input=False, capture_output=False)

Optionally, rename chain_of_thought_reasoning to a short rationales array and instruct the model to produce terse bullet points only.
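
For instance, the structured output could be tightened along these lines (a sketch only; the existing schema name and fields may differ):

from pydantic import BaseModel, Field


class TableColumnsSelection(BaseModel):  # illustrative replacement schema
    # Terse bullet points instead of free-form chain-of-thought.
    rationales: list[str] = Field(default_factory=list, max_length=5)
    results: dict[str, list[str]]  # table name -> selected columns
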

🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py around lines
501 to 505, the system prompt requests step-by-step "chain_of_thought_reasoning"
which risks logging CoT tokens; remove or shorten CoT by replacing that field
with a brief "rationales" array of terse bullet points (or drop it entirely),
and change the Langfuse observe decorator (outside this hunk) to disable output
capture by using capture_output=False (i.e., @observe(as_type="generation",
capture_input=False, capture_output=False)); ensure the generator prompt
instructs concise rationales only and update any schema/field names accordingly.

Comment on lines +92 to +99
     def _update_configs(self):
         _model = (self._llm_provider.model,)
         if _model == "gpt-4o-mini" or _model == "gpt-4o":
             _encoding = tiktoken.get_encoding("o200k_base")
         else:
             _encoding = tiktoken.get_encoding("cl100k_base")

-        self._configs = {
+        return {

⚠️ Potential issue

Fix model tuple bug and robustly select encoding

The current code wraps the model name in a one-element tuple and then compares it with ==, so the 'o200k_base' branch never triggers.

     def _update_configs(self):
-        _model = (self._llm_provider.model,)
-        if _model == "gpt-4o-mini" or _model == "gpt-4o":
+        _model = self._llm_provider.model
+        if "gpt-4o" in _model:
             _encoding = tiktoken.get_encoding("o200k_base")
         else:
             _encoding = tiktoken.get_encoding("cl100k_base")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def _update_configs(self):
        _model = self._llm_provider.model
        if "gpt-4o" in _model:
            _encoding = tiktoken.get_encoding("o200k_base")
        else:
            _encoding = tiktoken.get_encoding("cl100k_base")
        return {
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py around lines
92 to 99, the code incorrectly wraps the model name in a tuple and compares with
== so the o200k_base branch never matches; change to read the model string
directly (e.g., model = self._llm_provider.model), guard for None, and select
encoding robustly by checking the model name (for example use
model.startswith("gpt-4o") or check membership in an explicit set of names) to
return tiktoken.get_encoding("o200k_base") for gpt-4o variants and
tiktoken.get_encoding("cl100k_base") otherwise.

Comment on lines +106 to 110
    def update_components(self, document_store_provider: DocumentStoreProvider, **_):
        self._document_store_provider = document_store_provider
        self._retriever = self._document_store_provider.get_retriever(
            self._document_store_provider.get_store("project_meta")
        )

⚠️ Potential issue

Align update_components signature with Base + clear stale cache

This override diverges from BasicPipeline.update_components(...), which can break central reconfiguration calls. Also clear the TTL cache when the underlying store/retriever changes.

-    def update_components(self, document_store_provider: DocumentStoreProvider, **_):
-        self._document_store_provider = document_store_provider
-        self._retriever = self._document_store_provider.get_retriever(
-            self._document_store_provider.get_store("project_meta")
-        )
+    def update_components(
+        self,
+        llm_provider=None,
+        embedder_provider=None,
+        document_store_provider: DocumentStoreProvider | None = None,
+        update_components: bool = True,
+        **_,
+    ):
+        # keep base state consistent; we don't need base to rebuild components here
+        super().update_components(
+            llm_provider=llm_provider,
+            embedder_provider=embedder_provider,
+            document_store_provider=document_store_provider,
+            update_components=False,
+        )
+        if document_store_provider is not None:
+            self._document_store_provider = document_store_provider
+        self._retriever = self._document_store_provider.get_retriever(
+            self._document_store_provider.get_store("project_meta")
+        )
+        # avoid returning stale results across provider swaps
+        self._cache.clear()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def update_components(
        self,
        llm_provider=None,
        embedder_provider=None,
        document_store_provider: DocumentStoreProvider | None = None,
        update_components: bool = True,
        **_,
    ):
        # keep base state consistent; we don't need base to rebuild components here
        super().update_components(
            llm_provider=llm_provider,
            embedder_provider=embedder_provider,
            document_store_provider=document_store_provider,
            update_components=False,
        )
        if document_store_provider is not None:
            self._document_store_provider = document_store_provider
        self._retriever = self._document_store_provider.get_retriever(
            self._document_store_provider.get_store("project_meta")
        )
        # avoid returning stale results across provider swaps
        self._cache.clear()
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/sql_functions.py around lines 106 to
110, the update_components override uses a different signature than
BasicPipeline.update_components and doesn't clear cached values when the
underlying store/retriever changes; change the method signature to match the
base (accept the same positional and keyword args, e.g., (*args, **kwargs) or
the exact parameters used by BasicPipeline), call super().update_components(...)
so central reconfiguration still works, then set self._document_store_provider
and recompute self._retriever as before, and finally clear any TTL/lru caches
associated with cached helper methods or properties that depend on the
store/retriever (call .cache_clear() on those decorated functions or otherwise
invalidate the cache) so no stale data remains.

Comment on lines 195 to 203
     def get_text_embedder(self):
         return AsyncTextEmbedder(
             api_key=self._api_key,
             api_base_url=self._api_base,
-            model=self._embedding_model,
+            api_version=self._api_version,
+            model=self._model,
             timeout=self._timeout,
-            **self._kwargs,
+            **self._model_kwargs,
         )

⚠️ Potential issue

Constructor arg name mismatch: AsyncTextEmbedder expects 'api_base', not 'api_base_url'.

This will raise a TypeError at runtime.

     def get_text_embedder(self):
         return AsyncTextEmbedder(
             api_key=self._api_key,
-            api_base_url=self._api_base,
+            api_base=self._api_base,
             api_version=self._api_version,
             model=self._model,
             timeout=self._timeout,
             **self._model_kwargs,
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def get_text_embedder(self):
        return AsyncTextEmbedder(
            api_key=self._api_key,
            api_base=self._api_base,
            api_version=self._api_version,
            model=self._model,
            timeout=self._timeout,
            **self._model_kwargs,
        )
🤖 Prompt for AI Agents
In wren-ai-service/src/providers/embedder/litellm.py around lines 195 to 203,
the constructor call uses the wrong keyword name api_base_url while
AsyncTextEmbedder expects api_base; update the argument name to
api_base=self._api_base (leave other args as-is) so the correct keyword is
passed and avoid the TypeError at runtime.
