feat(wren-ai-service): configs-api #1962
base: main
Conversation
Walkthrough
Adds runtime configuration endpoints (/configs GET/POST), restructures startup to store providers and pipe components, and propagates description/alias metadata. Refactors provider interfaces to properties, introduces per-pipeline update_components across many pipelines, and renames some indexing pipeline keys. Adds Configs schema and pipe metadata builder.
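For orientation, a minimal usage sketch of the two endpoints follows. It assumes the service is reachable at http://localhost:5556 and that the request body mirrors the Configs schema; the host/port and the example field values are assumptions, not taken from this PR.

import httpx

BASE_URL = "http://localhost:5556"  # assumed host/port; adjust to your deployment

# Introspect the current runtime configuration (env vars, providers, pipelines).
current = httpx.get(f"{BASE_URL}/configs").json()
print(current["providers"]["llm"])

# Apply a new configuration; the payload must validate against the Configs schema.
payload = {
    "providers": {
        "llm": [{"model": "gpt-4o-mini", "alias": "default", "timeout": 120, "kwargs": {}}],
        "embedder": [{"model": "text-embedding-3-large", "alias": "default", "dimension": 3072, "timeout": 120, "kwargs": {}}],
    },
    "pipelines": {},
}
httpx.post(f"{BASE_URL}/configs", json=payload).raise_for_status()  # 200 OK -> providers and pipelines rebuilt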
Sequence Diagram(s)
sequenceDiagram
autonumber
participant U as Client
participant S as Service (FastAPI)
participant C as generate_components
participant SC as ServiceContainer
participant G as create_pipe_components
rect rgba(230,245,255,0.5)
note over S: Startup
S->>C: generate_components(configs)
C-->>S: (pipe_components, instantiated_providers)
S->>SC: build services with providers
S->>G: create_pipe_components(ServiceContainer)
G-->>S: pipe_service_components
S->>S: store in app.state
end
rect rgba(240,255,230,0.5)
note over U,S: Introspection
U->>S: GET /configs
S-->>U: env_vars, providers (llm/embedder), pipelines + linkage
end
rect rgba(255,245,230,0.5)
note over U,S: Live reconfiguration
U->>S: POST /configs (Configs)
S->>S: rebuild providers (LLM/Embedder/Qdrant)
S->>SC: update pipelines via update_components(...)
S->>G: create_pipe_components(...)
S-->>U: 200 OK (updated snapshot)
end
sequenceDiagram
autonumber
participant Svc as Service
participant P as Pipeline Instance
participant Prov as Providers (LLM/Embedder/Store)
note over Svc,P: Dynamic pipeline update
Svc->>P: update_components(llm_provider, embedder_provider, document_store_provider)
P->>P: set internal refs
alt needs recompute
P->>P: _update_components()
end
P-->>Svc: components refreshed
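Both diagrams assume each pipeline exposes an update_components hook on a shared base class. A minimal sketch of that pattern, with names taken from the walkthrough (the real signature may differ):

class BasicPipeline:
    def update_components(
        self,
        llm_provider=None,
        embedder_provider=None,
        document_store_provider=None,
        update_components: bool = True,
        **_,
    ):
        # Swap provider references in place so a running service picks up new models.
        if llm_provider is not None:
            self._llm_provider = llm_provider
        if embedder_provider is not None:
            self._embedder_provider = embedder_provider
        if document_store_provider is not None:
            self._document_store_provider = document_store_provider
        if update_components:
            # Subclasses rebuild their generators/retrievers from the new providers.
            self._components = self._update_components()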
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (26)
wren-ai-service/src/providers/document_store/qdrant.py (2)
381-394
: Dangerous boolean parsing for SHOULD_FORCE_DEPLOY can trigger destructive resets.
bool(os.getenv(...))
is True for any non-empty string (e.g., "false"). Combined with the new conditional reset, this can still wipe collections unintentionally.
-    recreate_index: bool = (
-        bool(os.getenv("SHOULD_FORCE_DEPLOY"))
-        if os.getenv("SHOULD_FORCE_DEPLOY")
-        else False
-    ),
+    recreate_index: bool = (
+        os.getenv("SHOULD_FORCE_DEPLOY", "").strip().lower()
+        in {"1", "true", "t", "yes", "y", "on"}
+    ),
     **_,
 ):
     self._location = location
     self._api_key = Secret.from_token(api_key) if api_key else None
     self._timeout = timeout
     self._embedding_model_dim = embedding_model_dim
-    if recreate_index:
-        self._reset_document_store(recreate_index)
+    if recreate_index:
+        self._reset_document_store(recreate_index)
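If this parsing is needed in more than one provider, a small helper keeps the accepted truthy spellings in one place. This is a sketch, not code from the PR; the name env_flag is hypothetical.

import os

def env_flag(name: str, default: bool = False) -> bool:
    # Treats unset, "", "false", "0", etc. as False instead of "any non-empty string is True".
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "t", "yes", "y", "on"}

# usage: recreate_index = env_flag("SHOULD_FORCE_DEPLOY")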
395-402
: Collection names out of sync with new indexing keys.
Reset still targets "sql_pairs" and "instructions" but upstream now uses "..._indexing". This will leave stale/new collections unmanaged.
 def _reset_document_store(self, recreate_index: bool):
     self.get_store(recreate_index=recreate_index)
     self.get_store(dataset_name="table_descriptions", recreate_index=recreate_index)
     self.get_store(dataset_name="view_questions", recreate_index=recreate_index)
-    self.get_store(dataset_name="sql_pairs", recreate_index=recreate_index)
-    self.get_store(dataset_name="instructions", recreate_index=recreate_index)
+    self.get_store(dataset_name="sql_pairs_indexing", recreate_index=recreate_index)
+    self.get_store(dataset_name="instructions_indexing", recreate_index=recreate_index)
     self.get_store(dataset_name="project_meta", recreate_index=recreate_index)
wren-ai-service/src/pipelines/generation/sql_answer.py (2)
152-153
: Catch asyncio.TimeoutError, not built-in TimeoutError.
asyncio.wait_for raises asyncio.TimeoutError; the current except block won't trigger.
Apply this diff:
-        except TimeoutError:
+        except asyncio.TimeoutError:
             break
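One caveat: whether the original except clause fires depends on the interpreter version, so the fix matters most on older runtimes. A standalone check, not tied to this codebase:

import asyncio, sys

# On Python 3.11+ asyncio.TimeoutError is an alias of the builtin TimeoutError,
# so `except TimeoutError` does catch asyncio.wait_for timeouts there; on older
# interpreters only asyncio.TimeoutError matches, making the explicit form the
# portable choice.
print(sys.version_info[:2], asyncio.TimeoutError is TimeoutError)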
155-165
: Avoid dynamic default evaluated at import time.
current_time: str = Configuration().show_current_time()
is evaluated once when the function is defined, not per call.Apply this diff:
- async def run( + async def run( self, query: str, sql: str, sql_data: dict, language: str, - current_time: str = Configuration().show_current_time(), + current_time: Optional[str] = None, query_id: Optional[str] = None, custom_instruction: Optional[str] = None, ) -> dict: @@ - return await self._pipe.execute( + if current_time is None: + current_time = Configuration().show_current_time() + return await self._pipe.execute( ["generate_answer"], inputs={ "query": query, "sql": sql, "sql_data": sql_data, "language": language, "current_time": current_time, "query_id": query_id, "custom_instruction": custom_instruction or "", **self._components, }, )wren-ai-service/src/pipelines/generation/question_recommendation.py (2)
163-186
: Template usescategories
, but prompt() doesn’t pass it.
user_prompt_template
referencescategories
, yetprompt()
neither accepts nor forwards it toPromptBuilder.run
, so categories are silently ignored.Proposed change (outside this hunk):
def prompt( previous_questions: list[str], documents: list, language: str, max_questions: int, max_categories: int, categories: list[str], # add prompt_builder: PromptBuilder, ) -> dict: _prompt = prompt_builder.run( documents=documents, previous_questions=previous_questions, language=language, max_questions=max_questions, max_categories=max_categories, categories=categories, # add ) return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}And add
"categories": categories
to the inputs of any upstreamexecute
call if missing.
265-270
: Avoid mutable default args.
previous_questions: list[str] = [] and categories: list[str] = [] share state across calls.
Apply this diff:
-        previous_questions: list[str] = [],
-        categories: list[str] = [],
+        previous_questions: Optional[list[str]] = None,
+        categories: Optional[list[str]] = None,
 @@
-        return await self._pipe.execute(
+        previous_questions = previous_questions or []
+        categories = categories or []
+        return await self._pipe.execute(
             [self._final],
             inputs={
                 "documents": contexts,
                 "previous_questions": previous_questions,
                 "categories": categories,
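For readers less familiar with this pitfall, a short standalone illustration of why the shared default matters (unrelated to the pipeline code):

def remember(item, seen: list = []):  # the same list object is reused on every call
    seen.append(item)
    return seen

print(remember("a"))  # ['a']
print(remember("b"))  # ['a', 'b']  <- state leaked from the previous call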
62-68
: Return type mismatch: generation step returns a tuple but is typed/used as dictgenerate_sql_question is annotated to return dict, but it returns (dict, generator_name). post_process consumes it as a dict. Align by returning only the generator result.
 async def generate_sql_question(
-    prompt: dict, generator: Any, generator_name: str
+    prompt: dict, generator: Any, generator_name: str
 ) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+    return await generator(prompt=prompt.get("prompt"))
80-87
: Return type mismatch: generation step returns a tuple but is typed as dictReturning (dict, generator_name) from data_assistance changes the final pipeline output shape. Return only the generator result.
async def data_assistance( prompt: dict, generator: Any, query_id: str, generator_name: str ) -> dict: - return await generator( + return await generator( prompt=prompt.get("prompt"), query_id=query_id, - ), generator_name + )
152-153
: Catch asyncio.TimeoutError, not TimeoutError
asyncio.wait_for raises asyncio.TimeoutError.
-        except TimeoutError:
+        except asyncio.TimeoutError:
             break
wren-ai-service/src/pipelines/generation/misleading_assistance.py (2)
80-87
: Return type mismatch: generation step returns a tuple but is typed as dictReturn only the generator result to keep the pipeline output stable.
async def misleading_assistance( prompt: dict, generator: Any, query_id: str, generator_name: str ) -> dict: - return await generator( + return await generator( prompt=prompt.get("prompt"), query_id=query_id, - ), generator_name + )
152-153
: Catch asyncio.TimeoutError
Same issue as data_assistance.
-        except TimeoutError:
+        except asyncio.TimeoutError:
             break
wren-ai-service/src/pipelines/generation/chart_generation.py (1)
88-92
: Return type mismatch: generation step returns a tuple but is typed/used as dictpost_process expects generate_chart: dict. Return only the generator result.
 async def generate_chart(prompt: dict, generator: Any, generator_name: str) -> dict:
-    return await generator(prompt=prompt.get("prompt")), generator_name
+    return await generator(prompt=prompt.get("prompt"))
92-102
: Return type mismatch: generation step returns a tuple but is typed/consumed as dictpost_process uses generate_sql_reasoning as dict. Return only the generator result.
async def generate_sql_reasoning( prompt: dict, generator: Any, query_id: str, generator_name: str, ) -> dict: - return await generator( + return await generator( prompt=prompt.get("prompt"), query_id=query_id, - ), generator_name + )
172-173
: Catch asyncio.TimeoutError
Same asyncio-specific exception applies here.
-        except TimeoutError:
+        except asyncio.TimeoutError:
             break
wren-ai-service/src/pipelines/generation/chart_adjustment.py (1)
114-121
: Return type mismatch: generation step returns a tuple but is typed/used as dictpost_process expects dict. Return only generator result.
async def generate_chart_adjustment( prompt: dict, generator: Any, generator_name: str, ) -> dict: - return await generator(prompt=prompt.get("prompt")), generator_name + return await generator(prompt=prompt.get("prompt"))wren-ai-service/src/pipelines/generation/user_guide_assistance.py (1)
145-146
: Catch asyncio.TimeoutError, not TimeoutError
asyncio.wait_for raises asyncio.TimeoutError.
-        except TimeoutError:
+        except asyncio.TimeoutError:
             break
wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py (2)
81-87
: Bug: returning a tuple breaks post_process expecting a dict
post_process()
calls.get("replies")
on the output of this step; returning(dict, generator_name)
will raise at runtime.-async def generate_sql_reasoning( - prompt: dict, generator: Any, query_id: str, generator_name: str -) -> dict: - return await generator( - prompt=prompt.get("prompt"), query_id=query_id - ), generator_name +async def generate_sql_reasoning( + prompt: dict, generator: Any, query_id: str, generator_name: str +) -> dict: + return await generator(prompt=prompt.get("prompt"), query_id=query_id)
157-158
: Catch asyncio.TimeoutError
Same issue as other pipelines.
-        except TimeoutError:
+        except asyncio.TimeoutError:
             break
wren-ai-service/src/pipelines/generation/relationship_recommendation.py (1)
113-115
: Bug: return only the generator result
normalized(generate: dict)
expects a dict; returning(dict, generator_name)
will fail.-async def generate(prompt: dict, generator: Any, generator_name: str) -> dict: - return await generator(prompt=prompt.get("prompt")), generator_name +async def generate(prompt: dict, generator: Any, generator_name: str) -> dict: + return await generator(prompt=prompt.get("prompt"))wren-ai-service/src/pipelines/generation/sql_regeneration.py (1)
130-136
: Bug: return only the generator result
post_process
consumesregenerate_sql
as a dict; tuple return will break.-async def regenerate_sql( - prompt: dict, - generator: Any, - generator_name: str, -) -> dict: - return await generator(prompt=prompt.get("prompt")), generator_name +async def regenerate_sql( + prompt: dict, + generator: Any, + generator_name: str, +) -> dict: + return await generator(prompt=prompt.get("prompt"))wren-ai-service/src/pipelines/generation/semantics_description.py (1)
148-151
: Fix: generate() must return a dict, not a tuple.Downstream expects dict (normalize() calls .get). Return the generator result and (optionally) inject generator_name into it.
-async def generate(prompt: dict, generator: Any, generator_name: str) -> dict: - return await generator(prompt=prompt.get("prompt")), generator_name +async def generate(prompt: dict, generator: Any, generator_name: str) -> dict: + result = await generator(prompt=prompt.get("prompt")) + # optional: expose generator_name for observability + result["generator_name"] = generator_name + return resultwren-ai-service/src/pipelines/generation/sql_correction.py (1)
101-105
: Fix: generate_sql_correction() must return a dict, not a tuple.post_process() calls .get on its input.
-async def generate_sql_correction( - prompt: dict, generator: Any, generator_name: str -) -> dict: - return await generator(prompt=prompt.get("prompt")), generator_name +async def generate_sql_correction( + prompt: dict, generator: Any, generator_name: str +) -> dict: + result = await generator(prompt=prompt.get("prompt")) + result["generator_name"] = generator_name # optional + return resultwren-ai-service/src/pipelines/generation/followup_sql_generation.py (1)
121-128
: Bug: generate_sql_in_followup returns a tuple but downstream expects a dict.post_process calls .get("replies"), which fails on a tuple. Return a dict with "replies" and "generator_name".
async def generate_sql_in_followup( prompt: dict, generator: Any, histories: list[AskHistory], generator_name: str ) -> dict: history_messages = construct_ask_history_messages(histories) - return await generator( - prompt=prompt.get("prompt"), history_messages=history_messages - ), generator_name + replies = await generator( + prompt=prompt.get("prompt"), history_messages=history_messages + ) + return {"replies": replies, "generator_name": generator_name}wren-ai-service/src/pipelines/generation/intent_classification.py (1)
294-297
: Return the generator result dict, not a tupleReturning (dict, generator_name) breaks downstream post_process expecting a dict.
@observe(as_type="generation", capture_input=False) @trace_cost async def classify_intent(prompt: dict, generator: Any, generator_name: str) -> dict: - return await generator(prompt=prompt.get("prompt")), generator_name + return await generator(prompt=prompt.get("prompt"))wren-ai-service/src/pipelines/retrieval/instructions.py (1)
169-179
: default_instructions returns list but callers expect dict; fix type and return shapeThis currently raises when filtered_documents is truthy and default_instructions is [].
-@observe(capture_input=False) -def default_instructions( +@observe(capture_input=False) +def default_instructions( count_documents: int, retriever: Any, project_id: str, scope_filter: ScopeFilter, scope: str, -) -> list[Document]: - if not count_documents: - return [] +) -> dict: + if not count_documents: + return {"documents": []} @@ - return dict(documents=res.get("documents")) + return dict(documents=res.get("documents"))Also applies to: 133-166
wren-ai-service/src/pipelines/indexing/sql_pairs.py (1)
125-136
: Fix delete_all: current filter never deletes anything when delete_all=Trueclean() currently passes an empty sql_pair_ids -> the store builds an "in: []" filter (matches nothing). Pass delete_all through and build filters conditionally. Consider requiring project_id when delete_all=True to avoid accidental global wipes.
@@ @component class SqlPairsCleaner: - def __init__(self, sql_pairs_store: DocumentStore) -> None: + def __init__(self, sql_pairs_store: DocumentStore) -> None: self.store = sql_pairs_store @component.output_types() async def run( - self, sql_pair_ids: List[str], project_id: Optional[str] = None + self, + sql_pair_ids: List[str], + project_id: Optional[str] = None, + delete_all: bool = False, ) -> None: - filter = { - "operator": "AND", - "conditions": [ - {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids}, - ], - } - - if project_id: - filter["conditions"].append( - {"field": "project_id", "operator": "==", "value": project_id} - ) - - return await self.store.delete_documents(filter) + conditions = [] + if not delete_all: + conditions.append( + {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids} + ) + if project_id: + conditions.append( + {"field": "project_id", "operator": "==", "value": project_id} + ) + filters = {"operator": "AND", "conditions": conditions} if conditions else None + return await self.store.delete_documents(filters) @@ async def clean( cleaner: SqlPairsCleaner, sql_pairs: List[SqlPair], embedding: Dict[str, Any] = {}, project_id: str = "", delete_all: bool = False, ) -> Dict[str, Any]: @@ - if sql_pair_ids or delete_all: - await cleaner.run(sql_pair_ids=sql_pair_ids, project_id=project_id) + if sql_pair_ids or delete_all: + await cleaner.run( + sql_pair_ids=sql_pair_ids, + project_id=project_id, + delete_all=delete_all, + )Also applies to: 53-74
🧹 Nitpick comments (57)
wren-ai-service/src/web/v1/services/semantics_preparation.py (1)
129-134
: Fix error message: missing f-string and wrong field.Currently the braces are literal and it references a non-existent attribute
id
. Usemdl_hash
with an f-string.- return SemanticsPreparationStatusResponse( - status="failed", - error=SemanticsPreparationStatusResponse.SemanticsPreparationError( - code="OTHERS", - message="{prepare_semantics_status_request.id} is not found", - ), - ) + return SemanticsPreparationStatusResponse( + status="failed", + error=SemanticsPreparationStatusResponse.SemanticsPreparationError( + code="OTHERS", + message=f"{prepare_semantics_status_request.mdl_hash} is not found", + ), + )wren-ai-service/src/providers/document_store/qdrant.py (2)
161-163
: Creating payload index on every init may raise if already exists.Guard for idempotency to avoid noisy errors on restarts.
- self.client.create_payload_index( - collection_name=index, field_name="project_id", field_schema="keyword" - ) + try: + self.client.create_payload_index( + collection_name=index, field_name="project_id", field_schema="keyword" + ) + except Exception as e: + logger.debug("create_payload_index skipped: %s", e)
351-357
: Boolean coalescing drops explicit False.
Using or here ignores caller-supplied False for scale_score / return_embedding. Respect None vs False.
-            top_k=top_k or self._top_k,
-            scale_score=scale_score or self._scale_score,
-            return_embedding=return_embedding or self._return_embedding,
+            top_k=self._top_k if top_k is None else top_k,
+            scale_score=self._scale_score if scale_score is None else scale_score,
+            return_embedding=(
+                self._return_embedding
+                if return_embedding is None
+                else return_embedding
+            ),
wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py (2)
144-159
: Dynamic update_components LGTM; consider chaining return self.Returning self enables fluent reconfiguration, but optional.
def update_components( self, embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, **_, ): super().update_components( embedder_provider=embedder_provider, document_store_provider=document_store_provider, update_components=False, ) self._view_questions_store = self._document_store_provider.get_store( dataset_name="view_questions" ) self._components = self._update_components() + return self
24-35
: Avoid shadowing built-in list.Minor readability nit.
- list = [ + items = [ { "question": doc.content, "summary": doc.meta.get("summary", ""), "statement": doc.meta.get("statement") or doc.meta.get("sql"), "viewId": doc.meta.get("viewId", ""), } for doc in documents ] - return {"documents": list} + return {"documents": items}wren-ai-service/src/providers/llm/litellm.py (3)
32-43
: Alias and api key name plumbed correctly. Consider lazy api key resolution for runtime config updates.Storing
_api_key
at init means env/key updates after process start won't apply. Prefer deferring toos.getenv(self._api_key_name)
at call time or add anapi_key
property that does this.Example (outside this hunk):
@property def api_key(self) -> Optional[str]: return os.getenv(self._api_key_name) if self._api_key_name else self._api_keyThen pass
api_key=self.api_key
in acompletion.
139-147
: Guard against empty stream producing UnboundLocalError.If no chunks arrive,
chunk
is undefined when callingconnect_chunks(chunk, chunks)
. Guard or track the last chunk explicitly.Apply this diff:
- async for chunk in completion: + last_chunk = None + async for chunk in completion: if chunk.choices and streaming_callback: chunk_delta: StreamingChunk = build_chunk(chunk) chunks.append(chunk_delta) streaming_callback( chunk_delta, query_id ) # invoke callback with the chunk_delta - completions = [connect_chunks(chunk, chunks)] + last_chunk = chunk + if last_chunk is None: + return {"replies": [""], "meta": [{}]} + completions = [connect_chunks(last_chunk, chunks)]
75-81
: Broaden retryable exceptions.
Only retrying on openai.APIError may miss connection/timeouts/rate-limits and LiteLLM errors. Consider including openai.RateLimitError, openai.APIConnectionError, openai.APITimeoutError, and LiteLLM exceptions; a hedged decorator sketch follows.
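A sketch of the broader retry policy, assuming the backoff and openai packages already imported in this module; the exact exception tuple should be adjusted to whatever LiteLLM actually raises:

import backoff
import openai

RETRYABLE_ERRORS = (
    openai.APIError,
    openai.APIConnectionError,
    openai.APITimeoutError,
    openai.RateLimitError,
)

@backoff.on_exception(backoff.expo, RETRYABLE_ERRORS, max_time=60.0, max_tries=3)
async def _call_llm(**kwargs):
    ...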
wren-ai-service/src/pipelines/generation/sql_answer.py (2)
120-129
: Streaming queue logic is fine; ensure queue cleanup always runs.Cleanup is gated on "" which is good. Just a note: consider a finally/timeout path to avoid orphaned queues if a session never sends finish.
82-87
: Return type hint mismatch.
generate_answer
returns a tuple(dict, generator_name)
but is annotated as-> dict
. Adjust totuple[dict, str]
or return only the dict.wren-ai-service/src/pipelines/generation/sql_diagnosis.py (1)
92-96
: Return type annotation incorrect.
post_process
returns a parsed JSON object (dict), but the annotation is-> str
.Apply this diff:
-async def post_process( - generate_sql_diagnosis: dict, -) -> str: +async def post_process( + generate_sql_diagnosis: dict, +) -> dict:wren-ai-service/src/pipelines/generation/question_recommendation.py (2)
196-210
: Consider validating LLM JSON with Pydantic.After
orjson.loads
, validate againstQuestionResult
to fail fast on malformed outputs and normalize types.Example (outside this hunk):
try: parsed = QuestionResult.model_validate_json(text) return parsed.model_dump() except Exception as e: logger.error(f"Validation error: {e}") return {"questions": []}
191-193
: Return type hint alignment.
generate
returns a tuple(dict, generator_name)
but annotated as-> dict
. Align the annotation or the return shape.wren-ai-service/src/pipelines/generation/sql_question.py (2)
71-75
: Add defensive JSON parsing to avoid crashes on malformed repliesGuard against invalid JSON or empty replies to prevent 500s.
@observe(capture_input=False) def post_process( generate_sql_question: dict, ) -> str: - return orjson.loads(generate_sql_question.get("replies")[0])["question"] + try: + reply = (generate_sql_question.get("replies") or [None])[0] + if not reply: + return "" + payload = orjson.loads(reply) + return payload.get("question", "") + except Exception as e: + logger.warning("sql_question: failed to parse reply: %s", e) + return ""
95-109
: Description is stored but unusedIf description is intended for metadata/configs, propagate it to pipe metadata; otherwise remove to avoid dead state.
wren-ai-service/src/pipelines/generation/data_assistance.py (1)
130-153
: Ensure queues are cleaned up on timeout/error to avoid leaksIf the stream times out, the per-user queue remains in memory. Pop it in a finally block.
async def get_streaming_results(self, query_id): async def _get_streaming_results(query_id): return await self._user_queues[query_id].get() if query_id not in self._user_queues: self._user_queues[ query_id ] = asyncio.Queue() # Ensure the user's queue exists - while True: - try: + try: + while True: + try: # Wait for an item from the user's queue - self._streaming_results = await asyncio.wait_for( - _get_streaming_results(query_id), timeout=120 - ) - if ( - self._streaming_results == "<DONE>" - ): # Check for end-of-stream signal - del self._user_queues[query_id] - break - if self._streaming_results: # Check if there are results to yield - yield self._streaming_results - self._streaming_results = "" # Clear after yielding - except TimeoutError: - break + item = await asyncio.wait_for( + _get_streaming_results(query_id), timeout=120 + ) + if item == "<DONE>": + break + if item: + yield item + except asyncio.TimeoutError: + break + finally: + self._user_queues.pop(query_id, None)wren-ai-service/src/pipelines/generation/misleading_assistance.py (1)
130-153
: Clean up queues on timeout/errorMirror the leak prevention suggested for data_assistance.
- while True: - try: + try: + while True: + try: # Wait for an item from the user's queue - self._streaming_results = await asyncio.wait_for( - _get_streaming_results(query_id), timeout=120 - ) - if ( - self._streaming_results == "<DONE>" - ): # Check for end-of-stream signal - del self._user_queues[query_id] - break - if self._streaming_results: # Check if there are results to yield - yield self._streaming_results - self._streaming_results = "" # Clear after yielding - except TimeoutError: - break + item = await asyncio.wait_for( + _get_streaming_results(query_id), timeout=120 + ) + if item == "<DONE>": + break + if item: + yield item + except asyncio.TimeoutError: + break + finally: + self._user_queues.pop(query_id, None)wren-ai-service/src/pipelines/generation/chart_generation.py (1)
137-143
: File path robustness: avoid CWD-dependent open()Use a path relative to this file to prevent failures when CWD differs (e.g., when installed as a package).
- with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: - _vega_schema = orjson.loads(f.read()) + from pathlib import Path + schema_path = Path(__file__).parent / "utils" / "vega-lite-schema-v5.json" + _vega_schema = orjson.loads(schema_path.read_bytes())wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py (1)
66-87
: Avoid instantiating Configuration in default argumentsUse None default and construct inside to prevent unintended shared state.
def prompt( query: str, documents: list[str], histories: list[AskHistory], sql_samples: list[dict], instructions: list[dict], prompt_builder: PromptBuilder, - configuration: Configuration | None = Configuration(), + configuration: Configuration | None = None, ) -> dict: + configuration = configuration or Configuration()wren-ai-service/src/pipelines/generation/chart_adjustment.py (1)
164-169
: File path robustness: avoid CWD-dependent open()Same as chart_generation: use a path relative to this file.
- with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: - _vega_schema = orjson.loads(f.read()) + from pathlib import Path + schema_path = Path(__file__).parent / "utils" / "vega-lite-schema-v5.json" + _vega_schema = orjson.loads(schema_path.read_bytes())wren-ai-service/src/pipelines/retrieval/sql_functions.py (2)
24-33
: Fix return annotation for _extract
function_type
is a string; the type hint sayslist
.- def _extract() -> tuple[str, list, str]: + def _extract() -> tuple[str, str, str]:
128-135
: Avoid shadowing built-ininput
Use a different variable name to prevent confusion with the built-in.
- input = { + inputs = { "data_source": _data_source, "project_id": project_id, **self._components, } - result = await self._pipe.execute(["cache"], inputs=input) + result = await self._pipe.execute(["cache"], inputs=inputs)wren-ai-service/src/pipelines/generation/user_guide_assistance.py (1)
70-76
: Return only the generator result (not a tuple)Downstream doesn’t consume
generator_name
from the return value; keeping it in the tuple is inconsistent and risks surprises.-async def user_guide_assistance( - prompt: dict, generator: Any, query_id: str, generator_name: str -) -> dict: - return await generator( - prompt=prompt.get("prompt"), query_id=query_id - ), generator_name +async def user_guide_assistance( + prompt: dict, generator: Any, query_id: str, generator_name: str +) -> dict: + return await generator(prompt=prompt.get("prompt"), query_id=query_id)wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py (2)
64-76
: Avoid defaulting to a constructed ConfigurationDon’t use
Configuration()
as a default value; construct it inside when needed.-def prompt( +def prompt( query: str, documents: list[str], sql_samples: list[dict], - instructions: list[dict], + instructions: list[str] | list[dict], prompt_builder: PromptBuilder, - configuration: Configuration | None = Configuration(), + configuration: Configuration | None = None, ) -> dict: - _prompt = prompt_builder.run( + configuration = configuration or Configuration() + _prompt = prompt_builder.run(
167-171
: Run signature: avoid constructing default ConfigurationMirror the fix in
prompt
and pass a fresh instance when absent.- configuration: Configuration = Configuration(), + configuration: Optional[Configuration] = None, query_id: Optional[str] = None, ): @@ - inputs={ + inputs={ "query": query, "documents": contexts, "sql_samples": sql_samples or [], "instructions": instructions or [], - "configuration": configuration, + "configuration": configuration or Configuration(), "query_id": query_id, **self._components, },wren-ai-service/src/pipelines/generation/relationship_recommendation.py (2)
76-99
: Type hint: function returns a list, not a dictAdjust return annotation for clarity.
-@observe(capture_input=False) -def cleaned_models(mdl: dict) -> dict: +@observe(capture_input=False) +def cleaned_models(mdl: dict) -> list[dict]:
119-129
: Wrapper return type hint is wrong
wrapper
returns a dict (or{}
), notstr
.- def wrapper(text: str) -> str: + def wrapper(text: str) -> dict:wren-ai-service/src/pipelines/generation/semantics_description.py (1)
155-165
: Type hint fix in normalize(): wrapper returns dict.Align annotation with return value and keep behavior.
-def wrapper(text: str) -> str: +def wrapper(text: str) -> dict:wren-ai-service/src/__main__.py (2)
151-174
: Simplify pipeline llm/embedder mapping; you already store aliases.The alias lookup map is model→alias but pipe_service_components already holds alias. Set directly.
- if llm_model: - if llm_model_alias := _llm_model_alias_mapping.get(llm_model): - _configs["pipelines"][pipe_name]["llm"] = llm_model_alias - else: - _configs["pipelines"][pipe_name]["llm"] = llm_model - if embedding_model: - if embedding_model_alias := _embedder_model_alias_mapping.get( - embedding_model - ): - _configs["pipelines"][pipe_name]["embedder"] = embedding_model_alias - else: - _configs["pipelines"][pipe_name]["embedder"] = embedding_model + if llm_model: + _configs["pipelines"][pipe_name]["llm"] = llm_model + if embedding_model: + _configs["pipelines"][pipe_name]["embedder"] = embedding_model
262-266
: Return the updated config snapshot.Saves a follow-up GET and confirms applied state.
- for pipeline_name, _ in app.state.pipe_components.items(): - pass + for pipeline_name, _ in app.state.pipe_components.items(): + pass # placeholder for future metadata updates + return get_configs()wren-ai-service/src/core/pipeline.py (1)
52-56
: Nit: make fields Optional[...] and default description empty string.Improves typing consistency and avoids None checks downstream.
-@dataclass -class PipelineComponent(Mapping): - description: str = None - llm_provider: LLMProvider = None - embedder_provider: EmbedderProvider = None - document_store_provider: DocumentStoreProvider = None +@dataclass +class PipelineComponent(Mapping): + description: str = "" + llm_provider: Optional[LLMProvider] = None + embedder_provider: Optional[EmbedderProvider] = None + document_store_provider: Optional[DocumentStoreProvider] = Nonewren-ai-service/src/pipelines/indexing/instructions.py (4)
66-78
: Use explicit 'filters' and avoid shadowing built-in 'filter'.Minor clarity/readability and avoids shadowing a built-in. Pass the kwarg explicitly to match Haystack API.
Apply this diff:
- filter = { + filters = { "operator": "AND", "conditions": [ {"field": "instruction_id", "operator": "in", "value": instruction_ids}, ], } if project_id: - filter["conditions"].append( + filters["conditions"].append( {"field": "project_id", "operator": "==", "value": project_id} ) - return await self.store.delete_documents(filter) + return await self.store.delete_documents(filters=filters)
101-109
: Avoid default mutable argument for 'embedding'.Safer to use None and default to an empty payload at runtime.
Apply this diff:
-async def clean( +async def clean( cleaner: InstructionsCleaner, instructions: List[Instruction], - embedding: Dict[str, Any] = {}, + embedding: Optional[Dict[str, Any]] = None, project_id: str = "", delete_all: bool = False, ) -> Dict[str, Any]: @@ - return embedding + return embedding or {"documents": []}
116-121
: Return type annotation mismatch.Function returns the writer’s result; annotate accordingly for type checkers.
-async def write( +async def write( clean: Dict[str, Any], writer: AsyncDocumentWriter, -) -> None: +) -> Dict[str, Any]: return await writer.run(documents=clean["documents"])
184-191
: Avoid shadowing built-in 'input'.Rename local variable for clarity.
- input = { + inputs = { "project_id": project_id, "instructions": instructions, **self._components, } - return await self._pipe.execute(["write"], inputs=input) + return await self._pipe.execute(["write"], inputs=inputs)wren-ai-service/src/pipelines/indexing/project_meta.py (1)
59-61
: Return type annotation mismatch.write returns the writer’s result; annotate accordingly.
-async def write(clean: dict[str, Any], writer: DocumentWriter) -> None: +async def write(clean: dict[str, Any], writer: DocumentWriter) -> dict[str, Any]: return await writer.run(documents=clean["documents"])wren-ai-service/src/providers/embedder/litellm.py (3)
128-142
: Harden usage aggregation to avoid KeyError when earlier batches lack usage.Make accumulation resilient across mixed responses.
- if "usage" not in meta: - meta["usage"] = ( - dict(response.usage) if hasattr(response, "usage") else {} - ) - else: - if hasattr(response, "usage"): - meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens - meta["usage"]["total_tokens"] += response.usage.total_tokens + if hasattr(response, "usage"): + usage_dict = dict(response.usage) + u = meta.setdefault("usage", {}) + u["prompt_tokens"] = u.get("prompt_tokens", 0) + usage_dict.get("prompt_tokens", 0) + u["total_tokens"] = u.get("total_tokens", 0) + usage_dict.get("total_tokens", 0)
52-53
: Broaden retry policy to catch common transient errors.Consider including rate limits/timeouts from the SDK to reduce flakiness.
Example:
-@backoff.on_exception(backoff.expo, openai.APIError, max_time=60.0, max_tries=3) +@backoff.on_exception(backoff.expo, (openai.APIError, openai.RateLimitError), max_time=60.0, max_tries=3)And similarly for AsyncDocumentEmbedder.run.
Also applies to: 146-147
18-21
: Docstring says metadata is embedded, but only content is used.Align comment with behavior, or concatenate selected metadata fields if intended.
wren-ai-service/src/pipelines/indexing/historical_question.py (3)
55-56
: Return type annotation mismatch for ViewChunker.run.It returns a dict; annotate accordingly.
- def run(self, mdl: Dict[str, Any], project_id: Optional[str] = None) -> None: + def run(self, mdl: Dict[str, Any], project_id: Optional[str] = None) -> Dict[str, Any]:
128-130
: Return type annotation mismatch.write returns the writer’s result; annotate accordingly.
-async def write(clean: Dict[str, Any], writer: DocumentWriter) -> None: +async def write(clean: Dict[str, Any], writer: DocumentWriter) -> Dict[str, Any]: return await writer.run(documents=clean["documents"])
85-92
: Avoid tqdm in production hot path (optional).Progress bars can add overhead/log noise in services; gate behind log level.
- "documents": [ - Document(**chunk) - for chunk in tqdm( - chunks, - desc=f"Project ID: {project_id}, Chunking views into documents", - ) - ] + "documents": [ + Document(**chunk) + for chunk in ( + tqdm( + chunks, + desc=f"Project ID: {project_id}, Chunking views into documents", + ) + if logger.isEnabledFor(logging.INFO) + else chunks + ) + ]wren-ai-service/src/pipelines/indexing/sql_pairs.py (5)
128-131
: Avoid mutable default argumentsUse None defaults to prevent shared state across calls.
@@ -async def clean( - cleaner: SqlPairsCleaner, - sql_pairs: List[SqlPair], - embedding: Dict[str, Any] = {}, +async def clean( + cleaner: SqlPairsCleaner, + sql_pairs: List[SqlPair], + embedding: Optional[Dict[str, Any]] = None, @@ - sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs] + embedding = embedding or {} + sql_pair_ids = [sql_pair.id for sql_pair in sql_pairs] @@ - async def clean( - self, - sql_pairs: List[SqlPair] = [], + async def clean( + self, + sql_pairs: Optional[List[SqlPair]] = None, project_id: Optional[str] = None, delete_all: bool = False, ) -> None: - await clean( + await clean( sql_pairs=sql_pairs, cleaner=self._components["cleaner"], project_id=project_id, delete_all=delete_all, )Also applies to: 239-249
81-87
: Type-guard boilerplate before .lower()Guard against non-string values in mdl["models"][*]["properties"]["boilerplate"].
-return { - boilerplate.lower() - for model in mdl.get("models", []) - if (boilerplate := model.get("properties", {}).get("boilerplate")) -} +return { + boilerplate.lower() + for model in mdl.get("models", []) + if (boilerplate := model.get("properties", {}).get("boilerplate")) and isinstance(boilerplate, str) +}
95-104
: Ensure SqlPair.id is populatedExternal data may omit id. Provide a stable fallback.
- SqlPair( - id=pair.get("id"), + SqlPair( + id=str(pair.get("id") or uuid.uuid4()), question=pair.get("question"), sql=pair.get("sql"), )
224-235
: Rename local variable input to inputsAvoid shadowing built-in input().
- input = { + inputs = { "mdl_str": mdl_str, "project_id": project_id, "external_pairs": { **self._external_pairs, **(external_pairs or {}), }, **self._components, } - return await self._pipe.execute(["write"], inputs=input) + return await self._pipe.execute(["write"], inputs=inputs)
36-49
: Optional: use stable document ids to make overwrites independent of cleaningConsider deriving Document.id from sql_pair_id to let DuplicatePolicy.OVERWRITE supersede prior versions even if cleaning is skipped.
- Document( - id=str(uuid.uuid4()), + Document( + id=f"sqlpair:{sql_pair.id}",wren-ai-service/src/pipelines/generation/intent_classification.py (1)
327-335
: Fix typo in constant name for clarityRename INTENT_CLASSIFICAION_MODEL_KWARGS → INTENT_CLASSIFICATION_MODEL_KWARGS and update usage.
-INTENT_CLASSIFICAION_MODEL_KWARGS = { +INTENT_CLASSIFICATION_MODEL_KWARGS = { @@ - "generator": self._llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=intent_classification_system_prompt, - generation_kwargs=INTENT_CLASSIFICAION_MODEL_KWARGS, + generation_kwargs=INTENT_CLASSIFICATION_MODEL_KWARGS, ),Also applies to: 380-384
wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py (2)
23-33
: Avoid shadowing built-in listUse a neutral name.
def run(self, documents: List[Document]): - list = [] + items = [] @@ - list.append(formatted) + items.append(formatted) - return {"documents": list} + return {"documents": items}
38-41
: Loosen store type to generic DocumentStoreThe provider returns an async store; using the generic interface avoids coupling to a specific backend type.
-from haystack_integrations.document_stores.qdrant import QdrantDocumentStore +from haystack.document_stores.types import DocumentStore @@ -async def count_documents( - store: QdrantDocumentStore, project_id: Optional[str] = None +async def count_documents( + store: DocumentStore, project_id: Optional[str] = None ) -> int:wren-ai-service/src/pipelines/retrieval/instructions.py (2)
23-35
: Avoid shadowing built-in listRename local var.
def run(self, documents: List[Document]): - list = [] + items = [] @@ - list.append(formatted) + items.append(formatted) - return {"documents": list} + return {"documents": items}
59-63
: Loosen store type to generic DocumentStoreAligns with provider interface and async usage.
-from haystack_integrations.document_stores.qdrant import QdrantDocumentStore +from haystack.document_stores.types import DocumentStore @@ -async def count_documents( - store: QdrantDocumentStore, project_id: Optional[str] = None +async def count_documents( + store: DocumentStore, project_id: Optional[str] = None ) -> int:wren-ai-service/src/pipelines/generation/sql_generation.py (1)
157-159
: Consider moving super().init() to after instance variables are initialized.Moving the
super().__init__()
call after initializing instance variables follows a more typical initialization pattern and ensures all dependencies are ready before the base class potentially uses them.- super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - self._document_store_provider = document_store_provider self._retriever = self._document_store_provider.get_retriever( self._document_store_provider.get_store("project_meta") ) self._llm_provider = llm_provider self._engine = engine self._description = description + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._components = self._update_components()wren-ai-service/src/utils.py (1)
259-268
: Consider making the hardcoded pipeline list configurable.The
has_db_data_in_llm_prompt
function hardcodes specific pipeline names. Consider making this list configurable through the Configs class or environment variables for better maintainability.def has_db_data_in_llm_prompt(pipe_name: str) -> bool: - pipes_containing_db_data = set( - [ - "sql_answer", - "chart_adjustment", - "chart_generation", - ] - ) + # Could be loaded from config or environment + pipes_containing_db_data = getattr( + settings, + 'PIPES_WITH_DB_DATA', + {"sql_answer", "chart_adjustment", "chart_generation"} + ) return pipe_name in pipes_containing_db_datawren-ai-service/src/providers/__init__.py (1)
343-401
: Fix: Function signature doesn't match return type annotation.The function annotation says it returns
dict[str, PipelineComponent]
but actually returns a tuple. Update the annotation to match the implementation.-def generate_components(configs: list[dict]) -> dict[str, PipelineComponent]: +def generate_components(configs: list[dict]) -> tuple[dict[str, PipelineComponent], dict]:wren-ai-service/src/pipelines/indexing/db_schema.py (1)
363-363
: Keep helper.load_helpers(), but run it before component initialization or re-run it on updates.load_helpers populates MODEL_PREPROCESSORS/COLUMN_PREPROCESSORS/COLUMN_COMMENT_HELPERS used by DDLChunker at runtime; it’s currently called after self._components = self._update_components() in wren-ai-service/src/pipelines/indexing/db_schema.py (DBSchema.init) and is not invoked in update_components(), so dynamically added helpers would not be picked up.
- Action: Move helper.load_helpers() to before self._components = self._update_components() in DBSchema.init OR call helper.load_helpers() inside update_components() so helpers are loaded/reloaded on dynamic updates.
- Ensure load_helpers is idempotent / safe to call multiple times.
wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py (1)
511-534
: Align update_components with BasicPipeline and remove unused retriever attributesCallers pass None and _table_retriever/_dbschema_retriever are only assigned here — making providers Optional and dropping those attributes is safe.
- def update_components( - self, - llm_provider: LLMProvider, - embedder_provider: EmbedderProvider, - document_store_provider: DocumentStoreProvider, - **_, - ): - super().update_components( - llm_provider=llm_provider, - embedder_provider=embedder_provider, - document_store_provider=document_store_provider, - update_components=False, - ) - self._table_retriever = self._document_store_provider.get_retriever( - self._document_store_provider.get_store(dataset_name="table_descriptions"), - top_k=self._table_retrieval_size, - ) - self._dbschema_retriever = self._document_store_provider.get_retriever( - self._document_store_provider.get_store(), - top_k=self._table_column_retrieval_size, - ) - self._components = self._update_components() - self._configs = self._update_configs() + def update_components( + self, + llm_provider: Optional[LLMProvider] = None, + embedder_provider: Optional[EmbedderProvider] = None, + document_store_provider: Optional[DocumentStoreProvider] = None, + **_, + ): + super().update_components( + llm_provider=llm_provider, + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._components = self._update_components() + self._configs = self._update_configs()
@app.get("/configs")
def get_configs():
    _configs = {
        "env_vars": {},
        "providers": {
            "llm": [],
            "embedder": [],
        },
        "pipelines": {},
    }

    _llm_model_alias_mapping = {}
    _embedder_model_alias_mapping = {}

    _llm_configs = []
    for _, model_config in app.state.instantiated_providers["llm"].items():
        _llm_config = {
            "model": model_config._model,
            "alias": model_config._alias,
            "context_window_size": model_config._context_window_size,
            "timeout": model_config._timeout,
            "kwargs": model_config._model_kwargs,
        }
        if model_config._api_base:
            _llm_config["api_base"] = model_config._api_base
        if model_config._api_version:
            _llm_config["api_version"] = model_config._api_version
        _llm_configs.append(_llm_config)
        _llm_model_alias_mapping[model_config._model] = model_config._alias
    _configs["providers"]["llm"] = _llm_configs

    _embedder_configs = []
    # we only support one embedding model now
    for _, model_config in app.state.instantiated_providers["embedder"].items():
        _embedder_config = {
            "model": model_config._model,
            "alias": model_config._alias,
            "dimension": app.state.instantiated_providers["document_store"][
                "qdrant"
            ]._embedding_model_dim,
            "timeout": model_config._timeout,
            "kwargs": model_config._model_kwargs,
        }
        if model_config._api_base:
            _embedder_config["api_base"] = model_config._api_base
        if model_config._api_version:
            _embedder_config["api_version"] = model_config._api_version
        _embedder_configs.append(_embedder_config)
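For orientation, the handler above produces a payload roughly shaped like the following; the values are illustrative only and depend on the deployed configuration:

# Illustrative GET /configs response (example values, not defaults)
{
    "env_vars": {},
    "providers": {
        "llm": [
            {"model": "gpt-4o-mini", "alias": "default", "context_window_size": 128000,
             "timeout": 120, "kwargs": {}},
        ],
        "embedder": [
            {"model": "text-embedding-3-large", "alias": "default", "dimension": 3072,
             "timeout": 120, "kwargs": {}},
        ],
    },
    "pipelines": {
        "sql_generation": {"llm": "default", "embedder": "default"},
    },
}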
🧩 Analysis chain
Sanitize exposed kwargs to avoid accidental secret leakage.
If any provider kwargs include sensitive fields, filter them before returning.
Proposed filter example:
- "kwargs": model_config._model_kwargs,
+ "kwargs": {k: v for k, v in (model_config._model_kwargs or {}).items() if k.lower() not in {"api_key", "key", "token", "secret"}},
Run to verify presence of sensitive keys in kwargs across providers:
🏁 Script executed:
#!/bin/bash
rg -nP --type=py -C2 '_model_kwargs' src/providers
Length of output: 53
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "PWD: $(pwd)"
echo "Listing top-level files/dirs:"
ls -la || true
echo
echo "Searching repository for occurrences of '_model_kwargs' and related patterns..."
rg -n --hidden -S "_model_kwargs" -C3 || true
echo
echo "Searching for '\"kwargs':' occurrences (to see where kwargs are returned in JSON responses)..."
rg -n --hidden -S '"kwargs"\s*:' -C3 || true
echo
echo "Searching for 'def get_configs' and __main__ to locate the reported file."
rg -n --hidden -S "def get_configs" -C3 || true
rg -n --hidden -S "__main__\.py" -C3 || true
echo
echo "Searching for 'instantiated_providers' usages (to confirm provider objects structure)..."
rg -n --hidden -S "instantiated_providers" -C3 || true
echo
echo "Searching for common secret-like keys in code (api_key, token, secret) to spot exposed fields."
rg -n --hidden -S -i "api_key|apikey|token|secret|api_secret" -C1 || true
Length of output: 317607
Sanitize exposed provider kwargs in API responses
The API currently returns raw provider kwargs (risk: leaking secrets such as API keys/tokens). Filter sensitive keys (handle None) before returning.
- Files & places to fix:
  - wren-ai-service/src/__main__.py — _llm_config and _embedder_config (currently "kwargs": model_config._model_kwargs). Replace with a sanitized dict. Example:
- "kwargs": model_config._model_kwargs,
+ "kwargs": {k: v for k, v in (model_config._model_kwargs or {}).items()
+ if k.lower() not in {"api_key","key","token","secret","password","access_key","access_token","secret_key","client_secret","authorization"}},
  - wren-ai-service/src/globals.py — llm_model_kwargs (line ~317) is populated from provider.model_kwargs; sanitize there too using the same filter before exposing.
- Tests & config:
  - Tests that assert presence/shape of kwargs (e.g. tests/pytest/test_config.py, tests/pytest/providers/test_providers.py) may need updating if they expect raw values.
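One way to apply the same filter in both files without repeating the key list is a small helper; the name and placement below are assumptions, not code from the PR:

SENSITIVE_KEYS = {
    "api_key", "key", "token", "secret", "password",
    "access_key", "access_token", "secret_key", "client_secret", "authorization",
}

def sanitize_kwargs(kwargs: dict | None) -> dict:
    # Drop secret-looking entries before exposing provider kwargs over the API.
    return {k: v for k, v in (kwargs or {}).items() if k.lower() not in SENSITIVE_KEYS}

# e.g. "kwargs": sanitize_kwargs(model_config._model_kwargs)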
@app.post("/configs")
def update_configs(configs_request: Configs):
    try:
        # override current instantiated_providers
        app.state.instantiated_providers["embedder"] = {
            f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
                **embedder_provider.__dict__
            )
            for embedder_provider in configs_request.providers.embedder
        }
        app.state.instantiated_providers["llm"] = {
            f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
                **llm_provider.__dict__
            )
            for llm_provider in configs_request.providers.llm
        }
        app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
            location=app.state.instantiated_providers["document_store"][
                "qdrant"
            ]._location,
            api_key=app.state.instantiated_providers["document_store"][
                "qdrant"
            ]._api_key,
            timeout=app.state.instantiated_providers["document_store"][
                "qdrant"
            ]._timeout,
            embedding_model_dim=configs_request.providers.embedder[0].dimension,
            recreate_index=True,
        )
Guard against empty embedder list.
IndexError if providers.embedder is empty. Add checks or 400.
-def update_configs(configs_request: Configs):
+def update_configs(configs_request: Configs):
try:
+ if not configs_request.providers.embedder:
+ raise HTTPException(status_code=400, detail="At least one embedder is required")
# override current instantiated_providers
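An alternative to guarding inside the handler is to validate on the schema itself so FastAPI rejects the request with a 422 before it reaches this code. A sketch, assuming Configs is a Pydantic v2 model whose providers field carries the embedder list:

from pydantic import BaseModel, field_validator

class Providers(BaseModel):
    llm: list[dict] = []
    embedder: list[dict] = []

    @field_validator("embedder")
    @classmethod
    def require_at_least_one_embedder(cls, value):
        if not value:
            raise ValueError("at least one embedder provider is required")
        return value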
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
@app.post("/configs") | |
def update_configs(configs_request: Configs): | |
try: | |
# override current instantiated_providers | |
app.state.instantiated_providers["embedder"] = { | |
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( | |
**embedder_provider.__dict__ | |
) | |
for embedder_provider in configs_request.providers.embedder | |
} | |
app.state.instantiated_providers["llm"] = { | |
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider( | |
**llm_provider.__dict__ | |
) | |
for llm_provider in configs_request.providers.llm | |
} | |
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider( | |
location=app.state.instantiated_providers["document_store"][ | |
"qdrant" | |
]._location, | |
api_key=app.state.instantiated_providers["document_store"][ | |
"qdrant" | |
]._api_key, | |
timeout=app.state.instantiated_providers["document_store"][ | |
"qdrant" | |
]._timeout, | |
embedding_model_dim=configs_request.providers.embedder[0].dimension, | |
recreate_index=True, | |
) | |
@app.post("/configs") | |
def update_configs(configs_request: Configs): | |
try: | |
if not configs_request.providers.embedder: | |
raise HTTPException(status_code=400, detail="At least one embedder is required") | |
# override current instantiated_providers | |
app.state.instantiated_providers["embedder"] = { | |
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( | |
**embedder_provider.__dict__ | |
) | |
for embedder_provider in configs_request.providers.embedder | |
} | |
app.state.instantiated_providers["llm"] = { | |
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider( | |
**llm_provider.__dict__ | |
) | |
for llm_provider in configs_request.providers.llm | |
} | |
app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider( | |
location=app.state.instantiated_providers["document_store"][ | |
"qdrant" | |
]._location, | |
api_key=app.state.instantiated_providers["document_store"][ | |
"qdrant" | |
]._api_key, | |
timeout=app.state.instantiated_providers["document_store"][ | |
"qdrant" | |
]._timeout, | |
embedding_model_dim=configs_request.providers.embedder[0].dimension, | |
recreate_index=True, | |
) |
🤖 Prompt for AI Agents
In wren-ai-service/src/__main__.py around lines 178 to 206, the code assumes
configs_request.providers.embedder has at least one item and will raise
IndexError when empty; before accessing
configs_request.providers.embedder[0].dimension, validate that
providers.embedder is not empty and if it is return a 400 HTTP response (raise
fastapi.HTTPException(status_code=400, detail="...")) or otherwise handle the
missing embedder (e.g., require a default embedder or skip updating
document_store.embedding_model_dim), then proceed to instantiate providers;
ensure the validation occurs before any use of embedder[0] so the endpoint fails
fast with a clear error message.
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( | ||
**embedder_provider.__dict__ | ||
) | ||
for embedder_provider in configs_request.providers.embedder | ||
} | ||
app.state.instantiated_providers["llm"] = { | ||
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider( | ||
**llm_provider.__dict__ | ||
) | ||
for llm_provider in configs_request.providers.llm | ||
} |
Avoid expanding Pydantic models via __dict__; use model_dump(exclude_none=True).
__dict__ can include internal attributes, while model_dump returns only validated fields and is the stable, safe API.
- f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
- **embedder_provider.__dict__
- )
+ f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider(
+ **embedder_provider.model_dump(exclude_none=True)
+ )
for embedder_provider in configs_request.providers.embedder
}
app.state.instantiated_providers["llm"] = {
- f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
- **llm_provider.__dict__
- )
+ f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(
+ **llm_provider.model_dump(exclude_none=True)
+ )
for llm_provider in configs_request.providers.llm
}
📝 Committable suggestion
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( | |
**embedder_provider.__dict__ | |
) | |
for embedder_provider in configs_request.providers.embedder | |
} | |
app.state.instantiated_providers["llm"] = { | |
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider( | |
**llm_provider.__dict__ | |
) | |
for llm_provider in configs_request.providers.llm | |
} | |
f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( | |
**embedder_provider.model_dump(exclude_none=True) | |
) | |
for embedder_provider in configs_request.providers.embedder | |
} | |
app.state.instantiated_providers["llm"] = { | |
f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider( | |
**llm_provider.model_dump(exclude_none=True) | |
) | |
for llm_provider in configs_request.providers.llm | |
} |
🤖 Prompt for AI Agents
In wren-ai-service/src/__main__.py around lines 183 to 193, the code uses
Pydantic objects' __dict__ to expand kwargs for LitellmEmbedderProvider and
LitellmLLMProvider; replace each use of embedder_provider.__dict__ and
llm_provider.__dict__ with embedder_provider.model_dump(exclude_none=True) and
llm_provider.model_dump(exclude_none=True) respectively so only validated fields
are passed and internal attributes are omitted. Ensure both comprehensions call
model_dump(exclude_none=True) on each provider instance and no other behavior
changes are introduced.
        app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
            location=app.state.instantiated_providers["document_store"][
                "qdrant"
            ]._location,
            api_key=app.state.instantiated_providers["document_store"][
                "qdrant"
            ]._api_key,
            timeout=app.state.instantiated_providers["document_store"][
                "qdrant"
            ]._timeout,
            embedding_model_dim=configs_request.providers.embedder[0].dimension,
            recreate_index=True,
        )
Danger: unconditional Qdrant reindex can drop data. Gate on dimension change.
Recreate only if the embedding dim actually changes (or behind an explicit flag).
-        app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
-            location=app.state.instantiated_providers["document_store"][
-                "qdrant"
-            ]._location,
-            api_key=app.state.instantiated_providers["document_store"][
-                "qdrant"
-            ]._api_key,
-            timeout=app.state.instantiated_providers["document_store"][
-                "qdrant"
-            ]._timeout,
-            embedding_model_dim=configs_request.providers.embedder[0].dimension,
-            recreate_index=True,
-        )
+        _current_qdrant = app.state.instantiated_providers["document_store"]["qdrant"]
+        _new_dim = (
+            configs_request.providers.embedder[0].dimension
+            if configs_request.providers.embedder
+            else _current_qdrant._embedding_model_dim
+        )
+        _should_recreate = _new_dim != _current_qdrant._embedding_model_dim
+        app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
+            location=_current_qdrant._location,
+            api_key=_current_qdrant._api_key,
+            timeout=_current_qdrant._timeout,
+            embedding_model_dim=_new_dim,
+            recreate_index=_should_recreate,
+        )
📝 Committable suggestion
        _current_qdrant = app.state.instantiated_providers["document_store"]["qdrant"]
        _new_dim = (
            configs_request.providers.embedder[0].dimension
            if configs_request.providers.embedder
            else _current_qdrant._embedding_model_dim
        )
        _should_recreate = _new_dim != _current_qdrant._embedding_model_dim
        app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider(
            location=_current_qdrant._location,
            api_key=_current_qdrant._api_key,
            timeout=_current_qdrant._timeout,
            embedding_model_dim=_new_dim,
            recreate_index=_should_recreate,
        )
🤖 Prompt for AI Agents
In wren-ai-service/src/__main__.py around lines 194 to 206, the code
unconditionally sets recreate_index=True when instantiating the QdrantProvider
which can drop existing data; change this to only recreate when the embedder
dimension actually changes or when an explicit override flag is set. Implement a
check that reads the current provider's embedding_model_dim (if present) and
compares it to configs_request.providers.embedder[0].dimension and set
recreate_index=True only if they differ; otherwise set recreate_index=False.
Also add support for an explicit override (e.g. a config or env flag like
configs_request.providers.document_store.force_recreate or
app.state.force_recreate) that, when true, forces recreate_index=True regardless
of dimension comparison.
    def update_components(
        self,
        llm_provider: LLMProvider,
        document_store_provider: DocumentStoreProvider,
        **_,
    ):
        super().update_components(
            llm_provider=llm_provider, document_store_provider=document_store_provider
        )
        self._retriever = self._document_store_provider.get_retriever(
            self._document_store_provider.get_store("project_meta")
        )
|
Missing call to self._components = self._update_components() in the update_components method.
The update_components method updates the providers and the retriever but doesn't rebuild the components dictionary. This differs from the pattern in other pipelines such as followup_sql_generation.py, where self._components = self._update_components() is called at the end.
Apply this fix to ensure components are rebuilt with the updated providers:
     def update_components(
         self,
         llm_provider: LLMProvider,
         document_store_provider: DocumentStoreProvider,
         **_,
     ):
         super().update_components(
             llm_provider=llm_provider, document_store_provider=document_store_provider
         )
         self._retriever = self._document_store_provider.get_retriever(
             self._document_store_provider.get_store("project_meta")
         )
+        self._components = self._update_components()
📝 Committable suggestion
    def update_components(
        self,
        llm_provider: LLMProvider,
        document_store_provider: DocumentStoreProvider,
        **_,
    ):
        super().update_components(
            llm_provider=llm_provider, document_store_provider=document_store_provider
        )
        self._retriever = self._document_store_provider.get_retriever(
            self._document_store_provider.get_store("project_meta")
        )
        self._components = self._update_components()
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/generation/sql_generation.py around lines 170
to 182, the update_components method updates providers and the retriever but
does not rebuild the components dict; add a final call to self._components =
self._update_components() (exactly as other pipelines do) so the pipeline's
components are reconstructed using the updated llm_provider and
document_store_provider.
    def _update_components(self):
        return {
            "embedder": self._embedder_provider.get_text_embedder(),
            "table_retriever": self._document_store_provider.get_retriever(
                self._document_store_provider.get_store(
                    dataset_name="table_descriptions"
                ),
                top_k=self._table_retrieval_size,
            ),
-           "dbschema_retriever": document_store_provider.get_retriever(
-               document_store_provider.get_store(),
-               top_k=table_column_retrieval_size,
+           "dbschema_retriever": self._document_store_provider.get_retriever(
+               self._document_store_provider.get_store(),
+               top_k=self._table_column_retrieval_size,
            ),
-           "table_columns_selection_generator": llm_provider.get_generator(
+           "table_columns_selection_generator": self._llm_provider.get_generator(
                system_prompt=table_columns_selection_system_prompt,
                generation_kwargs=RETRIEVAL_MODEL_KWARGS,
            ),
-           "generator_name": llm_provider.get_model(),
+           "generator_name": self._llm_provider.model,
            "prompt_builder": PromptBuilder(
                template=table_columns_selection_user_prompt_template
            ),
        }
Guard against context overflow when prompting the selection LLM.
When the token count exceeds the context window, all DDLs are still sent to the generator for column selection, so the selection prompt itself can overflow the same window. Consider chunking the DDLs, pre-filtering tables (e.g., by semantic similarity or name constraints), and reserving headroom for prompt overhead.
I can draft a chunked selection pass (N tables per call with merge) if helpful; a sketch follows the agent prompt below.
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py around lines
488-509, the selection LLM is sent all DDLs unbounded which can exceed the model
context window; modify _update_components and the selection flow to enforce a
token-safe pipeline: compute a token budget for the generator (context window
minus prompt overhead), pre-filter candidate tables by simple heuristics (name
patterns, schema size) and/or a fast semantic similarity pass to reduce
candidates, split remaining DDLs into chunks that fit the token budget (N tables
per call), call the generator iteratively for each chunk and merge the partial
selections into a final set, and ensure generator calls include reserved
headroom for system/user prompt and any completion; update any config to expose
chunk size and token headroom.
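A minimal sketch of that chunked selection pass, assuming a stand-in generate_selection callable and illustrative token budgets; none of these names are taken from the pipeline itself.
# Sketch only: split DDLs into token-budgeted chunks, select columns per chunk, merge.
from typing import Callable

import tiktoken

_ENCODING = tiktoken.get_encoding("cl100k_base")


def _chunk_ddls(ddls: list[str], token_budget: int) -> list[list[str]]:
    chunks, current, used = [], [], 0
    for ddl in ddls:
        cost = len(_ENCODING.encode(ddl))
        if current and used + cost > token_budget:
            chunks.append(current)
            current, used = [], 0
        current.append(ddl)
        used += cost
    if current:
        chunks.append(current)
    return chunks


def select_columns_chunked(
    ddls: list[str],
    generate_selection: Callable[[list[str]], set[str]],  # stand-in for the selection LLM call
    context_window: int = 128_000,
    prompt_headroom: int = 8_000,
) -> set[str]:
    selected: set[str] = set()
    for chunk in _chunk_ddls(ddls, context_window - prompt_headroom):
        # each call sees only a slice of the schema; the final result is the union
        selected |= generate_selection(chunk)
    return selected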
"table_columns_selection_generator": self._llm_provider.get_generator( | ||
system_prompt=table_columns_selection_system_prompt, | ||
generation_kwargs=RETRIEVAL_MODEL_KWARGS, | ||
), | ||
"generator_name": llm_provider.get_model(), | ||
"generator_name": self._llm_provider.model, |
Avoid logging chain-of-thought; tighten the output schema to reduce tokens.
The system prompt asks for step-by-step "chain_of_thought_reasoning", and even with @observe(as_type="generation", capture_input=False) the output may still be logged by Langfuse. Either disable output capture on the generation step or drop CoT from the schema in favor of brief rationales.
Outside this hunk, update the decorator:
@observe(as_type="generation", capture_input=False, capture_output=False)
Optionally, rename chain_of_thought_reasoning to a short rationales array and instruct the model to produce terse bullet points only (a schema sketch follows the agent prompt below).
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py around lines
501 to 505, the system prompt requests step-by-step "chain_of_thought_reasoning"
which risks logging CoT tokens; remove or shorten CoT by replacing that field
with a brief "rationales" array of terse bullet points (or drop it entirely),
and change the Langfuse observe decorator (outside this hunk) to disable output
capture by using capture_output=False (i.e., @observe(as_type="generation",
capture_input=False, capture_output=False)); ensure the generator prompt
instructs concise rationales only and update any schema/field names accordingly.
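A minimal sketch of the tightened output schema, assuming the pipeline validates generator output with Pydantic; the class and field names are illustrative, not the existing schema.
# Illustrative only: terse rationales instead of free-form chain-of-thought.
from pydantic import BaseModel, Field


class TableColumnsSelection(BaseModel):
    # short bullet points keep the generation (and anything Langfuse captures) small
    rationales: list[str] = Field(default_factory=list, max_length=5)
    selected_columns: dict[str, list[str]] = Field(default_factory=dict)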
    def _update_configs(self):
        _model = (self._llm_provider.model,)
        if _model == "gpt-4o-mini" or _model == "gpt-4o":
            _encoding = tiktoken.get_encoding("o200k_base")
        else:
            _encoding = tiktoken.get_encoding("cl100k_base")

-       self._configs = {
+       return {
Fix the model tuple bug and select the encoding robustly.
The current code wraps the model name in a tuple and compares it with ==, so the o200k_base branch never triggers.
     def _update_configs(self):
-        _model = (self._llm_provider.model,)
-        if _model == "gpt-4o-mini" or _model == "gpt-4o":
+        _model = self._llm_provider.model
+        if "gpt-4o" in _model:
             _encoding = tiktoken.get_encoding("o200k_base")
         else:
             _encoding = tiktoken.get_encoding("cl100k_base")
📝 Committable suggestion
    def _update_configs(self):
        _model = self._llm_provider.model
        if "gpt-4o" in _model:
            _encoding = tiktoken.get_encoding("o200k_base")
        else:
            _encoding = tiktoken.get_encoding("cl100k_base")
        return {
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py around lines
92 to 99, the code incorrectly wraps the model name in a tuple and compares with
== so the o200k_base branch never matches; change to read the model string
directly (e.g., model = self._llm_provider.model), guard for None, and select
encoding robustly by checking the model name (for example use
model.startswith("gpt-4o") or check membership in an explicit set of names) to
return tiktoken.get_encoding("o200k_base") for gpt-4o variants and
tiktoken.get_encoding("cl100k_base") otherwise.
    def update_components(self, document_store_provider: DocumentStoreProvider, **_):
        self._document_store_provider = document_store_provider
        self._retriever = self._document_store_provider.get_retriever(
            self._document_store_provider.get_store("project_meta")
        )
Align the update_components signature with the base class and clear the stale cache.
This override diverges from BasicPipeline.update_components(...), which can break central reconfiguration calls. Also clear the TTL cache when the underlying store/retriever changes.
-    def update_components(self, document_store_provider: DocumentStoreProvider, **_):
-        self._document_store_provider = document_store_provider
-        self._retriever = self._document_store_provider.get_retriever(
-            self._document_store_provider.get_store("project_meta")
-        )
+    def update_components(
+        self,
+        llm_provider=None,
+        embedder_provider=None,
+        document_store_provider: DocumentStoreProvider | None = None,
+        update_components: bool = True,
+        **_,
+    ):
+        # keep base state consistent; we don't need base to rebuild components here
+        super().update_components(
+            llm_provider=llm_provider,
+            embedder_provider=embedder_provider,
+            document_store_provider=document_store_provider,
+            update_components=False,
+        )
+        if document_store_provider is not None:
+            self._document_store_provider = document_store_provider
+            self._retriever = self._document_store_provider.get_retriever(
+                self._document_store_provider.get_store("project_meta")
+            )
+        # avoid returning stale results across provider swaps
+        self._cache.clear()
📝 Committable suggestion
    def update_components(
        self,
        llm_provider=None,
        embedder_provider=None,
        document_store_provider: DocumentStoreProvider | None = None,
        update_components: bool = True,
        **_,
    ):
        # keep base state consistent; we don't need base to rebuild components here
        super().update_components(
            llm_provider=llm_provider,
            embedder_provider=embedder_provider,
            document_store_provider=document_store_provider,
            update_components=False,
        )
        if document_store_provider is not None:
            self._document_store_provider = document_store_provider
            self._retriever = self._document_store_provider.get_retriever(
                self._document_store_provider.get_store("project_meta")
            )
        # avoid returning stale results across provider swaps
        self._cache.clear()
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/retrieval/sql_functions.py around lines 106 to
110, the update_components override uses a different signature than
BasicPipeline.update_components and doesn't clear cached values when the
underlying store/retriever changes; change the method signature to match the
base (accept the same positional and keyword args, e.g., (*args, **kwargs) or
the exact parameters used by BasicPipeline), call super().update_components(...)
so central reconfiguration still works, then set self._document_store_provider
and recompute self._retriever as before, and finally clear any TTL/lru caches
associated with cached helper methods or properties that depend on the
store/retriever (call .cache_clear() on those decorated functions or otherwise
invalidate the cache) so no stale data remains.
    def get_text_embedder(self):
        return AsyncTextEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
-           model=self._embedding_model,
+           api_version=self._api_version,
+           model=self._model,
            timeout=self._timeout,
-           **self._kwargs,
+           **self._model_kwargs,
        )
Constructor arg name mismatch: AsyncTextEmbedder expects 'api_base', not 'api_base_url'.
This will raise a TypeError at runtime.
     def get_text_embedder(self):
         return AsyncTextEmbedder(
             api_key=self._api_key,
-            api_base_url=self._api_base,
+            api_base=self._api_base,
             api_version=self._api_version,
             model=self._model,
             timeout=self._timeout,
             **self._model_kwargs,
         )
📝 Committable suggestion
    def get_text_embedder(self):
        return AsyncTextEmbedder(
            api_key=self._api_key,
            api_base=self._api_base,
            api_version=self._api_version,
            model=self._model,
            timeout=self._timeout,
            **self._model_kwargs,
        )
🤖 Prompt for AI Agents
In wren-ai-service/src/providers/embedder/litellm.py around lines 195 to 203,
the constructor call uses the wrong keyword name api_base_url while
AsyncTextEmbedder expects api_base; update the argument name to
api_base=self._api_base (leave other args as-is) so the correct keyword is
passed and avoid the TypeError at runtime.
Summary by CodeRabbit