Skip to content

Commit

Permalink
Merge changes from github.
Browse files Browse the repository at this point in the history
END_PUBLIC

---
Commit daa67ad authored by Jonathan Hseu<vomjom@vomjom.net>
Committed by Frank Chen<frankchn@gmail.com>:
Remove unittest import (#11596)

---
Commit 491beb7 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 162423171

PiperOrigin-RevId: 162541442
  • Loading branch information
Jonathan Hseu authored and tensorflower-gardener committed Jul 19, 2017
1 parent 1fefe92 commit 9cc871e
Show file tree
Hide file tree
Showing 67 changed files with 746 additions and 118 deletions.
4 changes: 3 additions & 1 deletion CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Where component owners are known, add them here.

tensorflow/core/platform/windows/* @mrry
tensorflow/java/* @asimshankar
tensorflow/tensorboard/* @jart @dandelionmane
tensorflow/tools/docs/* @markdaoust
tensorflow/java/* @asimshankar

# contrib

Expand Down Expand Up @@ -46,5 +47,6 @@ tensorflow/contrib/stateless/* @girving
tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst
tensorflow/contrib/testing/* @dandelionmane
tensorflow/contrib/timeseries/* @allenlavoie
tensorflow/contrib/tpu/* @frankchn @saeta @jhseu
tensorflow/contrib/training/* @joel-shor @ebrevdo
tensorflow/contrib/util/* @sherrym
20 changes: 8 additions & 12 deletions configure
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function is_windows() {
}

function is_ppc64le() {
[[ "${uname -m}" == "ppc64le" ]]
[[ "$(uname -m)" == "ppc64le" ]]
}

function sed_in_place() {
Expand Down Expand Up @@ -298,7 +298,7 @@ fi # TF_NEED_MKL

## Set up architecture-dependent optimization flags.
if [ -z "$CC_OPT_FLAGS" ]; then
if [ is_ppc64le ]; then
if is_ppc64le; then
# gcc on ppc64le does not support -march, use mcpu instead
default_cc_opt_flags="-mcpu=native"
else
Expand Down Expand Up @@ -492,6 +492,8 @@ while true; do
if [ -z "$TF_CUDA_VERSION" ]; then
read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 8.0]: " TF_CUDA_VERSION
fi
# Set default CUDA version if not set
TF_CUDA_VERSION=${TF_CUDA_VERSION:-8.0}

fromuser=""
if [ -z "$CUDA_TOOLKIT_PATH" ]; then
Expand Down Expand Up @@ -545,11 +547,7 @@ while true; do
CUDA_TOOLKIT_PATH=""
done

# Set default CUDA version if not set
if [ -z "$TF_CUDA_VERSION" ]; then
TF_CUDA_VERSION="8.0"
export TF_CUDA_VERSION
fi
export TF_CUDA_VERSION
write_action_env_to_bazelrc "TF_CUDA_VERSION" "$TF_CUDA_VERSION"

# Set up which gcc nvcc should use as the host compiler
Expand Down Expand Up @@ -587,6 +585,8 @@ while true; do
if [ -z "$TF_CUDNN_VERSION" ]; then
read -p "Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 6.0]: " TF_CUDNN_VERSION
fi
# Set default CUDNN version if not set
TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-6}

fromuser=""
if [ -z "$CUDNN_INSTALL_PATH" ]; then
Expand Down Expand Up @@ -659,11 +659,7 @@ while true; do
CUDNN_INSTALL_PATH=""
done

# Set default CUDNN version if not set
if [ -z "$TF_CUDNN_VERSION" ]; then
TF_CUDNN_VERSION="6"
export TF_CUDNN_VERSION
fi
export TF_CUDNN_VERSION
write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION"

# Configure the compute capabilities that TensorFlow builds for.
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/cc/framework/gradients.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class SymbolicGradientBuilder {
// gradients for the node associated with `src`.
Status BackpropAlongEdge(const Output& dst_grad, const Output& src);

// Adds a node to the graph (returned in`grad`) that sums the in-bound
// Adds a node to the graph (returned in `grad`) that sums the in-bound
// gradients to `src` (if there are more than one).
Status SumGradients(const Output& src, Output* grad);

Expand Down
7 changes: 4 additions & 3 deletions tensorflow/compiler/tf2xla/xla_op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

constexpr std::array<DataType, 2> kIntTypes = {{DT_INT32, DT_INT64}};
constexpr std::array<DataType, 2> kFloatTypes = {{DT_FLOAT, DT_DOUBLE}};
constexpr std::array<DataType, 4> kNumericTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}};
constexpr std::array<DataType, 3> kFloatTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE}};
constexpr std::array<DataType, 5> kNumericTypes = {
{DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}};

constexpr std::array<DataType, 5> kCpuAllTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/util.h"

#include <numeric>
#include <stdarg.h>
#include <numeric>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class UnionClusterResolver(ClusterResolver):
This class performs a union given two or more existing ClusterResolvers. It
merges the underlying ClusterResolvers, and returns one unified ClusterSpec
when as_cluster_spec is called. The details of the merge function is
documented in the as_cluster_spec function.
when cluster_spec is called. The details of the merge function is
documented in the cluster_spec function.
"""

def __init__(self, *args):
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/contrib/cmake/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ Step-by-step Windows build
* `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include
GPU support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and CUDNN 5.1.
CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unziped_cudnn.
CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unzipped_cudnn.
* `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds cc unit tests.
There are many of them and building will take a few hours.
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/contrib/framework/python/ops/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def get_unique_variable(var_op_name):
for candidate in candidates:
if candidate.op.name == var_op_name:
return candidate
raise ValueError('Variable %s does not uniquely identify a variable',
raise ValueError('Variable %s does not uniquely identify a variable' %
var_op_name)


Expand Down Expand Up @@ -444,7 +444,7 @@ def assign_from_values(var_names_to_values):
var_value = var_names_to_values[var_name]
var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, var_name)
if not var:
raise ValueError('Variable %s wasnt found', var_name)
raise ValueError('Variable %s wasn\'t found' % var_name)
elif len(var) > 1:
# tf.get_collection is just a filter on the prefix: find the exact match:
found = False
Expand All @@ -455,7 +455,7 @@ def assign_from_values(var_names_to_values):
break

if not found:
raise ValueError('Variable %s doesnt uniquely identify a variable',
raise ValueError('Variable %s doesn\'t uniquely identify a variable' %
var_name)
else:
var = var[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def sparse_boolean_mask(sparse_tensor, mask, name="sparse_boolean_mask"):
Args:
sparse_tensor: a `SparseTensor`.
mask: a 1D boolean dense`Tensor` whose length is equal to the 0th dimension
mask: a 1D boolean dense `Tensor` whose length is equal to the 0th dimension
of `sparse_tensor`.
name: optional name for this operation.
Returns:
Expand Down
7 changes: 4 additions & 3 deletions tensorflow/contrib/learn/python/learn/estimators/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from tensorflow.contrib import framework as framework_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib import lookup as lookup_lib
# TODO(ptucker): Use tf.losses and tf.metrics.
from tensorflow.contrib import losses as losses_lib
# TODO(ptucker): Use tf.metrics.
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
Expand All @@ -44,6 +43,7 @@
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training
Expand Down Expand Up @@ -1212,7 +1212,8 @@ def _loss_fn(labels, logits, weights=None):
with ops.name_scope(None, "hinge_loss", (logits, labels)) as name:
with ops.control_dependencies((_assert_labels_rank(labels),)):
labels = array_ops.reshape(labels, shape=(-1, 1))
loss = losses_lib.hinge_loss(logits=logits, labels=labels, scope=name)
loss = losses_lib.hinge_loss(labels=labels, logits=logits, scope=name,
reduction=losses_lib.Reduction.NONE)
return _compute_weighted_loss(loss, weights)

super(_BinarySvmHead, self).__init__(
Expand Down
9 changes: 5 additions & 4 deletions tensorflow/contrib/makefile/build_all_ios.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ set -e

# Make sure we're on OS X.
if [[ $(uname) != "Darwin" ]]; then
echo "ERROR: This makefile build requires OS X, which the current system "\
echo "ERROR: This makefile build requires macOS, which the current system "\
"is not."
exit 1
fi
Expand All @@ -37,7 +37,9 @@ rm -rf tensorflow/contrib/makefile/downloads
#
# ld: -bind_at_load and -bitcode_bundle (Xcode setting ENABLE_BITCODE=YES) cannot be used together
#
export MACOSX_DEPLOYMENT_TARGET="10.10"
if [[ -n MACOSX_DEPLOYMENT_TARGET ]]; then
export MACOSX_DEPLOYMENT_TARGET=$(sw_vers -productVersion)
fi

# Pull down the required versions of the frameworks we need.
tensorflow/contrib/makefile/download_dependencies.sh
Expand All @@ -48,6 +50,5 @@ tensorflow/contrib/makefile/compile_ios_protobuf.sh
# Build the iOS TensorFlow libraries.
tensorflow/contrib/makefile/compile_ios_tensorflow.sh "-O3"

# Creates a static universal library in
# Creates a static universal library in
# tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a

5 changes: 4 additions & 1 deletion tensorflow/contrib/makefile/compile_ios_protobuf.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
# ==============================================================================
# Builds protobuf 3 for iOS.

set -x
set -e

if [[ -n MACOSX_DEPLOYMENT_TARGET ]]; then
export MACOSX_DEPLOYMENT_TARGET=$(sw_vers -productVersion)
fi

SCRIPT_DIR=$(dirname $0)
source "${SCRIPT_DIR}/build_helper.subr"

Expand Down
6 changes: 5 additions & 1 deletion tensorflow/contrib/makefile/compile_ios_tensorflow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ function less_than_required_version() {
)
}

if [[ -n MACOSX_DEPLOYMENT_TARGET ]]; then
export MACOSX_DEPLOYMENT_TARGET=$(sw_vers -productVersion)
fi

ACTUAL_XCODE_VERSION=$(xcodebuild -version | head -n 1 | sed 's/Xcode //')
REQUIRED_XCODE_VERSION=7.3.0
if less_than_required_version $ACTUAL_XCODE_VERSION 7 3 0
Expand All @@ -44,7 +48,7 @@ LIBDIR=${GENDIR}lib
LIB_PREFIX=libtensorflow-core

make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \
TARGET=IOS IOS_ARCH=ARMV7 LIB_NAME=${LIB_PREFIX}-armv7.a OPTFLAGS="$1"
TARGET=IOS IOS_ARCH=ARMV7 LIB_NAME=${LIB_PREFIX}-armv7.a OPTFLAGS="$1"
if [ $? -ne 0 ]
then
echo "armv7 compilation failed."
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/contrib/signal/python/ops/reconstruction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _shuffle_to_front(input_tensor, k):
k: A scalar `Tensor` specifying how many indices to shuffle.
Returns:
A tranposed version of `input_tensor` with `k` indices shuffled to the
A transposed version of `input_tensor` with `k` indices shuffled to the
front.
Raises:
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/contrib/slim/python/slim/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
# Evaluate every 10 minutes:
slim.evaluation_loop(
master='',
'',
checkpoint_dir,
logdir,
num_evals=num_evals,
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/contrib/tfprof/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# tfprof: TensorFlow Profiler and Beyond

<h1>Please use `tf.profiler.xxx` instead of `tf.contrib.tfprof.xxx`</h1>
<h1>Full Document in tensorflow/core/profiler/README.md<h1>
<h1>Full Document in <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/README.md">tensorflow/core/profiler/README.md</a><h1>

###Features
### Features

* Profile model architectures
* parameters, tensor shapes, float operations, device placement, etc.
Expand All @@ -16,7 +16,7 @@
* operation configuration check
* distributed runtime check (Not OSS)

###Interfaces
### Interfaces

* Python API
* Command Line
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/contrib/timeseries/python/timeseries/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def _filtering_step(self, current_times, current_values, state, predictions):
Args:
current_times: A [batch size] Tensor of times for each observation.
current_values: A [batch size] Tensor of values for each observaiton.
current_values: A [batch size] Tensor of values for each observation.
state: Model state, updated to current_times.
predictions: The outputs of _prediction_step
Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def _filtering_step(self, current_times, current_values, state, predictions):
Args:
current_times: A [batch size] Tensor for times for each observation.
current_values: A [batch size] Tensor of values for each observaiton.
current_values: A [batch size] Tensor of values for each observation.
state: A tuple of (mean, covariance, previous_times) having shapes
mean; [batch size x state dimension]
covariance; [batch size x state dimension x state dimension]
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ load(
"full_path",
"if_android",
"if_ios",
"if_x86",
"if_linux_x86_64",
"if_not_mobile",
"if_not_windows",
"tf_copts",
Expand Down Expand Up @@ -1379,7 +1379,7 @@ cc_library(
name = "lib_hash_crc32c_accelerate_internal",
srcs = ["lib/hash/crc32c_accelerate.cc"],
# -msse4.2 enables the use of crc32c compiler builtins.
copts = tf_copts() + if_x86(["-msse4.2"]),
copts = tf_copts() + if_linux_x86_64(["-msse4.2"]),
)

cc_library(
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/core/framework/register_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,18 @@ limitations under the License.
#define TF_CALL_QUANTIZED_TYPES(m) \
TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m)

#ifdef TENSORFLOW_SYCL_NO_DOUBLE
#define TF_CALL_SYCL_double(m)
#else // TENSORFLOW_SYCL_NO_DOUBLE
#define TF_CALL_SYCL_double(m) TF_CALL_double(m)
#endif // TENSORFLOW_SYCL_NO_DOUBLE

#ifdef __ANDROID_TYPES_SLIM__
#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m)
#else // __ANDROID_TYPES_SLIM__
#define TF_CALL_SYCL_NUMBER_TYPES(m) \
TF_CALL_float(m) \
TF_CALL_SYCL_double(m)
#endif // __ANDROID_TYPES_SLIM__

#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/batch_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {

void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 batch_size;
int64 batch_size = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
OP_REQUIRES(
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/core/kernels/cwise_op_add_1.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ REGISTER_KERNEL_BUILDER(Name("Add")


#if TENSORFLOW_USE_SYCL
REGISTER2(BinaryOp, SYCL, "Add", functor::add, float, double);
#define REGISTER_KERNEL(type) REGISTER(BinaryOp, SYCL, "Add", functor::add, type);
TF_CALL_SYCL_NUMBER_TYPES(REGISTER_KERNEL);

REGISTER_KERNEL_BUILDER(Name("Add")
.Device(DEVICE_SYCL)
.HostMemory("x")
Expand Down
Loading

0 comments on commit 9cc871e

Please sign in to comment.