5
5
from contextlib import contextmanager
6
6
from typing import Any , Generator , List , Type
7
7
8
- from app .bootstrap import BOOTSTRAP
9
8
from app .config import Config , get_config
10
9
from app .models .base import BaseModel
10
+ from app .seeding import SEEDING
11
11
from fastapi import Depends
12
12
from sqlalchemy import Engine , create_engine , text
13
13
from sqlalchemy .exc import SQLAlchemyError
@@ -25,9 +25,14 @@ class DatabaseConnection:
25
25
def __init__ (self , config : Config = Depends (get_config )) -> None :
26
26
# Ensure the database folder exists.
27
27
os .makedirs (config .database_path .parent , exist_ok = True )
28
- self .engine = create_engine (
29
- config .database_url , connect_args = {"check_same_thread" : False }
30
- )
28
+ # Create engine with appropriate connection args for SQLite or other databases
29
+ db_url = config .database_url
30
+ if db_url .startswith ("sqlite" ): # SQLite needs check_same_thread
31
+ self .engine = create_engine (
32
+ db_url , connect_args = {"check_same_thread" : False }
33
+ )
34
+ else :
35
+ self .engine = create_engine (db_url )
31
36
self .session_local = sessionmaker (
32
37
autocommit = False ,
33
38
autoflush = False ,
@@ -62,7 +67,7 @@ def seed_bootstrap_data(self) -> None:
62
67
This method is called once during initialization.
63
68
"""
64
69
with self .get_session () as session :
65
- for model , seeds in BOOTSTRAP .items ():
70
+ for model , seeds in SEEDING .items ():
66
71
try :
67
72
self ._seed_model (
68
73
session = session ,
@@ -104,7 +109,6 @@ def _seed_model(
104
109
logger .info ("Seeding data for table '%s'" , table_name )
105
110
106
111
for seed in seeds :
107
- logger .debug ("Merging seed with id %s for table '%s'" , seed .id , table_name )
108
112
session .merge (seed )
109
113
session .flush ()
110
114
logger .info ("Merged %d seed(s) for table '%s'" , len (seeds ), table_name )
@@ -130,7 +134,24 @@ def _seed_model(
130
134
logger .info (
131
135
"Current sequence for table '%s': %d" , table_name , current_seq
132
136
)
133
- if current_seq < (sequence_start - 1 ):
137
+ # Determine the current max ID in the table
138
+ max_id_result = session .execute (
139
+ text (f"SELECT MAX(id) FROM { table_name } " )
140
+ ).fetchone ()
141
+ max_id = (
142
+ max_id_result [0 ]
143
+ if max_id_result and max_id_result [0 ] is not None
144
+ else 0
145
+ )
146
+ logger .info ("Max id for table '%s': %d" , table_name , max_id )
147
+ if max_id >= sequence_start :
148
+ logger .info (
149
+ "Skipping sequence update for table '%s' as max id %d >= desired start %d" ,
150
+ table_name ,
151
+ max_id ,
152
+ sequence_start ,
153
+ )
154
+ elif current_seq < (sequence_start - 1 ):
134
155
logger .info (
135
156
"Updating sequence for table '%s' to %d" ,
136
157
table_name ,
@@ -161,7 +182,24 @@ def _seed_model(
161
182
sequence_name ,
162
183
current_seq ,
163
184
)
164
- if current_seq < (sequence_start - 1 ):
185
+ # Determine the current max ID in the table
186
+ max_id_result = session .execute (
187
+ text (f"SELECT MAX(id) FROM { table_name } " )
188
+ ).fetchone ()
189
+ max_id = (
190
+ max_id_result [0 ]
191
+ if max_id_result and max_id_result [0 ] is not None
192
+ else 0
193
+ )
194
+ logger .info ("Max id for table '%s': %d" , table_name , max_id )
195
+ if max_id >= sequence_start :
196
+ logger .info (
197
+ "Skipping sequence update for table '%s' as max id %d >= desired start %d" ,
198
+ table_name ,
199
+ max_id ,
200
+ sequence_start ,
201
+ )
202
+ elif current_seq < (sequence_start - 1 ):
165
203
logger .info (
166
204
"Updating sequence for table '%s' (sequence: '%s') to %d" ,
167
205
table_name ,
0 commit comments