Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Conflicts:
	tensorflow/workspace.bzl
  • Loading branch information
benoitsteiner committed Dec 7, 2016
2 parents 2f2d45e + 287db3a commit a62c532
Show file tree
Hide file tree
Showing 587 changed files with 35,979 additions and 8,948 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
</div>
-----------------

| **`Linux CPU`** | **`Linux GPU PIP`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
|-----------------|---------------------|------------------|-------------------|---------------|
| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-gpu_pip)](https://ci.tensorflow.org/job/tensorflow-master-gpu_pip) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) |
| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) |

**TensorFlow** is an open source software library for numerical computation using
data flow graphs. Nodes in the graph represent mathematical operations, while
Expand Down
2 changes: 1 addition & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
removed.
* `tf.all_variables`, `tf.VARIABLES` and `tf.initialize_all_variables` renamed
to `tf.global_variables`, `tf.GLOBAL_VARIABLES` and
`tf.global_variable_initializers` respectively.
`tf.global_variables_initializer` respectively.

## Bug Fixes and Other Changes

Expand Down
2 changes: 1 addition & 1 deletion configure
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ while true; do
TF_CUDNN_VERSION=${BASH_REMATCH[1]}
echo "libcudnn.so resolves to libcudnn${TF_CUDNN_EXT}"
elif [[ "$REALVAL" =~ ([0-9]*).dylib ]]; then
TF_CUDNN_EXT="."${BASH_REMATCH[1]}".dylib"
TF_CUDNN_EXT=${BASH_REMATCH[1]}".dylib"
TF_CUDNN_VERSION=${BASH_REMATCH[1]}
echo "libcudnn.dylib resolves to libcudnn${TF_CUDNN_EXT}"
fi
Expand Down
115 changes: 115 additions & 0 deletions libxsmm.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Description:
# LIBXSMM: Library for small matrix-matrix multiplications targeting Intel Architecture (x86).

licenses(["notice"]) # BSD 3-clause
exports_files(["LICENSE"])

# Arguments to ./scripts/libxsmm_interface.py, see that file for detailed description.
# precision: SP & DP
# ilp64: no
# prefetch: 1 (auto)
libxsmm_interface_arguments = "0 0 1"

# Arguments to ./scripts/libxsmm_config.py, see that file for detailed description.
# ilp64: no
# offload: no
# alignment [b]
# prefetch: 1 (auto)
# threshold: fallback to BLAS if n*m*k above this
# synchronize: yes
# jit: yes
# flags
# alpha = 1
# beta = 1
libxsmm_config_arguments = "0 0 64 1 0 1 1 0 1 1"

genrule(
name = "libxsmm_headers",
srcs = [
"src/template/libxsmm.h",
"src/template/libxsmm_config.h",
],
outs = [
"include/libxsmm.h",
"include/libxsmm_config.h",
],
cmd = "$(location :libxsmm_interface) $(location src/template/libxsmm.h) " + libxsmm_interface_arguments + " > $(location include/libxsmm.h);" +
"$(location :libxsmm_config) $(location src/template/libxsmm_config.h) " + libxsmm_config_arguments + " > $(location include/libxsmm_config.h)",
tools = [
":libxsmm_config",
":libxsmm_interface",
],
)

cc_library(
name = "xsmm_avx",
srcs = [
"src/libxsmm_main.c",
"src/libxsmm_dump.c",
"src/libxsmm_malloc.c",
"src/libxsmm_gemm.c",
"src/libxsmm_timer.c",
"src/libxsmm_trace.c",
"src/libxsmm_trans.c",
"src/libxsmm_sync.c",
"src/libxsmm_perf.c",
"src/libxsmm_dnn.c",
"src/libxsmm_dnn_convolution_forward.c",
"src/libxsmm_cpuid_x86.c",
] + glob([
"src/generator_*.c",
]),
hdrs = [
"include/libxsmm_dnn.h",
"include/libxsmm_frontend.h",
"include/libxsmm_generator.h",
"include/libxsmm_macros.h",
"include/libxsmm_malloc.h",
"include/libxsmm_sync.h",
"include/libxsmm_timer.h",
"include/libxsmm_typedefs.h",
"include/libxsmm_dispatch.h",
"src/libxsmm_gemm_diff.c",
"src/libxsmm_cpuid_x86.c",
"src/libxsmm_hash.c",
# Generated:
"include/libxsmm.h",
"include/libxsmm_config.h",
] + glob([
"src/*.h",
"src/template/*.c",
]),
copts = [
"-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings.
"-Wno-vla", # Libxsmm convolutions heavily use VLA.
],
defines = [
"LIBXSMM_BUILD",
"LIBXSMM_CPUID_X86_NOINLINE",
"__BLAS=0",
],
includes = ["include"],
linkopts = ["-ldl"],
visibility = ["//visibility:public"],
deps = [
":libxsmm_headers",
],
)

py_library(
name = "libxsmm_scripts",
srcs = glob(["scripts/*.py"]),
data = ["version.txt"],
)

py_binary(
name = "libxsmm_interface",
srcs = ["scripts/libxsmm_interface.py"],
deps = [":libxsmm_scripts"],
)

py_binary(
name = "libxsmm_config",
srcs = ["scripts/libxsmm_config.py"],
deps = [":libxsmm_scripts"],
)
4 changes: 3 additions & 1 deletion tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ filegroup(
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/bayesflow:all_files",
"//tensorflow/contrib/compiler:all_files",
"//tensorflow/contrib/copy_graph:all_files",
"//tensorflow/contrib/crf:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
Expand All @@ -105,6 +106,8 @@ filegroup(
"//tensorflow/contrib/framework:all_files",
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/input_pipeline:all_files",
"//tensorflow/contrib/input_pipeline/kernels:all_files",
"//tensorflow/contrib/integrate:all_files",
"//tensorflow/contrib/labeled_tensor:all_files",
"//tensorflow/contrib/layers:all_files",
Expand All @@ -116,7 +119,6 @@ filegroup(
"//tensorflow/contrib/lookup:all_files",
"//tensorflow/contrib/losses:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/metrics/kernels:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/rnn:all_files",
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/cc/gradients/math_grad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ Status LogGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// f(x) = log1p(x) = y
// df/dx = 1 / (1 + x)
// dx = dy * (1 / (1 + x))
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
grad_outputs->push_back(
Div(scope, grad_inputs[0], Add(scope, one, op.input(0))));
return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/cc/gradients/math_grad_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class CWiseUnaryGradTest : public ::testing::Test {
RSQRT,
EXP,
LOG,
LOG1P,
TANH,
SIGMOID,
SIGN,
Expand Down Expand Up @@ -101,6 +102,9 @@ class CWiseUnaryGradTest : public ::testing::Test {
case LOG:
y = Log(scope_, x);
break;
case LOG1P:
y = Log1p(scope_, x);
break;
case TANH:
y = Tanh(scope_, x);
break;
Expand Down Expand Up @@ -207,6 +211,15 @@ TEST_F(CWiseUnaryGradTest, Log) {
TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p) {
auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) {
return dy * (1.0 / (1.0 + x));
};
TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh) {
auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
Expand Down
10 changes: 6 additions & 4 deletions tensorflow/cc/saved_model/loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ namespace tensorflow {
namespace {

auto* load_attempt_count = monitoring::Counter<2>::New(
"/tensorflow/cc/saved_model/load_attempt_count", "model_path", "status",
"The number of times a SavedModel was successfully loaded.");
"/tensorflow/cc/saved_model/load_attempt_count",
"The number of times a SavedModel was successfully loaded.", "model_path",
"status");
auto* load_latency = monitoring::Counter<1>::New(
"/tensorflow/cc/saved_model/load_latency", "model_path",
"Latency in microseconds for SavedModels that were successfully loaded.");
"/tensorflow/cc/saved_model/load_latency",
"Latency in microseconds for SavedModels that were succesfully loaded.",
"model_path");
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/cc/saved_model/loader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {

// Directory that exists but is an invalid SavedModel location.
const string invalid_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata");
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model");
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
}

Expand Down
4 changes: 3 additions & 1 deletion tensorflow/cc/training/queue_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
last_run = (runs_ == 0);
}

if (IsQueueClosed(status)) {
// Close the queue unless the coordinator is shutting down since the cancel op
// will be run anway in this case.
if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
if (last_run && !close_op_name_.empty()) {
UpdateStatus(sess->Run({}, {}, {close_op_name_}, nullptr));
}
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/contrib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/compiler:compiler_py",
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/crf:crf_py",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
Expand All @@ -24,10 +25,12 @@ py_library(
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/input_pipeline:input_pipeline_py",
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/labeled_tensor",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
Expand Down Expand Up @@ -58,9 +61,9 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
"//tensorflow/contrib/layers:bucketization_op_kernel",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/metrics:set_ops_kernels",
],
)

Expand All @@ -70,9 +73,9 @@ cc_library(
deps = [
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:bucketization_op_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/metrics:set_ops_op_lib",
],
)

Expand Down
3 changes: 3 additions & 0 deletions tensorflow/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import bayesflow
from tensorflow.contrib import compiler
from tensorflow.contrib import copy_graph
from tensorflow.contrib import crf
from tensorflow.contrib import cudnn_rnn
Expand All @@ -29,10 +30,12 @@
from tensorflow.contrib import framework
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import input_pipeline
from tensorflow.contrib import integrate
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import legacy_seq2seq
from tensorflow.contrib import linalg
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def testStochasticVariables(self):

self.assertEqual(
{"stochastic_variables/sv_mu", "stochastic_variables/sv_sigma"},
set([v.op.name for v in tf.all_variables()]))
self.assertEqual(set(tf.trainable_variables()), set(tf.all_variables()))
set([v.op.name for v in tf.global_variables()]))
self.assertEqual(set(tf.trainable_variables()), set(tf.global_variables()))

v = tf.convert_to_tensor(v)
self.assertEqual(list(shape), v.get_shape().as_list())
Expand All @@ -64,7 +64,7 @@ def testStochasticVariablesWithConstantInitializer(self):
})):
v = tf.get_variable("sv")

for var in tf.all_variables():
for var in tf.global_variables():
if "mu" in var.name:
mu_var = var
if "sigma" in var.name:
Expand Down Expand Up @@ -96,7 +96,7 @@ def sigma_init(shape, dtype, partition_info):
})):
v = tf.get_variable("sv", shape)

for var in tf.all_variables():
for var in tf.global_variables():
if "mu" in var.name:
mu_var = var
if "sigma" in var.name:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def mean_baseline(_, loss):
with vs.variable_scope(name, default_name="MeanBaseline"):
reduced_loss = math_ops.reduce_mean(loss)

ema = training.ExponentialMovingAverage(decay=ema_decay)
ema = training.ExponentialMovingAverage(decay=ema_decay, zero_debias=True)
update_op = ema.apply([reduced_loss])

with ops.control_dependencies([update_op]):
Expand Down
7 changes: 4 additions & 3 deletions tensorflow/contrib/cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ if(WIN32)
add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC -D__VERSION__=\"MSVC\")
add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN64 -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS)
add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH -D_ITERATOR_DEBUG_LEVEL=0)
add_definitions(-DNDEBUG /O2) # Equivalent of -c opt in Bazel.
add_definitions(/bigobj /nologo /EHsc /GF /FC /MP /Gm-)
# Suppress warnings to reduce build log size.
add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
Expand Down Expand Up @@ -149,11 +150,11 @@ if (tensorflow_ENABLE_GPU)

# by default we assume compute cabability 3.5 and 5.2. If you change this change it in
# CUDA_NVCC_FLAGS and cuda_config.h below
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_35,code=\"sm_35,compute_35\";-gencode arch=compute_52,code=\"sm_52,compute_52\")
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_30,code=\"sm_30,compute_30\";-gencode arch=compute_35,code=\"sm_35,compute_35\";-gencode arch=compute_52,code=\"sm_52,compute_52\")
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr)
set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include)
include_directories(${CUDA_INCLUDE})
add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.5,5.2)
add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2)

# add cudnn
include_directories(${CUDNN_HOME})
Expand All @@ -163,7 +164,7 @@ if (tensorflow_ENABLE_GPU)
FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h
"#ifndef CUDA_CUDA_CONFIG_H_\n"
"#define CUDA_CUDA_CONFIG_H_\n"
"#define TF_CUDA_CAPABILITIES CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n"
"#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n"
"#define TF_CUDA_VERSION \"64_80\"\n"
"#define TF_CUDNN_VERSION \"64_5\"\n"
"#endif // CUDA_CUDA_CONFIG_H_\n"
Expand Down
Loading

0 comments on commit a62c532

Please sign in to comment.