diff --git a/api/db/db_models.py b/api/db/db_models.py index 39c6f9351e..8b547f8ec6 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -18,18 +18,19 @@ import sys import typing import operator +from enum import Enum from functools import wraps from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from flask_login import UserMixin -from playhouse.migrate import MySQLMigrator, migrate +from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate from peewee import ( BigIntegerField, BooleanField, CharField, CompositeKey, IntegerField, TextField, FloatField, DateTimeField, Field, Model, Metadata ) -from playhouse.pool import PooledMySQLDatabase +from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase from api.db import SerializedType, ParserType -from api.settings import DATABASE, stat_logger, SECRET_KEY +from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE from api.utils.log_utils import getLogger from api import utils @@ -58,8 +59,13 @@ def _singleton(): "write_access"} +class TextFieldType(Enum): + MYSQL = 'LONGTEXT' + POSTGRES = 'TEXT' + + class LongTextField(TextField): - field_type = 'LONGTEXT' + field_type = TextFieldType[DATABASE_TYPE.upper()].value class JSONField(LongTextField): @@ -266,18 +272,69 @@ def __init__(self, object_hook=utils.from_dict_hook, super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, object_pairs_hook=object_pairs_hook, **kwargs) +class PooledDatabase(Enum): + MYSQL = PooledMySQLDatabase + POSTGRES = PooledPostgresqlDatabase + + +class DatabaseMigrator(Enum): + MYSQL = MySQLMigrator + POSTGRES = PostgresqlMigrator + @singleton class BaseDataBase: def __init__(self): database_config = DATABASE.copy() db_name = database_config.pop("name") - self.database_connection = PooledMySQLDatabase( - db_name, **database_config) - stat_logger.info('init mysql database on cluster mode successfully') + self.database_connection 
= PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config) + stat_logger.info('init database on cluster mode successfully') + +class PostgresDatabaseLock: + def __init__(self, lock_name, timeout=10, db=None): + self.lock_name = lock_name + self.timeout = int(timeout) + self.db = db if db else DB + + def lock(self): + cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(hashtext(%s))", (self.lock_name,)) + ret = cursor.fetchone() + if ret[0] == 0: + raise Exception(f'acquire postgres lock {self.lock_name} timeout') + elif ret[0] == 1: + return True + else: + raise Exception(f'failed to acquire lock {self.lock_name}') + + def unlock(self): + cursor = self.db.execute_sql("SELECT pg_advisory_unlock(hashtext(%s))", (self.lock_name,)) + ret = cursor.fetchone() + if ret[0] == 0: + raise Exception( + f'postgres lock {self.lock_name} was not established by this thread') + elif ret[0] == 1: + return True + else: + raise Exception(f'postgres lock {self.lock_name} does not exist') + def __enter__(self): + if isinstance(self.db, PooledPostgresqlDatabase): + self.lock() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if isinstance(self.db, PooledPostgresqlDatabase): + self.unlock() -class DatabaseLock: + def __call__(self, func): + @wraps(func) + def magic(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return magic + +class MysqlDatabaseLock: def __init__(self, lock_name, timeout=10, db=None): self.lock_name = lock_name self.timeout = int(timeout) @@ -325,8 +382,13 @@ def magic(*args, **kwargs): return magic +class DatabaseLock(Enum): + MYSQL = MysqlDatabaseLock + POSTGRES = PostgresDatabaseLock + + DB = BaseDataBase().database_connection -DB.lock = DatabaseLock +DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value def close_connection(): @@ -918,7 +980,7 @@ class Meta: def migrate_db(): with DB.transaction(): - migrator = MySQLMigrator(DB) + migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB) try: migrate( migrator.add_column('file', 'source_type', 
CharField(max_length=128, null=False, default="", diff --git a/api/db/db_utils.py b/api/db/db_utils.py index 7e156b4a78..795fc7f767 100644 --- a/api/db/db_utils.py +++ b/api/db/db_utils.py @@ -17,6 +17,8 @@ from functools import reduce from typing import Dict, Type, Union +from playhouse.pool import PooledMySQLDatabase + from api.utils import current_timestamp, timestamp_to_date from api.db.db_models import DB, DataBaseModel @@ -49,7 +51,10 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): with DB.atomic(): query = model.insert_many(data_source[i:i + batch_size]) if replace_on_conflict: - query = query.on_conflict(preserve=preserve) + if isinstance(DB, PooledMySQLDatabase): + query = query.on_conflict(preserve=preserve) + else: + query = query.on_conflict(conflict_target="id", preserve=preserve) query.execute() diff --git a/api/settings.py b/api/settings.py index 95bf196138..fdb9bb595c 100644 --- a/api/settings.py +++ b/api/settings.py @@ -164,7 +164,8 @@ PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy") PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol") -DATABASE = decrypt_database_config(name="mysql") +DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') +DATABASE = decrypt_database_config(name=DATABASE_TYPE) # Switch # upload diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 8e983374d0..7d09df50d4 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -9,6 +9,14 @@ mysql: port: 3306 max_connections: 100 stale_timeout: 30 +postgres: + name: 'rag_flow' + user: 'rag_flow' + password: 'infini_rag_flow' + host: 'postgres' + port: 5432 + max_connections: 100 + stale_timeout: 30 minio: user: 'rag_flow' password: 'infini_rag_flow' diff --git a/requirements.txt b/requirements.txt index 720e44c977..2a8f577217 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,6 +31,7 @@ Flask==3.0.3 Flask_Cors==5.0.0 Flask_Login==0.6.3 flask_session==0.8.0 +psycopg2==2.9.9 
google_search_results==2.4.2 groq==0.9.0 hanziconv==0.3.2