1
1
# -*- coding: utf-8 -*-
2
2
# Copyright 2018 New Vector Ltd
3
+ # Copyright 2019 The Matrix.org Foundation C.I.C.
3
4
#
4
5
# Licensed under the Apache License, Version 2.0 (the "License");
5
6
# you may not use this file except in compliance with the License.
12
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14
# See the License for the specific language governing permissions and
14
15
# limitations under the License.
16
+
15
17
from synapse .python_dependencies import DependencyException , check_requirements
18
+ from synapse .util .module_loader import load_python_module
16
19
17
20
from ._base import Config , ConfigError
18
21
19
22
23
+ def _dict_merge (merge_dict , into_dict ):
24
+ """Do a deep merge of two dicts
25
+
26
+ Recursively merges `merge_dict` into `into_dict`:
27
+ * For keys where both `merge_dict` and `into_dict` have a dict value, the values
28
+ are recursively merged
29
+ * For all other keys, the values in `into_dict` (if any) are overwritten with
30
+ the value from `merge_dict`.
31
+
32
+ Args:
33
+ merge_dict (dict): dict to merge
34
+ into_dict (dict): target dict
35
+ """
36
+ for k , v in merge_dict .items ():
37
+ if k not in into_dict :
38
+ into_dict [k ] = v
39
+ continue
40
+
41
+ current_val = into_dict [k ]
42
+
43
+ if isinstance (v , dict ) and isinstance (current_val , dict ):
44
+ _dict_merge (v , current_val )
45
+ continue
46
+
47
+ # otherwise we just overwrite
48
+ into_dict [k ] = v
49
+
50
+
20
51
class SAML2Config (Config ):
21
52
def read_config (self , config , ** kwargs ):
22
53
self .saml2_enabled = False
@@ -36,15 +67,20 @@ def read_config(self, config, **kwargs):
36
67
37
68
self .saml2_enabled = True
38
69
39
- import saml2 .config
40
-
41
- self .saml2_sp_config = saml2 .config .SPConfig ()
42
- self .saml2_sp_config .load (self ._default_saml_config_dict ())
43
- self .saml2_sp_config .load (saml2_config .get ("sp_config" , {}))
70
+ saml2_config_dict = self ._default_saml_config_dict ()
71
+ _dict_merge (
72
+ merge_dict = saml2_config .get ("sp_config" , {}), into_dict = saml2_config_dict
73
+ )
44
74
45
75
config_path = saml2_config .get ("config_path" , None )
46
76
if config_path is not None :
47
- self .saml2_sp_config .load_file (config_path )
77
+ mod = load_python_module (config_path )
78
+ _dict_merge (merge_dict = mod .CONFIG , into_dict = saml2_config_dict )
79
+
80
+ import saml2 .config
81
+
82
+ self .saml2_sp_config = saml2 .config .SPConfig ()
83
+ self .saml2_sp_config .load (saml2_config_dict )
48
84
49
85
# session lifetime: in milliseconds
50
86
self .saml2_session_lifetime = self .parse_duration (
0 commit comments