Skip to content

Commit 1d4fd06

Browse files
committed
Merge commit for internal changes
2 parents 66edcda + e5bcf54 commit 1d4fd06

File tree

185 files changed

+9436
-873
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

185 files changed

+9436
-873
lines changed

RELEASE.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Changes Since Last Release
2+
3+
## Features & Improvements
4+
* TensorBoard now has an Audio Dashboard, with associated audio summaries.
5+
* TensorBoard now has a reload button, and supports auto-reloading
6+
* TensorBoard scalar charts now show tooltips with more information
7+
* TensorBoard now supports run filtering
8+
* TensorBoard has color changes: the same run always gets the same hue
9+
10+
## Bug Fixes and Other Changes
11+
* TensorBoard now displays graphs with only one data point
12+
* TensorBoard now visually displays NaN values
13+
114
# Release 0.8.0
215

316
## Major Features and Improvements

eigen.BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package(default_visibility = ["//visibility:public"])
22

3-
archive_dir = "eigen-eigen-50812b426b7c"
3+
archive_dir = "eigen-eigen-aaa010b0dd40"
44

55
cc_library(
66
name = "eigen",

farmhash.BUILD

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
prefix_dir = "farmhash-34c13ddfab0e35422f4c3979f360635a8c050260"
4+
5+
genrule(
6+
name = "configure",
7+
srcs = glob(
8+
["**/*"],
9+
exclude = [prefix_dir + "/config.h"],
10+
),
11+
outs = [prefix_dir + "/config.h"],
12+
cmd = "pushd external/farmhash_archive/%s; workdir=$$(mktemp -d -t tmp.XXXXXXXXXX); cp -a * $$workdir; pushd $$workdir; ./configure; popd; popd; cp $$workdir/config.h $(@D); rm -rf $$workdir;" % prefix_dir,
13+
)
14+
15+
cc_library(
16+
name = "farmhash",
17+
srcs = [prefix_dir + "/src/farmhash.cc"],
18+
hdrs = [prefix_dir + "/src/farmhash.h"] + [":configure"],
19+
includes = [prefix_dir],
20+
visibility = ["//visibility:public"]
21+
)

tensorflow/contrib/bayesflow/__init__.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
"""Ops for representing statistical distributions.
15+
"""Ops for representing Bayesian computation.
1616
17-
## This package provides classes for statistical distributions.
17+
## This package provides classes for Bayesian computation with TensorFlow.
1818
1919
"""
2020
from __future__ import absolute_import
2121
from __future__ import division
2222
from __future__ import print_function
23-
24-
# pylint: disable=unused-import,wildcard-import, line-too-long
25-
from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
26-
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
27-
from tensorflow.contrib.distributions.python.ops.gaussian import *
28-
# from tensorflow.contrib.distributions.python.ops.dirichlet import * # pylint: disable=line-too-long

tensorflow/contrib/cmake/external/eigen.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
include (ExternalProject)
99

10-
set(eigen_archive_hash "50812b426b7c")
10+
set(eigen_archive_hash "aaa010b0dd40")
1111

1212
set(eigen_INCLUDE_DIRS
1313
${CMAKE_CURRENT_BINARY_DIR}
@@ -16,7 +16,7 @@ set(eigen_INCLUDE_DIRS
1616
${tensorflow_source_dir}/third_party/eigen3
1717
)
1818
set(eigen_URL https://bitbucket.org/eigen/eigen/get/${eigen_archive_hash}.tar.gz)
19-
set(eigen_HASH SHA256=fa95e425c379c2c7b8a49d9ef7bd0c5a8369171c987affd6dbae5de8a8911c1a)
19+
set(eigen_HASH SHA256=948cccc08e3ce922e890fe39916b087d6651297cd7422a04524dbf44e372ed9a)
2020
set(eigen_BUILD ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen)
2121
set(eigen_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/eigen/install)
2222

tensorflow/contrib/distributions/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ cuda_py_tests(
3838
],
3939
)
4040

41+
cuda_py_tests(
42+
name = "uniform_test",
43+
size = "small",
44+
srcs = ["python/kernel_tests/uniform_test.py"],
45+
additional_deps = [
46+
":distributions_py",
47+
"//tensorflow/python:platform_test",
48+
],
49+
)
50+
4151
cuda_py_tests(
4252
name = "mvn_test",
4353
size = "small",

tensorflow/contrib/distributions/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
"""Classes representing statistical distributions. Ops for working with them.
15+
"""Classes representing statistical distributions and ops for working with them.
1616
1717
## Classes for statistical distributions.
1818
1919
Classes that represent batches of statistical distributions. Each class is
2020
initialized with parameters that define the distributions.
2121
22+
### Base classes
23+
24+
@@BaseDistribution
25+
@@ContinuousDistribution
26+
@@DiscreteDistribution
27+
2228
### Univariate (scalar) distributions
2329
2430
@@Gaussian
31+
@@Uniform
2532
2633
### Multivariate distributions
2734
@@ -44,6 +51,8 @@
4451

4552
# pylint: disable=unused-import,wildcard-import,line-too-long
4653
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
54+
from tensorflow.contrib.distributions.python.ops.distribution import *
4755
from tensorflow.contrib.distributions.python.ops.gaussian import *
4856
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
4957
from tensorflow.contrib.distributions.python.ops.mvn import *
58+
from tensorflow.contrib.distributions.python.ops.uniform import *
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright 2015 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for Uniform distribution."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
import tensorflow as tf
23+
24+
25+
class UniformTest(tf.test.TestCase):
26+
27+
def testUniformRange(self):
28+
with self.test_session():
29+
a = 3.0
30+
b = 10.0
31+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
32+
self.assertAllClose(a, uniform.a.eval())
33+
self.assertAllClose(b, uniform.b.eval())
34+
self.assertAllClose(b - a, uniform.range.eval())
35+
36+
def testUniformPDF(self):
37+
with self.test_session():
38+
a = tf.constant([-3.0] * 5 + [15.0])
39+
b = tf.constant([11.0] * 5 + [20.0])
40+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
41+
42+
a_v = -3.0
43+
b_v = 11.0
44+
x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
45+
46+
def _expected_pdf():
47+
pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
48+
pdf[x > b_v] = 0.0
49+
pdf[x < a_v] = 0.0
50+
pdf[5] = 1.0 / (20.0 - 15.0)
51+
return pdf
52+
53+
expected_pdf = _expected_pdf()
54+
55+
pdf = uniform.pdf(x)
56+
self.assertAllClose(expected_pdf, pdf.eval())
57+
58+
log_pdf = uniform.log_pdf(x)
59+
self.assertAllClose(np.log(expected_pdf), log_pdf.eval())
60+
61+
def testUniformShape(self):
62+
with self.test_session():
63+
a = tf.constant([-3.0] * 5)
64+
b = tf.constant(11.0)
65+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
66+
67+
self.assertEqual(uniform.batch_shape().eval(), (5,))
68+
self.assertEqual(uniform.get_batch_shape(), tf.TensorShape([5]))
69+
self.assertEqual(uniform.event_shape().eval(), 1)
70+
self.assertEqual(uniform.get_event_shape(), tf.TensorShape([]))
71+
72+
def testUniformPDFWithScalarEndpoint(self):
73+
with self.test_session():
74+
a = tf.constant([0.0, 5.0])
75+
b = tf.constant(10.0)
76+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
77+
78+
x = np.array([0.0, 8.0], dtype=np.float32)
79+
expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
80+
81+
pdf = uniform.pdf(x)
82+
self.assertAllClose(expected_pdf, pdf.eval())
83+
84+
def testUniformCDF(self):
85+
with self.test_session():
86+
batch_size = 6
87+
a = tf.constant([1.0] * batch_size)
88+
b = tf.constant([11.0] * batch_size)
89+
a_v = 1.0
90+
b_v = 11.0
91+
x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
92+
93+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
94+
95+
def _expected_cdf():
96+
cdf = (x - a_v) / (b_v - a_v)
97+
cdf[x >= b_v] = 1
98+
cdf[x < a_v] = 0
99+
return cdf
100+
101+
cdf = uniform.cdf(x)
102+
self.assertAllClose(_expected_cdf(), cdf.eval())
103+
104+
log_cdf = uniform.log_cdf(x)
105+
self.assertAllClose(np.log(_expected_cdf()), log_cdf.eval())
106+
107+
def testUniformEntropy(self):
108+
with self.test_session():
109+
a_v = np.array([1.0, 1.0, 1.0])
110+
b_v = np.array([[1.5, 2.0, 3.0]])
111+
uniform = tf.contrib.distributions.Uniform(a=a_v, b=b_v)
112+
113+
expected_entropy = np.log(b_v - a_v)
114+
self.assertAllClose(expected_entropy, uniform.entropy().eval())
115+
116+
def testUniformAssertMaxGtMin(self):
117+
with self.test_session():
118+
a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
119+
b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
120+
uniform = tf.contrib.distributions.Uniform(a=a_v, b=b_v)
121+
122+
with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
123+
"x < y"):
124+
uniform.a.eval()
125+
126+
def testUniformSample(self):
127+
with self.test_session():
128+
a = tf.constant([3.0, 4.0])
129+
b = tf.constant(13.0)
130+
a1_v = 3.0
131+
a2_v = 4.0
132+
b_v = 13.0
133+
n = tf.constant(100000)
134+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
135+
136+
samples = uniform.sample(n, seed=137)
137+
sample_values = samples.eval()
138+
self.assertEqual(sample_values.shape, (100000, 2))
139+
self.assertAllClose(sample_values[::, 0].mean(), (b_v + a1_v) / 2,
140+
atol=1e-2)
141+
self.assertAllClose(sample_values[::, 1].mean(), (b_v + a2_v) / 2,
142+
atol=1e-2)
143+
self.assertFalse(np.any(sample_values[::, 0] < a1_v) or np.any(
144+
sample_values >= b_v))
145+
self.assertFalse(np.any(sample_values[::, 1] < a2_v) or np.any(
146+
sample_values >= b_v))
147+
148+
def testUniformSampleMultiDimensional(self):
149+
with self.test_session():
150+
batch_size = 2
151+
a_v = [3.0, 22.0]
152+
b_v = [13.0, 35.0]
153+
a = tf.constant([a_v] * batch_size)
154+
b = tf.constant([b_v] * batch_size)
155+
156+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
157+
158+
n_v = 100000
159+
n = tf.constant(n_v)
160+
samples = uniform.sample(n, seed=138)
161+
self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
162+
163+
sample_values = samples.eval()
164+
165+
self.assertFalse(np.any(sample_values[:, 0, 0] < a_v[0]) or np.any(
166+
sample_values[:, 0, 0] >= b_v[0]))
167+
self.assertFalse(np.any(sample_values[:, 0, 1] < a_v[1]) or np.any(
168+
sample_values[:, 0, 1] >= b_v[1]))
169+
170+
self.assertAllClose(sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2,
171+
atol=1e-2)
172+
self.assertAllClose(sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2,
173+
atol=1e-2)
174+
175+
def testUniformMeanAndVariance(self):
176+
with self.test_session():
177+
a = 10.0
178+
b = 100.0
179+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
180+
self.assertAllClose(uniform.variance.eval(), (b - a)**2 / 12)
181+
self.assertAllClose(uniform.mean.eval(), (b + a) / 2)
182+
183+
def testUniformNans(self):
184+
with self.test_session():
185+
a = 10.0
186+
b = [11.0, 100.0]
187+
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
188+
189+
no_nans = tf.constant(1.0)
190+
nans = tf.constant(0.0) / tf.constant(0.0)
191+
self.assertTrue(tf.is_nan(nans).eval())
192+
with_nans = tf.pack([no_nans, nans])
193+
194+
pdf = uniform.pdf(with_nans)
195+
196+
is_nan = tf.is_nan(pdf).eval()
197+
print(pdf.eval())
198+
self.assertFalse(is_nan[0])
199+
self.assertTrue(is_nan[1])
200+
201+
def testUniformSamplePdf(self):
202+
with self.test_session():
203+
a = 10.0
204+
b = [11.0, 100.0]
205+
uniform = tf.contrib.distributions.Uniform(a, b)
206+
self.assertTrue(tf.reduce_all(uniform.pdf(uniform.sample(10)) > 0).eval())
207+
208+
def testUniformBroadcasting(self):
209+
with self.test_session():
210+
a = 10.0
211+
b = [11.0, 20.0]
212+
uniform = tf.contrib.distributions.Uniform(a, b)
213+
214+
pdf = uniform.pdf([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
215+
expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
216+
self.assertAllClose(expected_pdf, pdf.eval())
217+
218+
219+
if __name__ == "__main__":
220+
tf.test.main()

0 commit comments

Comments
 (0)