2020
2121logger = logging .getLogger (__name__ )
2222
23+ NON_SQLITE_DATABASE_PATH_WARNING = """\
24+ Ignoring 'database_path' setting: not using a sqlite3 database.
25+ --------------------------------------------------------------------------------
26+ """
27+
2328DEFAULT_CONFIG = """\
2429 ## Database ##
2530
@@ -105,6 +110,11 @@ def __init__(self, name: str, db_config: dict):
105110class DatabaseConfig (Config ):
106111 section = "database"
107112
113+ def __init__ (self , * args , ** kwargs ):
114+ super ().__init__ (* args , ** kwargs )
115+
116+ self .databases = []
117+
108118 def read_config (self , config , ** kwargs ):
109119 self .event_cache_size = self .parse_size (config .get ("event_cache_size" , "10K" ))
110120
@@ -125,54 +135,69 @@ def read_config(self, config, **kwargs):
125135
126136 multi_database_config = config .get ("databases" )
127137 database_config = config .get ("database" )
138+ database_path = config .get ("database_path" )
128139
129140 if multi_database_config and database_config :
130141 raise ConfigError ("Can't specify both 'database' and 'datbases' in config" )
131142
132143 if multi_database_config :
133- if config . get ( " database_path" ) :
144+ if database_path :
134145 raise ConfigError ("Can't specify 'database_path' with 'databases'" )
135146
136147 self .databases = [
137148 DatabaseConnectionConfig (name , db_conf )
138149 for name , db_conf in multi_database_config .items ()
139150 ]
140151
141- else :
142- if database_config is None :
143- database_config = {"name" : "sqlite3" , "args" : {}}
144-
152+ if database_config :
145153 self .databases = [DatabaseConnectionConfig ("master" , database_config )]
146154
147- self .set_databasepath (config .get ("database_path" ))
155+ if database_path :
156+ if self .databases and self .databases [0 ].name != "sqlite3" :
157+ logger .warning (NON_SQLITE_DATABASE_PATH_WARNING )
158+ return
159+
160+ database_config = {"name" : "sqlite3" , "args" : {}}
161+ self .databases = [DatabaseConnectionConfig ("master" , database_config )]
162+ self .set_databasepath (database_path )
148163
149164 def generate_config_section (self , data_dir_path , ** kwargs ):
150165 return DEFAULT_CONFIG % {
151166 "database_path" : os .path .join (data_dir_path , "homeserver.db" )
152167 }
153168
154169 def read_arguments (self , args ):
155- self .set_databasepath (args .database_path )
170+ """
171+ Cases for the cli input:
172+ - If no databases are configured and no database_path is set, raise.
173+ - No databases and only database_path available ==> sqlite3 db.
174+ - If there are multiple databases and a database_path raise an error.
175+ - If the database set in the config file is sqlite then
176+ overwrite with the command line argument.
177+ """
156178
157- def set_databasepath (self , database_path ):
158- if database_path is None :
179+ if args .database_path is None :
180+ if not self .databases :
181+ raise ConfigError ("No database config provided" )
159182 return
160183
161- if database_path != ":memory:" :
162- database_path = self .abspath (database_path )
184+ if len (self .databases ) == 0 :
185+ database_config = {"name" : "sqlite3" , "args" : {}}
186+ self .databases = [DatabaseConnectionConfig ("master" , database_config )]
187+ self .set_databasepath (args .database_path )
188+ return
189+
190+ if self .get_single_database ().name == "sqlite3" :
191+ self .set_databasepath (args .database_path )
192+ else :
193+ logger .warning (NON_SQLITE_DATABASE_PATH_WARNING )
163194
164- # We only support setting a database path if we have a single sqlite3
165- # database.
166- if len (self .databases ) != 1 :
167- raise ConfigError ("Cannot specify 'database_path' with multiple databases" )
195+ def set_databasepath (self , database_path ):
168196
169- database = self .get_single_database ()
170- if database .config ["name" ] != "sqlite3" :
171- # We don't raise here as we haven't done so before for this case.
172- logger .warn ("Ignoring 'database_path' for non-sqlite3 database" )
173- return
197+ if database_path != ":memory:" :
198+ database_path = self .abspath (database_path )
174199
175- database .config ["args" ]["database" ] = database_path
200+ self . databases [ 0 ] .config ["args" ]["database" ] = database_path
176201
177202 @staticmethod
178203 def add_arguments (parser ):
@@ -187,7 +212,7 @@ def add_arguments(parser):
187212 def get_single_database (self ) -> DatabaseConnectionConfig :
188213 """Returns the database if there is only one, useful for e.g. tests
189214 """
190- if len ( self .databases ) != 1 :
215+ if not self .databases :
191216 raise Exception ("More than one database exists" )
192217
193218 return self .databases [0 ]
0 commit comments