Skip to content

Commit 830619e

Browse files
committed
property resolution fixed
1 parent 083dc62 commit 830619e

File tree

2 files changed

+93
-45
lines changed

2 files changed

+93
-45
lines changed

params/params.py

+42-17
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
class Param:
1616
""" Provides a parameter specification to be used within a Params instance. """
17-
def __init__(self, value, doc: Text = None, dtype: Type = None, required: bool = False):
17+
def __init__(self, value, doc: Text = None, dtype: Type = None, required: bool = False, params_class=None):
1818
"""
1919
Constructs a parameter specification to be used in a Params instance:
2020
@@ -31,17 +31,28 @@ class MyParams(pp.Params):
3131
:param dtype: (Optional) type
3232
:param required: default is True.
3333
"""
34-
self.default_value = value
34+
self.params_class = params_class
35+
self._default_value = value
3536
self.doc_string = doc
3637
self.required = required
37-
self.dtype = type(value) if (dtype is None and value is not None) else dtype
38-
if value is not None:
38+
self.dtype = dtype
39+
if dtype is None and value is not None and not callable(value):
40+
self.dtype = type(value)
41+
if value is not None and not callable(value):
3942
if not isinstance(value, self.dtype):
4043
raise RuntimeError(f"Param({value}) does not match dtype:[{self.dtype}]")
4144
self.name = None
45+
self.is_property = callable(value)
4246

47+
@property
48+
def default_value(self):
49+
return self._default_value(self.params_class) if self.is_property else self._default_value
4350

44-
class Params(dict): # TODO use collections.UserDict instead of dict - see #1
51+
def value(self, params):
52+
return self._default_value(params) if self.is_property else getattr(self.name, params)
53+
54+
55+
class Params(dict):
4556
""" Base class for defining safe parameter dictionaries.
4657
4758
Example:
@@ -67,37 +78,48 @@ class MyParams(pp.Params):
6778

6879
def __init_subclass__(cls, **kwargs):
6980
""" Aggregates the Param spec of the parameters over the hierarchy. """
70-
specs = {}
71-
for base in reversed(cls.__bases__):
81+
base_specs = {}
82+
for base in cls.__bases__:
7283
if issubclass(base, Params):
73-
specs.update(base.__specs)
84+
base_specs.update(base.__specs)
7485

86+
cls_specs = [] # evaluate in order of declaration
7587
for attr, value in cls.__dict__.items():
7688
if attr.startswith("_") or callable(getattr(cls, attr)):
7789
continue
7890

79-
param_spec = getattr(cls, attr)
80-
if isinstance(param_spec, property):
81-
param_spec = Param(param_spec.fget(cls))
82-
elif not isinstance(param_spec, Param):
91+
attr_val = getattr(cls, attr)
92+
if isinstance(attr_val, property):
93+
param_spec = Param(attr_val.fget, params_class=cls)
94+
elif not isinstance(attr_val, Param):
8395
param_spec = Param(value)
96+
else:
97+
param_spec = attr_val
8498

8599
param_spec.name = attr
86-
specs[attr] = param_spec
100+
cls_specs.append((attr, param_spec))
87101

88-
for attr, value in specs.items():
102+
_specs = {}
103+
for attr, value in list(base_specs.items()) + cls_specs:
89104
setattr(cls, attr, value.default_value)
105+
_specs[attr] = value
90106

91-
cls.__specs = specs
107+
cls.__specs = _specs
92108
cls.__defaults = {key: val.default_value for key, val in cls.__specs.items()}
93109

94110
def __init__(self, *args, **kwargs):
95111
super(Params, self).__init__()
96112
self.update(self.__defaults) # start with default values
97113
self.update(dict(*args)) # override with tuple list
98114
self.update(kwargs) # override with kwargs
115+
# update any overridden @property parameters
116+
prop_specs = list(filter(lambda spec: spec.is_property,
117+
self.__class__.__specs.values()))
99118

100-
def update(self, arg=None, **kwargs):
119+
for spec in prop_specs:
120+
self.update({spec.name: spec.value(self)})
121+
122+
def update(self, arg=None, **kwargs): # see dict.update()
101123
if arg:
102124
keys = getattr(arg, "keys") if hasattr(arg, "keys") else None
103125
if keys and (inspect.ismethod(keys) or inspect.isbuiltin(keys)):
@@ -111,7 +133,10 @@ def update(self, arg=None, **kwargs):
111133

112134
def __getattribute__(self, attr):
113135
if not attr.startswith("_") and attr in self.__defaults:
114-
return self.__getitem__(attr)
136+
if self.__specs[attr].is_property:
137+
return self.__specs[attr].value(self)
138+
else:
139+
return self.__getitem__(attr)
115140
return object.__getattribute__(self, attr)
116141

117142
def __setattr__(self, key, value):

tests/test_subclassing.py

+51-28
Original file line numberDiff line numberDiff line change
@@ -7,70 +7,93 @@
77

88
import unittest
99

10-
from params import Params
10+
import params as pp
1111

1212

13-
class BaseParams(Params):
14-
param_a = True
15-
param_b = 1
13+
class BaseParams(pp.Params):
14+
param_a = "a"
15+
param_b = "b"
1616

1717

1818
class SubParams(BaseParams):
19-
param_c = 'a'
20-
param_a = False
19+
param_a = "Sa"
20+
param_c = "Sc"
2121

2222

2323
class SubParamsA(SubParams):
24-
param_d = 'A'
25-
param_e = False
24+
param_d = "aD"
25+
param_e = "aE"
2626

2727

2828
class SubParamsB(SubParams):
29-
param_f = 'B'
30-
param_g = False
29+
param_f = 'SBf'
30+
31+
@property
32+
def param_g(self):
33+
return "SBBg_" + self.param_b
34+
35+
@property
36+
def param_a(self):
37+
return "SBBa_" + self.param_g + self.param_c
3138

3239

3340
class MySubParams(SubParamsA, SubParamsB):
34-
param_h = 'a'
35-
param_j = False
41+
param_h = "MSh"
42+
param_j = "MSj"
3643

3744
@property
38-
def param_f(self):
39-
return 'F'
45+
def param_c(self):
46+
return "MSc_" + self.param_g
47+
48+
49+
class AnotherSubParams(MySubParams):
50+
@property
51+
def param_h(self):
52+
return "ASh_" + self.param_d
53+
54+
@property
55+
def param_g(self):
56+
return "ASg_" + self.param_h
57+
4058

4159

4260
class ParamsSubclassingTest(unittest.TestCase):
4361
def test_subclassing(self):
44-
Params()
4562
params = SubParams()
46-
expected = {'param_a': False, 'param_b': 1, 'param_c': 'a'}
63+
expected = {"param_a": "Sa", "param_b": "b", "param_c": "Sc"}
4764
self.assertEqual(dict(params), expected)
4865

49-
params = SubParams(param_b=2)
50-
params.param_c = 'b'
51-
params.param_a = True
52-
expected = {'param_a': True, 'param_b': 2, 'param_c': 'b'}
66+
params = SubParams(param_b="bb")
67+
params.param_c = "cc"
68+
params.param_a = "aa"
69+
expected = {"param_a": "aa", "param_b": "bb", "param_c": "cc"}
5370
self.assertEqual(dict(params), expected)
5471

5572
params = BaseParams()
5673
try:
57-
params.param_c = 'c'
74+
params.param_c = "cc"
5875
self.fail()
5976
except AttributeError:
6077
pass
6178

62-
params = SubParams(param_b=2).clone(param_c=3)
63-
expected = {'param_b': 2, 'param_a': False, 'param_c': 3}
79+
params = SubParams(param_b="bb").clone(param_c="cc")
80+
expected = {"param_b": "bb", "param_a": "Sa", "param_c": "cc"}
6481
self.assertEqual(dict(params), expected)
6582

6683
def test_hierarchy(self):
67-
params = MySubParams()
68-
self.assertEqual(params.param_b, 1)
69-
self.assertEqual(params.param_a, False)
84+
params = MySubParams(param_f="SBf")
85+
self.assertEqual(params.param_b, "b")
86+
self.assertEqual(params.param_a, "SBBa_SBBg_bMSc_SBBg_b")
87+
7088
MySubParams.to_argument_parser().print_help()
7189

72-
self.assertEqual(params.param_f, "F")
73-
self.assertEqual(MySubParams.param_f, "F")
90+
self.assertEqual(params.param_f, "SBf")
91+
self.assertEqual(MySubParams.param_f, "SBf")
92+
93+
def test_hierarchy_override(self):
94+
params = AnotherSubParams(param_d='Z')
95+
self.assertEqual(params.param_g, "ASg_ASh_Z")
96+
7497

7598
if __name__ == '__main__':
7699
unittest.main()

0 commit comments

Comments
 (0)