
Commit 501a24f

Add new shape utils

Add in texar/tf/utils/shapes.py:
- `reduce_with_weights`
- `varlength_*`

2 parents 4c8b0c0 + 91d909f

File tree

3 files changed: +435 −1 lines changed

docs/code/utils.rst (+17)

```diff
@@ -124,10 +124,27 @@ Shape
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: texar.tf.utils.pad_and_concat
 
+:hidden:`reduce_with_weights`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.tf.utils.reduce_with_weights
+
 :hidden:`flatten`
 ~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: texar.tf.utils.flatten
 
+:hidden:`varlength_concat`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.tf.utils.varlength_concat
+
+:hidden:`varlength_concat_py`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.tf.utils.varlength_concat_py
+
+:hidden:`varlength_roll`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.tf.utils.varlength_roll
+
+
 Dictionary
 ===========
```
texar/tf/utils/shapes.py (+306 −1)

```diff
@@ -35,9 +35,13 @@
     "mask_sequences",
     "_mask_sequences_tensor",
     "_mask_sequences_py",
+    "reduce_with_weights",
     "flatten",
     "shape_list",
-    "pad_and_concat"
+    "pad_and_concat",
+    "varlength_concat",
+    "varlength_concat_py",
+    "varlength_roll"
 ]
```
```diff
@@ -249,6 +253,90 @@ def _mask_sequences_py(sequence,
     return sequence
 
 
+def reduce_with_weights(tensor,
+                        weights=None,
+                        average_across_batch=True,
+                        average_across_remaining=False,
+                        sum_over_batch=False,
+                        sum_over_remaining=True,
+                        tensor_rank=None):
+    """Weights and reduces a tensor.
+
+    Args:
+        tensor: A Tensor to weight and reduce, of shape
+            `[batch_size, ...]`.
+        weights (optional): A Tensor of the same shape and dtype as
+            :attr:`tensor`. For example, this can be a 0-1 tensor
+            for masking values of :attr:`tensor`.
+        average_across_batch (bool): If set, average the tensor across
+            the batch dimension. Must not set `average_across_batch`
+            and `sum_over_batch` at the same time.
+        average_across_remaining (bool): If set, average the tensor
+            across the remaining dimensions. Must not set
+            `average_across_remaining` and `sum_over_remaining` at the
+            same time. If :attr:`weights` is given, this is a weighted
+            average.
+        sum_over_batch (bool): If set, sum the tensor across the
+            batch dimension. Must not set `average_across_batch`
+            and `sum_over_batch` at the same time.
+        sum_over_remaining (bool): If set, sum the tensor across the
+            remaining dimensions. Must not set `average_across_remaining`
+            and `sum_over_remaining` at the same time.
+            If :attr:`weights` is given, this is a weighted sum.
+        tensor_rank (int, optional): The number of dimensions of
+            :attr:`tensor`. If not given, inferred from :attr:`tensor`
+            automatically.
+
+    Returns:
+        A Tensor.
+
+    Example:
+        .. code-block:: python
+
+            x = tf.constant([[10, 10, 2, 2],
+                             [20, 2, 2, 2]])
+            mask = tf.constant([[1, 1, 0, 0],
+                                [1, 0, 0, 0]])
+
+            z = reduce_with_weights(x, weights=mask)
+            # z == 20
+            # (all 2s in x are masked)
+    """
+    if tensor_rank is None:
+        tensor_rank = get_rank(tensor)
+    if tensor_rank is None:
+        raise ValueError('Unable to infer the rank of `tensor`. '
+                         'Please set `tensor_rank` explicitly.')
+
+    if weights is not None:
+        tensor = tensor * weights
+
+    if tensor_rank > 1:
+        if average_across_remaining and sum_over_remaining:
+            raise ValueError("Only one of `average_across_remaining` and "
+                             "`sum_over_remaining` can be set.")
+        if average_across_remaining:
+            if weights is None:
+                tensor = tf.reduce_mean(tensor, axis=np.arange(1, tensor_rank))
+            else:
+                tensor = tf.reduce_sum(tensor, axis=np.arange(1, tensor_rank))
+                weights = tf.reduce_sum(weights, axis=np.arange(1, tensor_rank))
+                tensor = tensor / weights
+        elif sum_over_remaining:
+            tensor = tf.reduce_sum(tensor, axis=np.arange(1, tensor_rank))
+
+    if average_across_batch and sum_over_batch:
+        raise ValueError("Only one of `average_across_batch` and "
+                         "`sum_over_batch` can be set.")
+    if sum_over_batch:
+        tensor = tf.reduce_sum(tensor, axis=[0])
+    elif average_across_batch:
+        tensor = tf.reduce_mean(tensor, axis=[0])
+
+    return tensor
+
+
 def flatten(tensor, preserve_dims, flattened_dim=None):
     """Flattens a tensor while keeping several leading dimensions.
```
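A minimal usage sketch of `reduce_with_weights` (not part of the commit): computing a masked average of per-token losses, assuming TF 1.x graph mode as used by Texar-TF. The names `losses` and `mask` are illustrative.

```python
# Hypothetical example: masked average of per-token losses.
import tensorflow as tf

from texar.tf.utils.shapes import reduce_with_weights

losses = tf.constant([[0.5, 0.3, 0.9, 0.9],
                      [0.2, 0.9, 0.9, 0.9]])
mask = tf.constant([[1., 1., 0., 0.],   # first example has 2 real tokens
                    [1., 0., 0., 0.]])  # second example has 1 real token

# Weighted average over real tokens within each example,
# then an unweighted mean over the batch.
avg_loss = reduce_with_weights(
    losses, weights=mask,
    average_across_batch=True,
    average_across_remaining=True,
    sum_over_remaining=False)

with tf.Session() as sess:
    print(sess.run(avg_loss))  # ((0.5 + 0.3) / 2 + 0.2 / 1) / 2 = 0.3
```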
```diff
@@ -382,3 +470,220 @@ def _pad_to_size(value, axis_, size):
         values[i] = _pad_to_size(v, pa, max_dim_size)
 
     return tf.concat(values, axis)
+
+
+def varlength_concat(x, y, x_length, dtype=None, tensor_rank=None):
+    """Concatenates rows of `x` and `y` where each row of
+    `x` has a variable length.
+
+    Both `x` and `y` are of numeric dtypes, such as `tf.int32` and
+    `tf.float32`, with mask value `0`. The two tensors must be of the
+    same dtype.
+
+    Args:
+        x: A tensor of shape `[batch_size, x_dim_2, other_dims]`.
+        y: A tensor of shape `[batch_size, y_dim_2, other_dims]`.
+            All dimensions except the 2nd dimension must be the same
+            as those of `x`.
+        x_length: A 1D int tensor of shape `[batch_size]` containing
+            the length of each `x` row. Elements beyond the respective
+            lengths will be made zero.
+        dtype: Type of :attr:`x`. If `None`, inferred from :attr:`x`
+            automatically.
+        tensor_rank (int, optional): The number of dimensions of
+            :attr:`x`. If not given, inferred from :attr:`x`
+            automatically.
+
+    Returns:
+        A Tensor of shape `[batch_size, x_dim_2 + y_dim_2, other_dims]`.
+
+    Example:
+        .. code-block:: python
+
+            x = tf.constant([[1, 1, 0, 0],
+                             [1, 1, 1, 0]])
+            x_length = [2, 3]
+            y = tf.constant([[2, 2, 0],
+                             [2, 2, 2]])
+
+            out = varlength_concat(x, y, x_length)
+            # out = [[1, 1, 2, 2, 0, 0, 0]
+            #        [1, 1, 1, 2, 2, 2, 0]]
+    """
+    x = tf.convert_to_tensor(x)
+    y = tf.convert_to_tensor(y)
+    x_length = tf.convert_to_tensor(x_length)
+
+    if tensor_rank is None:
+        tensor_rank = get_rank(x) or get_rank(y)
+    if tensor_rank is None:
+        raise ValueError('Unable to infer the rank of `x`. '
+                         'Please set `tensor_rank` explicitly.')
+
+    x_masked = mask_sequences(x, x_length, dtype=dtype, tensor_rank=tensor_rank)
+    zeros_y = tf.zeros_like(y)
+    x_aug = tf.concat([x_masked, zeros_y], axis=1)
+
+    zeros_x = tf.zeros_like(x)
+    y_aug = tf.concat([zeros_x, y], axis=1)
+
+    # Now, x_aug.shape == y_aug.shape
+
+    max_length_x = tf.shape(x)[1]
+    batch_size = tf.shape(x)[0]
+
+    initial_index = tf.constant(0, dtype=tf.int32)
+    initial_outputs_ta = tf.TensorArray(
+        dtype=dtype or x.dtype,
+        size=0,
+        dynamic_size=True)
+
+    def _cond(index, _):
+        return tf.less(index, batch_size)
+
+    def _body(index, outputs_ta):
+        y_aug_i_rolled = tf.roll(
+            input=y_aug[index],
+            shift=x_length[index] - max_length_x,  # shift to the left
+            axis=0)
+        xy = x_aug[index] + y_aug_i_rolled
+        return [index + 1, outputs_ta.write(index, xy)]
+
+    res = tf.while_loop(_cond, _body, [initial_index, initial_outputs_ta])
+
+    return res[1].stack()
```
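A hedged sketch of a typical use of `varlength_concat` (not part of the commit): appending generated continuations to variable-length prompts. Assumes TF 1.x graph mode; the variable names are illustrative.

```python
import tensorflow as tf

from texar.tf.utils.shapes import varlength_concat

prompts = tf.constant([[1, 1, 0, 0],      # real length 2, 0-padded
                       [1, 1, 1, 0]])     # real length 3, 0-padded
prompt_lengths = tf.constant([2, 3])
continuations = tf.constant([[2, 2, 0],
                             [2, 2, 2]])

# Each continuation is attached right after its prompt's last real token.
merged = varlength_concat(prompts, continuations, prompt_lengths)

with tf.Session() as sess:
    print(sess.run(merged))
    # [[1 1 2 2 0 0 0]
    #  [1 1 1 2 2 2 0]]
```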
```diff
+
+
+def varlength_concat_py(x, y, x_length, dtype=None):
+    """Concatenates rows of `x` and `y` where each row of
+    `x` has a variable length.
+
+    The function has the same semantics as :func:`varlength_concat`,
+    except that this function is for numpy arrays instead of TF tensors.
+
+    Both `x` and `y` are of numeric dtypes, such as `int32` and `float32`,
+    with mask value `0`. The two arrays must be of the same dtype.
+
+    Args:
+        x: An array of shape `[batch_size, x_dim_2, other_dims]`.
+        y: An array of shape `[batch_size, y_dim_2, other_dims]`.
+            All dimensions except the 2nd dimension must be the same
+            as those of `x`.
+        x_length: A 1D int array of shape `[batch_size]` containing
+            the length of each `x` row. Elements beyond the respective
+            lengths will be made zero.
+        dtype: Type of :attr:`x`. If `None`, inferred from :attr:`x`
+            automatically.
+
+    Returns:
+        An array of shape `[batch_size, x_dim_2 + y_dim_2, other_dims]`.
+
+    Example:
+        .. code-block:: python
+
+            x = np.asarray([[1, 1, 0, 0],
+                            [1, 1, 1, 0]])
+            x_length = [2, 3]
+            y = np.asarray([[2, 2, 0],
+                            [2, 2, 2]])
+
+            out = varlength_concat_py(x, y, x_length)
+            # out = [[1, 1, 2, 2, 0, 0, 0]
+            #        [1, 1, 1, 2, 2, 2, 0]]
+    """
+    x = np.asarray(x, dtype=dtype)
+    y = np.asarray(y, dtype=dtype)
+
+    x_masked = mask_sequences(x, x_length, dtype=dtype)
+    zeros_y = np.zeros_like(y)
+    x_aug = np.concatenate([x_masked, zeros_y], axis=1)
+
+    zeros_x = np.zeros_like(x)
+    y_aug = np.concatenate([zeros_x, y], axis=1)
+
+    # Now, x_aug.shape == y_aug.shape
+
+    max_length_x = x.shape[1]
+    batch_size = x.shape[0]
+
+    for index in np.arange(batch_size):
+        y_aug_i_rolled = np.roll(
+            a=y_aug[index],
+            shift=x_length[index] - max_length_x,  # shift to the left
+            axis=0)
+        x_aug[index] += y_aug_i_rolled
+
+    return x_aug
```
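The NumPy variant can be sanity-checked directly, with no TF session. A quick self-contained run (same data as the docstring example):

```python
import numpy as np

from texar.tf.utils.shapes import varlength_concat_py

x = np.asarray([[1, 1, 0, 0],
                [1, 1, 1, 0]])
y = np.asarray([[2, 2, 0],
                [2, 2, 2]])

out = varlength_concat_py(x, y, x_length=[2, 3])
print(out)
# [[1 1 2 2 0 0 0]
#  [1 1 1 2 2 2 0]]
```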
```diff
+
+
+def varlength_roll(input, shift, axis=1, dtype=None):
+    """Rolls the elements of *each row* of a tensor along an axis for
+    variable steps.
+
+    This is a `tf.while_loop` wrapper of :tf_main:`tf.roll <roll>`. Note the
+    different definition of :attr:`shift` and :attr:`axis` here compared
+    to :tf_main:`tf.roll <roll>`.
+
+    Args:
+        input: A tensor of shape `[batch_size, other_dims]` where
+            `other_dims` can be multiple dimensions.
+        shift: A 1D int tensor of shape `[batch_size]` containing
+            the number of steps by which each row in the batch is rolled.
+            Positive shifts will roll towards larger indices, while
+            negative shifts will roll towards smaller indices.
+        axis: A scalar int tensor > 0. The dimension along which the
+            roll occurs.
+        dtype: Type of :attr:`input`. If `None`, inferred from
+            :attr:`input` automatically.
+
+    Returns:
+        A Tensor of the same shape/dtype as :attr:`input`.
+
+    Example:
+        .. code-block:: python
+
+            x = tf.constant([[0, 0, 1, 0],
+                             [0, 1, 1, 1]])
+            shift = [-2, -1]
+
+            out = varlength_roll(x, shift)
+            # out = [[1, 0, 0, 0]
+            #        [1, 1, 1, 0]]
+
+        .. code-block:: python
+
+            x = tf.constant([[1, 2, 3, 4],
+                             [5, 6, 7, 8]])
+            shift = [1, -1]
+
+            out = varlength_roll(x, shift)
+            # out = [[4, 1, 2, 3]
+            #        [6, 7, 8, 5]]
+    """
+    x = tf.convert_to_tensor(input)
+    shift = tf.convert_to_tensor(shift)
+
+    batch_size = tf.shape(x)[0]
+
+    initial_index = tf.constant(0, dtype=tf.int32)
+    initial_outputs_ta = tf.TensorArray(
+        dtype=dtype or x.dtype,
+        size=0,
+        dynamic_size=True)
+
+    def _cond(index, _):
+        return tf.less(index, batch_size)
+
+    def _body(index, outputs_ta):
+        x_i_rolled = tf.roll(
+            input=x[index],
+            shift=shift[index],
+            axis=axis - 1)  # axis is 0-based after indexing out the batch dim
+        return [index + 1, outputs_ta.write(index, x_i_rolled)]
+
+    res = tf.while_loop(_cond, _body, [initial_index, initial_outputs_ta])
+
+    return res[1].stack()
```
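A hedged sketch of one way `varlength_roll` can be applied (not from the commit): left-aligning sequences that are left-padded with zeros, by rolling each row left by its padding amount. Assumes TF 1.x graph mode; names are illustrative.

```python
import tensorflow as tf

from texar.tf.utils.shapes import varlength_roll

x = tf.constant([[0, 0, 7, 8],    # real length 2, left-padded
                 [0, 5, 6, 7]])   # real length 3, left-padded
lengths = tf.constant([2, 3])
max_time = tf.shape(x)[1]

# Negative shifts roll towards smaller indices, removing the left padding.
left_aligned = varlength_roll(x, shift=lengths - max_time)

with tf.Session() as sess:
    print(sess.run(left_aligned))
    # [[7 8 0 0]
    #  [5 6 7 0]]
```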
