Skip to content

Commit

Permalink
Common: Support PostgreSQL database as the metadata db. (infiniflow#2357
Browse files Browse the repository at this point in the history
)

infiniflow#2356

### What problem does this PR solve?

As the title says: allow PostgreSQL, in addition to MySQL, to be used as the metadata database.

### Type of change

- [X] New Feature (non-breaking change which adds functionality)
  • Loading branch information
baifachuan authored Sep 12, 2024
1 parent ba834ae commit f8e9a05
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 12 deletions.
82 changes: 72 additions & 10 deletions api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@
import sys
import typing
import operator
from enum import Enum
from functools import wraps
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin
from playhouse.migrate import MySQLMigrator, migrate
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from peewee import (
BigIntegerField, BooleanField, CharField,
CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata
)
from playhouse.pool import PooledMySQLDatabase
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType
from api.settings import DATABASE, stat_logger, SECRET_KEY
from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
from api.utils.log_utils import getLogger
from api import utils

Expand Down Expand Up @@ -58,8 +59,13 @@ def _singleton():
"write_access"}


class TextFieldType(Enum):
    """Column type name used for long text columns on each supported database.

    Members are named after the database type so they can be looked up by the
    upper-cased database type string (e.g. ``TextFieldType['MYSQL']``).
    """
    # MySQL's large text column type.
    MYSQL = 'LONGTEXT'
    # Column type used for the same fields on Postgres.
    POSTGRES = 'TEXT'


class LongTextField(TextField):
    """Text field whose column type follows the configured metadata database.

    The column type is read from ``TextFieldType`` using the upper-cased
    ``DATABASE_TYPE`` ('LONGTEXT' for MySQL, 'TEXT' for Postgres). An
    unsupported ``DATABASE_TYPE`` raises ``KeyError`` when this class is defined.
    """
    # Single assignment only: a hard-coded 'LONGTEXT' placed before this line
    # would be dead code, because the backend lookup always overwrites it.
    field_type = TextFieldType[DATABASE_TYPE.upper()].value


class JSONField(LongTextField):
Expand Down Expand Up @@ -266,18 +272,69 @@ def __init__(self, object_hook=utils.from_dict_hook,
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs)

class PooledDatabase(Enum):
    """Pooled database class to instantiate for each supported database type.

    Members are named after the database type so they can be looked up by the
    upper-cased database type string; ``.value`` is the class itself.
    """
    MYSQL = PooledMySQLDatabase
    POSTGRES = PooledPostgresqlDatabase


class DatabaseMigrator(Enum):
    """Schema migrator class to use for each supported database type.

    Members are named after the database type so they can be looked up by the
    upper-cased database type string; ``.value`` is the migrator class.
    """
    MYSQL = MySQLMigrator
    POSTGRES = PostgresqlMigrator


@singleton
class BaseDataBase:
    """Holder for the pooled connection to the metadata database.

    The pooled database class is chosen from ``PooledDatabase`` using the
    upper-cased ``DATABASE_TYPE``. The connection settings come from
    ``DATABASE``: its ``name`` entry is the database name and all remaining
    entries are passed through as keyword arguments.
    """

    def __init__(self):
        # Work on a copy so popping "name" does not mutate the shared settings.
        database_config = DATABASE.copy()
        db_name = database_config.pop("name")
        # Build exactly one pool, for the configured backend. Creating a MySQL
        # pool first and then overwriting it would waste an object and log a
        # misleading "mysql" message when Postgres is selected.
        self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
        stat_logger.info('init database on cluster mode successfully')

class PostgresDatabaseLock:
    """Named lock implemented with PostgreSQL advisory-lock functions.

    Usable as a context manager. ``pg_try_advisory_lock`` and
    ``pg_advisory_unlock`` take an integer key, so ``lock_name`` is hashed to
    a stable signed 64-bit integer (``lock_id``); equal names always map to
    the same key and different names to (practically) different keys.

    Args:
        lock_name: Logical name of the lock.
        timeout: Stored as an int for signature parity with the MySQL lock.
            It is not sent to PostgreSQL: ``pg_try_advisory_lock`` takes only
            the lock key.
        db: Database object used to execute the lock statements; falls back to
            the module-level ``DB`` when not given.
    """

    def __init__(self, lock_name, timeout=10, db=None):
        # Local import keeps this block self-contained (no file-level import needed).
        import hashlib

        self.lock_name = lock_name
        self.timeout = int(timeout)
        self.db = db if db else DB
        # Derive the advisory-lock key from the lock NAME (not the timeout), so
        # that unrelated locks do not all contend on the same key.
        digest = hashlib.sha256(str(lock_name).encode("utf-8")).digest()
        self.lock_id = int.from_bytes(digest[:8], "big", signed=True)

    def lock(self):
        """Try to take the lock; return True on success, raise otherwise."""
        # Parameters must be a sequence, hence the one-element tuple.
        cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", (self.lock_id,))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(f'acquire postgres lock {self.lock_name} timeout')
        elif ret[0] == 1:
            return True
        else:
            raise Exception(f'failed to acquire lock {self.lock_name}')

    def unlock(self):
        """Release the lock; return True on success, raise otherwise."""
        cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", (self.lock_id,))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(
                f'postgres lock {self.lock_name} was not established by this thread')
        elif ret[0] == 1:
            return True
        else:
            raise Exception(f'postgres lock {self.lock_name} does not exist')

    def __enter__(self):
        # Only lock when the target really is a Postgres database. The check
        # must be against the database class; testing against this lock class
        # is never true and would silently skip locking.
        if isinstance(self.db, PooledPostgresqlDatabase):
            self.lock()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if isinstance(self.db, PooledPostgresqlDatabase):
            self.unlock()

class DatabaseLock:
    """Lets a lock instance be applied to a function as a decorator.

    Calling the instance with a function returns a wrapper that runs every
    invocation inside ``with self``. The ``__enter__`` / ``__exit__`` methods
    that the ``with`` statement needs are not defined in this block.
    """

    def __call__(self, func):
        """Return ``func`` wrapped so each call executes under this lock."""
        lock = self

        @wraps(func)
        def guarded(*call_args, **call_kwargs):
            # Return from inside the ``with`` so the lock is released only
            # after the wrapped call has produced its result.
            with lock:
                return func(*call_args, **call_kwargs)

        return guarded

class MysqlDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
self.timeout = int(timeout)
Expand Down Expand Up @@ -325,8 +382,13 @@ def magic(*args, **kwargs):
return magic


class DatabaseLock(Enum):
    """Lock class to use for each supported database type.

    Members are named after the database type so they can be looked up by the
    upper-cased database type string; ``.value`` is the lock class.
    """
    MYSQL = MysqlDatabaseLock
    POSTGRES = PostgresDatabaseLock


# Module-wide database connection taken from the BaseDataBase holder.
DB = BaseDataBase().database_connection
# Attach the lock class that matches the configured backend. Assigning the
# DatabaseLock enum itself first would be dead code: this lookup overwrites it.
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value


def close_connection():
Expand Down Expand Up @@ -918,7 +980,7 @@ class Meta:

def migrate_db():
with DB.transaction():
migrator = MySQLMigrator(DB)
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
try:
migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
Expand Down
7 changes: 6 additions & 1 deletion api/db/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from functools import reduce
from typing import Dict, Type, Union

from playhouse.pool import PooledMySQLDatabase

from api.utils import current_timestamp, timestamp_to_date

from api.db.db_models import DB, DataBaseModel
Expand Down Expand Up @@ -49,7 +51,10 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
with DB.atomic():
query = model.insert_many(data_source[i:i + batch_size])
if replace_on_conflict:
query = query.on_conflict(preserve=preserve)
if isinstance(DB, PooledMySQLDatabase):
query = query.on_conflict(preserve=preserve)
else:
query = query.on_conflict(conflict_target="id", preserve=preserve)
query.execute()


Expand Down
3 changes: 2 additions & 1 deletion api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@
PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")

DATABASE = decrypt_database_config(name="mysql")
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)

# Switch
# upload
Expand Down
8 changes: 8 additions & 0 deletions conf/service_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ mysql:
port: 3306
max_connections: 100
stale_timeout: 30
postgres:
name: 'rag_flow'
user: 'rag_flow'
password: 'infini_rag_flow'
host: 'postgres'
port: 5432
max_connections: 100
stale_timeout: 30
minio:
user: 'rag_flow'
password: 'infini_rag_flow'
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Flask==3.0.3
Flask_Cors==5.0.0
Flask_Login==0.6.3
flask_session==0.8.0
psycopg2==2.9.9
google_search_results==2.4.2
groq==0.9.0
hanziconv==0.3.2
Expand Down

0 comments on commit f8e9a05

Please sign in to comment.