This repository was archived by the owner on Jul 7, 2023. It is now read-only.

WIP: Added new initializers #1666

Merged: 18 commits, Aug 28, 2019
107 changes: 94 additions & 13 deletions tensor2tensor/trax/layers/initializers.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trax initializers."""

from __future__ import absolute_import
@@ -23,28 +22,110 @@
from tensor2tensor.trax import backend


def _get_fans(shape, out_dim=-1, in_dim=-2):
  # Temporary fix until numpy.delete supports negative indices.
  if out_dim < 0:
    out_dim += len(shape)
  if in_dim < 0:
    in_dim += len(shape)

  receptive_field = backend.numpy.prod(onp.delete(shape, [in_dim, out_dim]))
  if len(shape) >= 2:
    fan_in, fan_out = shape[in_dim], shape[out_dim]
  elif len(shape) == 1:
    fan_in = fan_out = shape[0]
  else:
    fan_in = fan_out = 1.
  fan_in *= receptive_field
  fan_out *= receptive_field
  return fan_in, fan_out
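
# Illustrative example (not part of the diff): for a conv-style kernel of shape
# (3, 3, 16, 32) with the default dims, the receptive field is 3 * 3 = 9, so
# _get_fans returns fan_in = 16 * 9 = 144 and fan_out = 32 * 9 = 288.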


def RandomNormalInitializer(stddev=1e-2):
  """An initializer function for random normal coefficients."""

  def Init(shape, rng):
    return (stddev * backend.random.normal(rng, shape)).astype('float32')

  return Init


-def GlorotNormalInitializer(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
-  """An initializer function for random Glorot-scaled coefficients."""
def RandomUniformInitializer(lim=1.0):
  """An initializer function for random uniform coefficients."""

  def Init(shape, rng):
-    fan_in, fan_out = shape[in_dim], shape[out_dim]
-    size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
-    std = scale / backend.numpy.sqrt((fan_in + fan_out) / 2. * size)
-    return (std * backend.random.normal(rng, shape)).astype('float32')
    return backend.random.uniform(rng, shape, backend.numpy.float32, -lim,
                                  lim)

  return Init


-def GlorotUniformInitializer(out_dim=0, in_dim=1):
-  """An initializer function for random uniform Glorot-scaled coefficients."""
def VarianceScalingInitializer(out_dim, in_dim, scale, mode, distribution):
  """Initializer capable of adapting its scale to the shape of weights tensors."""
  if scale <= 0.:
    raise ValueError('scale must be positive float, {} given'.format(scale))
  if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
    raise ValueError(
        'Invalid mode argument: {}, must be either fan_in, fan_out or fan_avg'
        .format(mode))

  def Init(shape, rng):
-    fan_in, fan_out = shape[in_dim], shape[out_dim]
-    std = backend.numpy.sqrt(2.0 / (fan_in + fan_out))
-    a = backend.numpy.sqrt(3.0) * std
-    return backend.random.uniform(rng, shape, minval=-a, maxval=a)
    fan_in, fan_out = _get_fans(shape, out_dim, in_dim)
    gain = scale
    if mode == 'fan_in':
      gain /= fan_in
    elif mode == 'fan_out':
      gain /= fan_out
    elif mode == 'fan_avg':
      gain /= (fan_in + fan_out) / 2
    if distribution == 'truncated_normal':
      # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.).
      stddev = backend.numpy.sqrt(gain) / .87962566103423978
      return (backend.random.truncated_normal(rng, -2, 2, shape) *
              stddev).astype('float32')
    elif distribution == 'normal':
      return (backend.random.normal(rng, shape) *
              backend.numpy.sqrt(gain)).astype('float32')
    elif distribution == 'uniform':
      lim = backend.numpy.sqrt(3. * gain)
      return backend.random.uniform(rng, shape, backend.numpy.float32, -lim,
                                    lim)
    else:
      raise ValueError('Invalid distribution for variance scaling initializer.')

  return Init


def GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.):
  """An initializer function for random Glorot-scaled coefficients."""
  return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_avg', 'normal')


def GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.):
  """An initializer function for random uniform Glorot-scaled coefficients."""
  return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_avg',
                                    'uniform')


def LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.):
  """An initializer function for random LeCun-scaled coefficients."""
  return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_in', 'normal')


def LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.):
  """An initializer function for random uniform LeCun-scaled coefficients."""
  return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_in', 'uniform')


def KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.):
  """An initializer function for random Kaiming-scaled coefficients."""
  return VarianceScalingInitializer(out_dim, in_dim,
                                    2.0 / backend.numpy.sqrt(1 + param**2),
                                    'fan_in', 'normal')


def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.):
  """An initializer function for random uniform Kaiming-scaled coefficients."""
  return VarianceScalingInitializer(out_dim, in_dim,
                                    2.0 / backend.numpy.sqrt(1 + param**2),
                                    'fan_in', 'uniform')
48 changes: 48 additions & 0 deletions tensor2tensor/trax/layers/initializers_test.py
@@ -32,5 +32,53 @@ def test_random_normal(self):
    self.assertEqual(tuple(init_value.shape), input_shape)


  def test_random_uniform(self):
    initializer = initializers.RandomUniformInitializer()
    input_shape = (29, 5, 7, 20)
    init_value = initializer(input_shape, random.get_prng(0))
    self.assertEqual(tuple(init_value.shape), input_shape)

  def test_glorot_normal(self):
    initializer = initializers.GlorotNormalInitializer()
    input_shape = (29, 5, 7, 20)
    init_value = initializer(input_shape, random.get_prng(0))
    self.assertEqual(tuple(init_value.shape), input_shape)

  def test_glorot_uniform(self):
    initializer = initializers.GlorotUniformInitializer()
    input_shape = (29, 5, 7, 20)
    init_value = initializer(input_shape, random.get_prng(0))
    self.assertEqual(tuple(init_value.shape), input_shape)

  def test_lecun_normal(self):
    initializer = initializers.LeCunNormalInitializer()
    input_shape = (29, 5, 7, 20)
    init_value = initializer(input_shape, random.get_prng(0))
    self.assertEqual(tuple(init_value.shape), input_shape)

  def test_lecun_uniform(self):
    initializer = initializers.LeCunUniformInitializer()
    input_shape = (29, 5, 7, 20)
    init_value = initializer(input_shape, random.get_prng(0))
    self.assertEqual(tuple(init_value.shape), input_shape)

  def test_kaiming_normal(self):
    initializer = initializers.KaimingNormalInitializer()
    input_shape = (29, 5, 7, 20)
    init_value = initializer(input_shape, random.get_prng(0))
    self.assertEqual(tuple(init_value.shape), input_shape)

  def test_kaiming_uniform(self):
    initializer = initializers.KaimingUniformInitializer()
    input_shape = (29, 5, 7, 20)
    init_value = initializer(input_shape, random.get_prng(0))
    self.assertEqual(tuple(init_value.shape), input_shape)


if __name__ == "__main__":
  absltest.main()
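A possible follow-up (sketch only, not part of this PR): the tests above check only output shapes; a scale check for the Glorot-uniform case could look like the method below, assuming the test module also imports numpy as onp:

  def test_glorot_uniform_scale(self):
    initializer = initializers.GlorotUniformInitializer()
    input_shape = (20, 10)
    init_value = initializer(input_shape, random.get_prng(0))
    # For a 2-D shape the Glorot-uniform limit is sqrt(6 / (fan_in + fan_out)).
    lim = onp.sqrt(6. / (20 + 10))
    self.assertLessEqual(float(onp.max(onp.abs(onp.asarray(init_value)))), lim)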