Skip to content

Commit 540ae01

Browse files
committed
feat: add embeddings provider configuration and factory
Introduce a flexible embeddings configuration system:
- Add EMBEDDINGS_PROVIDER to .env.example and config
- Create EmbeddingsFactory to dynamically select embedding providers
- Refactor existing code to use the new embedding factory
- Prepare for future multi-provider embedding support
1 parent af089b8 commit 540ae01

File tree

7 files changed

+41
-30
lines changed

7 files changed

+41
-30
lines changed

backend/.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ MYSQL_DATABASE=ragwebui
88
SECRET_KEY=your-secret-key-here
99
ACCESS_TOKEN_EXPIRE_MINUTES=30
1010

11+
# Embeddings settings
12+
# If set to "openai", embeddings are created with the OpenAI API key and base URL (OPENAI_API_KEY / OPENAI_API_BASE).
13+
EMBEDDINGS_PROVIDER=openai
14+
1115
# Vector DB settings
1216
CHROMA_DB_HOST=localhost
1317
CHROMA_DB_PORT=8001

backend/app/api/api_v1/knowledge_base.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
44
from sqlalchemy.orm import Session
55
from langchain_chroma import Chroma
6-
from langchain_openai import OpenAIEmbeddings
76
from sqlalchemy import text
87
import logging
98
from datetime import datetime, timedelta
@@ -25,6 +24,7 @@
2524
from app.core.minio import get_minio_client
2625
from minio.error import MinioException
2726
from app.services.vector_store import VectorStoreFactory
27+
from app.services.embedding.embedding_factory import EmbeddingsFactory
2828

2929
router = APIRouter()
3030

@@ -162,10 +162,7 @@ async def delete_knowledge_base(
162162

163163
# Initialize services
164164
minio_client = get_minio_client()
165-
embeddings = OpenAIEmbeddings(
166-
openai_api_key=settings.OPENAI_API_KEY,
167-
openai_api_base=settings.OPENAI_API_BASE
168-
)
165+
embeddings = EmbeddingsFactory.create()
169166

170167
vector_store = VectorStoreFactory.create(
171168
store_type=settings.VECTOR_STORE_TYPE,
@@ -511,10 +508,7 @@ async def test_retrieval(
511508
detail=f"Knowledge base {request.kb_id} not found",
512509
)
513510

514-
embeddings = OpenAIEmbeddings(
515-
openai_api_key=settings.OPENAI_API_KEY,
516-
openai_api_base=settings.OPENAI_API_BASE
517-
)
511+
embeddings = EmbeddingsFactory.create()
518512

519513
vector_store = VectorStoreFactory.create(
520514
store_type=settings.VECTOR_STORE_TYPE,

backend/app/api/openapi/knowledge.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from fastapi import APIRouter, Depends, HTTPException
33
from sqlalchemy.orm import Session
44
from langchain_chroma import Chroma
5-
from langchain_openai import OpenAIEmbeddings
65
from app.services.vector_store import VectorStoreFactory
76

87
from app import models
98
from app.db.session import get_db
109
from app.core.security import get_api_key_user
1110
from app.core.config import settings
11+
from app.services.embedding.embedding_factory import EmbeddingsFactory
1212

1313
router = APIRouter()
1414

@@ -36,10 +36,7 @@ def query_knowledge_base(
3636
detail=f"Knowledge base {knowledge_base_id} not found",
3737
)
3838

39-
embeddings = OpenAIEmbeddings(
40-
openai_api_key=settings.OPENAI_API_KEY,
41-
openai_api_base=settings.OPENAI_API_BASE
42-
)
39+
embeddings = EmbeddingsFactory.create()
4340

4441
vector_store = VectorStoreFactory.create(
4542
store_type=settings.VECTOR_STORE_TYPE,

backend/app/core/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def get_database_url(self) -> str:
3636
ALGORITHM: str = "HS256"
3737
ACCESS_TOKEN_EXPIRE_MINUTES: int = 10080
3838

39+
# Embeddings settings
40+
EMBEDDINGS_PROVIDER: str = os.getenv("EMBEDDINGS_PROVIDER", "openai")
41+
3942
# Vector DB settings
4043
CHROMA_DB_HOST: str = os.getenv("CHROMA_DB_HOST", "localhost")
4144
CHROMA_DB_PORT: int = int(os.getenv("CHROMA_DB_PORT", "8001"))

backend/app/services/chat_service.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import base64
33
from typing import List, AsyncGenerator
44
from sqlalchemy.orm import Session
5-
from langchain_chroma import Chroma
6-
from langchain_openai import OpenAIEmbeddings
75
from langchain_openai import ChatOpenAI
86
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
97
from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -14,6 +12,7 @@
1412
from app.models.knowledge import KnowledgeBase, Document
1513
from langchain.globals import set_verbose, set_debug
1614
from app.services.vector_store import VectorStoreFactory
15+
from app.services.embedding.embedding_factory import EmbeddingsFactory
1716

1817
set_verbose(True)
1918
set_debug(True)
@@ -52,10 +51,7 @@ async def generate_response(
5251
)
5352

5453
# Initialize embeddings
55-
embeddings = OpenAIEmbeddings(
56-
openai_api_key=settings.OPENAI_API_KEY,
57-
openai_api_base=settings.OPENAI_API_BASE
58-
)
54+
embeddings = EmbeddingsFactory.create()
5955

6056
# Create a vector store for each knowledge base
6157
vector_stores = []

backend/app/services/document_processor.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from langchain.text_splitter import RecursiveCharacterTextSplitter
1717
from langchain_core.documents import Document as LangchainDocument
1818
from pydantic import BaseModel
19-
from langchain_openai import OpenAIEmbeddings
20-
from langchain_chroma import Chroma
2119
from sqlalchemy import create_engine, text
2220
from sqlalchemy.orm import Session
2321
from app.core.config import settings
@@ -31,6 +29,7 @@
3129
from minio import Minio
3230
from minio.commonconfig import CopySource
3331
from app.services.vector_store import VectorStoreFactory
32+
from app.services.embedding.embedding_factory import EmbeddingsFactory
3433

3534
class UploadResult(BaseModel):
3635
file_path: str
@@ -56,10 +55,7 @@ async def process_document(file_path: str, file_name: str, kb_id: int, document_
5655

5756
# Initialize embeddings
5857
logger.info("Initializing OpenAI embeddings...")
59-
embeddings = OpenAIEmbeddings(
60-
openai_api_key=settings.OPENAI_API_KEY,
61-
openai_api_base=settings.OPENAI_API_BASE
62-
)
58+
embeddings = EmbeddingsFactory.create()
6359

6460
logger.info(f"Initializing vector store with collection: kb_{kb_id}")
6561
vector_store = VectorStoreFactory.create(
@@ -303,10 +299,7 @@ async def process_document_background(
303299

304300
# 3. 创建向量存储
305301
logger.info(f"Task {task_id}: Initializing vector store")
306-
embeddings = OpenAIEmbeddings(
307-
openai_api_key=settings.OPENAI_API_KEY,
308-
openai_api_base=settings.OPENAI_API_BASE
309-
)
302+
embeddings = EmbeddingsFactory.create()
310303

311304
vector_store = VectorStoreFactory.create(
312305
store_type=settings.VECTOR_STORE_TYPE,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from app.core.config import settings
2+
from langchain_openai import OpenAIEmbeddings
3+
# If you plan on adding other embeddings, import them here
4+
# from some_other_module import AnotherEmbeddingClass
5+
6+
class EmbeddingsFactory:
    """Build the embeddings client selected by application configuration."""

    @staticmethod
    def create(provider: "str | None" = None):
        """Create an embeddings instance for the requested provider.

        Args:
            provider: Provider name. It is matched case-insensitively and
                surrounding whitespace is ignored. When ``None`` (the
                default), ``settings.EMBEDDINGS_PROVIDER`` is used, so
                existing ``EmbeddingsFactory.create()`` callers are unaffected.

        Returns:
            For the ``"openai"`` provider, an ``OpenAIEmbeddings`` instance
            configured with ``settings.OPENAI_API_KEY`` and
            ``settings.OPENAI_API_BASE``.

        Raises:
            ValueError: If the provider name is not supported.
        """
        if provider is None:
            provider = settings.EMBEDDINGS_PROVIDER

        # Values read from .env files often carry stray whitespace or mixed
        # case; normalize before comparing so "OpenAI " still selects openai.
        embeddings_provider = provider.strip().lower()

        if embeddings_provider == "openai":
            return OpenAIEmbeddings(
                openai_api_key=settings.OPENAI_API_KEY,
                openai_api_base=settings.OPENAI_API_BASE,
            )
        # Extend with other providers:
        # if embeddings_provider == "another_provider":
        #     return AnotherEmbeddingClass(...)

        raise ValueError(f"Unsupported embeddings provider: {embeddings_provider}")

0 commit comments

Comments
 (0)