Skip to content

Commit 0c90493

Browse files
committed
more tests
1 parent 71d193d commit 0c90493

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

python/tests/transformers/tf_tensor_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,24 @@
1313
# limitations under the License.
1414
#
1515

16+
from keras.layers import Conv1D, Dense, Flatten, MaxPool1D
1617
import numpy as np
1718
import tensorflow as tf
19+
import tensorframes as tfs
1820

1921
from pyspark.sql.types import Row
2022

2123
from sparkdl.graph.builder import IsolatedSession
24+
import sparkdl.graph.utils as tfx
2225
from sparkdl.transformers.tf_tensor import TFOneDimTensorTransformer
2326

2427
from ..tests import SparkDLTestCase
2528

29+
def grab_df_arr(df, output_col):
30+
""" Stack the numpy array from a DataFrame column """
31+
return np.array([row.asDict()[output_col]
32+
for row in df.select(output_col).toLocalIterator()])
33+
2634
class TFOneDimTransformerTest(SparkDLTestCase):
2735

2836
def test_simple(self):
@@ -52,3 +60,56 @@ def test_simple(self):
5260

5361
out_tgt = np.array([row.outCol for row in final_df.select('outCol').collect()])
5462
self.assertTrue(np.allclose(out_ref, out_tgt))
63+
64+
65+
def test_map_blocks_graph(self):
66+
67+
vec_size = 17
68+
num_vecs = 137
69+
70+
input_col = 'vec'
71+
output_col = 'outCol'
72+
73+
df = self.session.createDataFrame([
74+
Row(idx=idx, vec=np.random.randn(vec_size).tolist())
75+
for idx in range(num_vecs)])
76+
analyzed_df = tfs.analyze(df)
77+
78+
# Build the graph: the output should have the same leading/batch dimension
79+
with IsolatedSession(using_keras=True) as issn:
80+
tnsr_in = tfs.block(analyzed_df, input_col)
81+
inp = tf.expand_dims(tnsr_in, axis=2)
82+
inp = tf.cast(inp, tf.float32)
83+
conv = Conv1D(filters=4, kernel_size=2)(inp)
84+
pool = MaxPool1D(pool_size=2)(conv)
85+
flat = Flatten()(pool)
86+
dense = Dense(1)(flat)
87+
redsum = tf.reduce_sum(dense, axis=1)
88+
tnsr_out = tf.cast(redsum, tf.double, name='TnsrOut')
89+
90+
# Initialize the variables
91+
init_op = tf.global_variables_initializer()
92+
issn.run(init_op)
93+
# Train the model ...
94+
gfn = issn.asGraphFunction([tnsr_in], [tnsr_out])
95+
96+
97+
with IsolatedSession() as issn:
98+
feeds, fetches = issn.importGraphFunction(gfn, prefix='')
99+
orig_in_name = tfx.op_name(issn.graph, feeds[0])
100+
input_df = analyzed_df.withColumnRenamed(input_col, orig_in_name)
101+
output_df = tfs.map_blocks(fetches, input_df)
102+
orig_out_name = tfx.op_name(issn.graph, fetches[0])
103+
final_df = output_df.withColumnRenamed(orig_out_name, output_col)
104+
105+
arr_ref = grab_df_arr(final_df, output_col)
106+
107+
# Using the Transformer
108+
transformer = TFOneDimTensorTransformer(
109+
graphFunction=gfn, inputCol=input_col, outputCol=output_col)
110+
transformed_df = transformer.transform(analyzed_df)
111+
112+
arr_tgt = grab_df_arr(transformed_df, output_col)
113+
114+
self.assertTrue(np.allclose(arr_ref, arr_tgt))
115+

0 commit comments

Comments
 (0)