Skip to content

Commit

Permalink
add tile utility function; speed up pointmass sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Aug 31, 2016
1 parent 08bad5c commit 870dd96
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 4 deletions.
2 changes: 1 addition & 1 deletion edward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
MFVI, KLpq, MAP, Laplace
from edward.util import cumprod, dot, get_dims, get_session, hessian, \
kl_multivariate_normal, log_sum_exp, logit, multivariate_rbf, rbf, \
set_seed, to_simplex
set_seed, tile, to_simplex
from edward.version import __version__
6 changes: 3 additions & 3 deletions edward/models/pointmass.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def sample_n(self, n, seed=None, name="sample_n"):
"""
with ops.name_scope(self.name):
with ops.op_scope([self._params, n], name):
# TODO
n = n.eval()
return tf.pack([self._params] * n)
multiples = tf.concat(0, [tf.expand_dims(n, 0),
[1] * len(self._params.get_shape())])
return tile(self._params, multiples)

@property
def is_reparameterized(self):
Expand Down
99 changes: 99 additions & 0 deletions edward/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,105 @@ def set_seed(x):
tf.set_random_seed(x)


def tile(input, multiples, *args, **kwargs):
"""Constructs a tensor by tiling a given tensor.
This extends ``tf.tile`` to features available in ``np.tile``.
Namely, ``inputs`` and ``multiples`` can be a 0-D tensor. Further,
if 1-D, ``multiples`` can be of any length according to broadcasting
rules (see documentation of ``np.tile`` or examples below).
Parameters
----------
input : tf.Tensor
The input tensor.
multiples : tf.Tensor
The number of repetitions of ``input`` along each axis. Has type
``tf.int32``. 0-D or 1-D.
*args :
Passed into ``tf.tile``.
**kwargs :
Passed into ``tf.tile``.
Returns
-------
tf.Tensor
Has the same type as ``input``.
Examples
--------
>>> a = tf.constant([0, 1, 2])
>>> sess.run(ed.tile(a, 2))
array([0, 1, 2, 0, 1, 2], dtype=int32)
>>> sess.run(ed.tile(a, (2, 2)))
array([[0, 1, 2, 0, 1, 2],
[0, 1, 2, 0, 1, 2]], dtype=int32)
>>> sess.run(ed.tile(a, (2, 1, 2)))
array([[[0, 1, 2, 0, 1, 2]],
[[0, 1, 2, 0, 1, 2]]], dtype=int32)
>>>
>>> b = tf.constant([[1, 2], [3, 4]])
>>> sess.run(ed.tile(b, 2))
array([[1, 2, 1, 2],
[3, 4, 3, 4]], dtype=int32)
>>> sess.run(ed.tile(b, (2, 1)))
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)
>>>
>>> c = tf.constant([1, 2, 3, 4])
>>> sess.run(ed.tile(c, (4, 1)))
array([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]], dtype=int32)
Notes
-----
Sometimes this can result in an unknown shape. The core reason for
this is the following behavior:
>>> temp = tf.constant([1])
>>> tf.tile(tf.constant([[1.0]]),
... tf.concat(0, [temp, tf.constant([1.0]).get_shape()]))
<tf.Tensor 'Tile:0' shape=(1, 1) dtype=float32>
>>> temp = tf.reshape([tf.constant(1)], [1])
>>> tf.tile(tf.constant([[1.0]]),
... tf.concat(0, [temp, tf.constant([1.0]).get_shape()]))
<tf.Tensor 'Tile_1:0' shape=(?, ?) dtype=float32>
For this reason, we try to fetch ``multiples`` out of session if
possible. This can be slow if ``multiples`` has computationally
intensive dependencies in order to perform this fetch.
"""
input = tf.convert_to_tensor(input)
multiples = tf.convert_to_tensor(multiples)

# 0-d tensor
if len(input.get_shape()) == 0:
input = tf.expand_dims(input, 0)

# 0-d tensor
if len(multiples.get_shape()) == 0:
multiples = tf.expand_dims(multiples, 0)

try:
get_session()
multiples = tf.convert_to_tensor(multiples.eval())
except:
pass

# broadcasting
diff = len(input.get_shape()) - get_dims(multiples)[0]
if diff < 0:
input = tf.reshape(input, [1] * np.abs(diff) + get_dims(input))
elif diff > 0:
multiples = tf.concat(0, [tf.ones(diff, dtype=tf.int32), multiples])

return tf.tile(input, multiples, *args, **kwargs)


def to_simplex(x):
"""Transform real vector of length ``(K-1)`` to a simplex of dimension ``K``
using a backward stick breaking construction.
Expand Down
52 changes: 52 additions & 0 deletions tests/test-util/test_tile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from edward.util import get_dims, tile


class test_tile_class(tf.test.TestCase):

def _test(self, input, multiples):
if isinstance(multiples, int) or isinstance(multiples, float):
multiples_shape = [multiples]
elif isinstance(multiples, tuple):
multiples_shape = list(multiples)
else:
multiples_shape = multiples

input_shape = get_dims(input)
diff = len(input_shape) - len(multiples_shape)
if diff < 0:
input_shape = [1] * abs(diff) + input_shape
elif diff > 0:
multiples_shape = [1] * diff + multiples_shape

val_true = [x * y for x, y in zip(input_shape, multiples_shape)]
with self.test_session():
val_est = get_dims(tile(input, multiples))
assert val_est == val_true

def test_0d(self):
x = tf.constant(0)
self._test(x, 2)
self._test(x, (2, 1))

def test_1d(self):
x = tf.constant([0, 1, 2])
self._test(x, 2)
self._test(x, (2, 2))
self._test(x, (2, 1, 2))
x = tf.constant([1, 2, 3, 4])
self._test(x, (4, 1))

def test_2d(self):
x = tf.constant([[1, 2], [3, 4]])
self._test(x, 2)
self._test(x, (2, 1))

if __name__ == '__main__':
tf.test.main()

0 comments on commit 870dd96

Please sign in to comment.