Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,20 @@ def _x_default(self):

a = A()
self.assertEqual(a._trait_values, {})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(a.x, 11)
self.assertEqual(a._trait_values, {'x': 11})
b = B()
self.assertEqual(b._trait_values, {'x': 20})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(b.x, 20)
self.assertEqual(b._trait_values, {'x': 20})
c = C()
self.assertEqual(c._trait_values, {})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(c.x, 21)
self.assertEqual(c._trait_values, {'x': 21})
# Ensure that the base class remains unmolested when the _default
# initializer gets overridden in a subclass.
a = A()
c = C()
self.assertEqual(a._trait_values, {})
self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
self.assertEqual(a.x, 11)
self.assertEqual(a._trait_values, {'x': 11})

Expand Down Expand Up @@ -448,7 +444,7 @@ class A(HasTraits):
klass = Type(allow_none=True)

a = A()
self.assertEqual(a.klass, None)
self.assertEqual(a.klass, object)

a.klass = B
self.assertEqual(a.klass, B)
Expand Down Expand Up @@ -606,7 +602,9 @@ class Foo(object): pass
class A(HasTraits):
inst = Instance(Foo)

self.assertRaises(TraitError, A)
a = A()
with self.assertRaises(TraitError):
a.inst

def test_instance(self):
class Foo(object): pass
Expand Down Expand Up @@ -1110,8 +1108,14 @@ def test_dict_default_value():
"""Check that the `{}` default value of the Dict traitlet constructor is
actually copied."""

d1, d2 = Dict(), Dict()
nt.assert_false(d1.get_default_value() is d2.get_default_value())
class Foo(HasTraits):
d1 = Dict()
d2 = Dict()

foo = Foo()
nt.assert_equal(foo.d1, {})
nt.assert_equal(foo.d2, {})
nt.assert_is_not(foo.d1, foo.d2)


class TestValidationHook(TestCase):
Expand Down
155 changes: 60 additions & 95 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class link(object):
Examples
--------

>>> c = link((src, 'value'), (tgt, 'value'),
>>> c = link((src, 'value'), (tgt, 'value'))
>>> src.value = 5 # updates other objects as well
"""
updating = False
Expand Down Expand Up @@ -371,41 +371,50 @@ def init(self):
pass

def get_default_value(self):
"""Create a new instance of the default value."""
"""Retrieve the static default value for this trait"""
return self.default_value

def init_default_value(self, obj):
"""Instantiate the default value for the trait type.
def validate_default_value(self, obj):
"""Retrieve and validate the static default value"""
v = self.get_default_value()
return self._validate(obj, v)

This method is called when accessing the trait value for the first
time in :meth:`HasTraits.__get__`.
def init_default_value(self, obj):
"""DEPRECATED: Set the static default value for the trait type.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deprecation warning on the deprecated method?

"""
value = self.get_default_value()
value = self._validate(obj, value)
warn("init_default_value is deprecated, and may be removed in the future",
stacklevel=2)
value = self.validate_default_value()
obj._trait_values[self.name] = value
return value

def _setup_dynamic_initializer(self, obj):
# Check for a deferred initializer defined in the same class as the
# trait declaration or above.
mro = type(obj).mro()
meth_name = '_%s_default' % self.name
for cls in mro[:mro.index(self.this_class)+1]:
if meth_name in cls.__dict__:
break
else:
return False
# Complete the dynamic initialization.
obj._trait_dyn_inits[self.name] = meth_name
return True

def _set_default_value_at_instance_init(self, obj):
# As above, but if no default was specified, don't try to set it.
# If the trait is accessed before it is given a value, init_default_value
# will be called at that point.
if (not self._setup_dynamic_initializer(obj)) \
def _dynamic_default_callable(self, obj):
"""Retrieve a callable to calculate the default for this traitlet.

This looks for:

- obj._{name}_default() on the class with the traitlet, or a subclass
that obj belongs to.
- trait.make_dynamic_default, which is defined by Instance

If neither exist, it returns None
"""
# Traitlets without a name are not on the instance, e.g. in List or Union
if self.name:
mro = type(obj).mro()
meth_name = '_%s_default' % self.name
for cls in mro[:mro.index(self.this_class)+1]:
if meth_name in cls.__dict__:
return getattr(obj, meth_name)

return getattr(self, 'make_dynamic_default', None)

def instance_init(self, obj):
# If no dynamic initialiser is present, and the trait implementation or
# use provides a static default, transfer that to obj._trait_values.
if (self._dynamic_default_callable(obj) is None) \
and (self.default_value is not Undefined):
self.init_default_value(obj)
self.validate_default_value(obj)

def __get__(self, obj, cls=None):
"""Get the value of the trait by self.name for the instance.
Expand All @@ -422,15 +431,13 @@ def __get__(self, obj, cls=None):
value = obj._trait_values[self.name]
except KeyError:
# Check for a dynamic initializer.
if self.name in obj._trait_dyn_inits:
method = getattr(obj, obj._trait_dyn_inits[self.name])
value = method()
# FIXME: Do we really validate here?
value = self._validate(obj, value)
obj._trait_values[self.name] = value
return value
dynamic_default = self._dynamic_default_callable(obj)
if dynamic_default is not None:
value = self._validate(obj, dynamic_default())
else:
return self.init_default_value(obj)
value = self.validate_default_value(obj)
obj._trait_values[self.name] = value
return value
except Exception:
# This should never be reached.
raise TraitError('Unexpected error in TraitType: '
Expand All @@ -443,7 +450,7 @@ def __set__(self, obj, value):
try:
old_value = obj._trait_values[self.name]
except KeyError:
old_value = Undefined
old_value = self.get_default_value()

obj._trait_values[self.name] = new_value
try:
Expand Down Expand Up @@ -554,7 +561,6 @@ def __new__(cls, *args, **kw):
inst = new_meth(cls, **kw)
inst._trait_values = {}
inst._trait_notifiers = {}
inst._trait_dyn_inits = {}
inst._cross_validation_lock = True
# Here we tell all the TraitType instances to set their default
# values on the instance.
Expand All @@ -569,8 +575,6 @@ def __new__(cls, *args, **kw):
else:
if isinstance(value, BaseDescriptor):
value.instance_init(inst)
if isinstance(value, TraitType) and key not in kw:
value._set_default_value_at_instance_init(inst)
inst._cross_validation_lock = False
return inst

Expand Down Expand Up @@ -901,7 +905,7 @@ def __init__ (self, default_value=None, klass=None, **metadata):
a particular class.

If only ``default_value`` is given, it is used for the ``klass`` as
well.
well. If neither are given, both default to ``object``.

Parameters
----------
Expand All @@ -915,13 +919,12 @@ def __init__ (self, default_value=None, klass=None, **metadata):
may be specified in a string like: 'foo.bar.MyClass'.
The string is resolved into real class, when the parent
:class:`HasTraits` class is instantiated.
allow_none : bool [ default True ]
Indicates whether None is allowed as an assignable value. Even if
``False``, the default value may be ``None``.
allow_none : bool [ default False ]
Indicates whether None is allowed as an assignable value.
"""
if default_value is None:
if klass is None:
klass = object
default_value = klass = object
elif klass is None:
klass = default_value

Expand Down Expand Up @@ -969,20 +972,6 @@ def _resolve_classes(self):
if isinstance(self.default_value, py3compat.string_types):
self.default_value = self._resolve_string(self.default_value)

def get_default_value(self):
return self.default_value


class DefaultValueGenerator(object):
"""A class for generating new default value instances."""

def __init__(self, *args, **kw):
self.args = args
self.kw = kw

def generate(self, klass):
return klass(*self.args, **self.kw)


class Instance(ClassBasedTraitType):
"""A trait whose value must be an instance of a specified class.
Expand Down Expand Up @@ -1030,25 +1019,15 @@ class or its subclasses. Our implementation is quite different
raise TraitError('The klass attribute must be a class'
' not: %r' % klass)

# self.klass is a class, so handle default_value
if args is None and kw is None:
default_value = None
else:
if args is None:
# kw is not None
args = ()
elif kw is None:
# args is not None
kw = {}

if not isinstance(kw, dict):
raise TraitError("The 'kw' argument must be a dict or None.")
if not isinstance(args, tuple):
raise TraitError("The 'args' argument must be a tuple or None.")
if (kw is not None) and not isinstance(kw, dict):
raise TraitError("The 'kw' argument must be a dict or None.")
if (args is not None) and not isinstance(args, tuple):
raise TraitError("The 'args' argument must be a tuple or None.")

default_value = DefaultValueGenerator(*args, **kw)
self.default_args = args
self.default_kwargs = kw

super(Instance, self).__init__(default_value, **metadata)
super(Instance, self).__init__(**metadata)

def validate(self, obj, value):
if isinstance(value, self.klass):
Expand All @@ -1075,18 +1054,11 @@ def _resolve_classes(self):
if isinstance(self.klass, py3compat.string_types):
self.klass = self._resolve_string(self.klass)

def get_default_value(self):
"""Instantiate a default value instance.

This is called when the containing HasTraits classes'
:meth:`__new__` method is called to ensure that a unique instance
is created for each HasTraits instance.
"""
dv = self.default_value
if isinstance(dv, DefaultValueGenerator):
return dv.generate(self.klass)
else:
return dv
def make_dynamic_default(self):
if (self.default_args is None) and (self.default_kwargs is None):
return None
return self.klass(*(self.default_args or ()),
**(self.default_kwargs or {}))


class ForwardDeclaredMixin(object):
Expand Down Expand Up @@ -1166,7 +1138,6 @@ def __init__(self, trait_types, **metadata):

def instance_init(self, obj):
for trait_type in self.trait_types:
trait_type.name = self.name
trait_type.this_class = self.this_class
trait_type.instance_init(obj)
super(Union, self).instance_init(obj)
Expand Down Expand Up @@ -1516,7 +1487,6 @@ def __init__(self, trait=None, default_value=None, **metadata):

if is_trait(trait):
self._trait = trait() if isinstance(trait, type) else trait
self._trait.name = 'element'
elif trait is not None:
raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))

Expand Down Expand Up @@ -1710,7 +1680,6 @@ def __init__(self, *traits, **metadata):
self._traits = []
for trait in traits:
t = trait() if isinstance(trait, type) else trait
t.name = 'element'
self._traits.append(t)

if self._traits and default_value is None:
Expand Down Expand Up @@ -1790,14 +1759,10 @@ def __init__(self, trait=None, traits=None, default_value=NoDefaultSpecified,
# Case where a type of TraitType is provided rather than an instance
if is_trait(trait):
self._trait = trait() if isinstance(trait, type) else trait
self._trait.name = 'element'
elif trait is not None:
raise TypeError("`trait` must be a Trait or None, got %s" % repr_type(trait))

self._traits = traits
if traits is not None:
for t in traits.values():
t.name = 'element'

super(Dict, self).__init__(klass=dict, args=args, **metadata)

Expand Down