Skip to content

Commit

Permalink
Fix tutorials failures. (tensorflow#6200)
Browse files Browse the repository at this point in the history
* Fix ptb_word_lm tutorial.

* Fix failures in cifar10_train example.

* Update cifar10_input.py

Correct shape for read_input.label
  • Loading branch information
gunan authored and caisq committed Dec 9, 2016
1 parent 144b72a commit 936ae38
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow/models/image/cifar10/cifar10_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def distorted_inputs(data_dir, batch_size):
# Subtract off the mean and divide by the variance of the pixels.
float_image = tf.image.per_image_standardization(distorted_image)

# Set the shapes of tensors.
float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])

# Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/models/rnn/ptb/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def ptb_producer(raw_data, batch_size, num_steps, name=None):
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
x = tf.strided_slice(data, [0, i * num_steps],
[batch_size, (i + 1) * num_steps])
x.set_shape([batch_size, num_steps])
y = tf.strided_slice(data, [0, i * num_steps + 1],
[batch_size, (i + 1) * num_steps + 1])
y.set_shape([batch_size, num_steps])
return x, y

0 comments on commit 936ae38

Please sign in to comment.