Skip to content

Commit 91d909f

Browse files
committed
Fix shape utils
Explict convert-to-tensors in reduce_with_weights and varlength_roll
1 parent adb535a commit 91d909f

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

texar/tf/utils/shapes.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@ def varlength_concat(x, y, x_length, dtype=None, tensor_rank=None):
510510
# out = [[1, 1, 2, 2, 0, 0, 0]
511511
# [1, 1, 1, 2, 2, 2, 0]]
512512
"""
513-
#x = tf.convert_to_tensor(x)
514-
#y = tf.convert_to_tensor(y)
513+
x = tf.convert_to_tensor(x)
514+
y = tf.convert_to_tensor(y)
515515
x_length = tf.convert_to_tensor(x_length)
516516

517517
if tensor_rank is None:
@@ -588,7 +588,7 @@ def varlength_concat_py(x, y, x_length, dtype=None):
588588
y = np.asarray([[2, 2, 0],
589589
[2, 2, 2]])
590590
591-
out = varlength_concat(x, y, x_length)
591+
out = varlength_concat_py(x, y, x_length)
592592
# out = [[1, 1, 2, 2, 0, 0, 0]
593593
# [1, 1, 1, 2, 2, 2, 0]]
594594
"""
@@ -662,8 +662,8 @@ def varlength_roll(input, shift, axis=1, dtype=None):
662662
# out = [[4, 1, 2, 3]
663663
# [6, 7, 8, 5]]
664664
"""
665-
#x = tf.convert_to_tensor(input)
666-
x = input
665+
x = tf.convert_to_tensor(input)
666+
#x = input
667667
shift = tf.convert_to_tensor(shift)
668668

669669
batch_size = tf.shape(x)[0]

texar/tf/utils/shapes_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_varlength_concat(self):
7474
[[1, 1, 0, 0],
7575
[1, 0, 0, 0],
7676
[1, 1, 1, 1]], dtype=np.int32)
77-
x_length = [2, 1, 4]
77+
x_length = np.asarray([2, 1, 4], dtype=np.int32)
7878
y = np.asarray(
7979
[[2, 2, 2, 0],
8080
[2, 2, 2, 2],

0 commit comments

Comments
 (0)